Commit 0da2d15: allow missing labels in training

1 parent: 817e21b

2 files changed: +58 -17 lines

chgnet/trainer/trainer.py (+44 -14)
```diff
@@ -49,6 +49,7 @@ def __init__(
         force_loss_ratio: float = 1,
         stress_loss_ratio: float = 0.1,
         mag_loss_ratio: float = 0.1,
+        allow_missing_labels: bool = True,
         optimizer: str = "Adam",
         scheduler: str = "CosLR",
         criterion: str = "MSE",
@@ -78,6 +79,9 @@ def __init__(
                 Default = 0.1
             mag_loss_ratio (float): magmom loss ratio in loss function
                 Default = 0.1
+            allow_missing_labels (bool): whether to allow missing labels in the dataset,
+                missed target will not contribute to loss and MAEs
+                Default = True
             optimizer (str): optimizer to update model. Can be "Adam", "SGD", "AdamW",
                 "RAdam". Default = 'Adam'
             scheduler (str): learning rate scheduler. Can be "CosLR", "ExponentialLR",
@@ -209,6 +213,7 @@ def __init__(
             force_loss_ratio=force_loss_ratio,
             stress_loss_ratio=stress_loss_ratio,
             mag_loss_ratio=mag_loss_ratio,
+            allow_missing_labels=allow_missing_labels,
             **kwargs,
         )
         self.epochs = epochs
@@ -726,6 +731,7 @@ def __init__(
         stress_loss_ratio: float = 0.1,
         mag_loss_ratio: float = 0.1,
         delta: float = 0.1,
+        allow_missing_labels: bool = True,
     ) -> None:
         """Initialize the combined loss.

@@ -745,6 +751,8 @@ def __init__(
             mag_loss_ratio (float): magmom loss ratio in loss function
                 Default = 0.1
             delta (float): delta for torch.nn.HuberLoss. Default = 0.1
+            allow_missing_labels (bool): whether to allow missing labels in the dataset,
+                missed target will not contribute to loss and MAEs
         """
         super().__init__()
         # Define loss criterion
@@ -771,6 +779,7 @@ def __init__(
             self.mag_loss_ratio = 0
         else:
             self.mag_loss_ratio = mag_loss_ratio
+        self.allow_missing_labels = allow_missing_labels

     def forward(
         self,
```
```diff
@@ -791,25 +800,37 @@ def forward(
         out = {"loss": 0.0}
         # Energy
         if "e" in self.target_str:
-            if self.is_intensive:
-                out["loss"] += self.energy_loss_ratio * self.criterion(
-                    targets["e"], prediction["e"]
-                )
-                out["e_MAE"] = mae(targets["e"], prediction["e"])
-                out["e_MAE_size"] = prediction["e"].shape[0]
+            if self.allow_missing_labels:
+                valid_value_indices = ~torch.isnan(targets["e"])
+                valid_e_target = targets["e"][valid_value_indices]
+                valid_atoms_per_graph = prediction["atoms_per_graph"][
+                    valid_value_indices
+                ]
+                valid_e_pred = prediction["e"][valid_value_indices]
+                if valid_e_pred.shape == torch.Size([]):
+                    valid_e_pred = valid_e_pred.view(1)
             else:
-                e_per_atom_target = targets["e"] / prediction["atoms_per_graph"]
-                e_per_atom_pred = prediction["e"] / prediction["atoms_per_graph"]
-                out["loss"] += self.energy_loss_ratio * self.criterion(
-                    e_per_atom_target, e_per_atom_pred
-                )
-                out["e_MAE"] = mae(e_per_atom_target, e_per_atom_pred)
-                out["e_MAE_size"] = prediction["e"].shape[0]
+                valid_e_target = targets["e"]
+                valid_atoms_per_graph = prediction["atoms_per_graph"]
+                valid_e_pred = prediction["e"]
+            if self.is_intensive:
+                valid_e_target = valid_e_target / valid_atoms_per_graph
+                valid_e_pred = valid_e_pred / valid_atoms_per_graph
+
+            out["loss"] += self.energy_loss_ratio * self.criterion(
+                valid_e_target, valid_e_pred
+            )
+            out["e_MAE"] = mae(valid_e_target, valid_e_pred)
+            out["e_MAE_size"] = prediction["e"].shape[0]

         # Force
         if "f" in self.target_str:
             forces_pred = torch.cat(prediction["f"], dim=0)
             forces_target = torch.cat(targets["f"], dim=0)
+            if self.allow_missing_labels:
+                valid_value_indices = ~torch.isnan(forces_target)
+                forces_target = forces_target[valid_value_indices]
+                forces_pred = forces_pred[valid_value_indices]
             out["loss"] += self.force_loss_ratio * self.criterion(
                 forces_target, forces_pred
             )
```
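The pattern in this hunk is simple: build a boolean mask of labelled entries with `~torch.isnan`, index both target and prediction with it, and only then evaluate the loss, so missing labels contribute nothing to the loss or its gradient. A minimal, self-contained sketch of that masking (toy tensors only; none of the variable names below come from chgnet, and `delta=0.1` merely mirrors the default used here):

```python
import torch

# Toy batch of 4 graphs; the second one carries no energy label.
e_target = torch.tensor([-1.2, float("nan"), -3.4, -0.7])
e_pred = torch.tensor([-1.1, -2.0, -3.3, -0.9])

criterion = torch.nn.HuberLoss(delta=0.1)

# Mask of graphs that actually have a label; NaN entries are dropped before
# the loss is evaluated, so they contribute nothing to the loss or gradient.
valid = ~torch.isnan(e_target)
loss = criterion(e_pred[valid], e_target[valid])
print(loss)  # finite, averaged over the 3 labelled graphs only
```

For forces (and stresses, in the next hunk) the mask is built from the concatenated tensors, so the filtering happens component-wise rather than per structure.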
```diff
@@ -820,6 +841,10 @@ def forward(
         if "s" in self.target_str:
             stress_pred = torch.cat(prediction["s"], dim=0)
             stress_target = torch.cat(targets["s"], dim=0)
+            if self.allow_missing_labels:
+                valid_value_indices = ~torch.isnan(stress_target)
+                stress_target = stress_target[valid_value_indices]
+                stress_pred = stress_pred[valid_value_indices]
             out["loss"] += self.stress_loss_ratio * self.criterion(
                 stress_target, stress_pred
             )
```
```diff
@@ -832,7 +857,12 @@ def forward(
             m_mae_size = 0
             for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
                 # exclude structures without magmom labels
-                if mag_target is not None:
+                if self.allow_missing_labels:
+                    if mag_target is not None and not np.isnan(mag_target).any():
+                        mag_preds.append(mag_pred)
+                        mag_targets.append(mag_target)
+                        m_mae_size += mag_target.shape[0]
+                else:
                     mag_preds.append(mag_pred)
                     mag_targets.append(mag_target)
                     m_mae_size += mag_target.shape[0]
```
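Magmoms are handled differently: they arrive as one per-atom tensor per structure, so the check is per structure, and a structure whose magmom tensor contains any NaN is skipped entirely (as is one whose label is None, matching the old behaviour). A toy sketch of that per-structure filter, using made-up tensors rather than chgnet batch objects:

```python
import torch

# One per-atom magmom tensor per structure; the second structure has no labels.
mag_targets = [
    torch.tensor([0.1, 0.2]),
    torch.full((2,), float("nan")),
    torch.tensor([0.3, 0.4]),
]
mag_preds = [torch.rand(2) for _ in mag_targets]

kept_preds, kept_targets = [], []
for pred, target in zip(mag_preds, mag_targets):
    # Skip the whole structure if its label is absent or contains any NaN.
    if target is not None and not torch.isnan(target).any():
        kept_preds.append(pred)
        kept_targets.append(target)

# Loss over the two structures that kept their labels.
loss = torch.nn.functional.huber_loss(
    torch.cat(kept_preds), torch.cat(kept_targets), delta=0.1
)
```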

tests/test_trainer.py (+14 -3)
```diff
@@ -21,7 +21,7 @@
 coords = [[0, 0, 0], [0.5, 0.5, 0.5]]
 NaCl = Structure(lattice, species, coords)
 structures, energies, forces, stresses, magmoms = [], [], [], [], []
-for _ in range(100):
+for _ in range(20):
     struct = NaCl.copy()
     struct.perturb(0.1)
     structures.append(struct)
@@ -30,15 +30,22 @@
     stresses.append(np.random.random([3, 3]))
     magmoms.append(np.random.random(2))

+# Create some missing labels
+energies[10] = np.nan
+forces[4] = (np.nan * np.ones((len(structures[4]), 3))).tolist()
+stresses[6] = (np.nan * np.ones((3, 3))).tolist()
+magmoms[8] = (np.nan * np.ones((len(structures[8]), 1))).tolist()
+
 data = StructureData(
     structures=structures,
     energies=energies,
     forces=forces,
     stresses=stresses,
     magmoms=magmoms,
+    shuffle=False,
 )
 train_loader, val_loader, _test_loader = get_train_val_test_loader(
-    data, batch_size=16, train_ratio=0.9, val_ratio=0.05
+    data, batch_size=4, train_ratio=0.9, val_ratio=0.05
 )
 chgnet = CHGNet.load()

@@ -55,6 +62,7 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
         wandb_path="test/run",
         wandb_init_kwargs=dict(anonymous="must"),
         extra_run_config=extra_run_config,
+        allow_missing_labels=True,
     )
     trainer.train(
         train_loader,
@@ -66,7 +74,9 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     for param in chgnet.composition_model.parameters():
         assert param.requires_grad is False
     assert tmp_path.is_dir(), "Training dir was not created"
-
+    for target_str in ["e", "f", "s", "m"]:
+        assert ~np.isnan(trainer.training_history[target_str]["train"]).any()
+        assert ~np.isnan(trainer.training_history[target_str]["val"]).any()
     output_files = [file.name for file in tmp_path.iterdir()]
     for prefix in ("epoch", "bestE_", "bestF_"):
         n_matches = sum(file.startswith(prefix) for file in output_files)
@@ -147,6 +157,7 @@ def test_wandb_init(mock_wandb):
         "wandb_path": "test-project/test-run",
         "wandb_init_kwargs": {"tags": ["test"]},
         "extra_run_config": None,
+        "allow_missing_labels": True,
     }
     mock_wandb.init.assert_called_once_with(
         project="test-project", name="test-run", config=expected_config, tags=["test"]
```
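Taken together, end-to-end usage would look roughly like the sketch below. It mirrors the test's recipe (perturbed NaCl cells with random labels); the lattice constant, dataset size, NaN position, and `epochs=2` are arbitrary choices for illustration, not values from this commit.

```python
import numpy as np
from pymatgen.core import Lattice, Structure

from chgnet.data.dataset import StructureData, get_train_val_test_loader
from chgnet.model import CHGNet
from chgnet.trainer import Trainer

# Small random dataset of perturbed NaCl cells, same recipe as the test above.
NaCl = Structure(Lattice.cubic(3.5), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
structures, energies, forces = [], [], []
for _ in range(20):
    struct = NaCl.copy()
    struct.perturb(0.1)
    structures.append(struct)
    energies.append(np.random.random())
    forces.append(np.random.random((2, 3)))

# One structure has no energy label: mark it with NaN instead of dropping it.
energies[3] = np.nan

data = StructureData(structures=structures, energies=energies, forces=forces)
train_loader, val_loader, _test_loader = get_train_val_test_loader(
    data, batch_size=4, train_ratio=0.9, val_ratio=0.05
)

# allow_missing_labels=True (the default) lets the NaN-labelled target drop
# out of the loss and the reported MAEs instead of propagating through them.
trainer = Trainer(model=CHGNet.load(), targets="ef", allow_missing_labels=True, epochs=2)
trainer.train(train_loader, val_loader)
```

With `allow_missing_labels=False` you would instead be responsible for ensuring every structure carries every target the loss expects, since NaNs would propagate straight into the loss.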
