Skip to content

Commit b819ef5

Browse files
committed
Support for edge case of structures with all isolated atoms
1 parent b1bc8a2 commit b819ef5

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

chgnet/model/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,9 @@ def from_graphs(
818818

819819
# Bonds
820820
atom_cart_coords = graph.atom_frac_coord @ lattice
821+
if graph.atom_graph.dim() == 1:
822+
# This is to avoid structure with all atoms isolated
823+
graph.atom_graph = graph.atom_graph.reshape(0, 2)
821824
bond_basis_ag, bond_basis_bg, bond_vectors = bond_basis_expansion(
822825
center=atom_cart_coords[graph.atom_graph[:, 0]],
823826
neighbor=atom_cart_coords[graph.atom_graph[:, 1]],

tests/test_model.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import pytest
7-
from pymatgen.core import Structure
7+
from pymatgen.core import Lattice, Structure
88

99
from chgnet import ROOT
1010
from chgnet.graph import CrystalGraphConverter
@@ -207,6 +207,18 @@ def test_predict_batched_structures() -> None:
207207
)
208208

209209

210+
def test_predict_isolated_structures() -> None:
211+
lattice10 = Lattice.cubic(10)
212+
lattice20 = Lattice.cubic(20)
213+
positions = [[0, 0, 0], [0.5, 0.5, 0.5]]
214+
215+
# Create the structure
216+
model.graph_converter.set_isolated_atom_response("ignore")
217+
prediction10 = model.predict_structure(Structure(lattice10, ["H", "H"], positions))
218+
prediction20 = model.predict_structure(Structure(lattice20, ["H", "H"], positions))
219+
assert prediction10["e"] == pytest.approx(prediction20["e"], rel=1e-5, abs=1e-5)
220+
221+
210222
def test_as_to_from_dict() -> None:
211223
dct = model.as_dict()
212224
assert {*dct} == {"model_args", "state_dict"}

0 commit comments

Comments
 (0)