3
3
from typing import TYPE_CHECKING
4
4
5
5
import numpy as np
6
+ import pytest
6
7
import torch
7
8
from pymatgen .core import Lattice , Structure
8
9
36
37
)
37
38
38
39
39
- def test_trainer (tmp_path : Path ) -> None :
40
+ def test_trainer (tmp_path : Path , monkeypatch : pytest . MonkeyPatch ) -> None :
40
41
chgnet = CHGNet .load ()
41
42
train_loader , val_loader , _test_loader = get_train_val_test_loader (
42
43
data , batch_size = 16 , train_ratio = 0.9 , val_ratio = 0.05
@@ -47,11 +48,13 @@ def test_trainer(tmp_path: Path) -> None:
47
48
optimizer = "Adam" ,
48
49
criterion = "MSE" ,
49
50
learning_rate = 1e-2 ,
50
- epochs = 5 ,
51
+ epochs = 500 ,
52
+ wandb_path = "/" ,
53
+ wandb_kwargs = dict (anonymous = "allow" ),
51
54
)
52
55
dir_name = "test_tmp_dir"
53
56
test_dir = tmp_path / dir_name
54
- trainer .train (train_loader , val_loader , save_dir = test_dir )
57
+ trainer .train (train_loader , val_loader )
55
58
for param in chgnet .composition_model .parameters ():
56
59
assert param .requires_grad is False
57
60
assert test_dir .is_dir (), "Training dir was not created"
@@ -63,6 +66,12 @@ def test_trainer(tmp_path: Path) -> None:
63
66
n_matches == 1
64
67
), f"Expected 1 { prefix } file, found { n_matches } in { output_files } "
65
68
69
+ # expect ImportError when passing wandb_path without wandb installed
70
+ err_msg = "Weights and Biases not installed. pip install wandb to use wandb logging"
71
+ with monkeypatch .context () as ctx , pytest .raises (ImportError , match = err_msg ): # noqa: PT012
72
+ ctx .setattr ("chgnet.trainer.trainer.wandb" , None )
73
+ _ = Trainer (model = chgnet , wandb_path = "radicalai/chgnet-test-finetune" )
74
+
66
75
67
76
def test_trainer_composition_model (tmp_path : Path ) -> None :
68
77
chgnet = CHGNet .load ()
0 commit comments