
Commit 7fbda8a

nb prep
1 parent 2964689 commit 7fbda8a

File tree

1 file changed: +103 -4 lines changed

nbs/models.diffusion.ipynb

@@ -1,5 +1,19 @@
 {
  "cells": [
+  {
+   "cell_type": "raw",
+   "metadata": {
+    "vscode": {
+     "languageId": "raw"
+    }
+   },
+   "source": [
+    "---\n",
+    "skip_exec: true\n",
+    "skip_showdoc: true\n",
+    "---"
+   ]
+  },
   {
    "attachments": {},
    "cell_type": "markdown",
@@ -476,7 +490,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [
     {
@@ -492,14 +506,99 @@
      },
      "metadata": {},
      "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[10:43:16] INFO - Train loss: nan\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5bf19447f51d42c5b9931765e88dda40",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/47 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[10:43:18] INFO - Epoch: 0, step: 0, n_steps: 256, loss: nan\n",
+      "[10:43:20] INFO - Epoch: 0, step: 1, n_steps: 512, loss: nan\n",
+      "[10:43:21] INFO - Epoch: 0, step: 2, n_steps: 768, loss: nan\n",
+      "[10:43:22] INFO - Epoch: 0, step: 3, n_steps: 1024, loss: nan\n",
+      "[10:43:23] INFO - Epoch: 0, step: 4, n_steps: 1280, loss: nan\n",
+      "[10:43:24] INFO - Epoch: 0, step: 5, n_steps: 1536, loss: nan\n",
+      "[10:43:25] INFO - Epoch: 0, step: 6, n_steps: 1792, loss: nan\n",
+      "[10:43:27] INFO - Epoch: 0, step: 7, n_steps: 2048, loss: nan\n",
+      "[10:43:28] INFO - Epoch: 0, step: 8, n_steps: 2304, loss: nan\n",
+      "[10:43:29] INFO - Epoch: 0, step: 9, n_steps: 2560, loss: nan\n",
+      "[10:43:30] INFO - Epoch: 0, step: 10, n_steps: 2816, loss: nan\n",
+      "[10:43:31] INFO - Epoch: 0, step: 11, n_steps: 3072, loss: nan\n",
+      "[10:43:32] INFO - Epoch: 0, step: 12, n_steps: 3328, loss: nan\n",
+      "[10:43:34] INFO - Epoch: 0, step: 13, n_steps: 3584, loss: nan\n",
+      "[10:43:35] INFO - Epoch: 0, step: 14, n_steps: 3840, loss: nan\n",
+      "[10:43:36] INFO - Epoch: 0, step: 15, n_steps: 4096, loss: nan\n",
+      "[10:43:37] INFO - Epoch: 0, step: 16, n_steps: 4352, loss: nan\n",
+      "[10:43:38] INFO - Epoch: 0, step: 17, n_steps: 4608, loss: nan\n",
+      "[10:43:39] INFO - Epoch: 0, step: 18, n_steps: 4864, loss: nan\n",
+      "[10:43:41] INFO - Epoch: 0, step: 19, n_steps: 5120, loss: nan\n",
+      "[10:43:41] INFO - Epoch: 0, step: 20, n_steps: 5376, loss: nan\n",
+      "[10:43:43] INFO - Epoch: 0, step: 21, n_steps: 5632, loss: nan\n",
+      "[10:43:44] INFO - Epoch: 0, step: 22, n_steps: 5888, loss: nan\n",
+      "[10:43:45] INFO - Epoch: 0, step: 23, n_steps: 6144, loss: nan\n",
+      "[10:43:46] INFO - Epoch: 0, step: 24, n_steps: 6400, loss: nan\n",
+      "[10:43:47] INFO - Epoch: 0, step: 25, n_steps: 6656, loss: nan\n",
+      "[10:43:48] INFO - Epoch: 0, step: 26, n_steps: 6912, loss: nan\n",
+      "[10:43:50] INFO - Epoch: 0, step: 27, n_steps: 7168, loss: nan\n",
+      "[10:43:50] INFO - Epoch: 0, step: 28, n_steps: 7424, loss: nan\n",
+      "[10:43:52] INFO - Epoch: 0, step: 29, n_steps: 7680, loss: nan\n",
+      "[10:43:53] INFO - Epoch: 0, step: 30, n_steps: 7936, loss: nan\n",
+      "[10:43:54] INFO - Epoch: 0, step: 31, n_steps: 8192, loss: nan\n",
+      "[10:43:55] INFO - Epoch: 0, step: 32, n_steps: 8448, loss: nan\n",
+      "[10:43:56] INFO - Epoch: 0, step: 33, n_steps: 8704, loss: nan\n",
+      "[10:43:57] INFO - Epoch: 0, step: 34, n_steps: 8960, loss: nan\n",
+      "[10:43:58] INFO - Epoch: 0, step: 35, n_steps: 9216, loss: nan\n",
+      "[10:43:59] INFO - Epoch: 0, step: 36, n_steps: 9472, loss: nan\n",
+      "[10:44:01] INFO - Epoch: 0, step: 37, n_steps: 9728, loss: nan\n",
+      "[10:44:01] INFO - Epoch: 0, step: 38, n_steps: 9984, loss: nan\n",
+      "[10:44:03] INFO - Epoch: 0, step: 39, n_steps: 10240, loss: nan\n",
+      "[10:44:04] INFO - Epoch: 0, step: 40, n_steps: 10496, loss: nan\n",
+      "[10:44:05] INFO - Epoch: 0, step: 41, n_steps: 10752, loss: nan\n",
+      "[10:44:06] INFO - Epoch: 0, step: 42, n_steps: 11008, loss: nan\n",
+      "[10:44:08] INFO - Epoch: 0, step: 43, n_steps: 11264, loss: nan\n",
+      "[10:44:09] INFO - Epoch: 0, step: 44, n_steps: 11520, loss: nan\n",
+      "[10:44:10] INFO - Epoch: 0, step: 45, n_steps: 11776, loss: nan\n"
+     ]
+    },
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[15], line 40\u001b[0m\n\u001b[1;32m 38\u001b[0m noisy_images \u001b[38;5;241m=\u001b[39m noise_scheduler\u001b[38;5;241m.\u001b[39madd_noise(images, noise, timesteps)\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# train model to predict noise\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoisy_images\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimesteps\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 41\u001b[0m \u001b[38;5;66;03m# output should be as close to input as possible\u001b[39;00m\n\u001b[1;32m 42\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(outputs, noise)\n",
+      "File \u001b[0;32m~/miniforge3/envs/nimrod/lib/python3.11/site-packages/torch/nn/modules/module.py:1732\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1729\u001b[0m tracing_state\u001b[38;5;241m.\u001b[39mpop_scope()\n\u001b[1;32m 1730\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n\u001b[0;32m-> 1732\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapped_call_impl\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 1733\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+     ]
     }
    ],
    "source": [
     "\n",
     "for epoch in range(NUM_EPOCHS):\n",
     "    i = 0\n",
     "    model.train()\n",
-    "    n_steps, loss = 0, 0\n",
+    "    n_steps, total_loss = 0, 0\n",
     "    for step, (images, labels) in tqdm(enumerate(dm.train_dataloader()), total=len(dm.train_dataloader())):\n",
     "        optimizer.zero_grad()\n",
     "        images, labels = images.to(device), labels.to(device)\n",
@@ -514,14 +613,14 @@
     "        # output should be as close to input as possible\n",
     "        loss = criterion(outputs, noise)\n",
     "        n_steps += len(images)\n",
-    "        loss += (loss.item() * len(images))\n",
+    "        total_loss += (loss.item() * len(images))\n",
     "        # logger.info(f\"loss.item(): {loss.item()}, len(images): {len(images)}\")\n",
     "        # logger.info(f\"Epoch: {epoch}, step: {step}, n_steps: {n_steps}, loss: {loss}\")\n",
     "        loss.backward()\n",
     "        optimizer.step()\n",
     "        lr_scheduler.step()\n",
     "\n",
-    "    logger.info(f\"Train loss: {loss / n_steps}\")\n",
+    "    logger.info(f\"Train loss: {total_loss / n_steps}\")\n",
     "\n",
     "    model.eval()\n",
     "    total_loss, n_steps = 0, 0\n",
