33
33
from ase .optimize .optimize import Optimizer
34
34
from typing_extensions import Self
35
35
36
+ from chgnet import PredTask
37
+
36
38
# We would like to thank M3GNet develop team for this module
37
39
# source: https://github.com/materialsvirtuallab/m3gnet
38
40
@@ -59,7 +61,7 @@ def __init__(
59
61
* ,
60
62
use_device : str | None = None ,
61
63
check_cuda_mem : bool = False ,
62
- stress_weight : float | None = 1 / 160.21766208 ,
64
+ stress_weight : float = units . GPa , # GPa to eV/A^3
63
65
on_isolated_atoms : Literal ["ignore" , "warn" , "error" ] = "warn" ,
64
66
return_site_energies : bool = False ,
65
67
** kwargs ,
@@ -124,6 +126,7 @@ def calculate(
124
126
atoms : Atoms | None = None ,
125
127
properties : list | None = None ,
126
128
system_changes : list | None = None ,
129
+ task : PredTask = "efsm" ,
127
130
) -> None :
128
131
"""Calculate various properties of the atoms using CHGNet.
129
132
@@ -133,6 +136,8 @@ def calculate(
133
136
Default is all properties.
134
137
system_changes (list | None): The changes made to the system.
135
138
Default is all changes.
139
+ task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm".
140
+ Default = "efsm"
136
141
"""
137
142
properties = properties or all_properties
138
143
system_changes = system_changes or all_changes
@@ -147,23 +152,28 @@ def calculate(
147
152
graph = self .model .graph_converter (structure )
148
153
model_prediction = self .model .predict_graph (
149
154
graph .to (self .device ),
150
- task = "efsm" ,
155
+ task = task ,
151
156
return_crystal_feas = True ,
152
157
return_site_energies = self .return_site_energies ,
153
158
)
154
159
155
160
# Convert Result
156
- factor = 1 if not self .model .is_intensive else structure .composition .num_atoms
157
- self .results .update (
158
- energy = model_prediction ["e" ] * factor ,
159
- forces = model_prediction ["f" ],
160
- free_energy = model_prediction ["e" ] * factor ,
161
- magmoms = model_prediction ["m" ],
162
- stress = model_prediction ["s" ] * self .stress_weight ,
163
- crystal_fea = model_prediction ["crystal_fea" ],
161
+ extensive_factor = len (structure ) if self .model .is_intensive else 1
162
+ key_map = dict (
163
+ e = ("energy" , extensive_factor ),
164
+ f = ("forces" , 1 ),
165
+ m = ("magmoms" , 1 ),
166
+ s = ("stress" , self .stress_weight ),
164
167
)
168
+ self .results |= {
169
+ long_key : model_prediction [key ] * factor
170
+ for key , (long_key , factor ) in key_map .items ()
171
+ if key in model_prediction
172
+ }
173
+ self .results ["free_energy" ] = self .results ["energy" ]
174
+ self .results ["crystal_fea" ] = model_prediction ["crystal_fea" ]
165
175
if self .return_site_energies :
166
- self .results . update ( energies = model_prediction ["site_energies" ])
176
+ self .results [ " energies" ] = model_prediction ["site_energies" ]
167
177
168
178
169
179
class StructOptimizer :
@@ -174,7 +184,7 @@ def __init__(
174
184
model : CHGNet | CHGNetCalculator | None = None ,
175
185
optimizer_class : Optimizer | str | None = "FIRE" ,
176
186
use_device : str | None = None ,
177
- stress_weight : float = 1 / 160.21766208 ,
187
+ stress_weight : float = units . GPa ,
178
188
on_isolated_atoms : Literal ["ignore" , "warn" , "error" ] = "warn" ,
179
189
) -> None :
180
190
"""Provide a trained CHGNet model and an optimizer to relax crystal structures.
@@ -773,7 +783,7 @@ def __init__(
773
783
model : CHGNet | CHGNetCalculator | None = None ,
774
784
optimizer_class : Optimizer | str | None = "FIRE" ,
775
785
use_device : str | None = None ,
776
- stress_weight : float = 1 / 160.21766208 ,
786
+ stress_weight : float = units . GPa ,
777
787
on_isolated_atoms : Literal ["ignore" , "warn" , "error" ] = "error" ,
778
788
) -> None :
779
789
"""Initialize a structure optimizer object for calculation of bulk modulus.
0 commit comments