Skip to content

Commit 50f7361

Browse files
authored
Merge pull request #210 from pyiron/symmetrize_tensor
Symmetrize tensor
2 parents cbb9558 + 3b31807 commit 50f7361

File tree

2 files changed

+161
-8
lines changed

2 files changed

+161
-8
lines changed

structuretoolkit/analyse/symmetry.py

+99
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import spglib
1010
from scipy.spatial import cKDTree
1111
from typing import Optional
12+
import string
13+
from functools import cached_property
1214

1315
import structuretoolkit.common.helper
1416
from structuretoolkit.common.error import SymmetryError
@@ -230,6 +232,33 @@ def symmetrize_vectors(
230232
np.einsum("ijk->jki", v_reshaped)[self.permutations],
231233
).reshape(np.shape(vectors)) / len(self["rotations"])
232234

235+
def symmetrize_tensor(self, tensor: np.ndarray) -> np.ndarray:
236+
"""
237+
Symmetrization of any tensor. The tensor is defined by a matrix with a
238+
shape of `n * (n_atoms, 3)`. For example, if the structure has 100
239+
atoms, the vector can have a shape of (100, 3), (100, 3, 100, 3),
240+
(100, 3, 100, 3, 100, 3) etc. Additionally, you can also have an array
241+
of tensors, i.e. in this example you can have a shape like (4, 100, 3)
242+
or (2, 100, 3, 100, 3). When the shape is (n, n_atoms, 3), the function
243+
works in the same way as `symmetrize_vectors`, which might be somewhat
244+
faster.
245+
246+
This function can be useful for the symmetrization of Hessian tensors,
247+
or any other tensors which should be symmetric.
248+
249+
Args:
250+
tensors (ndarray): n * (n_atoms, 3) tensor to symmetrize
251+
252+
Returns
253+
(np.ndarray) symmetrized tensor of the same shape
254+
"""
255+
return _SymmetrizeTensor(
256+
tensor=tensor,
257+
structure=self._structure,
258+
rotations=self.rotations,
259+
permutations=self.permutations,
260+
).result
261+
233262
def _get_spglib_cell(
234263
self, use_elements: Optional[bool] = None, use_magmoms: Optional[bool] = None
235264
) -> tuple:
@@ -389,3 +418,73 @@ def get_ir_reciprocal_mesh(
389418
if mesh is None:
390419
raise SymmetryError(spglib.spglib.spglib_error.message)
391420
return mesh
421+
422+
423+
class _SymmetrizeTensor:
424+
def __init__(self, tensor, structure, rotations, permutations):
425+
self._tensor = np.array(tensor)
426+
self._structure = structure
427+
self._rotations = rotations
428+
self._permutations = permutations
429+
430+
@cached_property
431+
def order(self):
432+
order = len(self._tensor.shape) // 2
433+
if self._tensor.shape[-2 * order :] != order * self._structure.positions.shape:
434+
raise ValueError(
435+
"Tensor must have a shape of a multiple of (n_atoms, 3). See"
436+
" docstring for more info"
437+
)
438+
return order
439+
440+
@cached_property
441+
def ij(self):
442+
return string.ascii_lowercase[: 2 * self.order]
443+
444+
@property
445+
def IJ(self):
446+
return self.ij.upper()
447+
448+
@property
449+
def ij_reorder(self):
450+
return "".join(
451+
[self.ij[ii] for ii in np.arange(2 * self.order).reshape(-1, 2).T.flatten()]
452+
)
453+
454+
@property
455+
def IJ_reorder(self):
456+
return "".join(
457+
[self.IJ[ii] for ii in np.arange(2 * self.order).reshape(2, -1).T.flatten()]
458+
)
459+
460+
@cached_property
461+
def t_t(self):
462+
return np.einsum("...{}->{}...".format(self.ij, self.ij_reorder), self._tensor)
463+
464+
@cached_property
465+
def str_einsum(self):
466+
return (
467+
",".join(
468+
[I + i for i, I in zip(self.ij[-self.order :], self.IJ[-self.order :])]
469+
)
470+
+ ","
471+
+ self.IJ[: self.order]
472+
+ self.ij[self.order :]
473+
+ "...->..."
474+
+ self.IJ_reorder
475+
)
476+
477+
@property
478+
def result(self):
479+
return np.mean(
480+
[
481+
np.einsum(
482+
self.str_einsum,
483+
*self.order * (rot,),
484+
self.t_t[tuple(np.meshgrid(*self.order * (perm,), indexing="ij"))],
485+
optimize=True,
486+
)
487+
for rot, perm in zip(self._rotations, self._permutations)
488+
],
489+
axis=0,
490+
)

tests/test_symmetry.py

+62-8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
try:
2020
import spglib
21+
from structuretoolkit.analyse.symmetry import _SymmetrizeTensor
2122

2223
skip_spglib_test = False
2324
except ImportError:
@@ -108,20 +109,22 @@ def test_get_symmetry(self):
108109
"AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True
109110
)
110111
v = np.random.rand(6).reshape(-1, 3)
112+
sym = stk.analyse.get_symmetry(structure=Al)
111113
self.assertAlmostEqual(
112-
np.linalg.norm(
113-
stk.analyse.get_symmetry(structure=Al).symmetrize_vectors(v)
114-
),
114+
np.linalg.norm(sym.symmetrize_vectors(v)),
115115
0,
116116
)
117117
vv = np.random.rand(12).reshape(2, 2, 3)
118-
for vvv in stk.analyse.get_symmetry(structure=Al).symmetrize_vectors(vv):
118+
for vvv in sym.symmetrize_vectors(vv):
119119
self.assertAlmostEqual(np.linalg.norm(vvv), 0)
120120
Al.positions[0, 0] += 0.01
121-
w = stk.analyse.get_symmetry(structure=Al).symmetrize_vectors(v)
121+
w = sym.symmetrize_vectors(v)
122122
self.assertAlmostEqual(
123123
np.absolute(w[:, 0]).sum(), np.linalg.norm(w, axis=-1).sum()
124124
)
125+
self.assertAlmostEqual(
126+
np.linalg.norm(sym.symmetrize_vectors(v) - sym.symmetrize_tensor(v)), 0
127+
)
125128

