Skip to content

Commit b9aab51

Browse files
author
Sylvain Le Groux
committed
re-compute image normalization m&std
1 parent 4b505ce commit b9aab51

File tree

6 files changed

+797
-371
lines changed

6 files changed

+797
-371
lines changed

config/data/image/tiny_imagenet.yaml

+4-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ transforms:
1111
transforms:
1212
- _target_: torchvision.transforms.ToTensor
1313
- _target_: torchvision.transforms.Normalize
14-
mean: [0.4822, 0.4495, 0.3985]
15-
std: [0.2771, 0.2690, 0.2826]
14+
mean: [0.4822, 0.4494, 0.3978]
15+
std: [0.2754, 0.2679, 0.2811]
16+
# mean: [0.4822, 0.4495, 0.3985]
17+
# std: [0.2771, 0.2690, 0.2826]
1618
# - _target_: torchvision.transforms.Resize
1719
# size: [32,32]
1820

config/data/image/tiny_imagenet_aug.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ transforms:
1111
transforms:
1212
- _target_: torchvision.transforms.ToTensor
1313
- _target_: torchvision.transforms.Normalize
14-
mean: [0.4822, 0.4495, 0.3985]
15-
std: [0.2771, 0.2690, 0.2826]
14+
mean: [0.4822, 0.4494, 0.3978]
15+
std: [0.2754, 0.2679, 0.2811]
1616
- _target_: torchvision.transforms.Resize
1717
size: 64
1818
- _target_: torchvision.transforms.RandomCrop

nbs/image.datasets.ipynb

