Skip to content

Commit fea112b

Browse files
authored
Merge pull request #111 from slegroux/108-simple-diffusion
108 simple diffusion
2 parents 8be360b + 88b57cf commit fea112b

12 files changed

+1695
-336
lines changed

config/data/image/mnist.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ batch_size: 64
66
num_workers: 0
77

88
pin_memory: True
9-
persistent_workers: False
9+
persistent_workers: True
1010
transforms:
11-
_target_: torchvision.transforms.Compose
11+
_target_: torchvision.transforms.v2.Compose
1212
transforms:
13-
- _target_: torchvision.transforms.ToTensor
14-
- _target_: torchvision.transforms.Normalize
13+
- _target_: torchvision.transforms.v2.ToImage
14+
- _target_: torchvision.transforms.v2.Normalize
1515
mean: [0.1307,]
1616
std: [0.3081,]
1717

config/model/image/diffusorx.yaml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
_target_: nimrod.models.diffusion.DiffusorX
2+
_partial_: true # we init optim & sched later
3+
4+
nnet:
5+
_target_: diffusers.UNet2DModel
6+
block_out_channels: [32, 64, 128, 256]
7+
sample_size: 32
8+
in_channels: 1
9+
out_channels: 1
10+
11+
noise_scheduler:
12+
_target_: diffusers.DDPMScheduler
13+
num_train_timesteps: 1000
14+
beta_start: 0.00085
15+
beta_end: 0.012
16+
17+
18+
# optimizer
19+
# scheduler%

config/model/image/tinyunetx.yaml

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
_target_: nimrod.models.unet.TinyUnetX
2+
_partial_: true
3+
4+
nnet:
5+
_target_: nimrod.models.unet.TinyUnet
6+
n_features: [3, 32, 64, 128, 256, 512, 1024] # channel/feature expansion
7+
pre_activation: true
8+
# activation:
9+
# _target_: torch.nn.ReLU
10+
# _partial_: true
11+
# leaky: 0.0
12+
13+
activation:
14+
_target_: torch.nn.LeakyReLU
15+
_partial_: true
16+
negative_slope: 0.1
17+
18+
19+
weight_initialization: true
20+
21+
# optimizer
22+
# schedule

nbs/models.core.ipynb

