Commit 111216c

add specific param logging to wandb; update python logging
1 parent 34637f6 commit 111216c

8 files changed: +177 -128 lines
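
Most of the notebook diffs below make the same change: instead of importing a shared logger object from nimrod.utils, each module now creates its own logger with logging.getLogger(__name__), so log records carry the emitting module's name while handlers stay centrally configured. A minimal sketch of that pattern; the function below is illustrative and not part of this commit:

import logging

# Module-level logger: records are tagged with this module's dotted name,
# while formatting and handlers come from the central logging.basicConfig
# call (done in nimrod.utils in this repo).
logger = logging.getLogger(__name__)

def load_images(path):
    # Illustrative function only, not part of the commit
    logger.info("loading images from %s", path)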

nbs/image.datasets.ipynb (+4 -1)

@@ -66,7 +66,10 @@
     "\n",
     "from typing import Any, Dict, Optional, Tuple, List\n",
     "from nimrod.data.core import DataModule\n",
-    "from nimrod.utils import logger, set_seed"
+    "from nimrod.utils import set_seed\n",
+    "\n",
+    "import logging\n",
+    "logger = logging.getLogger(__name__)"
    ]
   },
   {

nbs/models.conv.ipynb (+5 -2)

@@ -56,8 +56,11 @@
     "from omegaconf import OmegaConf\n",
     "\n",
     "from nimrod.image.datasets import MNISTDataModule\n",
-    "from nimrod.utils import get_device, logger\n",
-    "from nimrod.models.core import Classifier"
+    "from nimrod.utils import get_device\n",
+    "from nimrod.models.core import Classifier\n",
+    "\n",
+    "import logging\n",
+    "logger = logging.getLogger(__name__)"
    ]
   },
   {

nbs/models.core.ipynb (+4 -2)

@@ -42,9 +42,11 @@
     "import torch\n",
     "\n",
     "from abc import ABC, abstractmethod\n",
-    "from nimrod.utils import logger\n",
+    "# from nimrod.utils import logger\n",
     "\n",
-    "from torchmetrics import Accuracy\n"
+    "from torchmetrics import Accuracy\n",
+    "import logging\n",
+    "logger = logging.getLogger(__name__)\n"
    ]
   },
   {

nbs/models.lm.ipynb (+112 -94)

Large diffs are not rendered by default.

nbs/models.mlp.ipynb (+4 -2)

@@ -52,10 +52,12 @@
     "\n",
     "from nimrod.utils import get_device\n",
     "from nimrod.image.datasets import MNISTDataModule\n",
-    "from nimrod.utils import logger\n",
     "from nimrod.models.core import Classifier\n",
     "# torch.set_num_interop_threads(1)\n",
-    "# from IPython.core.debugger import set_trace"
+    "# from IPython.core.debugger import set_trace\n",
+    "\n",
+    "import logging\n",
+    "logger = logging.getLogger(__name__)"
    ]
   },
   {

nbs/text.datasets.ipynb (+5 -2)

@@ -76,8 +76,11 @@
     "\n",
     "# nimrod\n",
     "# from nimrod.models.lm import Vocab\n",
-    "from nimrod.utils import set_seed, logger\n",
-    "from nimrod.data.core import DataModule\n"
+    "from nimrod.utils import set_seed\n",
+    "from nimrod.data.core import DataModule\n",
+    "\n",
+    "import logging\n",
+    "logger = logging.getLogger(__name__)\n"
    ]
   },
   {

nbs/utils.ipynb (+11 -5)

@@ -51,6 +51,7 @@
     "import random\n",
     "import os\n",
     "import logging\n",
+    "from rich.logging import RichHandler\n",
     "import lightning as L"
    ]
   },
@@ -87,8 +88,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-     "cpu\n",
-     "Is MPS (Metal Performance Shader) built? False\n"
+     "mps\n",
+     "Is MPS (Metal Performance Shader) built? True\n"
    ]
   }
  ],
@@ -123,7 +124,7 @@
     "    # # Set a fixed value for the hash seed\n",
     "    # os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
     "    # print(f\"Random seed set as {seed}\")\n",
-    "    L.seed_everything(seed)"
+    "    L.seed_everything(seed, workers=True)"
    ]
   },
   {
@@ -166,10 +167,15 @@
     "#| export\n",
     "\n",
     "# Configure the logger\n",
-    "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
+    "logging.basicConfig(\n",
+    "    level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s',\n",
+    "    handlers=[RichHandler(rich_tracebacks=True)]\n",
+    "    )\n",
     "\n",
     "# Create a logger\n",
-    "logger = logging.getLogger(__name__)"
+    "# logger = logging.getLogger(__name__)\n",
+    "def get_logger(name=__name__):\n",
+    "    return logging.getLogger(name)"
    ]
   },
   {
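
With nbs/utils.ipynb now installing a RichHandler through logging.basicConfig and exposing a get_logger helper, downstream modules keep calling plain logging.getLogger while output is rendered by rich. A hedged usage sketch, assuming the #| export cells are exported to nimrod/utils.py as usual; the messages and seed value are illustrative:

import logging

from nimrod.utils import get_logger, set_seed  # exported from nbs/utils.ipynb

log = get_logger(__name__)   # same object as logging.getLogger(__name__)
set_seed(42)                 # wraps L.seed_everything(42, workers=True)
log.debug("debug-level records now flow through RichHandler")
log.info("rich_tracebacks=True pretty-prints tracebacks attached to log records")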

recipes/image/mnist/train.py (+32 -20)

@@ -4,7 +4,10 @@
 from omegaconf import DictConfig, OmegaConf
 import hydra
 from hydra.utils import instantiate
-import wandb
+from nimrod.utils import set_seed
+import json
+import logging
+log = logging.getLogger(__name__)
 
 @hydra.main(version_base="1.3",config_path="conf", config_name="train_mlp.yaml")
 def main(cfg: DictConfig) -> dict:
@@ -15,12 +18,10 @@ def main(cfg: DictConfig) -> dict:
     hp = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
 
     # SEED
-
-    L.seed_everything(cfg.seed, workers=True)
+    set_seed(cfg.seed)
 
     # MODEL
     model = instantiate(cfg.model)
-    # from IPython import embed; embed()
 
     # DATA
     datamodule = instantiate(cfg.datamodule)
@@ -31,42 +32,53 @@ def main(cfg: DictConfig) -> dict:
         callbacks.append(instantiate(cb_conf))
 
     loggers = []
+    # logger.info("Instantiating logger <{}>".format(cfg.logger._target_))
     for log_conf in cfg.loggers:
         logger = instantiate(cfg[log_conf])
         # wandb logger special setup
         if isinstance(logger, L.pytorch.loggers.WandbLogger):
             # deal with hangs when hp optim multirun training
             # wandb.init(settings=wandb.Settings(start_method="thread"))
             # wandb requires dict not DictConfig
-            logger.experiment.config.update(hp)
+            logger.experiment.config.update(hp["datamodule"], allow_val_change=True)
+            logger.experiment.config.update(hp["model"], allow_val_change=True)
         loggers.append(logger)
+
+    # print(json.dumps(hp, indent=4))
 
     # trainer
     profiler = instantiate(cfg.profiler)
-    trainer = instantiate(cfg.trainer, callbacks=callbacks, profiler=profiler, logger=[logger])
-    trainer.logger.log_hyperparams(hp)
-
-    # lr finder
-    # tuner = Tuner(trainer)
+    trainer = instantiate(cfg.trainer, callbacks=callbacks, profiler=profiler, logger=loggers)
+    # trainer.logger.log_hyperparams(hp)
 
-    # tuner.scale_batch_size(model, datamodule=datamodule, mode="power")
-    # lr_finder = tuner.lr_find(model,datamodule=datamodule)
-    # print(lr_finder.results)
-    # # Plot with
-    # fig = lr_finder.plot(suggest=True)
-    # fig.show()
-    # new_lr = lr_finder.suggestion()
-    # model.hparams.lr = new_lr
+    # batch size & lr optimization
+    tuner = Tuner(trainer)
+    if cfg.get("bs_finder"):
+        tuner.scale_batch_size(model, datamodule=datamodule, mode="power", init_val=65536)
+        if isinstance(logger, L.pytorch.loggers.WandbLogger):
+            # bs is automatically updated in the L datamodule but we need to manually update it in wandb
+            logger.experiment.config.update({"batch_size": datamodule.hparams.batch_size}, allow_val_change=True)
+
+    if cfg.get("lr_finder"):
+        lr_finder = tuner.lr_find(model, datamodule=datamodule)
+        print(lr_finder.results)
+        # Plot with
+        fig = lr_finder.plot(suggest=True)
+        fig.show()
+        new_lr = lr_finder.suggestion()
+        model.hparams.lr = new_lr
+        if isinstance(logger, L.pytorch.loggers.WandbLogger):
+            logger.experiment.config.update({"lr": new_lr}, allow_val_change=True)
 
 
     if cfg.get("train"):
         # trainer.fit(model=autoencoder_pl, train_dataloaders=train_dl, val_dataloaders=dev_dl, ckpt_path=cfg.get("ckpt_path"))
         trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
 
     # # TEST
-    # if cfg.get("test"):
+    if cfg.get("test"):
         # # trainer.test(autoencoder_pl, dataloaders=test_dl)
-    # trainer.test(datamodule=datamodule, ckpt_path="best")
+        trainer.test(datamodule=datamodule, ckpt_path="best")
 
     # wandb.finish()
 
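
On the "specific param logging to wandb" part of the commit message: rather than pushing the whole resolved Hydra config into wandb, train.py now updates only the datamodule and model groups and passes allow_val_change=True so the batch-size and lr finders can overwrite those values later in the run. A standalone sketch of the same wandb pattern; the project name and hyperparameter values are made up and not taken from this repo:

# Hypothetical, self-contained illustration of the config-update pattern used
# in train.py; it is not part of the commit itself.
import wandb

hp = {
    "model": {"lr": 1e-3, "hidden_size": 64},            # made-up hyperparameters
    "datamodule": {"batch_size": 64, "num_workers": 4},
}

run = wandb.init(project="nimrod-mnist", mode="offline")  # assumed project name
# Log only the groups of interest; allow_val_change lets later tuner results
# (e.g. a new batch_size or lr) overwrite the values logged here.
run.config.update(hp["datamodule"], allow_val_change=True)
run.config.update(hp["model"], allow_val_change=True)
run.config.update({"lr": 3e-4}, allow_val_change=True)     # e.g. after lr_find
run.finish()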