126129
def test_get_symmetry_dataset(self):
127130
cell = 2.2 * np.identity(3)
@@ -155,7 +158,7 @@ def test_get_primitive_cell(self):
155158
)
156159

157160
def test_get_primitive_cell_hex(self):
158-
elements = ['Fe', 'Fe', 'Fe', 'Fe', 'O', 'O', 'O', 'O', 'O', 'O']
161+
elements = ["Fe", "Fe", "Fe", "Fe", "O", "O", "O", "O", "O", "O"]
159162
positions = [
160163
[0.0, 0.0, 4.89],
161164
[0.0, 0.0, 11.78],
@@ -174,8 +177,7 @@ def test_get_primitive_cell_hex(self):
174177
sym = stk.analyse.get_symmetry(structure=structure_repeat)
175178
structure_prim_base = sym.get_primitive_cell()
176179
self.assertEqual(
177-
structure_prim_base.get_chemical_symbols(),
178-
structure.get_chemical_symbols()
180+
structure_prim_base.get_chemical_symbols(), structure.get_chemical_symbols()
179181
)
180182

181183
def test_get_equivalent_points(self):
@@ -284,5 +286,57 @@ def test_error(self):
284286
stk.analyse.get_symmetry(structure=structure)
285287

286288

289+
@unittest.skipIf(
290+
skip_spglib_test, "spglib is not installed, so the spglib tests are skipped."
291+
)
292+
class TestSymmetrizeTensors(unittest.TestCase):
293+
@classmethod
294+
def setUpClass(cls):
295+
cls.structure = bulk("Al", cubic=True, a=4.0).repeat(2)
296+
cls.dataset = {
297+
"structure": cls.structure,
298+
"rotations": np.eye(3),
299+
"permutations": np.arange(len(cls.structure)),
300+
}
301+
302+
def test_order(self):
303+
with self.assertRaises(ValueError):
304+
_SymmetrizeTensor(
305+
tensor=np.array([1]), **self.dataset
306+
).order
307+
self.assertEqual(
308+
_SymmetrizeTensor(
309+
tensor=np.random.randn(*self.structure.positions.shape), **self.dataset
310+
).order,
311+
1,
312+
)
313+
self.assertEqual(
314+
_SymmetrizeTensor(
315+
tensor=np.random.randn(*2 * self.structure.positions.shape),
316+
**self.dataset,
317+
).order,
318+
2,
319+
)
320+
321+
def test_indexing(self):
322+
st = _SymmetrizeTensor(
323+
tensor=np.random.randn(*2 * self.structure.positions.shape), **self.dataset
324+
)
325+
self.assertEqual(st.ij, "abcd")
326+
self.assertEqual(st.ij_reorder, "acbd")
327+
self.assertEqual(st.IJ, "ABCD")
328+
self.assertEqual(st.IJ_reorder, "ACBD")
329+
330+
def test_str_einsum(self):
331+
st = _SymmetrizeTensor(
332+
tensor=np.random.randn(*2 * self.structure.positions.shape), **self.dataset
333+
)
334+
self.assertEqual(st.str_einsum, "Cc,Dd,ABcd...->...ACBD")
335+
st = _SymmetrizeTensor(
336+
tensor=np.random.randn(*self.structure.positions.shape), **self.dataset
337+
)
338+
self.assertEqual(st.str_einsum, "Bb,Ab...->...AB")
339+
340+
287341
if __name__ == "__main__":
288342
unittest.main()

0 commit comments

Comments
 (0)