|
28 | 28 | },
|
29 | 29 | {
|
30 | 30 | "cell_type": "code",
|
31 |
| - "execution_count": 3, |
| 31 | + "execution_count": 16, |
32 | 32 | "metadata": {},
|
33 |
| - "outputs": [ |
34 |
| - { |
35 |
| - "name": "stderr", |
36 |
| - "output_type": "stream", |
37 |
| - "text": [ |
38 |
| - "Seed set to 42\n", |
39 |
| - "/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py:42: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.Output is equivalent up to float precision.\n", |
40 |
| - " warnings.warn(\n" |
41 |
| - ] |
42 |
| - } |
43 |
| - ], |
| 33 | + "outputs": [], |
44 | 34 | "source": [
|
| 35 | + "import torch.nn as nn\n", |
| 36 | + "import torch\n", |
| 37 | + "\n", |
45 | 38 | "from nimrod.image.datasets import ImageDataset, ImageDataModule\n",
|
46 | 39 | "from nimrod.models.core import lr_finder, train_one_cycle\n",
|
| 40 | + "from nimrod.models.resnet import ResBlock\n", |
| 41 | + "\n", |
47 | 42 | "from hydra.utils import instantiate\n",
|
48 | 43 | "from omegaconf import OmegaConf\n",
|
49 |
| - "from rich import print" |
| 44 | + "from rich import print\n", |
| 45 | + "from typing import Optional, Type" |
| 46 | + ] |
| 47 | + }, |
| 48 | + { |
| 49 | + "cell_type": "markdown", |
| 50 | + "metadata": {}, |
| 51 | + "source": [ |
| 52 | + "## tiny imagenet" |
50 | 53 | ]
|
51 | 54 | },
|
52 | 55 | {
|
|
71 | 74 | ")"
|
72 | 75 | ]
|
73 | 76 | },
|
74 |
| - { |
75 |
| - "cell_type": "code", |
76 |
| - "execution_count": null, |
77 |
| - "metadata": {}, |
78 |
| - "outputs": [], |
79 |
| - "source": [] |
80 |
| - }, |
81 | 77 | {
|
82 | 78 | "cell_type": "code",
|
83 | 79 | "execution_count": 5,
|
|
148 | 144 | "print(dm.dim)"
|
149 | 145 | ]
|
150 | 146 | },
|
| 147 | + { |
| 148 | + "cell_type": "code", |
| 149 | + "execution_count": 32, |
| 150 | + "metadata": {}, |
| 151 | + "outputs": [], |
| 152 | + "source": [ |
| 153 | + "#| export\n", |
| 154 | + "class UpBlock(nn.Module):\n", |
| 155 | + " def __init__(\n", |
| 156 | + " self,\n", |
| 157 | + " in_channels:int,\n", |
| 158 | + " out_channels:int,\n", |
| 159 | + " kernel_size:int=3,\n", |
| 160 | + " activation:Optional[Type[nn.Module]]=nn.ReLU\n", |
| 161 | + " ):\n", |
| 162 | + " super().__init__()\n", |
| 163 | + " layers = []\n", |
| 164 | + " # upsample receptive field\n", |
| 165 | + " layers.append(nn.UpsamplingNearest2d(scale_factor=2))\n", |
| 166 | + " # resnet block increase channels\n", |
| 167 | + " layers.append(ResBlock(in_channels, out_channels, kernel_size=kernel_size, activation=activation))\n", |
| 168 | + " self.nnet = nn.Sequential(*layers)\n", |
| 169 | + "\n", |
| 170 | + " def forward(self, x):\n", |
| 171 | + " return self.nnet(x)" |
| 172 | + ] |
| 173 | + }, |
| 174 | + { |
| 175 | + "cell_type": "code", |
| 176 | + "execution_count": 31, |
| 177 | + "metadata": {}, |
| 178 | + "outputs": [ |
| 179 | + { |
| 180 | + "data": { |
| 181 | + "text/html": [ |
| 182 | + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">torch.Size</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">128</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">128</span><span style=\"font-weight: bold\">])</span>\n", |
| 183 | + "</pre>\n" |
| 184 | + ], |
| 185 | + "text/plain": [ |
| 186 | + "\u001b[1;35mtorch.Size\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m, \u001b[1;36m8\u001b[0m, \u001b[1;36m128\u001b[0m, \u001b[1;36m128\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" |
| 187 | + ] |
| 188 | + }, |
| 189 | + "metadata": {}, |
| 190 | + "output_type": "display_data" |
| 191 | + } |
| 192 | + ], |
| 193 | + "source": [ |
| 194 | + "m = UpBlock(3, 8)\n", |
| 195 | + "x = torch.randn(1, 3, 64, 64)\n", |
| 196 | + "y = m(x)\n", |
| 197 | + "print(y.shape)" |
| 198 | + ] |
| 199 | + }, |
| 200 | + { |
| 201 | + "cell_type": "code", |
| 202 | + "execution_count": 29, |
| 203 | + "metadata": {}, |
| 204 | + "outputs": [ |
| 205 | + { |
| 206 | + "data": { |
| 207 | + "text/plain": [ |
| 208 | + "torch.Size([1, 3, 128, 128])" |
| 209 | + ] |
| 210 | + }, |
| 211 | + "execution_count": 29, |
| 212 | + "metadata": {}, |
| 213 | + "output_type": "execute_result" |
| 214 | + } |
| 215 | + ], |
| 216 | + "source": [ |
| 217 | + "m.nnet[0](x).shape" |
| 218 | + ] |
| 219 | + }, |
151 | 220 | {
|
152 | 221 | "cell_type": "code",
|
153 | 222 | "execution_count": null,
|
|
0 commit comments