
Commit 9717a32

Add wandb logging support to Trainer class (#166)
* add wandb logging to Trainer class
* add optional deps set logging = ["wandb>=0.17"] and install in CI
* test wandb_path and wandb_init_kwargs in test_trainer
* only wandb.log if "wandb_path" in trainer_args
* add and test extra_run_config: dict | None = None keyword to Trainer to specify run params like batch_size that aren't already recorded by the trainer_args dict
* fix ruff isort.known-third-party = ["wandb"]
1 parent d3f1b30 commit 9717a32
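A minimal usage sketch of the feature added in this commit, assuming `Trainer` is importable from `chgnet.trainer` as in the repo's tests; the project/run names and `extra_run_config` values are illustrative:

```python
from chgnet.model.model import CHGNet
from chgnet.trainer import Trainer

chgnet = CHGNet.load()
trainer = Trainer(
    model=chgnet,
    targets="efsm",
    optimizer="Adam",
    criterion="MSE",
    learning_rate=1e-2,
    epochs=5,
    # new in this commit: log metrics to wandb project "my-project", run "run-1"
    wandb_path="my-project/run-1",
    # forwarded verbatim to wandb.init
    wandb_init_kwargs=dict(anonymous="allow"),
    # hyper-params (e.g. the data-loader batch size) not already in trainer_args
    extra_run_config=dict(batch_size=16),
)
# trainer.train(train_loader, val_loader) then logs train/val/test MAEs to wandb
```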

File tree: 5 files changed, +70 −4 lines

- .github/workflows/test.yml (+1 −1)
- .gitignore (+3)
- chgnet/trainer/trainer.py (+50 −2)
- pyproject.toml (+2)
- tests/test_trainer.py (+14 −1)

.github/workflows/test.yml (+1 −1)

```diff
@@ -38,7 +38,7 @@ jobs:
           python setup.py build_ext --inplace
-          uv pip install -e .[test] --system --resolution=${{ matrix.version.resolution }}
+          uv pip install -e .[test,logging] --system --resolution=${{ matrix.version.resolution }}

       - name: Run Tests
         run: pytest --capture=no --cov --cov-report=xml
```

.gitignore (+3)

```diff
@@ -27,3 +27,6 @@ coverage.xml
 .ipynb_checkpoints
 bond_graph_error.cif
 test.py
+
+# training logs
+wandb
```

chgnet/trainer/trainer.py (+50 −2)
```diff
@@ -21,6 +21,12 @@
 from chgnet.model.model import CHGNet
 from chgnet.utils import AverageMeter, determine_device, mae, write_json

+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+
 if TYPE_CHECKING:
     from torch.utils.data import DataLoader
```

```diff
@@ -50,6 +56,9 @@ def __init__(
         data_seed: int | None = None,
         use_device: str | None = None,
         check_cuda_mem: bool = False,
+        wandb_path: str | None = None,
+        wandb_init_kwargs: dict | None = None,
+        extra_run_config: dict | None = None,
         **kwargs,
     ) -> None:
         """Initialize all hyper-parameters for trainer.
```
```diff
@@ -88,15 +97,22 @@ def __init__(
                 Default = None
             check_cuda_mem (bool): Whether to use cuda with most available memory
                 Default = False
+            wandb_path (str | None): The project and run name separated by a slash:
+                "project/run_name". If None, wandb logging is not used.
+                Default = None
+            wandb_init_kwargs (dict): Additional kwargs to pass to wandb.init.
+                Default = None
+            extra_run_config (dict): Additional hyper-params to be recorded by wandb
+                that are not included in the trainer_args. Default = None
+
             **kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
         """
         # Store trainer args for reproducibility
         self.trainer_args = {
             k: v
             for k, v in locals().items()
             if k not in {"self", "__class__", "model", "kwargs"}
-        }
-        self.trainer_args.update(kwargs)
+        } | kwargs

         self.model = model
         self.targets = targets
```
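The `} | kwargs` change above replaces `dict.update` with the PEP 584 dict-union operator (Python 3.9+), merging in a single expression; a minimal sketch of the equivalence, with a hypothetical extra hyper-param:

```python
trainer_args = {"learning_rate": 1e-2, "epochs": 5}
kwargs = {"weight_decay": 0.0}  # hypothetical extra hyper-param

# Old style: build the dict, then mutate it in place
merged_old = dict(trainer_args)
merged_old.update(kwargs)

# New style: one expression; right-hand keys win on collision
merged_new = trainer_args | kwargs

assert merged_old == merged_new == {"learning_rate": 1e-2, "epochs": 5, "weight_decay": 0.0}
```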
```diff
@@ -195,6 +211,27 @@ def __init__(
         ] = {key: {"train": [], "val": [], "test": []} for key in self.targets}
         self.best_model = None

+        # Initialize wandb if project/run specified
+        if wandb_path:
+            if wandb is None:
+                raise ImportError(
+                    "Weights and Biases not installed. pip install wandb to use "
+                    "wandb logging."
+                )
+            if wandb_path.count("/") == 1:
+                project, run_name = wandb_path.split("/")
+            else:
+                raise ValueError(
+                    f"{wandb_path=} should be in the format 'project/run_name' "
+                    "(no extra slashes)"
+                )
+            wandb.init(
+                project=project,
+                name=run_name,
+                config=self.trainer_args | (extra_run_config or {}),
+                **(wandb_init_kwargs or {}),
+            )
+
     def train(
         self,
         train_loader: DataLoader,
```
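A standalone sketch of the "project/run_name" parsing rule enforced above; the path value is illustrative:

```python
wandb_path = "my-project/run-1"

# Exactly one slash separates the wandb project from the run name
if wandb_path.count("/") == 1:
    project, run_name = wandb_path.split("/")
else:
    raise ValueError(f"{wandb_path=} should be in the format 'project/run_name'")

assert (project, run_name) == ("my-project", "run-1")
# "run-1" (no slash) or "org/my-project/run-1" (two slashes) would raise ValueError
```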
```diff
@@ -257,6 +294,13 @@ def train(

             self.save_checkpoint(epoch, val_mae, save_dir=save_dir)

+            # Log train/val metrics to wandb
+            if wandb is not None and self.trainer_args.get("wandb_path"):
+                wandb.log(
+                    {f"train_{k}_mae": v for k, v in train_mae.items()}
+                    | {f"val_{k}_mae": v for k, v in val_mae.items()}
+                )
+
         if test_loader is not None:
             # test best model
             print("---------Evaluate Model on Test Set---------------")
```
```diff
@@ -279,6 +323,10 @@ def train(
             self.training_history[key]["test"] = test_mae[key]
         self.save(filename=os.path.join(save_dir, test_file))

+        # Log test metrics to wandb
+        if wandb is not None and self.trainer_args.get("wandb_path"):
+            wandb.log({f"test_{k}_mae": v for k, v in test_mae.items()})
+
     def _train(self, train_loader: DataLoader, current_epoch: int) -> dict:
         """Train all data for one epoch.
```
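For reference, the shape of the dicts the wandb.log calls above emit; the MAE numbers below are made up, and the target keys assume targets="ef":

```python
train_mae = {"e": 0.031, "f": 0.072}  # hypothetical per-target MAEs
val_mae = {"e": 0.045, "f": 0.081}

# Same prefix-and-union pattern as the train/val logging above
payload = {f"train_{k}_mae": v for k, v in train_mae.items()} | {
    f"val_{k}_mae": v for k, v in val_mae.items()
}
assert payload == {
    "train_e_mae": 0.031,
    "train_f_mae": 0.072,
    "val_e_mae": 0.045,
    "val_f_mae": 0.081,
}
```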

pyproject.toml (+2)
```diff
@@ -32,6 +32,7 @@ test = ["pytest-cov>=4", "pytest>=8"]
 # needed to run interactive example notebooks
 examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
 docs = ["lazydocs>=0.4"]
+logging = ["wandb>=0.17"]

 [project.urls]
 Source = "https://github.com/CederGroupHub/chgnet"
```
```diff
@@ -89,6 +90,7 @@ ignore = [
 pydocstyle.convention = "google"
 isort.required-imports = ["from __future__ import annotations"]
 isort.split-on-trailing-comma = false
+isort.known-third-party = ["wandb"]

 [tool.ruff.format]
 docstring-code-format = true
```
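The new `logging` extra keeps wandb opt-in (CI installs it via `.[test,logging]` above). A quick stdlib check for whether the extra is present, mirroring the guarded import in trainer.py:

```python
import importlib.util

# True only if the "logging" extra (wandb) was installed alongside chgnet
has_wandb = importlib.util.find_spec("wandb") is not None
print(f"wandb available: {has_wandb}")
```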

tests/test_trainer.py (+14 −1)
```diff
@@ -3,7 +3,9 @@
 from typing import TYPE_CHECKING

 import numpy as np
+import pytest
 import torch
+import wandb
 from pymatgen.core import Lattice, Structure

 from chgnet.data.dataset import StructureData, get_train_val_test_loader
```
```diff
@@ -36,19 +38,24 @@
 )


-def test_trainer(tmp_path: Path) -> None:
+def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     chgnet = CHGNet.load()
     train_loader, val_loader, _test_loader = get_train_val_test_loader(
         data, batch_size=16, train_ratio=0.9, val_ratio=0.05
     )
+    extra_run_config = dict(some_other_hyperparam=42)
     trainer = Trainer(
         model=chgnet,
         targets="efsm",
         optimizer="Adam",
         criterion="MSE",
         learning_rate=1e-2,
         epochs=5,
+        wandb_path="test/run",
+        wandb_init_kwargs=dict(anonymous="must"),
+        extra_run_config=extra_run_config,
     )
+    assert dict(wandb.config).items() >= extra_run_config.items()
     dir_name = "test_tmp_dir"
     test_dir = tmp_path / dir_name
     trainer.train(train_loader, val_loader, save_dir=test_dir)
```
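The `dict(wandb.config).items() >= extra_run_config.items()` assertion above relies on the set-like superset comparison of dict views; a self-contained sketch with made-up values:

```python
run_config = {"batch_size": 16, "some_other_hyperparam": 42}  # stand-in for dict(wandb.config)
extra_run_config = {"some_other_hyperparam": 42}

# items() views compare like sets: >= holds when every (key, value)
# pair on the right also appears on the left
assert run_config.items() >= extra_run_config.items()
```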
```diff
@@ -63,6 +70,12 @@ def test_trainer(tmp_path: Path) -> None:
         n_matches == 1
     ), f"Expected 1 {prefix} file, found {n_matches} in {output_files}"

+    # expect ImportError when passing wandb_path without wandb installed
+    err_msg = "Weights and Biases not installed. pip install wandb to use wandb logging"
+    with monkeypatch.context() as ctx, pytest.raises(ImportError, match=err_msg):  # noqa: PT012
+        ctx.setattr("chgnet.trainer.trainer.wandb", None)
+        _ = Trainer(model=chgnet, wandb_path="some-org/some-project")
+

 def test_trainer_composition_model(tmp_path: Path) -> None:
     chgnet = CHGNet.load()
```
