Skip to content

Commit d0632a1

Browse files
authored
ruff fixes + type annotations (#156)
* ruff auto fixes * fix ruff FBT001 FBT002 * ruff select = ["ALL"] and fix legacy errors * fix TypeError: unhashable type 'list'chgnet/model/functions.py:71: in __init__ if hidden_dim in {None, 0}:
1 parent 455f4d8 commit d0632a1

17 files changed

+152
-129
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]
44

55
repos:
66
- repo: https://github.com/astral-sh/ruff-pre-commit
7-
rev: v0.4.3
7+
rev: v0.4.4
88
hooks:
99
- id: ruff
1010
args: [--fix]
@@ -46,7 +46,7 @@ repos:
4646
- svelte
4747

4848
- repo: https://github.com/pre-commit/mirrors-eslint
49-
rev: v9.2.0
49+
rev: v9.3.0
5050
hooks:
5151
- id: eslint
5252
types: [file]

chgnet/data/dataset.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
structures: list[Structure],
3434
energies: list[float],
3535
forces: list[Sequence[Sequence[float]]],
36+
*,
3637
stresses: list[Sequence[Sequence[float]]] | None = None,
3738
magmoms: list[Sequence[Sequence[float]]] | None = None,
3839
structure_ids: list | None = None,
@@ -63,7 +64,7 @@ def __init__(
6364
"""
6465
for idx, struct in enumerate(structures):
6566
if not isinstance(struct, Structure):
66-
raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}")
67+
raise TypeError(f"{idx} is not a pymatgen Structure object: {struct}")
6768
for name in "energies forces stresses magmoms structure_ids".split():
6869
labels = locals()[name]
6970
if labels is not None and len(labels) != len(structures):
@@ -80,7 +81,7 @@ def __init__(
8081
self.keys = np.arange(len(structures))
8182
if shuffle:
8283
random.shuffle(self.keys)
83-
print(f"{len(structures)} structures imported")
84+
print(f"{type(self).__name__} imported {len(structures):,} structures")
8485
self.graph_converter = graph_converter or CrystalGraphConverter(
8586
atom_graph_cutoff=6, bond_graph_cutoff=3
8687
)
@@ -91,11 +92,12 @@ def __init__(
9192
def from_vasp(
9293
cls,
9394
file_root: str,
95+
*,
9496
check_electronic_convergence: bool = True,
9597
save_path: str | None = None,
9698
graph_converter: CrystalGraphConverter | None = None,
9799
shuffle: bool = True,
98-
):
100+
) -> StructureData:
99101
"""Parse VASP output files into structures and labels and feed into the dataset.
100102
101103
Args:
@@ -196,6 +198,7 @@ class CIFData(Dataset):
196198
def __init__(
197199
self,
198200
cif_path: str,
201+
*,
199202
labels: str | dict = "labels.json",
200203
targets: TrainTask = "efsm",
201204
graph_converter: CrystalGraphConverter | None = None,
@@ -311,6 +314,7 @@ class GraphData(Dataset):
311314
def __init__(
312315
self,
313316
graph_path: str,
317+
*,
314318
labels: str | dict = "labels.json",
315319
targets: TrainTask = "efsm",
316320
exclude: str | list | None = None,
@@ -429,6 +433,7 @@ def get_train_val_test_loader(
429433
self,
430434
train_ratio: float = 0.8,
431435
val_ratio: float = 0.1,
436+
*,
432437
train_key: list[str] | None = None,
433438
val_key: list[str] | None = None,
434439
test_key: list[str] | None = None,
@@ -541,6 +546,7 @@ def __init__(
541546
self,
542547
data: str | dict,
543548
graph_converter: CrystalGraphConverter,
549+
*,
544550
targets: TrainTask = "efsm",
545551
energy_key: str = "energy_per_atom",
546552
force_key: str = "force",
@@ -580,14 +586,14 @@ def __init__(
580586
elif isinstance(data, dict):
581587
self.data = data
582588
else:
583-
raise ValueError(f"data must be JSON path or dictionary, got {type(data)}")
589+
raise TypeError(f"data must be JSON path or dictionary, got {type(data)}")
584590

585591
self.keys = [
586592
(mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dct
587593
]
588594
if shuffle:
589595
random.shuffle(self.keys)
590-
print(f"{len(self.data)} mp_ids, {len(self)} structures imported")
596+
print(f"{len(self.data)} MP IDs, {len(self)} structures imported")
591597
self.graph_converter = graph_converter
592598
self.energy_key = energy_key
593599
self.force_key = force_key
@@ -602,7 +608,7 @@ def __len__(self) -> int:
602608
return len(self.keys)
603609

604610
@functools.cache # Cache loaded structures
605-
def __getitem__(self, idx):
611+
def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
606612
"""Get one item in the dataset.
607613
608614
Returns:
@@ -654,6 +660,7 @@ def get_train_val_test_loader(
654660
self,
655661
train_ratio: float = 0.8,
656662
val_ratio: float = 0.1,
663+
*,
657664
train_key: list[str] | None = None,
658665
val_key: list[str] | None = None,
659666
test_key: list[str] | None = None,
@@ -747,7 +754,7 @@ def get_train_val_test_loader(
747754
return train_loader, val_loader, test_loader
748755

749756

750-
def collate_graphs(batch_data: list):
757+
def collate_graphs(batch_data: list) -> tuple[list[CrystalGraph], dict[str, Tensor]]:
751758
"""Collate of list of (graph, target) into batch data.
752759
753760
Args:
@@ -777,13 +784,14 @@ def collate_graphs(batch_data: list):
777784

778785
def get_train_val_test_loader(
779786
dataset: Dataset,
787+
*,
780788
batch_size: int = 64,
781789
train_ratio: float = 0.8,
782790
val_ratio: float = 0.1,
783791
return_test: bool = True,
784792
num_workers: int = 0,
785793
pin_memory: bool = True,
786-
):
794+
) -> tuple[DataLoader, DataLoader, DataLoader]:
787795
"""Randomly partition a dataset into train, val, test loaders.
788796
789797
Args:
@@ -842,7 +850,9 @@ def get_train_val_test_loader(
842850
return train_loader, val_loader
843851

844852

845-
def get_loader(dataset, batch_size=64, num_workers=0, pin_memory=True):
853+
def get_loader(
854+
dataset, *, batch_size: int = 64, num_workers: int = 0, pin_memory: bool = True
855+
) -> DataLoader:
846856
"""Get a dataloader from a dataset.
847857
848858
Args:

chgnet/graph/converter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class CrystalGraphConverter(nn.Module):
3333

3434
def __init__(
3535
self,
36+
*,
3637
atom_graph_cutoff: float = 6,
3738
bond_graph_cutoff: float = 3,
3839
algorithm: Literal["legacy", "fast"] = "fast",
@@ -274,7 +275,6 @@ def set_isolated_atom_response(
274275
None
275276
"""
276277
self.on_isolated_atoms = on_isolated_atoms
277-
return
278278

279279
def as_dict(self) -> dict[str, str | float]:
280280
"""Save the args of the graph converter."""

chgnet/graph/crystalgraph.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def __repr__(self) -> str:
183183
)
184184

185185
@property
186-
def num_isolated_atoms(self):
186+
def num_isolated_atoms(self) -> int:
187187
"""Number of isolated atoms given the atom graph cutoff
188188
Isolated atoms are disconnected nodes in the atom graph
189189
that will not get updated in CHGNet.

chgnet/graph/graph.py

+28-24
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class UndirectedEdge(Edge):
6666

6767
__hash__ = Edge.__hash__
6868

69-
def __eq__(self, other):
69+
def __eq__(self, other: object) -> bool:
7070
"""Check if two undirected edges are equal."""
7171
return set(self.nodes) == set(other.nodes) and self.info == other.info
7272

@@ -178,16 +178,16 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None:
178178
):
179179
# There is an undirected edge with similar length and only one of
180180
# the directed edges associated has been added
181-
added_DE = self.directed_edges_list[
181+
added_dir_edge = self.directed_edges_list[
182182
undirected_edge.info["directed_edge_index"][0]
183183
]
184184

185185
# See if the DE that's associated to this UDE
186186
# is the reverse of our DE
187-
if added_DE == this_directed_edge:
187+
if added_dir_edge == this_directed_edge:
188188
# Add UDE index to this DE
189189
this_directed_edge.info["undirected_edge_index"] = (
190-
added_DE.info["undirected_edge_index"]
190+
added_dir_edge.info["undirected_edge_index"]
191191
)
192192

193193
# At the center node, draw edge with this DE
@@ -217,7 +217,7 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None:
217217
self.nodes[center_index].add_neighbor(neighbor_index, this_directed_edge)
218218
self.directed_edges_list.append(this_directed_edge)
219219

220-
def adjacency_list(self):
220+
def adjacency_list(self) -> tuple[list[list[int]], list[int]]:
221221
"""Get the adjacency list
222222
Return:
223223
graph: the adjacency list
@@ -240,7 +240,7 @@ def adjacency_list(self):
240240
]
241241
return graph, directed2undirected
242242

243-
def line_graph_adjacency_list(self, cutoff):
243+
def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]:
244244
"""Get the line graph adjacency list.
245245
246246
Args:
@@ -264,11 +264,12 @@ def line_graph_adjacency_list(self, cutoff):
264264
a list of length = num_undirected_edge that
265265
maps the undirected edge index to one of its directed edges indices
266266
"""
267-
assert len(self.directed_edges_list) == 2 * len(self.undirected_edges_list), (
268-
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 * "
269-
f"number of undirected edges={len(self.directed_edges_list)}!"
270-
f"This indicates directed edges are not complete"
271-
)
267+
if len(self.directed_edges_list) != 2 * len(self.undirected_edges_list):
268+
raise ValueError(
269+
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 "
270+
f"* number of undirected edges={len(self.directed_edges_list)}!"
271+
f"This indicates directed edges are not complete"
272+
)
272273
line_graph = []
273274
undirected2directed = []
274275

@@ -285,39 +286,42 @@ def line_graph_adjacency_list(self, cutoff):
285286
# if encountered exception,
286287
# it means after Atom_Graph creation, the UDE has only 1 DE associated
287288
# This exception is not encountered from the develop team's experience
288-
assert len(u_edge.info["directed_edge_index"]) == 2, (
289-
"Did not find 2 Directed_edges !!!"
290-
f"undirected edge {u_edge} has:"
291-
f"edge.info['directed_edge_index'] = "
292-
f"{u_edge.info['directed_edge_index']}"
293-
f"len directed_edges_list = {len(self.directed_edges_list)}"
294-
f"len undirected_edges_list = {len(self.undirected_edges_list)}"
295-
)
289+
if len(u_edge.info["directed_edge_index"]) != 2:
290+
raise ValueError(
291+
"Did not find 2 Directed_edges !!!"
292+
f"undirected edge {u_edge} has:"
293+
f"edge.info['directed_edge_index'] = "
294+
f"{u_edge.info['directed_edge_index']}"
295+
f"len directed_edges_list = {len(self.directed_edges_list)}"
296+
f"len undirected_edges_list = {len(self.undirected_edges_list)}"
297+
)
296298

297299
# This UDE is valid to be considered as a node in Bond_Graph
298300

299301
# Get the two ends (centers) and the two DE associated with this UDE
300302
# DE1 should have center=center1 and DE2 should have center=center2
301303
# We will need to find directed edges with center = center1
302304
# and create angles with DE1, then do the same for center2 and DE2
303-
for center, DE in zip(u_edge.nodes, u_edge.info["directed_edge_index"]):
305+
for center, dir_edge in zip(
306+
u_edge.nodes, u_edge.info["directed_edge_index"]
307+
):
304308
for directed_edges in self.nodes[center].neighbors.values():
305309
for directed_edge in directed_edges:
306-
if directed_edge.index == DE:
310+
if directed_edge.index == dir_edge:
307311
continue
308312
if directed_edge.info["distance"] < cutoff:
309313
line_graph.append(
310314
[
311315
center,
312316
u_edge.index,
313-
DE,
317+
dir_edge,
314318
directed_edge.info["undirected_edge_index"],
315319
directed_edge.index,
316320
]
317321
)
318322
return line_graph, undirected2directed
319323

320-
def undirected2directed(self):
324+
def undirected2directed(self) -> list[int]:
321325
"""The index map from undirected_edge index to one of its directed_edge
322326
index.
323327
"""
@@ -326,7 +330,7 @@ def undirected2directed(self):
326330
for undirected_edge in self.undirected_edges_list
327331
]
328332

329-
def as_dict(self):
333+
def as_dict(self) -> dict:
330334
"""Return dictionary serialization of a Graph."""
331335
return {
332336
"nodes": self.nodes,

chgnet/model/basis.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
class Fourier(nn.Module):
99
"""Fourier Expansion for angle features."""
1010

11-
def __init__(self, order: int = 5, learnable: bool = False) -> None:
11+
def __init__(self, *, order: int = 5, learnable: bool = False) -> None:
1212
"""Initialize the Fourier expansion.
1313
1414
Args:
@@ -47,6 +47,7 @@ class RadialBessel(torch.nn.Module):
4747

4848
def __init__(
4949
self,
50+
*,
5051
num_radial: int = 9,
5152
cutoff: float = 5,
5253
learnable: bool = False,
@@ -90,7 +91,7 @@ def __init__(
9091
self.smooth_cutoff = None
9192

9293
def forward(
93-
self, dist: Tensor, return_smooth_factor: bool = False
94+
self, dist: Tensor, *, return_smooth_factor: bool = False
9495
) -> Tensor | tuple[Tensor, Tensor]:
9596
"""Apply Bessel expansion to a feature Tensor.
9697
@@ -122,8 +123,8 @@ class GaussianExpansion(nn.Module):
122123

123124
def __init__(
124125
self,
125-
min: float = 0,
126-
max: float = 5,
126+
min: float = 0, # noqa: A002
127+
max: float = 5, # noqa: A002
127128
step: float = 0.5,
128129
var: float | None = None,
129130
) -> None:
@@ -137,8 +138,10 @@ def __init__(
137138
var (float): variance in gaussian filter, default to step
138139
"""
139140
super().__init__()
140-
assert min < max
141-
assert max - min > step
141+
if min >= max:
142+
raise ValueError(f"{min=} must be less than {max=}")
143+
if max - min <= step:
144+
raise ValueError(f"{max - min=} must be greater than {step=}")
142145
self.register_buffer("gaussian_centers", torch.arange(min, max + step, step))
143146
self.var = var or step
144147
if self.var <= 0:

0 commit comments

Comments
 (0)