+95-11
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,7 @@
2323
"cell_type": "code",
2424
"execution_count": null,
2525
"metadata": {},
26-
"outputs": [
27-
{
28-
"name": "stdout",
29-
"output_type": "stream",
30-
"text": [
31-
"The autoreload extension is already loaded. To reload it, use:\n",
32-
" %reload_ext autoreload\n"
33-
]
34-
}
35-
],
26+
"outputs": [],
3627
"source": [
3728
"#| hide\n",
3829
"%load_ext autoreload\n",
@@ -462,7 +453,100 @@
462453
" n_rows:int=3, # Number of rows in the grid\n",
463454
" n_cols:int=3 # Number of columns in the grid\n",
464455
" ):\n",
465-
" self.plot_grid(self, n_rows, n_cols, self.hf_ds.features['label'].int2str)"
456+
" self.plot_grid(self, n_rows, n_cols, self.hf_ds.features['label'].int2str)\n",
457+
" \n",
458+
" def compute_image_normalization(self):\n",
459+
" full_dl = DataLoader(self, batch_size=len(self))\n",
460+
" full_batch = next(iter(full_dl))\n",
461+
" mean = full_batch[0].float().mean([0,2,3])\n",
462+
" std = full_batch[0].float().std([0,2,3])\n",
463+
" return mean, std"
464+
]
465+
},
466+
{
467+
"cell_type": "markdown",
468+
"metadata": {},
469+
"source": [
470+
"### Image normalization\n",
471+
"If both the train and validation splits are available, compute the normalization statistics (mean and std) from both; otherwise use only the train split.\n"
472+
]
473+
},
474+
{
475+
"cell_type": "code",
476+
"execution_count": null,
477+
"metadata": {},
478+
"outputs": [
479+
{
480+
"name": "stderr",
481+
"output_type": "stream",
482+
"text": [
483+
"[16:22:46] INFO - loading dataset slegroux/tiny-imagenet-200-clean with args () from split train\n",
484+
"[16:22:46] INFO - loading dataset slegroux/tiny-imagenet-200-clean from split train\n",
485+
"Overwrite dataset info from restored data version if exists.\n",
486+
"[16:22:48] INFO - Overwrite dataset info from restored data version if exists.\n",
487+
"Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n",
488+
"[16:22:48] INFO - Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n",
489+
"Found cached dataset tiny-imagenet-200-clean (/user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)\n",
490+
"[16:22:48] INFO - Found cached dataset tiny-imagenet-200-clean (/user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)\n",
491+
"Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n",
492+
"[16:22:48] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n",
493+
"[16:23:04] INFO - loading dataset slegroux/tiny-imagenet-200-clean with args () from split validation\n",
494+
"[16:23:04] INFO - loading dataset slegroux/tiny-imagenet-200-clean from split validation\n",
495+
"Overwrite dataset info from restored data version if exists.\n",
496+
"[16:23:05] INFO - Overwrite dataset info from restored data version if exists.\n",
497+
"Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n",
498+
"[16:23:05] INFO - Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n",
499+
"Found cached dataset tiny-imagenet-200-clean (/user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)\n",
500+
"[16:23:05] INFO - Found cached dataset tiny-imagenet-200-clean (/user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)\n",
501+
"Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n",
502+
"[16:23:05] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2\n"
503+
]
504+
},
505+
{
506+
"data": {
507+
"text/html": [
508+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">mean:<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">tensor</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4822</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4494</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.3978</span><span style=\"font-weight: bold\">])</span>, std: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">tensor</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.2754</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.2679</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.2811</span><span style=\"font-weight: bold\">])</span>\n",
509+
"</pre>\n"
510+
],
511+
"text/plain": [
512+
"mean:\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.4822\u001b[0m, \u001b[1;36m0.4494\u001b[0m, \u001b[1;36m0.3978\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m, std: \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.2754\u001b[0m, \u001b[1;36m0.2679\u001b[0m, \u001b[1;36m0.2811\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n"
513+
]
514+
},
515+
"metadata": {},
516+
"output_type": "display_data"
517+
}
518+
],
519+
"source": [
520+
"#| export\n",
521+
"\n",
522+
"def normalize_image_datasets(name, data_dir='../data/image', splits=['train', 'validation']):\n",
523+
" mean, std = [], []\n",
524+
" for split in splits:\n",
525+
"\n",
526+
" ds = ImageDataset(\n",
527+
" name=name,\n",
528+
" data_dir=data_dir,\n",
529+
" split=split,\n",
530+
" )\n",
531+
" m, s = ds.compute_image_normalization()\n",
532+
" mean.append(m)\n",
533+
" std.append(s)\n",
534+
"\n",
535+
" mean = torch.stack(mean).mean(dim=0)\n",
536+
" std = torch.stack(std).mean(dim=0)\n",
537+
" return mean, std"
538+
]
539+
},
540+
{
541+
"cell_type": "code",
542+
"execution_count": null,
543+
"metadata": {},
544+
"outputs": [],
545+
"source": [
546+
"#| notest\n",
547+
"\n",
548+
"mean, std = normalize_image_datasets('slegroux/tiny-imagenet-200-clean')\n",
549+
"print(f\"mean:{mean}, std: {std}\")"
466550
]
467551
},
468552
{

nimrod/_modidx.py

+4
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@
195195
'nimrod/image/datasets.py'),
196196
'nimrod.image.datasets.ImageDataset.__len__': ( 'image.datasets.html#imagedataset.__len__',
197197
'nimrod/image/datasets.py'),
198+
'nimrod.image.datasets.ImageDataset.compute_image_normalization': ( 'image.datasets.html#imagedataset.compute_image_normalization',
199+
'nimrod/image/datasets.py'),
198200
'nimrod.image.datasets.ImageDataset.dim': ( 'image.datasets.html#imagedataset.dim',
199201
'nimrod/image/datasets.py'),
200202
'nimrod.image.datasets.ImageDataset.int2str': ( 'image.datasets.html#imagedataset.int2str',
@@ -232,6 +234,8 @@
232234
'nimrod.image.datasets.ImageSuperResDataset.__init__': ( 'image.datasets.html#imagesuperresdataset.__init__',
233235
'nimrod/image/datasets.py'),
234236
'nimrod.image.datasets.make_grid': ('image.datasets.html#make_grid', 'nimrod/image/datasets.py'),
237+
'nimrod.image.datasets.normalize_image_datasets': ( 'image.datasets.html#normalize_image_datasets',
238+
'nimrod/image/datasets.py'),
235239
'nimrod.image.datasets.show_images': ( 'image.datasets.html#show_images',
236240
'nimrod/image/datasets.py')},
237241
'nimrod.image.med': { 'nimrod.image.med.BertAttention': ('image.med.html#bertattention', 'nimrod/image/med.py'),

nimrod/image/datasets.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/image.datasets.ipynb.
44

55
# %% auto 0
6-
__all__ = ['logger', 'TFM_LOW_RES', 'show_images', 'make_grid', 'ImagePlotMixin', 'ImageDataset', 'ImageDataModule',
7-
'ImageSuperResDataset', 'ImageSuperResDataModule']
6+
__all__ = ['logger', 'TFM_LOW_RES', 'show_images', 'make_grid', 'ImagePlotMixin', 'ImageDataset', 'normalize_image_datasets',
7+
'ImageDataModule', 'ImageSuperResDataset', 'ImageSuperResDataModule']
88

99
# %% ../../nbs/image.datasets.ipynb 3
1010
# torch
@@ -346,8 +346,33 @@ def show_grid(
346346
n_cols:int=3 # Number of columns in the grid
347347
):
348348
self.plot_grid(self, n_rows, n_cols, self.hf_ds.features['label'].int2str)
349+
350+
def compute_image_normalization(
    self,
    batch_size:int=1024 # Number of images reduced per step; bounds peak memory
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the per-channel mean and standard deviation of the dataset.

    The dataset is streamed through a `DataLoader` in chunks of `batch_size`
    items instead of being materialised as one giant batch, so memory use no
    longer grows with the dataset size. Per-channel sums and sums of squares
    are accumulated in float64 and turned into the mean and the unbiased
    (n - 1) standard deviation, matching the default of `torch.Tensor.std`.

    The first element of every batch is reduced over dims 0, 2 and 3, i.e. it
    is expected to be a 4-D image tensor with the channel on dim 1.

    Returns a `(mean, std)` tuple of float32 tensors with one entry per channel.
    Raises `ValueError` if the dataset is empty.
    """
    n_items = len(self)
    if n_items == 0:
        raise ValueError("cannot compute image normalization of an empty dataset")

    # Never ask for a batch larger than the dataset itself.
    loader = DataLoader(self, batch_size=min(batch_size, n_items))

    channel_sum = None
    channel_sq_sum = None
    n_values = 0  # number of values contributing to each channel
    for batch in loader:
        # float64 keeps the sum-of-squares formula numerically stable.
        images = batch[0].to(torch.float64)
        reduced = images.sum(dim=(0, 2, 3))
        reduced_sq = images.pow(2).sum(dim=(0, 2, 3))
        if channel_sum is None:
            channel_sum, channel_sq_sum = reduced, reduced_sq
        else:
            channel_sum += reduced
            channel_sq_sum += reduced_sq
        n_values += images.numel() // images.shape[1]

    mean = channel_sum / n_values
    # Unbiased variance; clamp guards against tiny negative rounding residue.
    var = (channel_sq_sum - n_values * mean.pow(2)) / (n_values - 1)
    std = var.clamp_min(0).sqrt()
    return mean.float(), std.float()
349356

350357
# %% ../../nbs/image.datasets.ipynb 12
358+
def normalize_image_datasets(
    name:str, # Dataset identifier forwarded to `ImageDataset`
    data_dir:str='../data/image', # Directory forwarded to `ImageDataset` as `data_dir`
    splits=('train', 'validation') # Split name(s) to include; a single string is also accepted
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return per-channel normalization statistics averaged over dataset splits.

    An `ImageDataset` is built for every entry in `splits`, its
    `compute_image_normalization()` is called, and the resulting per-split
    means and standard deviations are each averaged across splits.

    NOTE: the result is the plain, unweighted average of the per-split
    statistics. It is not weighted by split size and the returned std is the
    mean of the per-split stds, not a pooled standard deviation.

    Returns a `(mean, std)` tuple of tensors.
    Raises `ValueError` if `splits` is empty.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(splits, str):
        splits = (splits,)
    else:
        splits = tuple(splits)
    if not splits:
        raise ValueError("`splits` must name at least one dataset split")

    means, stds = [], []
    for split in splits:
        ds = ImageDataset(
            name=name,
            data_dir=data_dir,
            split=split,
        )
        split_mean, split_std = ds.compute_image_normalization()
        means.append(split_mean)
        stds.append(split_std)

    mean = torch.stack(means).mean(dim=0)
    std = torch.stack(stds).mean(dim=0)
    return mean, std
374+
375+
# %% ../../nbs/image.datasets.ipynb 15
351376
class ImageDataModule(ImagePlotMixin, DataModule):
352377

353378
def __init__(
@@ -510,7 +535,7 @@ def show_batch(
510535
return grid_im
511536

512537

513-
# %% ../../nbs/image.datasets.ipynb 23
538+
# %% ../../nbs/image.datasets.ipynb 26
514539
TFM_LOW_RES = transforms.Compose(
515540
[
516541
transforms.Resize((32, 32), antialias=True),
@@ -549,7 +574,7 @@ def __getitem__(self, idx:int) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
549574
return image_x, image_y
550575

551576

552-
# %% ../../nbs/image.datasets.ipynb 27
577+
# %% ../../nbs/image.datasets.ipynb 30
553578
TFM_LOW_RES = nn.Sequential(
554579
transforms.Resize((32, 32), antialias=True),
555580
transforms.Resize((64, 64), antialias=True)

0 commit comments

Comments
 (0)