We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3bc34b5 commit 2fb442dCopy full SHA for 2fb442d
tests/test_trainer.py
@@ -93,9 +93,9 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
93
new_chgnet = CHGNet.from_file(weights_path)
94
for param in new_chgnet.composition_model.parameters():
95
assert param.requires_grad is False
96
- comparison = (
97
- new_chgnet.composition_model.state_dict()["fc.weight"] == initial_weights
98
- )
+ comparison = new_chgnet.composition_model.state_dict()["fc.weight"].to(
+ "cpu"
+ ) == initial_weights.to("cpu")
99
expect = torch.ones_like(comparison)
100
# Only Na and Cl should have updated
101
expect[0][10] = 0
0 commit comments