forked from greydanus/mnist1d
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path · utils.py
57 lines (47 loc) · 1.72 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# The MNIST-1D dataset | 2020
# Sam Greydanus
import numpy as np
import torch
import random
import pickle
import matplotlib.pyplot as plt
from .transform import transform
def set_seed(seed):
    """Seed every random number generator used by this module.

    The generators are seeded in this order:
    Python's ``random`` module, NumPy's global RNG, PyTorch's default
    generator, and then PyTorch's CUDA generators on all devices.

    Args:
        seed: value handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def to_pickle(thing, path):
    """Serialize ``thing`` into the file at ``path`` with pickle protocol 3.

    The file is opened in binary write mode, so any existing file at
    ``path`` is truncated.
    """
    with open(path, mode='wb') as sink:
        pickle.dump(thing, sink, protocol=3)
def from_pickle(path):  # load something
    """Load and return the object pickled in the file at ``path``.

    Security: ``pickle.load`` can execute arbitrary code while unpickling.
    Only call this on files you created yourself or otherwise trust.

    Raises:
        OSError: if ``path`` cannot be opened (e.g. FileNotFoundError).
    """
    # The file is closed by the context manager even if unpickling raises.
    with open(path, 'rb') as handle:
        return pickle.load(handle)
class ObjectView:
    """Attribute-style access to the keys of a dictionary.

    The dictionary passed in becomes the instance's ``__dict__`` (it is not
    copied). So ``view.key`` reads ``d['key']``, and setting an attribute on
    the view writes into ``d``.
    """

    def __init__(self, d):
        # Alias rather than copy, so the view and the dict stay in sync.
        self.__dict__ = d
def plot_signals(xs, t, labels=None, args=None, ratio=2.6, do_transform=False,
                 dark_mode=False, zoom=1, rows=1, cols=10):
    """Plot a grid of 1-D signals, one per subplot, then show the figure.

    Each signal ``xs[i]`` is passed to ``plot`` as the x values and the
    shared ``t`` as the y values, and the y axis is inverted.

    Args:
        xs: indexable sequence of signals. The first ``rows * cols`` are
            plotted; fewer panels are drawn if ``xs`` is shorter.
        t: coordinates shared by every signal (used as the y values).
        labels: optional sequence; ``labels[i]`` becomes the title of
            panel ``i``.
        args: object forwarded to ``transform``. Required when
            ``do_transform`` is True.
        ratio: scale factor applied to the figure height.
        do_transform: if True, each ``(x, t)`` pair is passed through
            ``transform(x, t, args)`` before plotting.
        dark_mode: draw white circle markers on a black background instead
            of black lines.
        zoom: both axes are limited to ``[-zoom, zoom]``.
        rows, cols: grid shape. The defaults keep the original 1 x 10 layout.

    Returns:
        The matplotlib Figure that was created (and shown).

    Raises:
        ValueError: if ``do_transform`` is True but ``args`` is None.
    """
    # Validate before creating the figure so a bad call leaves no figure open,
    # and raise explicitly because `assert` is stripped under `python -O`.
    if do_transform and args is None:
        raise ValueError("Need an args object in order to do transforms")

    fig = plt.figure(figsize=[cols*1.5, rows*1.5*ratio], dpi=60)
    n_panels = min(rows * cols, len(xs))
    for ix in range(n_panels):
        ax = plt.subplot(rows, cols, ix + 1)
        # Per-panel locals: never rebind the caller's `t`. Otherwise each
        # transformed panel would feed its output `t` into the next panel.
        x_ix, t_ix = xs[ix], t
        if do_transform:
            x_ix, t_ix = transform(x_ix, t_ix, args)  # optionally, transform the signal in some manner
        if dark_mode:
            # NOTE(review): 'wo' draws markers only, so linewidth likely has
            # no visible effect here -- kept as in the original.
            ax.plot(x_ix, t_ix, 'wo', linewidth=6)
            ax.set_facecolor('k')
        else:
            ax.plot(x_ix, t_ix, 'k-', linewidth=2)
        if labels is not None:
            ax.set_title("label=" + str(labels[ix]), fontsize=22)
        ax.set_xlim(-zoom, zoom)
        ax.set_ylim(-zoom, zoom)
        ax.invert_yaxis()  # flips the limits just set, so it must follow set_ylim
        ax.set_xticks([])
        ax.set_yticks([])
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.show()
    return fig