Skip to content

Commit 9e4bd5e

Browse files
committed
added test for dataset shuffling
1 parent 72f9108 commit 9e4bd5e

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

tests/test_dataset.py

+40-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import random
4+
35
import numpy as np
46
import pytest
57
import torch
@@ -17,6 +19,7 @@
1719
@pytest.fixture()
1820
def structure_data() -> StructureData:
1921
"""Create a graph with 3 nodes and 3 directed edges."""
22+
random.seed(42)
2023
structures, energies, forces, stresses, magmoms, structure_ids = (
2124
[],
2225
[],
@@ -25,15 +28,15 @@ def structure_data() -> StructureData:
2528
[],
2629
[],
2730
)
28-
for _ in range(100):
31+
for index in range(100):
2932
struct = NaCl.copy()
3033
struct.perturb(0.1)
3134
structures.append(struct)
3235
energies.append(np.random.random(1))
3336
forces.append(np.random.random([2, 3]))
3437
stresses.append(np.random.random([3, 3]))
3538
magmoms.append(np.random.random([2, 1]))
36-
structure_ids.append("tmp_id")
39+
structure_ids.append(index)
3740
return StructureData(
3841
structures=structures,
3942
energies=energies,
@@ -47,7 +50,8 @@ def structure_data() -> StructureData:
4750
def test_structure_data(structure_data: StructureData) -> None:
4851
get_one = structure_data[0]
4952
assert isinstance(get_one[0], CrystalGraph)
50-
assert get_one[0].mp_id == "tmp_id"
53+
assert isinstance(get_one[0].mp_id, int)
54+
assert get_one[0].mp_id == 42
5155
assert isinstance(get_one[1], dict)
5256
assert isinstance(get_one[1]["e"], torch.Tensor)
5357
assert isinstance(get_one[1]["f"], torch.Tensor)
@@ -85,3 +89,36 @@ def test_structure_data_inconsistent_length():
8589
== f"Inconsistent number of structures and labels: {len(structures)=}, "
8690
f"{len(forces)=}"
8791
)
92+
93+
94+
def test_dataset_no_shuffling():
95+
structures, energies, forces, stresses, magmoms, structure_ids = (
96+
[],
97+
[],
98+
[],
99+
[],
100+
[],
101+
[],
102+
)
103+
for index in range(100):
104+
struct = NaCl.copy()
105+
struct.perturb(0.1)
106+
structures.append(struct)
107+
energies.append(np.random.random(1))
108+
forces.append(np.random.random([2, 3]))
109+
stresses.append(np.random.random([3, 3]))
110+
magmoms.append(np.random.random([2, 1]))
111+
structure_ids.append(index)
112+
structure_data = StructureData(
113+
structures=structures,
114+
energies=energies,
115+
forces=forces,
116+
stresses=stresses,
117+
magmoms=magmoms,
118+
structure_ids=structure_ids,
119+
shuffle=False,
120+
)
121+
122+
assert structure_data[0][0].mp_id == 0
123+
assert structure_data[1][0].mp_id == 1
124+
assert structure_data[2][0].mp_id == 2

0 commit comments

Comments
 (0)