 from omegaconf import DictConfig, OmegaConf
 import hydra
 from hydra.utils import instantiate
-import wandb
+from nimrod.utils import set_seed
+import json
+import logging
+log = logging.getLogger(__name__)
 
 @hydra.main(version_base="1.3", config_path="conf", config_name="train_mlp.yaml")
 def main(cfg: DictConfig) -> dict:
@@ -15,12 +18,10 @@ def main(cfg: DictConfig) -> dict:
     hp = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
 
     # SEED
-
-    L.seed_everything(cfg.seed, workers=True)
+    set_seed(cfg.seed)
 
     # MODEL
     model = instantiate(cfg.model)
-    # from IPython import embed; embed()
 
     # DATA
     datamodule = instantiate(cfg.datamodule)
@@ -31,42 +32,53 @@ def main(cfg: DictConfig) -> dict:
         callbacks.append(instantiate(cb_conf))
 
     loggers = []
+    # logger.info("Instantiating logger <{}>".format(cfg.logger._target_))
     for log_conf in cfg.loggers:
         logger = instantiate(cfg[log_conf])
         # wandb logger special setup
         if isinstance(logger, L.pytorch.loggers.WandbLogger):
             # deal with hangs when hp optim multirun training
             # wandb.init(settings=wandb.Settings(start_method="thread"))
             # wandb requires dict not DictConfig
-            logger.experiment.config.update(hp)
+            logger.experiment.config.update(hp["datamodule"], allow_val_change=True)
+            logger.experiment.config.update(hp["model"], allow_val_change=True)
         loggers.append(logger)
+
+    # print(json.dumps(hp, indent=4))
 
     # trainer
     profiler = instantiate(cfg.profiler)
-    trainer = instantiate(cfg.trainer, callbacks=callbacks, profiler=profiler, logger=[logger])
-    trainer.logger.log_hyperparams(hp)
-
-    # lr finder
-    # tuner = Tuner(trainer)
+    trainer = instantiate(cfg.trainer, callbacks=callbacks, profiler=profiler, logger=loggers)
+    # trainer.logger.log_hyperparams(hp)
 
-    # tuner.scale_batch_size(model, datamodule=datamodule, mode="power")
-    # lr_finder = tuner.lr_find(model,datamodule=datamodule)
-    # print(lr_finder.results)
-    # # Plot with
-    # fig = lr_finder.plot(suggest=True)
-    # fig.show()
-    # new_lr = lr_finder.suggestion()
-    # model.hparams.lr = new_lr
+    # batch size & lr optimization
+    tuner = Tuner(trainer)
+    if cfg.get("bs_finder"):
+        tuner.scale_batch_size(model, datamodule=datamodule, mode="power", init_val=65536)
+        if isinstance(logger, L.pytorch.loggers.WandbLogger):
+            # bs is automatically updated in L datamodule but we need to manually update it in wandb
+            logger.experiment.config.update({"batch_size": datamodule.hparams.batch_size}, allow_val_change=True)
+
+    if cfg.get("lr_finder"):
+        lr_finder = tuner.lr_find(model, datamodule=datamodule)
+        print(lr_finder.results)
+        # Plot with
+        fig = lr_finder.plot(suggest=True)
+        fig.show()
+        new_lr = lr_finder.suggestion()
+        model.hparams.lr = new_lr
+        if isinstance(logger, L.pytorch.loggers.WandbLogger):
+            logger.experiment.config.update({"lr": new_lr}, allow_val_change=True)
 
 
     if cfg.get("train"):
         # trainer.fit(model=autoencoder_pl, train_dataloaders=train_dl, val_dataloaders=dev_dl, ckpt_path=cfg.get("ckpt_path"))
         trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
 
     # # TEST
-    # if cfg.get("test"):
+    if cfg.get("test"):
         # # trainer.test(autoencoder_pl, dataloaders=test_dl)
-        # trainer.test(datamodule=datamodule, ckpt_path="best")
+        trainer.test(datamodule=datamodule, ckpt_path="best")
 
     # wandb.finish()
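
For reference, the new set_seed imported from nimrod.utils takes over from the removed L.seed_everything(cfg.seed, workers=True) call. Its implementation is not part of this diff; a minimal sketch of what such a helper typically does (an assumption, not necessarily the actual nimrod code):

import os
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # seed the Python, NumPy and PyTorch RNGs so training runs are reproducible
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)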
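Note that the uncommented tuner code calls Tuner(trainer), but the diff does not add a corresponding import, so it presumably already exists earlier in the file. In Lightning 2.x the batch-size and learning-rate finders are driven through lightning.pytorch.tuner; roughly (a sketch under that assumption, not the exact repo code):

from lightning.pytorch.tuner import Tuner

tuner = Tuner(trainer)
# find a workable batch size, then a learning rate, before calling trainer.fit()
tuner.scale_batch_size(model, datamodule=datamodule, mode="power")
new_lr = tuner.lr_find(model, datamodule=datamodule).suggestion()

Since both flags are read with cfg.get, bs_finder and lr_finder stay disabled unless they are set in the Hydra config or overridden on the command line.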