@@ -49,6 +49,7 @@ def __init__(
49
49
force_loss_ratio : float = 1 ,
50
50
stress_loss_ratio : float = 0.1 ,
51
51
mag_loss_ratio : float = 0.1 ,
52
+ allow_missing_labels : bool = True ,
52
53
optimizer : str = "Adam" ,
53
54
scheduler : str = "CosLR" ,
54
55
criterion : str = "MSE" ,
@@ -78,6 +79,9 @@ def __init__(
78
79
Default = 0.1
79
80
mag_loss_ratio (float): magmom loss ratio in loss function
80
81
Default = 0.1
82
+ allow_missing_labels (bool): whether to allow missing labels in the dataset,
83
+ missed target will not contribute to loss and MAEs
84
+ Default = True
81
85
optimizer (str): optimizer to update model. Can be "Adam", "SGD", "AdamW",
82
86
"RAdam". Default = 'Adam'
83
87
scheduler (str): learning rate scheduler. Can be "CosLR", "ExponentialLR",
@@ -209,6 +213,7 @@ def __init__(
209
213
force_loss_ratio = force_loss_ratio ,
210
214
stress_loss_ratio = stress_loss_ratio ,
211
215
mag_loss_ratio = mag_loss_ratio ,
216
+ allow_missing_labels = allow_missing_labels ,
212
217
** kwargs ,
213
218
)
214
219
self .epochs = epochs
@@ -726,6 +731,7 @@ def __init__(
726
731
stress_loss_ratio : float = 0.1 ,
727
732
mag_loss_ratio : float = 0.1 ,
728
733
delta : float = 0.1 ,
734
+ allow_missing_labels : bool = True ,
729
735
) -> None :
730
736
"""Initialize the combined loss.
731
737
@@ -745,6 +751,8 @@ def __init__(
745
751
mag_loss_ratio (float): magmom loss ratio in loss function
746
752
Default = 0.1
747
753
delta (float): delta for torch.nn.HuberLoss. Default = 0.1
754
+ allow_missing_labels (bool): whether to allow missing labels in the dataset,
755
+ missed target will not contribute to loss and MAEs
748
756
"""
749
757
super ().__init__ ()
750
758
# Define loss criterion
@@ -771,6 +779,7 @@ def __init__(
771
779
self .mag_loss_ratio = 0
772
780
else :
773
781
self .mag_loss_ratio = mag_loss_ratio
782
+ self .allow_missing_labels = allow_missing_labels
774
783
775
784
def forward (
776
785
self ,
@@ -791,25 +800,37 @@ def forward(
791
800
out = {"loss" : 0.0 }
792
801
# Energy
793
802
if "e" in self .target_str :
794
- if self .is_intensive :
795
- out ["loss" ] += self .energy_loss_ratio * self .criterion (
796
- targets ["e" ], prediction ["e" ]
797
- )
798
- out ["e_MAE" ] = mae (targets ["e" ], prediction ["e" ])
799
- out ["e_MAE_size" ] = prediction ["e" ].shape [0 ]
803
+ if self .allow_missing_labels :
804
+ valid_value_indices = ~ torch .isnan (targets ["e" ])
805
+ valid_e_target = targets ["e" ][valid_value_indices ]
806
+ valid_atoms_per_graph = prediction ["atoms_per_graph" ][
807
+ valid_value_indices
808
+ ]
809
+ valid_e_pred = prediction ["e" ][valid_value_indices ]
810
+ if valid_e_pred .shape == torch .Size ([]):
811
+ valid_e_pred = valid_e_pred .view (1 )
800
812
else :
801
- e_per_atom_target = targets ["e" ] / prediction ["atoms_per_graph" ]
802
- e_per_atom_pred = prediction ["e" ] / prediction ["atoms_per_graph" ]
803
- out ["loss" ] += self .energy_loss_ratio * self .criterion (
804
- e_per_atom_target , e_per_atom_pred
805
- )
806
- out ["e_MAE" ] = mae (e_per_atom_target , e_per_atom_pred )
807
- out ["e_MAE_size" ] = prediction ["e" ].shape [0 ]
813
+ valid_e_target = targets ["e" ]
814
+ valid_atoms_per_graph = prediction ["atoms_per_graph" ]
815
+ valid_e_pred = prediction ["e" ]
816
+ if self .is_intensive :
817
+ valid_e_target = valid_e_target / valid_atoms_per_graph
818
+ valid_e_pred = valid_e_pred / valid_atoms_per_graph
819
+
820
+ out ["loss" ] += self .energy_loss_ratio * self .criterion (
821
+ valid_e_target , valid_e_pred
822
+ )
823
+ out ["e_MAE" ] = mae (valid_e_target , valid_e_pred )
824
+ out ["e_MAE_size" ] = prediction ["e" ].shape [0 ]
808
825
809
826
# Force
810
827
if "f" in self .target_str :
811
828
forces_pred = torch .cat (prediction ["f" ], dim = 0 )
812
829
forces_target = torch .cat (targets ["f" ], dim = 0 )
830
+ if self .allow_missing_labels :
831
+ valid_value_indices = ~ torch .isnan (forces_target )
832
+ forces_target = forces_target [valid_value_indices ]
833
+ forces_pred = forces_pred [valid_value_indices ]
813
834
out ["loss" ] += self .force_loss_ratio * self .criterion (
814
835
forces_target , forces_pred
815
836
)
@@ -820,6 +841,10 @@ def forward(
820
841
if "s" in self .target_str :
821
842
stress_pred = torch .cat (prediction ["s" ], dim = 0 )
822
843
stress_target = torch .cat (targets ["s" ], dim = 0 )
844
+ if self .allow_missing_labels :
845
+ valid_value_indices = ~ torch .isnan (stress_target )
846
+ stress_target = stress_target [valid_value_indices ]
847
+ stress_pred = stress_pred [valid_value_indices ]
823
848
out ["loss" ] += self .stress_loss_ratio * self .criterion (
824
849
stress_target , stress_pred
825
850
)
@@ -832,7 +857,12 @@ def forward(
832
857
m_mae_size = 0
833
858
for mag_pred , mag_target in zip (prediction ["m" ], targets ["m" ], strict = True ):
834
859
# exclude structures without magmom labels
835
- if mag_target is not None :
860
+ if self .allow_missing_labels :
861
+ if mag_target is not None and not np .isnan (mag_target ).any ():
862
+ mag_preds .append (mag_pred )
863
+ mag_targets .append (mag_target )
864
+ m_mae_size += mag_target .shape [0 ]
865
+ else :
836
866
mag_preds .append (mag_pred )
837
867
mag_targets .append (mag_target )
838
868
m_mae_size += mag_target .shape [0 ]
0 commit comments