@@ -208,7 +208,6 @@ def __init__(
208
208
self .criterion = CombinedLoss (
209
209
target_str = self .targets ,
210
210
criterion = criterion ,
211
- is_intensive = self .model .is_intensive ,
212
211
energy_loss_ratio = energy_loss_ratio ,
213
212
force_loss_ratio = force_loss_ratio ,
214
213
stress_loss_ratio = stress_loss_ratio ,
@@ -725,7 +724,6 @@ def __init__(
725
724
* ,
726
725
target_str : str = "ef" ,
727
726
criterion : str = "MSE" ,
728
- is_intensive : bool = True ,
729
727
energy_loss_ratio : float = 1 ,
730
728
force_loss_ratio : float = 1 ,
731
729
stress_loss_ratio : float = 0.1 ,
@@ -740,8 +738,6 @@ def __init__(
740
738
Default = "ef"
741
739
criterion: loss criterion to use
742
740
Default = "MSE"
743
- is_intensive (bool): whether the energy label is intensive
744
- Default = True
745
741
energy_loss_ratio (float): energy loss ratio in loss function
746
742
Default = 1
747
743
force_loss_ratio (float): force loss ratio in loss function
@@ -765,7 +761,6 @@ def __init__(
765
761
else :
766
762
raise NotImplementedError
767
763
self .target_str = target_str
768
- self .is_intensive = is_intensive
769
764
self .energy_loss_ratio = energy_loss_ratio
770
765
if "f" not in self .target_str :
771
766
self .force_loss_ratio = 0
@@ -803,19 +798,12 @@ def forward(
803
798
if self .allow_missing_labels :
804
799
valid_value_indices = ~ torch .isnan (targets ["e" ])
805
800
valid_e_target = targets ["e" ][valid_value_indices ]
806
- valid_atoms_per_graph = prediction ["atoms_per_graph" ][
807
- valid_value_indices
808
- ]
809
801
valid_e_pred = prediction ["e" ][valid_value_indices ]
810
802
if valid_e_pred .shape == torch .Size ([]):
811
803
valid_e_pred = valid_e_pred .view (1 )
812
804
else :
813
805
valid_e_target = targets ["e" ]
814
- valid_atoms_per_graph = prediction ["atoms_per_graph" ]
815
806
valid_e_pred = prediction ["e" ]
816
- if self .is_intensive :
817
- valid_e_target = valid_e_target / valid_atoms_per_graph
818
- valid_e_pred = valid_e_pred / valid_atoms_per_graph
819
807
820
808
out ["loss" ] += self .energy_loss_ratio * self .criterion (
821
809
valid_e_target , valid_e_pred
0 commit comments