Skip to content

Commit 5ef8876

Browse files
committed
Fixed a bug in the energy loss
1 parent 84e8d55 commit 5ef8876

File tree

2 files changed

+55
-13
lines changed

2 files changed

+55
-13
lines changed

chgnet/trainer/trainer.py

-12
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ def __init__(
208208
self.criterion = CombinedLoss(
209209
target_str=self.targets,
210210
criterion=criterion,
211-
is_intensive=self.model.is_intensive,
212211
energy_loss_ratio=energy_loss_ratio,
213212
force_loss_ratio=force_loss_ratio,
214213
stress_loss_ratio=stress_loss_ratio,
@@ -725,7 +724,6 @@ def __init__(
725724
*,
726725
target_str: str = "ef",
727726
criterion: str = "MSE",
728-
is_intensive: bool = True,
729727
energy_loss_ratio: float = 1,
730728
force_loss_ratio: float = 1,
731729
stress_loss_ratio: float = 0.1,
@@ -740,8 +738,6 @@ def __init__(
740738
Default = "ef"
741739
criterion: loss criterion to use
742740
Default = "MSE"
743-
is_intensive (bool): whether the energy label is intensive
744-
Default = True
745741
energy_loss_ratio (float): energy loss ratio in loss function
746742
Default = 1
747743
force_loss_ratio (float): force loss ratio in loss function
@@ -765,7 +761,6 @@ def __init__(
765761
else:
766762
raise NotImplementedError
767763
self.target_str = target_str
768-
self.is_intensive = is_intensive
769764
self.energy_loss_ratio = energy_loss_ratio
770765
if "f" not in self.target_str:
771766
self.force_loss_ratio = 0
@@ -803,19 +798,12 @@ def forward(
803798
if self.allow_missing_labels:
804799
valid_value_indices = ~torch.isnan(targets["e"])
805800
valid_e_target = targets["e"][valid_value_indices]
806-
valid_atoms_per_graph = prediction["atoms_per_graph"][
807-
valid_value_indices
808-
]
809801
valid_e_pred = prediction["e"][valid_value_indices]
810802
if valid_e_pred.shape == torch.Size([]):
811803
valid_e_pred = valid_e_pred.view(1)
812804
else:
813805
valid_e_target = targets["e"]
814-
valid_atoms_per_graph = prediction["atoms_per_graph"]
815806
valid_e_pred = prediction["e"]
816-
if self.is_intensive:
817-
valid_e_target = valid_e_target / valid_atoms_per_graph
818-
valid_e_pred = valid_e_pred / valid_atoms_per_graph
819807

820808
out["loss"] += self.energy_loss_ratio * self.criterion(
821809
valid_e_target, valid_e_pred

tests/test_trainer.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from chgnet.data.dataset import StructureData, get_train_val_test_loader
1313
from chgnet.model import CHGNet
14-
from chgnet.trainer import Trainer
14+
from chgnet.trainer.trainer import CombinedLoss, Trainer
1515

1616
if TYPE_CHECKING:
1717
from pathlib import Path
@@ -50,6 +50,60 @@
5050
chgnet = CHGNet.load()
5151

5252

53+
def test_combined_loss() -> None:
54+
criterion = CombinedLoss(
55+
target_str="ef",
56+
criterion="MSE",
57+
energy_loss_ratio=1,
58+
force_loss_ratio=1,
59+
stress_loss_ratio=0.1,
60+
mag_loss_ratio=0.1,
61+
allow_missing_labels=False,
62+
)
63+
target1 = {"e": torch.Tensor([1]), "f": [torch.Tensor([[[1, 1, 1], [2, 2, 2]]])]}
64+
prediction1 = chgnet.predict_structure(NaCl)
65+
prediction1 = {
66+
"e": torch.from_numpy(prediction1["e"]).unsqueeze(0),
67+
"f": [torch.from_numpy(prediction1["f"])],
68+
"atoms_per_graph": torch.tensor([2]),
69+
}
70+
out1 = criterion(
71+
targets=target1,
72+
prediction=prediction1,
73+
)
74+
target2 = {
75+
"e": torch.Tensor([1]),
76+
"f": [
77+
torch.Tensor(
78+
[
79+
[
80+
[1, 1, 1],
81+
[1, 1, 1],
82+
[1, 1, 1],
83+
[1, 1, 1],
84+
[2, 2, 2],
85+
[2, 2, 2],
86+
[2, 2, 2],
87+
[2, 2, 2],
88+
]
89+
]
90+
)
91+
],
92+
}
93+
supercell = NaCl.make_supercell([2, 2, 1], in_place=False)
94+
prediction2 = chgnet.predict_structure(supercell)
95+
prediction2 = {
96+
"e": torch.from_numpy(prediction2["e"]).unsqueeze(0),
97+
"f": [torch.from_numpy(prediction2["f"])],
98+
"atoms_per_graph": torch.tensor([8]),
99+
}
100+
out2 = criterion(
101+
targets=target2,
102+
prediction=prediction2,
103+
)
104+
assert np.isclose(out1["loss"], out2["loss"], rtol=1e-04, atol=1e-05)
105+
106+
53107
def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
54108
extra_run_config = dict(some_other_hyperparam=42)
55109
trainer = Trainer(

0 commit comments

Comments (0)