560 | 560 | "\n"
561 | 561 | ]
562 | 562 | },
| 563 | + { |
| 564 | + "cell_type": "markdown", |
| 565 | + "metadata": {}, |
| 566 | + "source": [ |
| 567 | + "## Diffuser Abstract Class" |
| 568 | + ] |
| 569 | + }, |
563 | 570 | {
564 | 571 | "cell_type": "code",
565 | 572 | "execution_count": null,
566 | 573 | "metadata": {},
567 | | - "outputs": [ |
568 | | - { |
569 | | - "name": "stderr", |
570 | | - "output_type": "stream", |
571 | | - "text": [ |
572 | | - "[21:54:33] INFO - Init ImageDataModule for mnist\n" |
573 | | - ] |
574 | | - }, |
575 | | - { |
576 | | - "name": "stderr", |
577 | | - "output_type": "stream", |
578 | | - "text": [ |
579 | | - "[21:54:36] INFO - loading dataset mnist with args () from split train\n", |
580 | | - "[21:54:36] INFO - loading dataset mnist from split train\n", |
581 | | - "Overwrite dataset info from restored data version if exists.\n", |
582 | | - "[21:54:38] INFO - Overwrite dataset info from restored data version if exists.\n", |
583 | | - "Loading Dataset info from ../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n", |
584 | | - "[21:54:38] INFO - Loading Dataset info from ../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n", |
585 | | - "Found cached dataset mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c)\n", |
586 | | - "[21:54:38] INFO - Found cached dataset mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c)\n", |
587 | | - "Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n", |
588 | | - "[21:54:38] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n", |
589 | | - "[21:54:42] INFO - loading dataset mnist with args () from split test\n", |
590 | | - "[21:54:42] INFO - loading dataset mnist from split test\n", |
591 | | - "Overwrite dataset info from restored data version if exists.\n", |
592 | | - "[21:54:43] INFO - Overwrite dataset info from restored data version if exists.\n", |
593 | | - "Loading Dataset info from ../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n", |
594 | | - "[21:54:43] INFO - Loading Dataset info from ../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n", |
595 | | - "Found cached dataset mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c)\n", |
596 | | - "[21:54:43] INFO - Found cached dataset mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c)\n", |
597 | | - "Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n", |
598 | | - "[21:54:43] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/mnist/mnist/0.0.0/77f3279092a1c1579b2250db8eafed0ad422088c\n", |
599 | | - "[21:54:44] INFO - split train into train/val [0.8, 0.2]\n", |
600 | | - "[21:54:44] INFO - train: 48000 val: 12000, test: 10000\n" |
601 | | - ] |
602 | | - }, |
603 | | - { |
604 | | - "data": { |
605 | | - "image/png": "iVBORw0KGgo…" (base64-encoded PNG figure data elided), |
606 | | - "text/plain": [ |
607 | | - "<Figure size 200x200 with 1 Axes>" |
608 | | - ] |
609 | | - }, |
610 | | - "metadata": {}, |
611 | | - "output_type": "display_data" |
612 | | - } |
613 | | - ], |
| 574 | + "outputs": [], |
614 | 575 | "source": [
615 | | - "#| notest\n", |
| 576 | + "#| export\n", |
616 | 577 | "\n",
617 | | - "cfg = OmegaConf.load(\"../config/data/image/mnist.yaml\")\n", |
618 | | - "dm = instantiate(cfg)\n", |
619 | | - "dm.prepare_data()\n", |
620 | | - "dm.setup()\n", |
621 | | - "dm.show(0)\n" |
| 578 | + "class Diffuser(ABC, L.LightningModule):\n", |
| 579 | + " def __init__(\n", |
| 580 | + " self,\n", |
| 581 | + " nnet: L.LightningModule,\n", |
| 582 | + " optimizer: Callable[...,torch.optim.Optimizer], # partial of optimizer\n", |
| 583 | + " scheduler: Optional[Callable[...,Any]]=None, # partial of scheduler\n", |
| 584 | + "\n", |
| 585 | + " ):\n", |
| 586 | + " logger.info(\"Diffuser: init\")\n", |
| 587 | + " super().__init__()\n", |
| 588 | + " self.save_hyperparameters()\n", |
| 589 | + " self.lr = optimizer.keywords.get('lr') if optimizer else None # for lr finder\n", |
| 590 | + " self.nnet = nnet\n", |
| 591 | + " # explicitly register nnet as a module to track its parameters\n", |
| 592 | + " self.register_module('nnet', self.nnet)\n", |
| 593 | + "\n", |
| 594 | + " # loss\n", |
| 595 | + " self.criterion = nn.MSELoss()\n", |
| 596 | + "\n", |
| 597 | + " # average across batches\n", |
| 598 | + " self.train_loss = MeanMetric()\n", |
| 599 | + " self.val_loss = MeanMetric()\n", |
| 600 | + " self.test_loss = MeanMetric()\n", |
| 601 | + " self.val_loss_best = MinMetric()\n", |
| 602 | + "\n", |
| 603 | + " def forward(self, x:torch.Tensor, t:torch.Tensor)->torch.Tensor:\n", |
| 604 | + " return self.nnet(x, t)\n", |
| 605 | + "\n", |
| 606 | + " def configure_optimizers(self):\n", |
| 607 | + " logger.info(\"Diffuser: configure_optimizers\")\n", |
| 608 | + " self.optimizer = self.hparams.optimizer(params=self.parameters())\n", |
| 609 | + " logger.info(f\"Optimizer: {self.optimizer.__class__}\")\n", |
| 610 | + " if self.hparams.scheduler is None:\n", |
| 611 | + " logger.warning(\"no scheduler has been setup\")\n", |
| 612 | + " return {\"optimizer\": self.optimizer}\n", |
| 613 | + " self.scheduler = self.hparams.scheduler(optimizer=self.optimizer)\n", |
| 614 | + " if isinstance(self.scheduler, torch.optim.lr_scheduler.OneCycleLR):\n", |
| 615 | + " lr_scheduler = {\n", |
| 616 | + " \"scheduler\": self.scheduler,\n", |
| 617 | + " \"interval\": \"step\",\n", |
| 618 | + " \"frequency\": 1,\n", |
| 619 | + " }\n", |
| 620 | + " else:\n", |
| 621 | + " lr_scheduler = {\n", |
| 622 | + " \"scheduler\": self.scheduler,\n", |
| 623 | + " \"interval\": \"epoch\",\n", |
| 624 | + " \"frequency\": 1,\n", |
| 625 | + " }\n", |
| 626 | + " logger.info(f\"Scheduler: {self.scheduler.__class__}\")\n", |
| 627 | + " return {\"optimizer\": self.optimizer, \"lr_scheduler\": lr_scheduler}\n", |
| 628 | + "\n", |
| 629 | + " def on_train_start(self) -> None:\n", |
| 630 | + " self.val_loss.reset()\n", |
| 631 | + " self.val_loss_best.reset()\n", |
| 632 | + " \n", |
| 633 | + " def predict_step(self, batch, batch_idx, dataloader_idx=0):\n", |
| 634 | + " # sampling/generation is diffusion-specific: subclasses override this\n", |
| 635 | + " # (the base forward(x, t) needs a timestep, so there is no generic default)\n", |
| 636 | + " raise NotImplementedError(\"subclasses must implement predict_step\")\n", |
| 637 | + "\n", |
| 638 | + " def _step(self, batch:Tuple[torch.Tensor, torch.Tensor], batch_idx:int):\n", |
| 639 | + " # the diffusion objective is model-specific: subclasses sample timesteps,\n", |
| 640 | + " # corrupt the input and regress the noise with self.criterion,\n", |
| 641 | + " # returning (loss, y_hat, y) so the shared *_step methods can log it\n", |
| 642 | + " raise NotImplementedError(\"subclasses must implement _step\")\n", |
| 643 | + "\n", |
| 644 | + " def training_step(self, batch, batch_idx):\n", |
| 645 | + " loss, y_hat, y = self._step(batch, batch_idx)\n", |
| 646 | + " # update the running average of the training loss\n", |
| 647 | + " self.train_loss(loss)\n", |
| 648 | + " # log per step and per epoch\n", |
| 649 | + " self.log(\"train/loss\", self.train_loss, on_epoch=True, on_step=True, prog_bar=True)\n", |
| 650 | + " return loss\n", |
| 651 | + "\n", |
| 652 | + " def on_train_epoch_end(self) -> None:\n", |
| 653 | + " pass\n", |
| 654 | + "\n", |
| 655 | + " def validation_step(self, batch, batch_idx):\n", |
| 656 | + " loss, y_hat, y = self._step(batch, batch_idx)\n", |
| 657 | + " # update the running average of the validation loss\n", |
| 658 | + " self.val_loss(loss)\n", |
| 659 | + " # log once per epoch (sync across devices for correct aggregation)\n", |
| 660 | + " self.log(\"val/loss\", self.val_loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)\n", |
| 661 | + " \n", |
| 662 | + " def on_validation_epoch_end(self) -> None:\n", |
| 663 | + " loss = self.val_loss.compute() # current epoch validation loss\n", |
| 664 | + " self.val_loss_best(loss) # update best-so-far validation loss\n", |
| 665 | + " # log `val_loss_best` as a value through `.compute()`, instead of as a metric object,\n", |
| 666 | + " # otherwise the metric would be reset by lightning after each epoch\n", |
| 667 | + " self.log(\"val/loss_best\", self.val_loss_best.compute(), sync_dist=True, prog_bar=True)\n", |
| 668 | + " \n", |
| 669 | + "\n", |
| 670 | + " def test_step(self, batch, batch_idx):\n", |
| 671 | + " loss, y_hat, y = self._step(batch, batch_idx)\n", |
| 672 | + " # update the running average of the test loss\n", |
| 673 | + " self.test_loss(loss)\n", |
| 674 | + " \n", |
| 675 | + " # log once per epoch\n", |
| 676 | + " self.log(\"test/loss\", self.test_loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)\n", |
| 677 | + "\n", |
| 678 | + " def on_test_epoch_end(self) -> None:\n", |
| 679 | + " pass\n", |
| 680 | + "\n" |
622 | 681 | ]
623 | 682 | },
624 | 683 | {
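The `Diffuser` base class added above only provides shared plumbing (loss metrics, logging, optimizer/scheduler wiring); the diffusion objective in `_step` and the sampling logic in `predict_step` are left to subclasses. As a rough, hypothetical sketch of what such a subclass could look like (the `ToyDDPM` name, the linear beta schedule, `n_steps`, and `my_unet` are illustrative assumptions, not part of this commit or of the nimrod API), a DDPM-style noise-prediction objective might be wired up as follows, assuming 4D image batches of the form `(x, label)` such as those produced by the MNIST datamodule:

```python
import torch

# Hypothetical subclass for illustration only: the name, schedule and defaults are assumptions.
class ToyDDPM(Diffuser):
    def __init__(self, nnet, optimizer, scheduler=None, n_steps: int = 1000):
        super().__init__(nnet, optimizer, scheduler)
        betas = torch.linspace(1e-4, 0.02, n_steps)            # linear noise schedule
        alphas_bar = torch.cumprod(1.0 - betas, dim=0)          # cumulative product of (1 - beta_t)
        self.register_buffer("sqrt_ab", alphas_bar.sqrt())
        self.register_buffer("sqrt_one_minus_ab", (1.0 - alphas_bar).sqrt())
        self.n_steps = n_steps

    def _step(self, batch, batch_idx):
        x, _ = batch                                            # class labels are unused
        t = torch.randint(0, self.n_steps, (x.shape[0],), device=x.device)
        noise = torch.randn_like(x)
        # corrupt x to x_t and train the network to predict the injected noise
        x_t = (self.sqrt_ab[t].view(-1, 1, 1, 1) * x
               + self.sqrt_one_minus_ab[t].view(-1, 1, 1, 1) * noise)
        noise_hat = self(x_t, t)                                # base forward(x, t) -> nnet(x, t)
        loss = self.criterion(noise_hat, noise)                 # MSE on the noise
        return loss, noise_hat, noise

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # a real sampler would iterate the full reverse process; this stub denoises a single step
        x, _ = batch
        t = torch.full((x.shape[0],), self.n_steps - 1, device=x.device)
        return self(x, t)

# usage sketch: the base class expects a partial optimizer (it reads lr from its keywords), e.g.
# model = ToyDDPM(nnet=my_unet, optimizer=functools.partial(torch.optim.AdamW, lr=3e-4))
```

Keeping the schedule and sampling in the subclass matches the abstract design of the base class: the Lightning hooks, metric averaging and `configure_optimizers` logic stay reusable across different diffusion formulations.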