
Commit ece1758: tiny img + data aug
1 parent 81cc11e

File tree: 3 files changed, +550 -82 lines

config/data/image/tiny_imagenet_aug.yaml (+2, -4)
@@ -4,7 +4,6 @@ data_dir: '../data/image'
 train_val_split: [0.8, 0.2]
 batch_size: 512
 num_workers: 0
-size: 64
 pin_memory: True
 persistent_workers: False
 transforms:
@@ -15,11 +14,10 @@ transforms:
     mean: [0.4822, 0.4495, 0.3985]
     std: [0.2771, 0.2690, 0.2826]
   - _target_: torchvision.transforms.Resize
-    size: [${..size},${..size}]
+    size: 64
   - _target_: torchvision.transforms.RandomCrop
-    size: [${..size},${..size}]
+    size: 64
   - _target_: torchvision.transforms.RandomHorizontalFlip
   - _target_: torchvision.transforms.RandomVerticalFlip
-  - _target_: torchvision.transforms.RandomRotation


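For context, a minimal sketch (not part of the commit) of how a transform list in this format is typically consumed. The commit drops the shared `size` key and its `${..size}` interpolations in favor of hardcoded 64x64 sizes; the inline config below mirrors the updated file, while wrapping the result in `transforms.Compose` is an assumption about downstream usage.

# Illustrative sketch, not from this commit: instantiating the updated
# transform list with Hydra. The inline config mirrors tiny_imagenet_aug.yaml;
# composing the instances with transforms.Compose is an assumption.
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torchvision import transforms

cfg = OmegaConf.create("""
transforms:
  - _target_: torchvision.transforms.Resize
    size: 64
  - _target_: torchvision.transforms.RandomCrop
    size: 64
  - _target_: torchvision.transforms.RandomHorizontalFlip
""")

# Each _target_ entry resolves to a torchvision transform instance.
pipeline = transforms.Compose([instantiate(t) for t in cfg.transforms])
print(pipeline)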
tutorials/super_resolution.ipynb (+89, -20)
@@ -28,25 +28,28 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 16,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "Seed set to 42\n",
-       "/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py:42: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.Output is equivalent up to float precision.\n",
-       "  warnings.warn(\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
+     "import torch.nn as nn\n",
+     "import torch\n",
+     "\n",
      "from nimrod.image.datasets import ImageDataset, ImageDataModule\n",
      "from nimrod.models.core import lr_finder, train_one_cycle\n",
+     "from nimrod.models.resnet import ResBlock\n",
+     "\n",
      "from hydra.utils import instantiate\n",
      "from omegaconf import OmegaConf\n",
-     "from rich import print"
+     "from rich import print\n",
+     "from typing import Optional, Type"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## tiny imagenet"
     ]
    },
    {
@@ -71,13 +74,6 @@
     ")"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "code",
    "execution_count": 5,
@@ -148,6 +144,79 @@
     "print(dm.dim)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "class UpBlock(nn.Module):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        in_channels:int,\n",
+    "        out_channels:int,\n",
+    "        kernel_size:int=3,\n",
+    "        activation:Optional[Type[nn.Module]]=nn.ReLU\n",
+    "    ):\n",
+    "        super().__init__()\n",
+    "        layers = []\n",
+    "        # upsample receptive field\n",
+    "        layers.append(nn.UpsamplingNearest2d(scale_factor=2))\n",
+    "        # resnet block increase channels\n",
+    "        layers.append(ResBlock(in_channels, out_channels, kernel_size=kernel_size, activation=activation))\n",
+    "        self.nnet = nn.Sequential(*layers)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return self.nnet(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 8, 128, 128])"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "m = UpBlock(3, 8)\n",
+    "x = torch.randn(1, 3, 64, 64)\n",
+    "y = m(x)\n",
+    "print(y.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 3, 128, 128])"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m.nnet[0](x).shape"
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,

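The `m.nnet[0](x).shape` probe above shows that the upsampling layer alone doubles spatial resolution (64 to 128) without touching channels; the full `UpBlock` then maps 3 channels to 8. As a rough sketch (not from this commit), blocks with this interface can be chained for larger scale factors. The stand-in conv block below replaces nimrod's `ResBlock`, and the final RGB projection is an assumption:

# Hypothetical sketch: chaining UpBlock-style modules for 4x upscaling (64 -> 256).
# The conv+ReLU stand-in and the final 3x3 projection back to RGB are assumptions,
# not part of the commit; the notebook's UpBlock uses nimrod's ResBlock instead.
import torch
import torch.nn as nn

class UpBlock(nn.Module):
    # Same interface as the notebook's UpBlock: nearest-neighbor 2x upsample
    # followed by a channel-changing convolution.
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        self.nnet = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.nnet(x)

head = nn.Sequential(
    UpBlock(3, 16),                 # 64x64   -> 128x128
    UpBlock(16, 8),                 # 128x128 -> 256x256
    nn.Conv2d(8, 3, 3, padding=1),  # project back to 3 RGB channels
)

x = torch.randn(1, 3, 64, 64)
print(head(x).shape)  # torch.Size([1, 3, 256, 256])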