Skip to content

Commit 73d6096

Browse files
committed
test.yml set fail-fast: false, fix some doc strings missing Returns/Raises, change parse_vasp_dir error type FileNotFoundError->NotADirectoryError
1 parent ce1f29c commit 73d6096

File tree

9 files changed

+90
-35
lines changed

9 files changed

+90
-35
lines changed

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010
jobs:
1111
tests:
1212
strategy:
13-
fail-fast: true
13+
fail-fast: false
1414
matrix:
1515
os: [ubuntu-latest, macos-14, windows-latest]
1616
version:

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]
44

55
repos:
66
- repo: https://github.com/astral-sh/ruff-pre-commit
7-
rev: v0.6.2
7+
rev: v0.6.4
88
hooks:
99
- id: ruff
1010
args: [--fix]
@@ -48,7 +48,7 @@ repos:
4848
- svelte
4949

5050
- repo: https://github.com/pre-commit/mirrors-eslint
51-
rev: v9.9.0
51+
rev: v9.10.0
5252
hooks:
5353
- id: eslint
5454
types: [file]

chgnet/trainer/trainer.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ def __init__(
110110
that are not included in the trainer_args. Default = None
111111
112112
**kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
113+
114+
Raises:
115+
NotImplementedError: If the optimizer or scheduler is not implemented
116+
ImportError: If wandb_path is specified but wandb is not installed
117+
ValueError: If wandb_path is specified but not in the format
118+
'project/run_name'
113119
"""
114120
# Store trainer args for reproducibility
115121
self.trainer_args = {
@@ -271,6 +277,9 @@ def train(
271277
wandb_log_freq ("epoch" | "batch"): Frequency of logging to wandb.
272278
'epoch' logs once per epoch, 'batch' logs after every batch.
273279
Default = "batch"
280+
281+
Raises:
282+
ValueError: If model is not initialized
274283
"""
275284
if self.model is None:
276285
raise ValueError("Model needs to be initialized")
@@ -579,7 +588,11 @@ def _validate(
579588
return {k: round(mae_error.avg, 6) for k, mae_error in mae_errors.items()}
580589

581590
def get_best_model(self) -> CHGNet:
582-
"""Get best model recorded in the trainer."""
591+
"""Get best model recorded in the trainer.
592+
593+
Returns:
594+
CHGNet: the model with lowest validation set energy error
595+
"""
583596
if self.best_model is None:
584597
raise RuntimeError("the model needs to be trained first")
585598
MAE = min(self.training_history["e"]["val"]) # noqa: N806
@@ -649,7 +662,14 @@ def save_checkpoint(self, epoch: int, mae_error: dict, save_dir: str) -> None:
649662

650663
@classmethod
651664
def load(cls, path: str) -> Self:
652-
"""Load trainer state_dict."""
665+
"""Load trainer state_dict.
666+
667+
Args:
668+
path (str): path to the saved model
669+
670+
Returns:
671+
Trainer: the loaded trainer
672+
"""
653673
state = torch.load(path, map_location=torch.device("cpu"))
654674
model = CHGNet.from_dict(state["model"])
655675
print(f"Loaded model params = {sum(p.numel() for p in model.parameters()):,}")
@@ -664,8 +684,21 @@ def load(cls, path: str) -> Self:
664684
return trainer
665685

666686
@staticmethod
667-
def move_to(obj, device) -> Tensor | list[Tensor]:
668-
"""Move object to device."""
687+
def move_to(
688+
obj: Tensor | list[Tensor], device: torch.device
689+
) -> Tensor | list[Tensor]:
690+
"""Move object to device.
691+
692+
Args:
693+
obj (Tensor | list[Tensor]): object(s) to move to device
694+
device (torch.device): device to move object to
695+
696+
Raises:
697+
TypeError: if obj is not a tensor or list of tensors
698+
699+
Returns:
700+
Tensor | list[Tensor]: moved object(s)
701+
"""
669702
if torch.is_tensor(obj):
670703
return obj.to(device)
671704
if isinstance(obj, list):

chgnet/utils/common_utils.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def cuda_devices_sorted_by_free_mem() -> list[int]:
3939
"""List available CUDA devices sorted by increasing available memory.
4040
4141
To get the device with the most free memory, use the last list item.
42+
43+
Returns:
44+
list[int]: CUDA device numbers sorted by increasing free memory.
4245
"""
4346
if not torch.cuda.is_available():
4447
return []
@@ -94,10 +97,10 @@ def mae(prediction: Tensor, target: Tensor) -> Tensor:
9497

9598

9699
def read_json(filepath: str) -> dict:
97-
"""Read the json file.
100+
"""Read the JSON file.
98101
99102
Args:
100-
filepath (str): file name of json to read.
103+
filepath (str): file name of JSON to read.
101104
102105
Returns:
103106
dict: data stored in filepath
@@ -107,27 +110,27 @@ def read_json(filepath: str) -> dict:
107110

108111

109112
def write_json(dct: dict, filepath: str) -> dict:
110-
"""Write the json file.
113+
"""Write the JSON file.
111114
112115
Args:
113116
dct (dict): dictionary to write
114-
filepath (str): file name of json to write.
115-
116-
Returns:
117-
written dictionary
117+
filepath (str): file name of JSON to write.
118118
"""
119119

120120
def handler(obj: object) -> int | object:
121121
"""Convert numpy int64 to int.
122122
123123
Fixes TypeError: Object of type int64 is not JSON serializable
124124
reported in https://github.com/CederGroupHub/chgnet/issues/168.
125+
126+
Returns:
127+
int | object: object for serialization
125128
"""
126129
if isinstance(obj, np.integer):
127130
return int(obj)
128131
return obj
129132

130-
with open(filepath, "w") as file:
133+
with open(filepath, mode="w") as file:
131134
json.dump(dct, file, default=handler)
132135

133136

chgnet/utils/vasp_utils.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,16 @@ def parse_vasp_dir(
3030
Exception to VASP calculation that did not achieve electronic convergence.
3131
Default = True
3232
save_path (str): path to save the parsed VASP labels
33+
34+
Raises:
35+
NotADirectoryError: if the base_dir is not a directory
36+
37+
Returns:
38+
dict: a dictionary of lists with keys for structure, uncorrected_total_energy,
39+
energy_per_atom, force, magmom, stress.
3340
"""
3441
if os.path.isdir(base_dir) is False:
35-
raise FileNotFoundError(f"{base_dir=} is not a directory")
42+
raise NotADirectoryError(f"{base_dir=} is not a directory")
3643

3744
oszicar_path = zpath(f"{base_dir}/OSZICAR")
3845
vasprun_path = zpath(f"{base_dir}/vasprun.xml")
@@ -170,6 +177,9 @@ def solve_charge_by_mag(
170177
(3.5, 4.2): 3,
171178
(4.2, 5): 2
172179
))
180+
181+
Returns:
182+
Structure: pymatgen Structure with oxidation states assigned based on magmoms.
173183
"""
174184
out_structure = structure.copy()
175185
out_structure.remove_oxidation_states()

examples/make_graphs.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,22 @@ def make_graphs(
5858
make_partition(labels, graph_dir, train_ratio, val_ratio)
5959

6060

61-
def make_one_graph(mp_id: str, graph_id: str, data, graph_dir) -> dict | bool:
62-
"""Convert a structure to a CrystalGraph and save it."""
61+
def make_one_graph(
62+
mp_id: str, graph_id: str, data: StructureJsonData, graph_dir: str
63+
) -> dict | bool:
64+
"""Convert a structure to a CrystalGraph and save it.
65+
66+
Args:
67+
mp_id (str): The material id.
68+
graph_id (str): The graph id.
69+
data (StructureJsonData): The dataset. Warning: Dicts are popped from the data,
70+
i.e. modifying the data in place.
71+
graph_dir (str): The directory to save the graphs.
72+
73+
Returns:
74+
dict | bool: The label dictionary if the graph is saved successfully, False
75+
otherwise.
76+
"""
6377
dct = data.data[mp_id].pop(graph_id)
6478
struct = Structure.from_dict(dct.pop("structure"))
6579
try:

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ docstring-code-format = true
102102
"ANN201",
103103
"D100",
104104
"D103",
105+
"DOC201", # doc string missing Return section
105106
"FBT001",
106107
"FBT002",
107108
"INP001",

site/make_docs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,5 @@
4949
markdown = markdown.replace(
5050
"\n**Global Variables**\n---------------\n- **TYPE_CHECKING**\n\n", ""
5151
)
52-
with open(path, "w") as file:
52+
with open(path, mode="w") as file:
5353
file.write(markdown)

tests/test_dataset.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,17 @@
2020
def structure_data() -> StructureData:
2121
"""Create a graph with 3 nodes and 3 directed edges."""
2222
random.seed(42)
23-
structures, energies, forces, stresses, magmoms, structure_ids = (
24-
[],
25-
[],
26-
[],
27-
[],
28-
[],
29-
[],
30-
)
23+
structures, energies, forces = [], [], []
24+
stresses, magmoms, structure_ids = [], [], []
25+
3126
for index in range(100):
32-
struct = NaCl.copy()
33-
struct.perturb(0.1)
34-
structures.append(struct)
35-
energies.append(np.random.random(1))
36-
forces.append(np.random.random([2, 3]))
37-
stresses.append(np.random.random([3, 3]))
38-
magmoms.append(np.random.random([2, 1]))
39-
structure_ids.append(index)
27+
structures += [NaCl.copy().perturb(0.1)]
28+
energies += [np.random.random(1)]
29+
forces += [np.random.random([2, 3])]
30+
stresses += [np.random.random([3, 3])]
31+
magmoms += [np.random.random([2, 1])]
32+
structure_ids += [index]
33+
4034
return StructureData(
4135
structures=structures,
4236
energies=energies,

0 commit comments

Comments
 (0)