@@ -33,6 +33,7 @@ def __init__(
33
33
structures : list [Structure ],
34
34
energies : list [float ],
35
35
forces : list [Sequence [Sequence [float ]]],
36
+ * ,
36
37
stresses : list [Sequence [Sequence [float ]]] | None = None ,
37
38
magmoms : list [Sequence [Sequence [float ]]] | None = None ,
38
39
structure_ids : list | None = None ,
@@ -63,7 +64,7 @@ def __init__(
63
64
"""
64
65
for idx , struct in enumerate (structures ):
65
66
if not isinstance (struct , Structure ):
66
- raise ValueError (f"{ idx } is not a pymatgen Structure object: { struct } " )
67
+ raise TypeError (f"{ idx } is not a pymatgen Structure object: { struct } " )
67
68
for name in "energies forces stresses magmoms structure_ids" .split ():
68
69
labels = locals ()[name ]
69
70
if labels is not None and len (labels ) != len (structures ):
@@ -80,7 +81,7 @@ def __init__(
80
81
self .keys = np .arange (len (structures ))
81
82
if shuffle :
82
83
random .shuffle (self .keys )
83
- print (f"{ len (structures )} structures imported " )
84
+ print (f"{ type ( self ). __name__ } imported { len (structures ):, } structures" )
84
85
self .graph_converter = graph_converter or CrystalGraphConverter (
85
86
atom_graph_cutoff = 6 , bond_graph_cutoff = 3
86
87
)
@@ -91,11 +92,12 @@ def __init__(
91
92
def from_vasp (
92
93
cls ,
93
94
file_root : str ,
95
+ * ,
94
96
check_electronic_convergence : bool = True ,
95
97
save_path : str | None = None ,
96
98
graph_converter : CrystalGraphConverter | None = None ,
97
99
shuffle : bool = True ,
98
- ):
100
+ ) -> StructureData :
99
101
"""Parse VASP output files into structures and labels and feed into the dataset.
100
102
101
103
Args:
@@ -196,6 +198,7 @@ class CIFData(Dataset):
196
198
def __init__ (
197
199
self ,
198
200
cif_path : str ,
201
+ * ,
199
202
labels : str | dict = "labels.json" ,
200
203
targets : TrainTask = "efsm" ,
201
204
graph_converter : CrystalGraphConverter | None = None ,
@@ -311,6 +314,7 @@ class GraphData(Dataset):
311
314
def __init__ (
312
315
self ,
313
316
graph_path : str ,
317
+ * ,
314
318
labels : str | dict = "labels.json" ,
315
319
targets : TrainTask = "efsm" ,
316
320
exclude : str | list | None = None ,
@@ -429,6 +433,7 @@ def get_train_val_test_loader(
429
433
self ,
430
434
train_ratio : float = 0.8 ,
431
435
val_ratio : float = 0.1 ,
436
+ * ,
432
437
train_key : list [str ] | None = None ,
433
438
val_key : list [str ] | None = None ,
434
439
test_key : list [str ] | None = None ,
@@ -541,6 +546,7 @@ def __init__(
541
546
self ,
542
547
data : str | dict ,
543
548
graph_converter : CrystalGraphConverter ,
549
+ * ,
544
550
targets : TrainTask = "efsm" ,
545
551
energy_key : str = "energy_per_atom" ,
546
552
force_key : str = "force" ,
@@ -580,14 +586,14 @@ def __init__(
580
586
elif isinstance (data , dict ):
581
587
self .data = data
582
588
else :
583
- raise ValueError (f"data must be JSON path or dictionary, got { type (data )} " )
589
+ raise TypeError (f"data must be JSON path or dictionary, got { type (data )} " )
584
590
585
591
self .keys = [
586
592
(mp_id , graph_id ) for mp_id , dct in self .data .items () for graph_id in dct
587
593
]
588
594
if shuffle :
589
595
random .shuffle (self .keys )
590
- print (f"{ len (self .data )} mp_ids , { len (self )} structures imported" )
596
+ print (f"{ len (self .data )} MP IDs , { len (self )} structures imported" )
591
597
self .graph_converter = graph_converter
592
598
self .energy_key = energy_key
593
599
self .force_key = force_key
@@ -602,7 +608,7 @@ def __len__(self) -> int:
602
608
return len (self .keys )
603
609
604
610
@functools .cache # Cache loaded structures
605
- def __getitem__ (self , idx ) :
611
+ def __getitem__ (self , idx : int ) -> tuple [ CrystalGraph , dict [ str , Tensor ]] :
606
612
"""Get one item in the dataset.
607
613
608
614
Returns:
@@ -654,6 +660,7 @@ def get_train_val_test_loader(
654
660
self ,
655
661
train_ratio : float = 0.8 ,
656
662
val_ratio : float = 0.1 ,
663
+ * ,
657
664
train_key : list [str ] | None = None ,
658
665
val_key : list [str ] | None = None ,
659
666
test_key : list [str ] | None = None ,
@@ -747,7 +754,7 @@ def get_train_val_test_loader(
747
754
return train_loader , val_loader , test_loader
748
755
749
756
750
- def collate_graphs (batch_data : list ):
757
+ def collate_graphs (batch_data : list ) -> tuple [ list [ CrystalGraph ], dict [ str , Tensor ]] :
751
758
"""Collate of list of (graph, target) into batch data.
752
759
753
760
Args:
@@ -777,13 +784,14 @@ def collate_graphs(batch_data: list):
777
784
778
785
def get_train_val_test_loader (
779
786
dataset : Dataset ,
787
+ * ,
780
788
batch_size : int = 64 ,
781
789
train_ratio : float = 0.8 ,
782
790
val_ratio : float = 0.1 ,
783
791
return_test : bool = True ,
784
792
num_workers : int = 0 ,
785
793
pin_memory : bool = True ,
786
- ):
794
+ ) -> tuple [ DataLoader , DataLoader , DataLoader ] :
787
795
"""Randomly partition a dataset into train, val, test loaders.
788
796
789
797
Args:
@@ -842,7 +850,9 @@ def get_train_val_test_loader(
842
850
return train_loader , val_loader
843
851
844
852
845
- def get_loader (dataset , batch_size = 64 , num_workers = 0 , pin_memory = True ):
853
+ def get_loader (
854
+ dataset , * , batch_size : int = 64 , num_workers : int = 0 , pin_memory : bool = True
855
+ ) -> DataLoader :
846
856
"""Get a dataloader from a dataset.
847
857
848
858
Args:
0 commit comments