1
1
from __future__ import annotations
2
2
3
+ import random
4
+
3
5
import numpy as np
4
6
import pytest
5
7
import torch
17
19
@pytest .fixture ()
18
20
def structure_data () -> StructureData :
19
21
"""Create a graph with 3 nodes and 3 directed edges."""
22
+ random .seed (42 )
20
23
structures , energies , forces , stresses , magmoms , structure_ids = (
21
24
[],
22
25
[],
@@ -25,15 +28,15 @@ def structure_data() -> StructureData:
25
28
[],
26
29
[],
27
30
)
28
- for _ in range (100 ):
31
+ for index in range (100 ):
29
32
struct = NaCl .copy ()
30
33
struct .perturb (0.1 )
31
34
structures .append (struct )
32
35
energies .append (np .random .random (1 ))
33
36
forces .append (np .random .random ([2 , 3 ]))
34
37
stresses .append (np .random .random ([3 , 3 ]))
35
38
magmoms .append (np .random .random ([2 , 1 ]))
36
- structure_ids .append ("tmp_id" )
39
+ structure_ids .append (index )
37
40
return StructureData (
38
41
structures = structures ,
39
42
energies = energies ,
@@ -47,7 +50,8 @@ def structure_data() -> StructureData:
47
50
def test_structure_data (structure_data : StructureData ) -> None :
48
51
get_one = structure_data [0 ]
49
52
assert isinstance (get_one [0 ], CrystalGraph )
50
- assert get_one [0 ].mp_id == "tmp_id"
53
+ assert isinstance (get_one [0 ].mp_id , int )
54
+ assert get_one [0 ].mp_id == 42
51
55
assert isinstance (get_one [1 ], dict )
52
56
assert isinstance (get_one [1 ]["e" ], torch .Tensor )
53
57
assert isinstance (get_one [1 ]["f" ], torch .Tensor )
@@ -85,3 +89,36 @@ def test_structure_data_inconsistent_length():
85
89
== f"Inconsistent number of structures and labels: { len (structures )= } , "
86
90
f"{ len (forces )= } "
87
91
)
92
+
93
+
94
+ def test_dataset_no_shuffling ():
95
+ structures , energies , forces , stresses , magmoms , structure_ids = (
96
+ [],
97
+ [],
98
+ [],
99
+ [],
100
+ [],
101
+ [],
102
+ )
103
+ for index in range (100 ):
104
+ struct = NaCl .copy ()
105
+ struct .perturb (0.1 )
106
+ structures .append (struct )
107
+ energies .append (np .random .random (1 ))
108
+ forces .append (np .random .random ([2 , 3 ]))
109
+ stresses .append (np .random .random ([3 , 3 ]))
110
+ magmoms .append (np .random .random ([2 , 1 ]))
111
+ structure_ids .append (index )
112
+ structure_data = StructureData (
113
+ structures = structures ,
114
+ energies = energies ,
115
+ forces = forces ,
116
+ stresses = stresses ,
117
+ magmoms = magmoms ,
118
+ structure_ids = structure_ids ,
119
+ shuffle = False ,
120
+ )
121
+
122
+ assert structure_data [0 ][0 ].mp_id == 0
123
+ assert structure_data [1 ][0 ].mp_id == 1
124
+ assert structure_data [2 ][0 ].mp_id == 2
0 commit comments