@@ -110,6 +110,12 @@ def __init__(
110
110
that are not included in the trainer_args. Default = None
111
111
112
112
**kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
113
+
114
+ Raises:
115
+ NotImplementedError: If the optimizer or scheduler is not implemented
116
+ ImportError: If wandb_path is specified but wandb is not installed
117
+ ValueError: If wandb_path is specified but not in the format
118
+ 'project/run_name'
113
119
"""
114
120
# Store trainer args for reproducibility
115
121
self .trainer_args = {
@@ -271,6 +277,9 @@ def train(
271
277
wandb_log_freq ("epoch" | "batch"): Frequency of logging to wandb.
272
278
'epoch' logs once per epoch, 'batch' logs after every batch.
273
279
Default = "batch"
280
+
281
+ Raises:
282
+ ValueError: If model is not initialized
274
283
"""
275
284
if self .model is None :
276
285
raise ValueError ("Model needs to be initialized" )
@@ -579,7 +588,11 @@ def _validate(
579
588
return {k : round (mae_error .avg , 6 ) for k , mae_error in mae_errors .items ()}
580
589
581
590
def get_best_model (self ) -> CHGNet :
582
- """Get best model recorded in the trainer."""
591
+ """Get best model recorded in the trainer.
592
+
593
+ Returns:
594
+ CHGNet: the model with lowest validation set energy error
595
+ """
583
596
if self .best_model is None :
584
597
raise RuntimeError ("the model needs to be trained first" )
585
598
MAE = min (self .training_history ["e" ]["val" ]) # noqa: N806
@@ -649,7 +662,14 @@ def save_checkpoint(self, epoch: int, mae_error: dict, save_dir: str) -> None:
649
662
650
663
@classmethod
651
664
def load (cls , path : str ) -> Self :
652
- """Load trainer state_dict."""
665
+ """Load trainer state_dict.
666
+
667
+ Args:
668
+ path (str): path to the saved model
669
+
670
+ Returns:
671
+ Trainer: the loaded trainer
672
+ """
653
673
state = torch .load (path , map_location = torch .device ("cpu" ))
654
674
model = CHGNet .from_dict (state ["model" ])
655
675
print (f"Loaded model params = { sum (p .numel () for p in model .parameters ()):,} " )
@@ -664,8 +684,21 @@ def load(cls, path: str) -> Self:
664
684
return trainer
665
685
666
686
@staticmethod
667
- def move_to (obj , device ) -> Tensor | list [Tensor ]:
668
- """Move object to device."""
687
+ def move_to (
688
+ obj : Tensor | list [Tensor ], device : torch .device
689
+ ) -> Tensor | list [Tensor ]:
690
+ """Move object to device.
691
+
692
+ Args:
693
+ obj (Tensor | list[Tensor]): object(s) to move to device
694
+ device (torch.device): device to move object to
695
+
696
+ Raises:
697
+ TypeError: if obj is not a tensor or list of tensors
698
+
699
+ Returns:
700
+ Tensor | list[Tensor]: moved object(s)
701
+ """
669
702
if torch .is_tensor (obj ):
670
703
return obj .to (device )
671
704
if isinstance (obj , list ):
0 commit comments