|
1 | 1 | {
|
2 | 2 | "cells": [
|
| 3 | + { |
| 4 | + "cell_type": "raw", |
| 5 | + "metadata": { |
| 6 | + "vscode": { |
| 7 | + "languageId": "raw" |
| 8 | + } |
| 9 | + }, |
| 10 | + "source": [ |
| 11 | + "---\n", |
| 12 | + "skip_exec: true\n", |
| 13 | + "skip_showdoc: true\n", |
| 14 | + "---" |
| 15 | + ] |
| 16 | + }, |
3 | 17 | {
|
4 | 18 | "attachments": {},
|
5 | 19 | "cell_type": "markdown",
|
|
476 | 490 | },
|
477 | 491 | {
|
478 | 492 | "cell_type": "code",
|
479 |
| - "execution_count": null, |
| 493 | + "execution_count": 15, |
480 | 494 | "metadata": {},
|
481 | 495 | "outputs": [
|
482 | 496 | {
|
|
492 | 506 | },
|
493 | 507 | "metadata": {},
|
494 | 508 | "output_type": "display_data"
|
| 509 | + }, |
| 510 | + { |
| 511 | + "name": "stderr", |
| 512 | + "output_type": "stream", |
| 513 | + "text": [ |
| 514 | + "[10:43:16] INFO - Train loss: nan\n" |
| 515 | + ] |
| 516 | + }, |
| 517 | + { |
| 518 | + "data": { |
| 519 | + "application/vnd.jupyter.widget-view+json": { |
| 520 | + "model_id": "5bf19447f51d42c5b9931765e88dda40", |
| 521 | + "version_major": 2, |
| 522 | + "version_minor": 0 |
| 523 | + }, |
| 524 | + "text/plain": [ |
| 525 | + " 0%| | 0/47 [00:00<?, ?it/s]" |
| 526 | + ] |
| 527 | + }, |
| 528 | + "metadata": {}, |
| 529 | + "output_type": "display_data" |
| 530 | + }, |
| 531 | + { |
| 532 | + "name": "stderr", |
| 533 | + "output_type": "stream", |
| 534 | + "text": [ |
| 535 | + "[10:43:18] INFO - Epoch: 0, step: 0, n_steps: 256, loss: nan\n", |
| 536 | + "[10:43:20] INFO - Epoch: 0, step: 1, n_steps: 512, loss: nan\n", |
| 537 | + "[10:43:21] INFO - Epoch: 0, step: 2, n_steps: 768, loss: nan\n", |
| 538 | + "[10:43:22] INFO - Epoch: 0, step: 3, n_steps: 1024, loss: nan\n", |
| 539 | + "[10:43:23] INFO - Epoch: 0, step: 4, n_steps: 1280, loss: nan\n", |
| 540 | + "[10:43:24] INFO - Epoch: 0, step: 5, n_steps: 1536, loss: nan\n", |
| 541 | + "[10:43:25] INFO - Epoch: 0, step: 6, n_steps: 1792, loss: nan\n", |
| 542 | + "[10:43:27] INFO - Epoch: 0, step: 7, n_steps: 2048, loss: nan\n", |
| 543 | + "[10:43:28] INFO - Epoch: 0, step: 8, n_steps: 2304, loss: nan\n", |
| 544 | + "[10:43:29] INFO - Epoch: 0, step: 9, n_steps: 2560, loss: nan\n", |
| 545 | + "[10:43:30] INFO - Epoch: 0, step: 10, n_steps: 2816, loss: nan\n", |
| 546 | + "[10:43:31] INFO - Epoch: 0, step: 11, n_steps: 3072, loss: nan\n", |
| 547 | + "[10:43:32] INFO - Epoch: 0, step: 12, n_steps: 3328, loss: nan\n", |
| 548 | + "[10:43:34] INFO - Epoch: 0, step: 13, n_steps: 3584, loss: nan\n", |
| 549 | + "[10:43:35] INFO - Epoch: 0, step: 14, n_steps: 3840, loss: nan\n", |
| 550 | + "[10:43:36] INFO - Epoch: 0, step: 15, n_steps: 4096, loss: nan\n", |
| 551 | + "[10:43:37] INFO - Epoch: 0, step: 16, n_steps: 4352, loss: nan\n", |
| 552 | + "[10:43:38] INFO - Epoch: 0, step: 17, n_steps: 4608, loss: nan\n", |
| 553 | + "[10:43:39] INFO - Epoch: 0, step: 18, n_steps: 4864, loss: nan\n", |
| 554 | + "[10:43:41] INFO - Epoch: 0, step: 19, n_steps: 5120, loss: nan\n", |
| 555 | + "[10:43:41] INFO - Epoch: 0, step: 20, n_steps: 5376, loss: nan\n", |
| 556 | + "[10:43:43] INFO - Epoch: 0, step: 21, n_steps: 5632, loss: nan\n", |
| 557 | + "[10:43:44] INFO - Epoch: 0, step: 22, n_steps: 5888, loss: nan\n", |
| 558 | + "[10:43:45] INFO - Epoch: 0, step: 23, n_steps: 6144, loss: nan\n", |
| 559 | + "[10:43:46] INFO - Epoch: 0, step: 24, n_steps: 6400, loss: nan\n", |
| 560 | + "[10:43:47] INFO - Epoch: 0, step: 25, n_steps: 6656, loss: nan\n", |
| 561 | + "[10:43:48] INFO - Epoch: 0, step: 26, n_steps: 6912, loss: nan\n", |
| 562 | + "[10:43:50] INFO - Epoch: 0, step: 27, n_steps: 7168, loss: nan\n", |
| 563 | + "[10:43:50] INFO - Epoch: 0, step: 28, n_steps: 7424, loss: nan\n", |
| 564 | + "[10:43:52] INFO - Epoch: 0, step: 29, n_steps: 7680, loss: nan\n", |
| 565 | + "[10:43:53] INFO - Epoch: 0, step: 30, n_steps: 7936, loss: nan\n", |
| 566 | + "[10:43:54] INFO - Epoch: 0, step: 31, n_steps: 8192, loss: nan\n", |
| 567 | + "[10:43:55] INFO - Epoch: 0, step: 32, n_steps: 8448, loss: nan\n", |
| 568 | + "[10:43:56] INFO - Epoch: 0, step: 33, n_steps: 8704, loss: nan\n", |
| 569 | + "[10:43:57] INFO - Epoch: 0, step: 34, n_steps: 8960, loss: nan\n", |
| 570 | + "[10:43:58] INFO - Epoch: 0, step: 35, n_steps: 9216, loss: nan\n", |
| 571 | + "[10:43:59] INFO - Epoch: 0, step: 36, n_steps: 9472, loss: nan\n", |
| 572 | + "[10:44:01] INFO - Epoch: 0, step: 37, n_steps: 9728, loss: nan\n", |
| 573 | + "[10:44:01] INFO - Epoch: 0, step: 38, n_steps: 9984, loss: nan\n", |
| 574 | + "[10:44:03] INFO - Epoch: 0, step: 39, n_steps: 10240, loss: nan\n", |
| 575 | + "[10:44:04] INFO - Epoch: 0, step: 40, n_steps: 10496, loss: nan\n", |
| 576 | + "[10:44:05] INFO - Epoch: 0, step: 41, n_steps: 10752, loss: nan\n", |
| 577 | + "[10:44:06] INFO - Epoch: 0, step: 42, n_steps: 11008, loss: nan\n", |
| 578 | + "[10:44:08] INFO - Epoch: 0, step: 43, n_steps: 11264, loss: nan\n", |
| 579 | + "[10:44:09] INFO - Epoch: 0, step: 44, n_steps: 11520, loss: nan\n", |
| 580 | + "[10:44:10] INFO - Epoch: 0, step: 45, n_steps: 11776, loss: nan\n" |
| 581 | + ] |
| 582 | + }, |
| 583 | + { |
| 584 | + "ename": "KeyboardInterrupt", |
| 585 | + "evalue": "", |
| 586 | + "output_type": "error", |
| 587 | + "traceback": [ |
| 588 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 589 | + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", |
| 590 | + "Cell \u001b[0;32mIn[15], line 40\u001b[0m\n\u001b[1;32m 38\u001b[0m noisy_images \u001b[38;5;241m=\u001b[39m noise_scheduler\u001b[38;5;241m.\u001b[39madd_noise(images, noise, timesteps)\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# train model to predict noise\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoisy_images\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimesteps\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 41\u001b[0m \u001b[38;5;66;03m# output should be as close to input as possible\u001b[39;00m\n\u001b[1;32m 42\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(outputs, noise)\n", |
| 591 | + "File \u001b[0;32m~/miniforge3/envs/nimrod/lib/python3.11/site-packages/torch/nn/modules/module.py:1732\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1729\u001b[0m tracing_state\u001b[38;5;241m.\u001b[39mpop_scope()\n\u001b[1;32m 1730\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n\u001b[0;32m-> 1732\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapped_call_impl\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 1733\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n", |
| 592 | + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " |
| 593 | + ] |
495 | 594 | }
|
496 | 595 | ],
|
497 | 596 | "source": [
|
498 | 597 | "\n",
|
499 | 598 | "for epoch in range(NUM_EPOCHS):\n",
|
500 | 599 | " i = 0\n",
|
501 | 600 | " model.train()\n",
|
502 |
| - " n_steps, loss = 0, 0\n", |
| 601 | + " n_steps, total_loss = 0, 0\n", |
503 | 602 | " for step, (images, labels) in tqdm(enumerate(dm.train_dataloader()), total=len(dm.train_dataloader())):\n",
|
504 | 603 | " optimizer.zero_grad()\n",
|
505 | 604 | " images, labels = images.to(device), labels.to(device)\n",
|
|
514 | 613 | " # output should be as close to input as possible\n",
|
515 | 614 | " loss = criterion(outputs, noise)\n",
|
516 | 615 | " n_steps += len(images)\n",
|
517 |
| - " loss += (loss.item() * len(images))\n", |
| 616 | + " total_loss += (loss.item() * len(images))\n", |
518 | 617 | " # logger.info(f\"loss.item(): {loss.item()}, len(images): {len(images)}\")\n",
|
519 | 618 | " # logger.info(f\"Epoch: {epoch}, step: {step}, n_steps: {n_steps}, loss: {loss}\")\n",
|
520 | 619 | " loss.backward()\n",
|
521 | 620 | " optimizer.step()\n",
|
522 | 621 | " lr_scheduler.step()\n",
|
523 | 622 | "\n",
|
524 |
| - " logger.info(f\"Train loss: {loss / n_steps}\")\n", |
| 623 | + " logger.info(f\"Train loss: {total_loss / n_steps}\")\n", |
525 | 624 | "\n",
|
526 | 625 | " model.eval()\n",
|
527 | 626 | " total_loss, n_steps = 0, 0\n",
|
|
0 commit comments