Skip to content

Commit 266c426

Browse files
committed
refactor lr find & 1-cycle
1 parent d0bb928 commit 266c426

8 files changed

+354
-132
lines changed

nbs/image.datasets.ipynb

+3
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,9 @@
925925
" return list(self.train_ds[0][0].shape[-2:])\n",
926926
" raise RuntimeError(\"train_ds is not initialized. Call prepare_data() first.\")\n",
927927
"\n",
928+
" @property\n",
929+
" def name(self)->str:\n",
930+
" return self.hparams.name\n",
928931
"\n",
929932
" @property\n",
930933
" def num_classes(self) -> int: # num of classes in dataset\n",

nbs/models.conv.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
"\n",
6161
"from matplotlib import pyplot as plt\n",
6262
"import pandas as pd\n",
63-
"from typing import List, Optional, Type, Callable\n",
63+
"from typing import List, Optional, Type, Callable, Any\n",
6464
"\n",
6565
"from nimrod.utils import get_device, set_seed\n",
6666
"from nimrod.models.core import Classifier\n",
@@ -1659,7 +1659,7 @@
16591659
" nnet:ConvNet, # model\n",
16601660
" num_classes:int, # number of classes\n",
16611661
" optimizer:Callable[...,torch.optim.Optimizer], # optimizer\n",
1662-
" scheduler:Callable[...,torch.optim.lr_scheduler], # scheduler\n",
1662+
" scheduler: Optional[Callable[...,Any]]=None, # scheduler\n",
16631663
" ):\n",
16641664
"\n",
16651665
" logger.info(\"ConvNetX: init\")\n",

nbs/models.diffusion.ipynb

+5-1
Original file line numberDiff line numberDiff line change
@@ -745,9 +745,13 @@
745745
],
746746
"metadata": {
747747
"kernelspec": {
748-
"display_name": "python3",
748+
"display_name": "nimrod",
749749
"language": "python",
750750
"name": "python3"
751+
},
752+
"language_info": {
753+
"name": "python",
754+
"version": "3.11.8"
751755
}
752756
},
753757
"nbformat": 4,

nimrod/_modidx.py

+2
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@
174174
'nimrod/image/datasets.py'),
175175
'nimrod.image.datasets.ImageDataModule.label_names': ( 'image.datasets.html#imagedatamodule.label_names',
176176
'nimrod/image/datasets.py'),
177+
'nimrod.image.datasets.ImageDataModule.name': ( 'image.datasets.html#imagedatamodule.name',
178+
'nimrod/image/datasets.py'),
177179
'nimrod.image.datasets.ImageDataModule.num_classes': ( 'image.datasets.html#imagedatamodule.num_classes',
178180
'nimrod/image/datasets.py'),
179181
'nimrod.image.datasets.ImageDataModule.prepare_data': ( 'image.datasets.html#imagedatamodule.prepare_data',

nimrod/image/datasets.py

+3
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ def dim(self)->List[int]:
323323
return list(self.train_ds[0][0].shape[-2:])
324324
raise RuntimeError("train_ds is not initialized. Call prepare_data() first.")
325325

326+
@property
327+
def name(self)->str:
328+
return self.hparams.name
326329

327330
@property
328331
def num_classes(self) -> int: # num of classes in dataset

nimrod/models/conv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from matplotlib import pyplot as plt
2222
import pandas as pd
23-
from typing import List, Optional, Type, Callable
23+
from typing import List, Optional, Type, Callable, Any
2424

2525
from ..utils import get_device, set_seed
2626
from .core import Classifier
@@ -260,7 +260,7 @@ def __init__(
260260
nnet:ConvNet, # model
261261
num_classes:int, # number of classes
262262
optimizer:Callable[...,torch.optim.Optimizer], # optimizer
263-
scheduler:Callable[...,torch.optim.lr_scheduler], # scheduler
263+
scheduler: Optional[Callable[...,Any]]=None, # scheduler
264264
):
265265

266266
logger.info("ConvNetX: init")

tutorials/convnet.ipynb

+328-123
Large diffs are not rendered by default.

tutorials/resnet.ipynb

+9-4
Original file line numberDiff line numberDiff line change
@@ -408,11 +408,14 @@
408408
"source": [
409409
"# LR_FINDER\n",
410410
"\n",
411-
"cfg_model.nnet.n_features = [1, 8, 16, 32, 16]\n",
411+
"\n",
412412
"N_EPOCHS = 5\n",
413+
"# set sched total steps\n",
413414
"cfg_sched.total_steps = len(dm.train_dataloader()) * N_EPOCHS\n",
414-
"\n",
415415
"scheduler = instantiate(cfg_sched)\n",
416+
"\n",
417+
"# instantiate model with scheduler\n",
418+
"cfg_model.nnet.n_features = [1, 8, 16, 32, 16]\n",
416419
"model = instantiate(cfg_model)(optimizer=optimizer, scheduler=scheduler)\n",
417420
"\n",
418421
"trainer = Trainer(\n",
@@ -748,14 +751,16 @@
748751
"source": [
749752
"# ONE-CYCLE TRAIN\n",
750753
"\n",
751-
"cfg_model.nnet.n_features = [1, 8, 16, 32, 16]\n",
752754
"N_EPOCHS = 5\n",
755+
"\n",
753756
"cfg_sched.total_steps = len(dm.train_dataloader()) * N_EPOCHS\n",
754757
"cfg_sched.max_lr = lr_finder.suggestion()\n",
755758
"\n",
759+
"cfg_model.nnet.n_features = [1, 8, 16, 32, 16]\n",
760+
"\n",
756761
"wandb_logger = WandbLogger(\n",
757762
" project=\"MNIST Classification\",\n",
758-
" name=f\"resnetx-bs:{dm.batch_size}-epochs:{N_EPOCHS}-features:{cfg_model.nnet.n_features}\",\n",
763+
" name=f\"ResnetX-bs:{dm.batch_size}-epochs:{N_EPOCHS}-features:{cfg_model.nnet.n_features}\",\n",
759764
" save_dir='wandb',\n",
760765
" entity='slegroux',\n",
761766
" tags=['arch', 'dev'],\n",

0 commit comments

Comments
 (0)