+112-53
Original file line numberDiff line numberDiff line change
@@ -560,65 +560,124 @@
560560
"\n"
561561
]
562562
},
563+
{
564+
"cell_type": "markdown",
565+
"metadata": {},
566+
"source": [
567+
"## Diffuser Abstract Class"
568+
]
569+
},
563570
{
564571
"cell_type": "code",
565572
"execution_count": null,
566573
"metadata": {},
567-
"outputs": [
568-
{
569-
"name": "stderr",
570-
"output_type": "stream",
571-
"text": [
572-
"[21:54:33] INFO - Init ImageDataModule for mnist\n"
573-
]
574-
},
575-
{
576-
"name": "stderr",
577-
"output_type": "stream",
578-
"text": [
579-
"[21:54:36] INFO - loading dataset mnist with args () from split train\n",
580-
"[21:54:36] INFO - loading dataset mnist from split train\n",
581-
"Overwrite dataset info from restored data version if exists.\n",
582-
"[21:54:38] INFO - Overwrite dataset info from restored data version if exists.\n",
583-
"Loading Dataset info from ../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n",
584-
"[21:54:38] INFO - Loading Dataset info from ../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n",
585-
"Found cached dataset mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c)\n",
586-
"[21:54:38] INFO - Found cached dataset mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c)\n",
587-
"Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n",
588-
"[21:54:38] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n",
589-
"[21:54:42] INFO - loading dataset mnist with args () from split test\n",
590-
"[21:54:42] INFO - loading dataset mnist from split test\n",
591-
"Overwrite dataset info from restored data version if exists.\n",
592-
"[21:54:43] INFO - Overwrite dataset info from restored data version if exists.\n",
593-
"Loading Dataset info from ../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n",
594-
"[21:54:43] INFO - Loading Dataset info from ../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n",
595-
"Found cached dataset mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c)\n",
596-
"[21:54:43] INFO - Found cached dataset mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c)\n",
597-
"Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n",
598-
"[21:54:43] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n",
599-
"[21:54:44] INFO - split train into train/val [0.8, 0.2]\n",
600-
"[21:54:44] INFO - train: 48000 val: 12000, test: 10000\n"
601-
]
602-
},
603-
{
604-
"data": {
605-
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAADcCAYAAADa3YUtAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAE2ZJREFUeJzt3XtQVOX/B/D3rl9YCHEJhF1JFnE0L2laJIi3UBnRsSa8zFQzpY6Nli2N4pQTlmJmbnlJwyhrTNFuOJZY6YyTg4jVACpeGkTJHFRMd8WSBS+Ass/vD3/ut/3uWR4WFnbB92vm/MFnn939HOXN4TycfY5KCCFARC6pvd0Aka9jSIgkGBIiCYaESIIhIZJgSIgkGBIiCYaESIIhIZJgSHzQuXPnoFKpsGbNGo+95oEDB6BSqXDgwAGPveb9giHxkOzsbKhUKhw5csTbrbSp7du3IyEhAUFBQQgJCcGIESOwf/9+b7fVpv7j7Qao41i2bBmWL1+O6dOnY9asWbh9+zZKS0vx119/ebu1NsWQULMUFRVh+fLlWLt2LdLS0rzdTrvir1vtqKGhAUuXLkVsbCy0Wi2CgoIwevRo5Ofnu3zOunXrEB0djcDAQDz55JMoLS11GnP69GlMnz4doaGhCAgIwBNPPIEff/xR2s/Nmzdx+vRpXL16VTp2/fr10Ov1mD9/PoQQuH79uvQ5nQVD0o5qamqwadMmJCYm4oMPPsCyZctQVVWF5ORkHD9+3Gn8tm3bkJmZCaPRiPT0dJSWlmLcuHGwWCz2MSdPnsTw4cNx6tQpvPnmm1i7di2CgoKQkpKC3NzcJvs5dOgQBgwYgI8//ljae15eHoYNG4bMzEyEh4cjODgYPXr0aNZzOzxBHrFlyxYBQBw+fNjlmDt37oj6+nqH2rVr14ROpxOzZ8+21yoqKgQAERgYKC5evGivFxcXCwAiLS3NXhs/frwYPHiwqKurs9dsNpsYMWKE6Nu3r72Wn58vAIj8/HynWkZGRpP79s8//wgAIiwsTHTt2lWsXr1abN++XUycOFEAEBs3bmzy+R0dQ+IhzQnJvzU2Noq///5bVFVVicmTJ4uhQ4faH7sXkueff97pefHx8aJfv35CCCH+/vtvoVKpxLvvviuqqqoctnfeeUcAsIdMKSTNdeHCBQFAABA5OTkO+zBw4EDRs2dPt1+zI+GvW+1s69atePTRRxEQEICwsDCEh4djz549sFqtTmP79u3rVHv44Ydx7tw5AMCff/4JIQSWLFmC8PBwhy0jIwMAcOXKlVb3HBgYCADw8/PD9OnT7XW1Wo1nn30WFy9exIULF1r9Pr6Ks1vt6KuvvsKsWbOQkpKCN954AxEREejSpQtMJhPOnj3r9uvZbDYAwOuvv47k5GTFMX369GlVzwDsEwIhISHo0qWLw2MREREAgGvXrsFgMLT6vXwRQ9KOvvvuO/Tu3Rs7d+6ESqWy1+/91P9fZ86ccar98ccf6NWrFwCgd+/eAO7+hE9KSvJ8w/9PrVZj6NChOHz4MBoaGuDv729/7NKlSwCA8PDwNnt/b+OvW+3o3k9h8a+1N4qLi1FYWKg4fteuXQ5/qDt06BCKi4sxadIkAHd/iicmJuKzzz7D5cuXnZ5fVVXVZD/uTAE/++yzaGxsxNatW+21uro6fP311xg4cCAiIyOlr9FR8UjiYZs3b8bevXud6vPnz8dTTz2FnTt3YsqUKZg8eTIqKiqwceNGDBw4UPHvDn369MGoUaMwb9481NfXY/369QgLC8OiRYvsY7KysjBq1CgMHjwYc+bMQe/evWGxWFBYWIiLFy/ixIkTLns9dOgQxo4di4yMDCxbtqzJ/Xr55ZexadMmGI1G/PHHHzAYDPjyyy9x/vx5/PTTT83/B+qIvD1z0Fncm91ytVVWVgqbzSZWrlwpoqOjhUajEY899pjYvXu3mDlzpoiOjra/1r3ZrdWrV4u1a9eKqKgoodFoxOjRo8WJEyec3vvs2bNixowZQq/XCz8/P/HQQw+Jp556Snz33Xf2Ma2ZAr7HYrGImTNnitDQUKHRaER8fLzYu3dvS//JOgyVEFx3i6gpPCchkmBIiCQYEiIJhoRIgiEhkmBIiCTa7I+JWVlZWL16NcxmM4YMGYINGzYgLi5O+jybzYZLly4hODjY4dINIk8SQqC2thaRkZFQqyXHirb440tOTo7w9/cXmzdvFidPnhRz5swRISEhwmKxSJ9bWVnZ5B/luHHz5FZZWSn9nmyTkMTFxQmj0Wj/urGxUURGRgqTySR9bnV1tdf/4bjdP1t1dbX0e9Lj5yQNDQ0oKSlxuCpVrVYjKSlJ8UK++vp61NTU2Lfa2lpPt0TkUnN+pfd4SK5evYrGxkbodDqHuk6ng9lsdhpvMpmg1WrtW1RUlKdbImoVr89upaenw2q12rfKykpvt0TkwOOzW927d0eXLl0cVvQAAIvFAr1e7zReo9FAo9F4ug0ij/H4kcTf3x+xsbHIy8uz12w2G/Ly8pCQkODptyNqe62axnIhJydHaDQakZ2dLcrKysTcuXNFSEiIMJvN0udarVavz3hwu382q9Uq/Z5ssw9dbdiwQRgMBuHv7y/i4uJEUVFRs57HkHBrz605IfG5D13V1NRAq9V6uw26T1itVnTr1q3JMV6f3SLydVwIohObOHGiYv3fK57cc/ToUcWxL774olOtOaurdCY8khBJMCREEgwJkQRDQiTBE/dObO3atYr1sLAwp5qrJVHvt5N0JTySEEkwJEQSDAmRBENCJMGQEElwdquTGDNmjFNtwIABimNPnTrlVJsxY4bHe+oseCQhkmBIiCQYEiIJhoRIgifunUR6erpTzdWHTr///vu2bqdT4ZGESIIhIZJgSIgkGBIiCYaESIKzWx1MdHS0Yt1gMDjVXK2AkpmZ6dGeOjseSYgkGBIiCYaESIIhIZLgiXsHo3T5CQD069fPqebqBJ0roLiHRxIiCYaESIIhIZJgSIgkGBIiCd4Ozof179/fqVZWVqY4VmkFlEceecTjPXU2vB0ckQcwJEQSDAmRBENCJMHLUnxAeHi4Yn3dunVONVfzLCtWrPBoT/RfPJIQSTAkRBIMCZEEQ0Ik4XZIDh48iKeffhqRkZFQqVTYtWuXw+NCCCxduhQ9evRAYGAgkpKScObMGU/1S9Tu3J7dunHjBoYMGYLZs2dj6tSpTo+vWrUKmZmZ2Lp1K2JiYrBkyRIkJyejrKwMAQEBHmm6s1Fa6QQAHn/8caeaqxVQ9u3b59Ge6L/cDsmkSZMwadIkxceEEFi/fj3efvttPPPMMwCAbdu2QafTYdeuXXjuueda1y2RF3j0nKSiogJmsxlJSUn2mlarRXx8PAoLCxWfU19fj5qaGoeNyJd4NCRmsxkAoNPpHOo6nc7+2P8ymUzQarX2LSoqypMtEbWa12e30tPTYbVa7VtlZaW3WyJy4NHLUvR6PQDAYrGgR48e9rrFYsHQoUMVn6PRaKDRaDzZRoczevRoxXpYWJhTbcmSJYpjuQJK2/HokSQmJgZ6vR55eXn2Wk1NDYqLi5GQkODJtyJqN24fSa5fv44///zT/nVFRQWOHz+O0NBQGAwGLFiwACtWrEDfvn3tU8CRkZFISUnxZN9E7cbtkBw5cgRjx461f71w4UIAwMyZM5GdnY1Fixbhxo0bmDt3LqqrqzFq1Cjs3buXfyOhDsvtkCQmJrq8XBsAVCoVli9fjuXLl7eqMSJf4fXZLSJfx9VSfIDNZlOsK/3XuFoB5fTp0x7t6X7B1VKIPIAhIZJgSIgkGBIiCa6W0s6mTJniVHM1d7Jz506nGk/Q2x+PJEQSDAmRBENCJMGQEEkwJEQSnN1qI0FBQYr1F154wanm6tOYH330kUd7akpaWppiXemmQaNGjVIcazKZnGo3b95sXWM+gEcSIgmGhEiCISGSYEiIJHji3ka6d++uWB85cqRT7datW4pj2+oSFKW7+q5Zs0ZxrNIlMyqVqtnv5Wp1l46ERxIiCYaESIIhIZJgSIgkGBIiCc5utRGlGxwByrej/vnnnxXHtnZ9X6VLYABg69atTjVXl8YozW5FR0crjnW1pnFHxyMJkQRDQiTBkBBJMCREEjxxbyP9+vVTrCudCL/33nutfr+33nrLqeZq0XKl98vMzFQcu3jxYqfa/PnzFceeOnWqqRY7LB5JiCQYEiIJhoRIgiEhkmBIiCR4E582cuXKFcV6VVWVU83VjXncceTIEadabm6u4lilFVBczbApzdK5uoxm0qRJTbXok3gTHyIPYEiIJBgSIgmGhEiCl6W0EVeXaCitluKK0qomY8aMURwbFhbmVHvzzTeb/V6uliP9/PPPnWqdYQUUd/BIQiTBkBBJMCREEgwJkYRbITGZTBg2bBiCg4MRERGBlJQUlJeXO4ypq6uD0WhEWFgYunbtimnTpsFisXi0aaL25NbsVkFBAYxGI4YNG4Y7d+5g8eLFmDBhAsrKyuw3rUlLS8OePXuwY8cOaLVapKamYurUqfjtt9/aZAd8lat1fJVugOPqyiCbzeZUU6uVf64pjT169KjiWKWZt5UrVyqO5S2x3QzJ3r17Hb7Ozs5GREQESkpKMGbMGFitVnzxxRf45ptvMG7cOADAli1bMGDAABQVFWH48OGe65yonbTqnMRqtQIAQkNDAQAlJSW4ffs2kpKS7GP69+8Pg8GAwsJCxdeor69HTU2Nw0bkS1ocEpvNhgULFmDkyJEYNGgQAMBsNsPf3x8hISEOY3U6Hcxms+LrmEwmaLVa+xYVFdXSlojaRItDYjQaUVpaipycnFY1kJ6eDqvVat9crSRI5C0tuiwlNTUVu3fvxsGDB9GzZ097Xa/Xo6GhAdXV1Q5HE4vFAr1er/haGo0GGo2mJW34tF9++UWxrvT5DFfLgyqd0BcUFCiOVTrxdvW5D3KPW0cSIQRSU1ORm5uL/fv3IyYmxuHx2NhY+Pn5IS8vz14rLy/HhQsXkJCQ4JmOidqZW0cSo9GIb775Bj/88AOCg4Pt5xlarRaBgYHQarV46aWXsHDhQoSGhqJbt2547bXXkJCQwJkt6rDcCsmnn34KAEhMTHSob9myBbNmzQIArFu3Dmq1GtOmTUN9fT2Sk5PxySefeKRZIm9wKyTN+Th8QEAAsrKykJWV1eKmiHwJr90ikuBqKXRf42opRB7AkBBJMCREEgwJkQRDQiTBkBBJMCREEgwJkQRDQiTBkBBJMCREEgwJkQRDQiTBkBBJMCREEgwJkQRDQiTBkBBJMCREEgwJkQRDQiTBkBBJMCREEgwJkQRDQiTBkBBJMCREEgwJkQRDQiTBkBBJMCREEj4XEh+7XQp1cs35fvO5kNTW1nq7BbqPNOf7zefudGWz2XDp0iUEBwejtrYWUVFRqKyslN6NqKOpqanhvnmREAK1tbWIjIyEWt30scKtG4u2B7VajZ49ewIAVCoVAKBbt24++4/dWtw372nubQd97tctIl/DkBBJ+HRINBoNMjIyoNFovN2Kx3HfOg6fO3En8jU+fSQh8gUMCZEEQ0IkwZAQSfh0SLKystCrVy8EBAQgPj4ehw4d8nZLbjt48CCefvppREZGQqVSYdeuXQ6PCyGwdOlS9OjRA4GBgUhKSsKZM2e806wbTCYThg0bhuDgYERERCAlJQXl5eUOY+rq6mA0GhEWFoauXbti2rRpsFgsXuq45Xw2JNu3b8fChQuRkZGBo0ePYsiQIUhOTsaVK1e83Zpbbty4gSFDhiArK0vx8VWrViEzMxMbN25EcXExgoKCkJycjLq6unbu1D0FBQUwGo0oKirCvn37cPv2bUyYMAE3btywj0lLS8NPP/2EHTt2oKCgAJcuXcLUqVO92HULCR8VFxcnjEaj/evGxkYRGRkpTCaTF7tqHQAiNzfX/rXNZhN6vV6sXr3aXquurhYajUZ8++23Xuiw5a5cuSIAiIKCAiHE3f3w8/MTO3bssI85deqUACAKCwu91WaL+OSRpKGhASUlJUhKSrLX1Go1kpKSUFhY6MXOPKuiogJms9lhP7VaLeLj4zvcflqtVgBAaGgoAKCkpAS3b9922Lf+/fvDYDB0uH3zyZBcvXoVjY2N0Ol0DnWdTgez2eylrjzv3r509P202WxYsGABRo4ciUGDBgG4u2/+/v4ICQlxGNvR9g3wwauAqeMxGo0oLS3Fr7/+6u1W2oRPHkm6d++OLl26OM2EWCwW6PV6L3Xleff2pSPvZ2pqKnbv3o38/Hz7RxyAu/vW0NCA6upqh/Edad/u8cmQ+Pv7IzY2Fnl5efaazWZDXl4eEhISvNiZZ8XExECv1zvsZ01NDYqLi31+P4UQSE1NRW5uLvbv34+YmBiHx2NjY+Hn5+ewb+Xl5bhw4YLP75sTb88cuJKTkyM0Go3Izs4WZWVlYu7cuSIkJESYzWZvt+aW2tpacezYMXHs2DEBQHz44Yfi2LFj4vz580IIId5//30REhIifvjhB/H777+LZ555RsTExIhbt255ufOmzZs3T2i1WnHgwAFx+fJl+3bz5k37mFdeeUUYDAaxf/9+ceTIEZGQkCASEhK82HXL+GxIhBBiw4YNwmAwCH9/fxEXFyeKioq83ZLb8vPzBQCnbebMmUKIu9PAS5YsETqdTmg0GjF+/HhRXl7u3aabQWmfAIgtW7bYx9y6dUu8+uqr4sEHHxQPPPCAmDJlirh8+bL3mm4hXipPJOGT5yREvoQhIZJgSIgkGBIiCYaESIIhIZJgSIgkGBIiCYaESIIhIZJgSIgkGBIiif8DUF99nwmMfM8AAAAASUVORK5CYII=",
606-
"text/plain": [
607-
"<Figure size 200x200 with 1 Axes>"
608-
]
609-
},
610-
"metadata": {},
611-
"output_type": "display_data"
612-
}
613-
],
574+
"outputs": [],
614575
"source": [
615-
"#| notest\n",
576+
"#| export\n",
616577
"\n",
617-
"cfg = OmegaConf.load(\"../config/data/image/mnist.yaml\")\n",
618-
"dm = instantiate(cfg)\n",
619-
"dm.prepare_data()\n",
620-
"dm.setup()\n",
621-
"dm.show(0)\n"
578+
"class Diffuser(ABC, L.LightningModule):\n",
579+
" def __init__(\n",
580+
" self,\n",
581+
" nnet: L.LightningModule,\n",
582+
" optimizer: Callable[...,torch.optim.Optimizer], # partial of optimizer\n",
583+
" scheduler: Optional[Callable[...,Any]]=None, # partial of scheduler\n",
584+
"\n",
585+
" ):\n",
586+
" logger.info(\"Diffuser: init\")\n",
587+
" super().__init__()\n",
588+
" self.save_hyperparameters()\n",
589+
" self.lr = optimizer.keywords.get('lr') if optimizer else None # for lr finder\n",
590+
" self.nnet = nnet\n",
591+
" # explicitely register nnet as a module to track its parameters\n",
592+
" self.register_module('nnet', self.nnet)\n",
593+
"\n",
594+
" # loss\n",
595+
" self.criterion = nn.MSELoss()\n",
596+
"\n",
597+
" # average accross batches\n",
598+
" self.train_loss = MeanMetric()\n",
599+
" self.val_loss = MeanMetric()\n",
600+
" self.test_loss = MeanMetric()\n",
601+
" self.val_loss_best = MinMetric()\n",
602+
"\n",
603+
" def forward(self, x:torch.Tensor, t:torch.Tensor)->torch.Tensor:\n",
604+
" return self.nnet(x, t)\n",
605+
"\n",
606+
" def configure_optimizers(self):\n",
607+
" logger.info(\"Regressor: configure_optimizers\")\n",
608+
" self.optimizer = self.hparams.optimizer(params=self.parameters())\n",
609+
" logger.info(f\"Optimizer: {self.optimizer.__class__}\")\n",
610+
" if self.hparams.scheduler is None:\n",
611+
" logger.warning(\"no scheduler has been setup\")\n",
612+
" return {\"optimizer\": self.optimizer}\n",
613+
" self.scheduler = self.hparams.scheduler(optimizer=self.optimizer)\n",
614+
" if isinstance(self.scheduler, torch.optim.lr_scheduler.OneCycleLR):\n",
615+
" lr_scheduler = {\n",
616+
" \"scheduler\": self.scheduler,\n",
617+
" \"interval\": \"step\",\n",
618+
" \"frequency\": 1,\n",
619+
" }\n",
620+
" else:\n",
621+
" lr_scheduler = {\n",
622+
" \"scheduler\": self.scheduler,\n",
623+
" \"interval\": \"epoch\",\n",
624+
" \"frequency\": 1,\n",
625+
" }\n",
626+
" logger.info(f\"Scheduler: {self.scheduler.__class__}\")\n",
627+
" return {\"optimizer\": self.optimizer, \"lr_scheduler\": lr_scheduler}\n",
628+
"\n",
629+
" def on_train_start(self) -> None:\n",
630+
" self.val_loss.reset()\n",
631+
" self.val_loss_best.reset()\n",
632+
" \n",
633+
" def predict_step(self, batch, batch_idx, dataloader_idx=0):\n",
634+
" x, y = batch\n",
635+
" y_hat = self.forward(x)\n",
636+
" return y_hat\n",
637+
"\n",
638+
" def _step(self, batch:Tuple[torch.Tensor, torch.Tensor], batch_idx:int):\n",
639+
" x, y = batch\n",
640+
" y_hat = self.forward(x)\n",
641+
" loss = self.criterion(y_hat, y)\n",
642+
" return loss, y_hat, y\n",
643+
"\n",
644+
" def training_step(self, batch, batch_idx):\n",
645+
" loss, y_hat, y = self._step(batch, batch_idx)\n",
646+
" self.train_loss(loss)\n",
647+
" self.train_mse(y_hat, y)\n",
648+
" # self.log(\"train/mse\", self.train_mse, on_epoch=True, on_step=True, prog_bar=True)\n",
649+
" self.log(\"train/loss\", self.train_loss, on_epoch=True, on_step=True, prog_bar=True)\n",
650+
" return loss\n",
651+
"\n",
652+
" def on_train_epoch_end(self) -> None:\n",
653+
" pass\n",
654+
"\n",
655+
" def validation_step(self, batch, batch_idx):\n",
656+
" loss, y_hat, y = self._step(batch, batch_idx)\n",
657+
" self.val_loss(loss)\n",
658+
" self.val_mse(y_hat, y)\n",
659+
" # self.log(\"val/mse\", self.val_mse, on_epoch=True, on_step=True, prog_bar=True)\n",
660+
" self.log(\"val/loss\", self.val_loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)\n",
661+
" \n",
662+
" def on_validation_epoch_end(self) -> None:\n",
663+
" mse = self.val_mse.compute() # get current val acc\n",
664+
" self.val_mse_best(mse) # update best so far val acc\n",
665+
" # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object\n",
666+
" # otherwise metric would be reset by lightning after each epoch\n",
667+
" self.log(\"val/mse_best\", self.val_mse_best.compute(), sync_dist=True, prog_bar=True)\n",
668+
" \n",
669+
"\n",
670+
" def test_step(self, batch, batch_idx):\n",
671+
" loss, y_hat, y = self._step(batch, batch_idx)\n",
672+
" self.test_loss(loss)\n",
673+
" self.test_mse(y_hat, y)\n",
674+
" \n",
675+
" # self.log(\"test/mse\", self.test_mse, on_epoch=True, on_step=True, prog_bar=True)\n",
676+
" self.log(\"test/loss\", self.test_loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)\n",
677+
"\n",
678+
" def on_test_epoch_end(self) -> None:\n",
679+
" pass\n",
680+
"\n"
622681
]
623682
},
624683
{

nbs/models.diffusion.ipynb

+315-95
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)