|
23 | 23 | "cell_type": "code",
|
24 | 24 | "execution_count": null,
|
25 | 25 | "metadata": {},
|
26 |
| - "outputs": [ |
27 |
| - { |
28 |
| - "name": "stdout", |
29 |
| - "output_type": "stream", |
30 |
| - "text": [ |
31 |
| - "The autoreload extension is already loaded. To reload it, use:\n", |
32 |
| - " %reload_ext autoreload\n" |
33 |
| - ] |
34 |
| - } |
35 |
| - ], |
| 26 | + "outputs": [], |
36 | 27 | "source": [
|
37 | 28 | "#| hide\n",
|
38 | 29 | "%load_ext autoreload\n",
|
|
462 | 453 | " n_rows:int=3, # Number of rows in the grid\n",
|
463 | 454 | " n_cols:int=3 # Number of columns in the grid\n",
|
464 | 455 | " ):\n",
|
465 |
| - " self.plot_grid(self, n_rows, n_cols, self.hf_ds.features['label'].int2str)" |
| 456 | + " self.plot_grid(self, n_rows, n_cols, self.hf_ds.features['label'].int2str)\n", |
| 457 | + " \n", |
| 458 | + " def compute_image_normalization(self):\n", |
| 459 | + " full_dl = DataLoader(self, batch_size=len(self))\n", |
| 460 | + " full_batch = next(iter(full_dl))\n", |
| 461 | + " mean = full_batch[0].float().mean([0,2,3])\n", |
| 462 | + " std = full_batch[0].float().std([0,2,3])\n", |
| 463 | + " return mean, std" |
| 464 | + ] |
| 465 | + }, |
| 466 | + { |
| 467 | + "cell_type": "markdown", |
| 468 | + "metadata": {}, |
| 469 | + "source": [ |
| 470 | + "### Image normalization\n", |
| 471 | + "if both train and validation splits are available normalize both. else just train\n" |
| 472 | + ] |
| 473 | + }, |
| 474 | + { |
| 475 | + "cell_type": "code", |
| 476 | + "execution_count": null, |
| 477 | + "metadata": {}, |
| 478 | + "outputs": [ |
| 479 | + { |
| 480 | + "name": "stderr", |
| 481 | + "output_type": "stream", |
| 482 | + "text": [ |
| 483 | + "[16:22:46] INFO - loading dataset slegroux/tiny-imagenet-200-clean with args () from split train\n", |
| 484 | + "[16:22:46] INFO - loading dataset slegroux/tiny-imagenet-200-clean from split train\n", |
| 485 | + "Overwrite dataset info from restored data version if exists.\n", |
| 486 | + "[16:22:48] INFO - Overwrite dataset info from restored data version if exists.\n", |
| 487 | + "Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n", |
| 488 | + "[16:22:48] INFO - Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n", |
| 489 | + "Found cached dataset tiny-imagenet-200-clean (/user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)\n", |
| 490 | + "[16:22:48] INFO - Found cached dataset tiny-imagenet-200-clean (/user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)\n", |
| 491 | + "Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n", |
| 492 | + "[16:22:48] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n", |
| 493 | + "[16:23:04] INFO - loading dataset slegroux/tiny-imagenet-200-clean with args () from split validation\n", |
| 494 | + "[16:23:04] INFO - loading dataset slegroux/tiny-imagenet-200-clean from split validation\n", |
| 495 | + "Overwrite dataset info from restored data version if exists.\n", |
| 496 | + "[16:23:05] INFO - Overwrite dataset info from restored data version if exists.\n", |
| 497 | + "Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n", |
| 498 | + "[16:23:05] INFO - Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n", |
| 499 | + "Found cached dataset tiny-imagenet-200-clean (/user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)\n", |
| 500 | + "[16:23:05] INFO - Found cached dataset tiny-imagenet-200-clean (/user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)\n", |
| 501 | + "Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n", |
| 502 | + "[16:23:05] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n" |
| 503 | + ] |
| 504 | + }, |
| 505 | + { |
| 506 | + "data": { |
| 507 | + "text/html": [ |
| 508 | + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">mean:<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">tensor</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4822</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4494</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.3978</span><span style=\"font-weight: bold\">])</span>, std: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">tensor</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.2754</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.2679</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.2811</span><span style=\"font-weight: bold\">])</span>\n", |
| 509 | + "</pre>\n" |
| 510 | + ], |
| 511 | + "text/plain": [ |
| 512 | + "mean:\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.4822\u001b[0m, \u001b[1;36m0.4494\u001b[0m, \u001b[1;36m0.3978\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m, std: \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.2754\u001b[0m, \u001b[1;36m0.2679\u001b[0m, \u001b[1;36m0.2811\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" |
| 513 | + ] |
| 514 | + }, |
| 515 | + "metadata": {}, |
| 516 | + "output_type": "display_data" |
| 517 | + } |
| 518 | + ], |
| 519 | + "source": [ |
| 520 | + "#| export\n", |
| 521 | + "\n", |
| 522 | + "def normalize_image_datasets(name, data_dir='../data/image', splits=['train', 'validation']):\n", |
| 523 | + " mean, std = [], []\n", |
| 524 | + " for split in splits:\n", |
| 525 | + "\n", |
| 526 | + " ds = ImageDataset(\n", |
| 527 | + " name=name,\n", |
| 528 | + " data_dir=data_dir,\n", |
| 529 | + " split=split,\n", |
| 530 | + " )\n", |
| 531 | + " m, s = ds.compute_image_normalization()\n", |
| 532 | + " mean.append(m)\n", |
| 533 | + " std.append(s)\n", |
| 534 | + "\n", |
| 535 | + " mean = torch.stack(mean).mean(dim=0)\n", |
| 536 | + " std = torch.stack(std).mean(dim=0)\n", |
| 537 | + " return mean, std" |
| 538 | + ] |
| 539 | + }, |
| 540 | + { |
| 541 | + "cell_type": "code", |
| 542 | + "execution_count": null, |
| 543 | + "metadata": {}, |
| 544 | + "outputs": [], |
| 545 | + "source": [ |
| 546 | + "#| notest\n", |
| 547 | + "\n", |
| 548 | + "mean, std = normalize_image_datasets('slegroux/tiny-imagenet-200-clean')\n", |
| 549 | + "print(f\"mean:{mean}, std: {std}\")" |
466 | 550 | ]
|
467 | 551 | },
|
468 | 552 | {
|
|
0 commit comments