Skip to content

Commit 0d5340f

Browse files
committed
test wandb_path and wandb_kwargs in test_trainer
1 parent 265fb60 commit 0d5340f

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tests/test_trainer.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING
44

55
import numpy as np
6+
import pytest
67
import torch
78
from pymatgen.core import Lattice, Structure
89

@@ -36,7 +37,7 @@
3637
)
3738

3839

39-
def test_trainer(tmp_path: Path) -> None:
40+
def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
4041
chgnet = CHGNet.load()
4142
train_loader, val_loader, _test_loader = get_train_val_test_loader(
4243
data, batch_size=16, train_ratio=0.9, val_ratio=0.05
@@ -47,11 +48,13 @@ def test_trainer(tmp_path: Path) -> None:
4748
optimizer="Adam",
4849
criterion="MSE",
4950
learning_rate=1e-2,
50-
epochs=5,
51+
epochs=500,
52+
wandb_path="/",
53+
wandb_kwargs=dict(anonymous="allow"),
5154
)
5255
dir_name = "test_tmp_dir"
5356
test_dir = tmp_path / dir_name
54-
trainer.train(train_loader, val_loader, save_dir=test_dir)
57+
trainer.train(train_loader, val_loader)
5558
for param in chgnet.composition_model.parameters():
5659
assert param.requires_grad is False
5760
assert test_dir.is_dir(), "Training dir was not created"
@@ -63,6 +66,12 @@ def test_trainer(tmp_path: Path) -> None:
6366
n_matches == 1
6467
), f"Expected 1 {prefix} file, found {n_matches} in {output_files}"
6568

69+
# expect ImportError when passing wandb_path without wandb installed
70+
err_msg = "Weights and Biases not installed. pip install wandb to use wandb logging"
71+
with monkeypatch.context() as ctx, pytest.raises(ImportError, match=err_msg): # noqa: PT012
72+
ctx.setattr("chgnet.trainer.trainer.wandb", None)
73+
_ = Trainer(model=chgnet, wandb_path="radicalai/chgnet-test-finetune")
74+
6675

6776
def test_trainer_composition_model(tmp_path: Path) -> None:
6877
chgnet = CHGNet.load()

0 commit comments

Comments
 (0)