@@ -29,14 +29,19 @@ class ImageDataset(Dataset):
     def show_idx(self,
                  index: int  # Index of the (image,label) sample to visualize
                  ):
+        "display image from data point index of an image dataset"
         X, y = self.__getitem__(index)
         plt.figure(figsize=(1, 1))
         plt.imshow(X.numpy().reshape(28, 28), cmap='gray')
         plt.title(f"Label: {int(y)}")
         plt.show()

     @staticmethod
-    def show_grid(imgs, save_path=None):
+    def show_grid(
+        imgs: List[torch.Tensor],  # python list of images dim (C,H,W)
+        save_path=None  # path where image can be saved
+    ):
+        "display list of mnist-like images (C,H,W)"
         if not isinstance(imgs, list):
             imgs = [imgs]
         fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
@@ -50,6 +55,7 @@ def show_grid(imgs, save_path=None):
     def show_random(self,
                     n=3  # number of images to display
                     ):
+        "display grid of random images"
         indices = torch.randint(0, len(self), (n,))
         images = []
         for index in indices:
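For reference, a minimal usage sketch of the display helpers touched above. It is not part of this commit and assumes the MNISTDataset subclass defined further down in this file:

    ds = MNISTDataset(data_dir='~/Data', train=True)
    ds.show_idx(0)        # plot one (image, label) sample
    ds.show_random(n=3)   # plot a grid of 3 random samples
    ds.show_grid([ds[i][0] for i in range(3)])  # explicit list of (C,H,W) tensors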
@@ -59,7 +65,7 @@ def show_random(self,
         self.show_grid(images)


-# %% ../../nbs/image.datasets.ipynb 8
+# %% ../../nbs/image.datasets.ipynb 11
 class MNISTDataset(ImageDataset):
     "MNIST digit dataset"

@@ -68,6 +74,9 @@ def __init__(
         data_dir: str = '~/Data',  # path where data is saved
         train=True,  # train or test dataset
         transform: torchvision.transforms.transforms = torchvision.transforms.ToTensor()  # data formatting
+        # TODO: add normalization?
+        # torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
+
     ):

         super().__init__()
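To illustrate the TODO above (0.1307 and 0.3081 are the usual MNIST mean/std), a sketch of how a normalizing pipeline could be passed as the transform argument; this is not part of the commit:

    import torchvision

    mnist_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),  # dataset-wide mean/std
    ])
    ds = MNISTDataset(data_dir='~/Data', train=True, transform=mnist_transform)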
@@ -79,18 +88,19 @@ def __init__(
             download=True
         )

-    def __len__(self):
+    def __len__(self) -> int:  # length of dataset
         return len(self.ds)

-    def __getitem__(self, idx):
+    def __getitem__(self, idx  # index into the dataset
+                    ) -> tuple[torch.FloatTensor, int]:  # x image data, y digit label
         x = self.ds[idx][0]
         y = self.ds[idx][1]
         return x, y

     def train_dev_split(self,
                         ratio: float,  # percentage of train/dev split
                         seed: int = 42  # rand generator seed
-                        ):
+                        ) -> tuple[torchvision.datasets.MNIST, torchvision.datasets.MNIST]:  # train and dev MNIST datasets
         train_set_size = int(len(self.ds) * ratio)
         valid_set_size = len(self.ds) - train_set_size

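A usage sketch for train_dev_split (not in this commit); the size arithmetic above suggests the method hands the two lengths plus a seeded generator to torch.utils.data.random_split, which is the assumption here:

    mnist = MNISTDataset(data_dir='~/Data', train=True)
    train_ds, dev_ds = mnist.train_dev_split(ratio=0.9, seed=42)
    print(len(train_ds), len(dev_ds))  # 54000 6000 for the 60k MNIST training set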
@@ -101,15 +111,15 @@ def train_dev_split(self,



-# %% ../../nbs/image.datasets.ipynb 14
+# %% ../../nbs/image.datasets.ipynb 18
 class MNISTDataModule(LightningDataModule):
     def __init__(
         self,
-        data_dir: str = "~/Data/",
-        train_val_test_split: List[float] = [0.8, 0.1, 0.1],
-        batch_size: int = 64,
-        num_workers: int = 0,
-        pin_memory: bool = False,
+        data_dir: str = "~/Data/",  # path to source data dir
+        train_val_test_split: List[float] = [0.8, 0.1, 0.1],  # train val test %
+        batch_size: int = 64,  # size of compute batch
+        num_workers: int = 0,  # 0 loads data in the main process; n > 0 spawns n worker processes (1 gives a single worker, which can be slow)
+        pin_memory: bool = False,  # when samples are loaded on CPU and pushed to the GPU during training, pinning allocates them in page-locked memory and speeds up host-to-device transfer
     ):
         super().__init__()
         self.save_hyperparameters(logger=False)  # can access inputs with self.hparams
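Illustrative only, not part of the commit: wiring the DataModule into a pytorch_lightning Trainer run; `model` stands in for any LightningModule.

    import pytorch_lightning

    dm = MNISTDataModule(
        data_dir="~/Data/",
        train_val_test_split=[0.8, 0.1, 0.1],
        batch_size=64,
        num_workers=4,    # parallel loading workers
        pin_memory=True,  # worthwhile when training on a CUDA device
    )
    trainer = pytorch_lightning.Trainer(max_epochs=1)
    trainer.fit(model, datamodule=dm)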
@@ -122,17 +132,19 @@ def __init__(
             raise Exception('split percentages should sum up to 1.0')

     @property
-    def num_classes(self):
+    def num_classes(self) -> int:  # num of classes in dataset
         return 10

-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """Download data if needed + format with MNISTDataset
         """
         MNISTDataset(self.hparams.data_dir, train=True)
         MNISTDataset(self.hparams.data_dir, train=False)

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
+        # concat train & test MNIST datasets and randomly split into train, val, test sets
         if not self.data_train and not self.data_val and not self.data_test:
+            # each item: ((C, H, W) image tensor, int label)
             trainset = MNISTDataset(self.hparams.data_dir, train=True, transform=self.transforms)
             testset = MNISTDataset(self.hparams.data_dir, train=False, transform=self.transforms)
             dataset = ConcatDataset(datasets=[trainset, testset])
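The split itself is completed in the next hunk; for clarity, a standalone sketch of the same pattern (illustrative, not the commit's exact code), reusing the trainset/testset built above and assuming percentage splits are converted to sample counts:

    import torch
    from torch.utils.data import ConcatDataset, random_split

    dataset = ConcatDataset(datasets=[trainset, testset])       # 60k + 10k = 70k samples
    lengths = [int(p * len(dataset)) for p in (0.8, 0.1, 0.1)]  # [56000, 7000, 7000]
    data_train, data_val, data_test = random_split(
        dataset,
        lengths,
        generator=torch.Generator().manual_seed(42),  # fixed seed -> reproducible split
    )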
@@ -143,7 +155,7 @@ def setup(self, stage: Optional[str] = None):
                 generator=torch.Generator().manual_seed(42),
             )

-    def train_dataloader(self):
+    def train_dataloader(self) -> torch.utils.data.DataLoader:
         return DataLoader(
             dataset=self.data_train,
             batch_size=self.hparams.batch_size,
@@ -152,7 +164,7 @@ def train_dataloader(self):
             shuffle=True,
         )

-    def val_dataloader(self):
+    def val_dataloader(self) -> torch.utils.data.DataLoader:
         return DataLoader(
             dataset=self.data_val,
             batch_size=self.hparams.batch_size,
@@ -161,7 +173,7 @@ def val_dataloader(self):
             shuffle=False,
         )

-    def test_dataloader(self):
+    def test_dataloader(self) -> torch.utils.data.DataLoader:
         return DataLoader(
             dataset=self.data_test,
             batch_size=self.hparams.batch_size,
@@ -170,7 +182,7 @@ def test_dataloader(self):
             shuffle=False,
         )

-    def teardown(self, stage: Optional[str] = None):
+    def teardown(self, stage: Optional[str] = None) -> None:
         """Clean up after fit or test."""
         pass

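An end-to-end sketch (not part of the commit): prepare, set up, and pull one batch from the train loader to sanity-check shapes.

    dm = MNISTDataModule(batch_size=64)
    dm.prepare_data()
    dm.setup()
    xb, yb = next(iter(dm.train_dataloader()))
    print(xb.shape, yb.shape)  # expected: torch.Size([64, 1, 28, 28]) torch.Size([64])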