Skip to content

Commit 6408de5

Browse files
committed
Adding function to initialize StructureData from vasp dir
1 parent 11315e5 commit 6408de5

File tree

5 files changed

+181
-53
lines changed

5 files changed

+181
-53
lines changed

chgnet/data/dataset.py

+47-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
forces: list[Sequence[Sequence[float]]],
3636
stresses: list[Sequence[Sequence[float]]] | None = None,
3737
magmoms: list[Sequence[Sequence[float]]] | None = None,
38-
structure_ids: list[str] | None = None,
38+
structure_ids: list | None = None,
3939
graph_converter: CrystalGraphConverter | None = None,
4040
shuffle: bool = True,
4141
) -> None:
@@ -49,7 +49,7 @@ def __init__(
4949
Default = None
5050
magmoms (list[list[float]], optional): [data_size, n_atoms, 1]
5151
Default = None
52-
structure_ids (list[str], optional): a list of ids to track the structures
52+
structure_ids (list, optional): a list of ids to track the structures
5353
Default = None
5454
graph_converter (CrystalGraphConverter, optional): Converts the structures
5555
to graphs. If None, it will be set to CHGNet 0.3.0 converter
@@ -87,6 +87,51 @@ def __init__(
8787
self.failed_idx: list[int] = []
8888
self.failed_graph_id: dict[str, str] = {}
8989

90+
@classmethod
def from_vasp(
    cls,
    file_root: str,
    check_electronic_convergence: bool = True,
    save_path: str | None = None,
    graph_converter: CrystalGraphConverter | None = None,
    shuffle: bool = True,
):
    """Build a dataset directly from the outputs of a VASP calculation.

    The directory is parsed with ``utils.parse_vasp_dir`` and the parsed
    structures and labels are handed to the class constructor.

    Args:
        file_root (str): the directory of the VASP calculation outputs
        check_electronic_convergence (bool): forwarded to the parser; if set
            to True, an Exception is raised for a VASP calculation that did
            not achieve electronic convergence.
            Default = True
        save_path (str): path to save the parsed VASP labels, forwarded to
            the parser unchanged
            Default = None
        graph_converter (CrystalGraphConverter, optional): Converts the structures
            to graphs. If None, it will be set to CHGNet 0.3.0 converter
            with AtomGraph cutoff = 6A.
        shuffle (bool): whether to shuffle the sequence of dataset
            Default = True

    Returns:
        A new instance of ``cls``. The structure ids are the integer
        positions of the structures in the parsed list, and stresses /
        magmoms are passed as None when the parser returned None or an
        empty list for them.
    """
    parsed = utils.parse_vasp_dir(
        file_root=file_root,
        check_electronic_convergence=check_electronic_convergence,
        save_path=save_path,
    )
    structures = parsed["structure"]
    energies = parsed["energy_per_atom"]
    forces = parsed["force"]

    # Optional labels: a missing or empty list is normalised to None.
    stresses = parsed["stress"]
    if stresses in [None, []]:
        stresses = None
    magmoms = parsed["magmom"]
    if magmoms in [None, []]:
        magmoms = None

    return cls(
        structures=structures,
        energies=energies,
        forces=forces,
        stresses=stresses,
        magmoms=magmoms,
        structure_ids=np.arange(len(structures)),
        graph_converter=graph_converter,
        shuffle=shuffle,
    )
134+
90135
def __len__(self) -> int:
    """Return the dataset size, i.e. the number of entries in ``self.keys``."""
    n_entries = len(self.keys)
    return n_entries

chgnet/utils/vasp_utils.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
from monty.io import reverse_readfile
88
from pymatgen.io.vasp.outputs import Oszicar, Vasprun
99

10+
from chgnet.utils import write_json
11+
1012
if TYPE_CHECKING:
1113
from pymatgen.core import Structure
1214

1315

1416
def parse_vasp_dir(
15-
file_root: str, check_electronic_convergence: bool = True
17+
file_root: str,
18+
check_electronic_convergence: bool = True,
19+
save_path: str | None = None,
1620
) -> dict[str, list]:
1721
"""Parse VASP output files into structures and labels
1822
By default, the magnetization is read from mag_x from VASP,
@@ -22,6 +26,8 @@ def parse_vasp_dir(
2226
file_root (str): the directory of the VASP calculation outputs
2327
check_electronic_convergence (bool): if set to True, this function will raise
2428
Exception to VASP calculation that did not achieve electronic convergence.
29+
Default = True
30+
save_path (str): path to save the parsed VASP labels
2531
"""
2632
if os.path.exists(file_root) is False:
2733
raise FileNotFoundError("No such file or directory")
@@ -153,6 +159,10 @@ def parse_vasp_dir(
153159
if dataset["uncorrected_total_energy"] == []:
154160
raise RuntimeError(f"No data parsed from {file_root}!")
155161

162+
if save_path is not None:
163+
save_dict = dataset.copy()
164+
save_dict["structure"] = [struct.as_dict() for struct in dataset["structure"]]
165+
write_json(save_dict, save_path)
156166
return dataset
157167

158168

0 commit comments

Comments
 (0)