18
18
19
19
try :
20
20
import spglib
21
+ from structuretoolkit .analyse .symmetry import _SymmetrizeTensor
21
22
22
23
skip_spglib_test = False
23
24
except ImportError :
@@ -108,20 +109,22 @@ def test_get_symmetry(self):
108
109
"AlAl" , scaled_positions = [(0 , 0 , 0 ), (0.5 , 0.5 , 0.5 )], cell = cell , pbc = True
109
110
)
110
111
v = np .random .rand (6 ).reshape (- 1 , 3 )
112
+ sym = stk .analyse .get_symmetry (structure = Al )
111
113
self .assertAlmostEqual (
112
- np .linalg .norm (
113
- stk .analyse .get_symmetry (structure = Al ).symmetrize_vectors (v )
114
- ),
114
+ np .linalg .norm (sym .symmetrize_vectors (v )),
115
115
0 ,
116
116
)
117
117
vv = np .random .rand (12 ).reshape (2 , 2 , 3 )
118
- for vvv in stk . analyse . get_symmetry ( structure = Al ) .symmetrize_vectors (vv ):
118
+ for vvv in sym .symmetrize_vectors (vv ):
119
119
self .assertAlmostEqual (np .linalg .norm (vvv ), 0 )
120
120
Al .positions [0 , 0 ] += 0.01
121
- w = stk . analyse . get_symmetry ( structure = Al ) .symmetrize_vectors (v )
121
+ w = sym .symmetrize_vectors (v )
122
122
self .assertAlmostEqual (
123
123
np .absolute (w [:, 0 ]).sum (), np .linalg .norm (w , axis = - 1 ).sum ()
124
124
)
125
+ self .assertAlmostEqual (
126
+ np .linalg .norm (sym .symmetrize_vectors (v ) - sym .symmetrize_tensor (v )), 0
127
+ )
125
128
126
129
def test_get_symmetry_dataset (self ):
127
130
cell = 2.2 * np .identity (3 )
@@ -155,7 +158,7 @@ def test_get_primitive_cell(self):
155
158
)
156
159
157
160
def test_get_primitive_cell_hex (self ):
158
- elements = ['Fe' , 'Fe' , 'Fe' , 'Fe' , 'O' , 'O' , 'O' , 'O' , 'O' , 'O' ]
161
+ elements = ["Fe" , "Fe" , "Fe" , "Fe" , "O" , "O" , "O" , "O" , "O" , "O" ]
159
162
positions = [
160
163
[0.0 , 0.0 , 4.89 ],
161
164
[0.0 , 0.0 , 11.78 ],
@@ -174,8 +177,7 @@ def test_get_primitive_cell_hex(self):
174
177
sym = stk .analyse .get_symmetry (structure = structure_repeat )
175
178
structure_prim_base = sym .get_primitive_cell ()
176
179
self .assertEqual (
177
- structure_prim_base .get_chemical_symbols (),
178
- structure .get_chemical_symbols ()
180
+ structure_prim_base .get_chemical_symbols (), structure .get_chemical_symbols ()
179
181
)
180
182
181
183
def test_get_equivalent_points (self ):
@@ -284,5 +286,57 @@ def test_error(self):
284
286
stk .analyse .get_symmetry (structure = structure )
285
287
286
288
289
+ @unittest .skipIf (
290
+ skip_spglib_test , "spglib is not installed, so the spglib tests are skipped."
291
+ )
292
+ class TestSymmetrizeTensors (unittest .TestCase ):
293
+ @classmethod
294
+ def setUpClass (cls ):
295
+ cls .structure = bulk ("Al" , cubic = True , a = 4.0 ).repeat (2 )
296
+ cls .dataset = {
297
+ "structure" : cls .structure ,
298
+ "rotations" : np .eye (3 ),
299
+ "permutations" : np .arange (len (cls .structure )),
300
+ }
301
+
302
+ def test_order (self ):
303
+ with self .assertRaises (ValueError ):
304
+ _SymmetrizeTensor (
305
+ tensor = np .array ([1 ]), ** self .dataset
306
+ ).order
307
+ self .assertEqual (
308
+ _SymmetrizeTensor (
309
+ tensor = np .random .randn (* self .structure .positions .shape ), ** self .dataset
310
+ ).order ,
311
+ 1 ,
312
+ )
313
+ self .assertEqual (
314
+ _SymmetrizeTensor (
315
+ tensor = np .random .randn (* 2 * self .structure .positions .shape ),
316
+ ** self .dataset ,
317
+ ).order ,
318
+ 2 ,
319
+ )
320
+
321
+ def test_indexing (self ):
322
+ st = _SymmetrizeTensor (
323
+ tensor = np .random .randn (* 2 * self .structure .positions .shape ), ** self .dataset
324
+ )
325
+ self .assertEqual (st .ij , "abcd" )
326
+ self .assertEqual (st .ij_reorder , "acbd" )
327
+ self .assertEqual (st .IJ , "ABCD" )
328
+ self .assertEqual (st .IJ_reorder , "ACBD" )
329
+
330
+ def test_str_einsum (self ):
331
+ st = _SymmetrizeTensor (
332
+ tensor = np .random .randn (* 2 * self .structure .positions .shape ), ** self .dataset
333
+ )
334
+ self .assertEqual (st .str_einsum , "Cc,Dd,ABcd...->...ACBD" )
335
+ st = _SymmetrizeTensor (
336
+ tensor = np .random .randn (* self .structure .positions .shape ), ** self .dataset
337
+ )
338
+ self .assertEqual (st .str_einsum , "Bb,Ab...->...AB" )
339
+
340
+
287
341
if __name__ == "__main__" :
288
342
unittest .main ()
0 commit comments