Skip to content

Commit 2fb442d

Browse files
committed
fixed test
1 parent 3bc34b5 commit 2fb442d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_trainer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
9393
new_chgnet = CHGNet.from_file(weights_path)
9494
for param in new_chgnet.composition_model.parameters():
9595
assert param.requires_grad is False
96-
comparison = (
97-
new_chgnet.composition_model.state_dict()["fc.weight"] == initial_weights
98-
)
96+
comparison = new_chgnet.composition_model.state_dict()["fc.weight"].to(
97+
"cpu"
98+
) == initial_weights.to("cpu")
9999
expect = torch.ones_like(comparison)
100100
# Only Na and Cl should have updated
101101
expect[0][10] = 0

0 commit comments

Comments
 (0)