from chgnet.model.model import CHGNet
from chgnet.utils import AverageMeter, determine_device, mae, write_json

+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+
if TYPE_CHECKING:
    from torch.utils.data import DataLoader
@@ -50,6 +56,9 @@ def __init__(
        data_seed: int | None = None,
        use_device: str | None = None,
        check_cuda_mem: bool = False,
+        wandb_path: str | None = None,
+        wandb_init_kwargs: dict | None = None,
+        extra_run_config: dict | None = None,
        **kwargs,
    ) -> None:
        """Initialize all hyper-parameters for trainer.
@@ -88,15 +97,22 @@ def __init__(
                Default = None
            check_cuda_mem (bool): Whether to use cuda with most available memory
                Default = False
+            wandb_path (str | None): The project and run name separated by a slash:
+                "project/run_name". If None, wandb logging is not used.
+                Default = None
+            wandb_init_kwargs (dict): Additional kwargs to pass to wandb.init.
+                Default = None
+            extra_run_config (dict): Additional hyper-params to be recorded by wandb
+                that are not included in the trainer_args. Default = None
+
            **kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
        """
        # Store trainer args for reproducibility
        self.trainer_args = {
            k: v
            for k, v in locals().items()
            if k not in {"self", "__class__", "model", "kwargs"}
-        }
-        self.trainer_args.update(kwargs)
+        } | kwargs

        self.model = model
        self.targets = targets
@@ -195,6 +211,27 @@ def __init__(
        ] = {key: {"train": [], "val": [], "test": []} for key in self.targets}
        self.best_model = None

+        # Initialize wandb if project/run specified
+        if wandb_path:
+            if wandb is None:
+                raise ImportError(
+                    "Weights and Biases not installed. pip install wandb to use "
+                    "wandb logging."
+                )
+            if wandb_path.count("/") == 1:
+                project, run_name = wandb_path.split("/")
+            else:
+                raise ValueError(
+                    f"{wandb_path=} should be in the format 'project/run_name' "
+                    "(no extra slashes)"
+                )
+            wandb.init(
+                project=project,
+                name=run_name,
+                config=self.trainer_args | (extra_run_config or {}),
+                **(wandb_init_kwargs or {}),
+            )
+
    def train(
        self,
        train_loader: DataLoader,
@@ -257,6 +294,13 @@ def train(
                self.save_checkpoint(epoch, val_mae, save_dir=save_dir)

+            # Log train/val metrics to wandb
+            if wandb is not None and self.trainer_args.get("wandb_path"):
+                wandb.log(
+                    {f"train_{k}_mae": v for k, v in train_mae.items()}
+                    | {f"val_{k}_mae": v for k, v in val_mae.items()}
+                )
+
        if test_loader is not None:
            # test best model
            print("---------Evaluate Model on Test Set---------------")
@@ -279,6 +323,10 @@ def train(
                self.training_history[key]["test"] = test_mae[key]
            self.save(filename=os.path.join(save_dir, test_file))

+            # Log test metrics to wandb
+            if wandb is not None and self.trainer_args.get("wandb_path"):
+                wandb.log({f"test_{k}_mae": v for k, v in test_mae.items()})
+
    def _train(self, train_loader: DataLoader, current_epoch: int) -> dict:
        """Train all data for one epoch.
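For reference, here is a minimal sketch of how the new arguments might be wired up when fine-tuning. The pretrained checkpoint, hyper-parameter values, project/run names, and the commented-out loaders are illustrative placeholders, not part of this commit:

```python
# Hypothetical usage of the wandb arguments added in this diff; all values
# below are placeholders chosen for illustration.
from chgnet.model.model import CHGNet
from chgnet.trainer import Trainer

chgnet = CHGNet.load()  # pretrained weights
trainer = Trainer(
    model=chgnet,
    targets="ef",  # train on energies and forces
    epochs=5,
    learning_rate=1e-3,
    use_device="cpu",
    wandb_path="my-project/run-1",  # parsed as project="my-project", name="run-1"
    wandb_init_kwargs={"mode": "offline"},  # forwarded verbatim to wandb.init
    extra_run_config={"dataset": "toy-subset"},  # merged into the logged run config
)
# trainer.train(train_loader, val_loader) would then log train_<target>_mae and
# val_<target>_mae each epoch, plus test_<target>_mae after evaluation.
```

Passing `wandb_path=None` (the default) leaves the trainer's behavior unchanged: `wandb.init` is never called, and every `wandb.log` call is guarded by both the import fallback and the `trainer_args` check.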