@@ -35,7 +35,7 @@ def __init__(
35
35
forces : list [Sequence [Sequence [float ]]],
36
36
stresses : list [Sequence [Sequence [float ]]] | None = None ,
37
37
magmoms : list [Sequence [Sequence [float ]]] | None = None ,
38
- structure_ids : list [ str ] | None = None ,
38
+ structure_ids : list | None = None ,
39
39
graph_converter : CrystalGraphConverter | None = None ,
40
40
shuffle : bool = True ,
41
41
) -> None :
@@ -49,7 +49,7 @@ def __init__(
49
49
Default = None
50
50
magmoms (list[list[float]], optional): [data_size, n_atoms, 1]
51
51
Default = None
52
- structure_ids (list[str] , optional): a list of ids to track the structures
52
+ structure_ids (list, optional): a list of ids to track the structures
53
53
Default = None
54
54
graph_converter (CrystalGraphConverter, optional): Converts the structures
55
55
to graphs. If None, it will be set to CHGNet 0.3.0 converter
@@ -87,6 +87,51 @@ def __init__(
87
87
self .failed_idx : list [int ] = []
88
88
self .failed_graph_id : dict [str , str ] = {}
89
89
90
+ @classmethod
91
+ def from_vasp (
92
+ cls ,
93
+ file_root : str ,
94
+ check_electronic_convergence : bool = True ,
95
+ save_path : str | None = None ,
96
+ graph_converter : CrystalGraphConverter | None = None ,
97
+ shuffle : bool = True ,
98
+ ):
99
+ """Parse VASP output files into structures and labels and feed into the dataset.
100
+
101
+ Args:
102
+ file_root (str): the directory of the VASP calculation outputs
103
+ check_electronic_convergence (bool): if set to True, this function will
104
+ raise Exception to VASP calculation that did not achieve
105
+ electronic convergence.
106
+ Default = True
107
+ save_path (str): path to save the parsed VASP labels
108
+ Default = None
109
+ graph_converter (CrystalGraphConverter, optional): Converts the structures
110
+ to graphs. If None, it will be set to CHGNet 0.3.0 converter
111
+ with AtomGraph cutoff = 6A.
112
+ shuffle (bool): whether to shuffle the sequence of dataset
113
+ Default = True
114
+ """
115
+ result_dict = utils .parse_vasp_dir (
116
+ file_root = file_root ,
117
+ check_electronic_convergence = check_electronic_convergence ,
118
+ save_path = save_path ,
119
+ )
120
+ return cls (
121
+ structures = result_dict ["structure" ],
122
+ energies = result_dict ["energy_per_atom" ],
123
+ forces = result_dict ["force" ],
124
+ stresses = None
125
+ if result_dict ["stress" ] in [None , []]
126
+ else result_dict ["stress" ],
127
+ magmoms = None
128
+ if result_dict ["magmom" ] in [None , []]
129
+ else result_dict ["magmom" ],
130
+ structure_ids = np .arange (len (result_dict ["structure" ])),
131
+ graph_converter = graph_converter ,
132
+ shuffle = shuffle ,
133
+ )
134
+
90
135
def __len__ (self ) -> int :
91
136
"""Get the number of structures in this dataset."""
92
137
return len (self .keys )
0 commit comments