@@ -216,3 +216,70 @@ def test_sub_placeholder_const_broadcast_5D(self, params, ie_device, precision,
216
216
use_legacy_frontend = use_legacy_frontend ),
217
217
ie_device , precision , ir_version ,
218
218
temp_dir = temp_dir , use_legacy_frontend = use_legacy_frontend )
219
+
220
+
221
+ class TestComplexSub (CommonTFLayerTest ):
222
+ def _prepare_input (self , inputs_info ):
223
+ rng = np .random .default_rng (84821 )
224
+
225
+ assert 'param_real_x:0' in inputs_info
226
+ assert 'param_imag_x:0' in inputs_info
227
+
228
+ assert 'param_real_y:0' in inputs_info
229
+ assert 'param_imag_y:0' in inputs_info
230
+
231
+ param_real_shape_x = inputs_info ['param_real_x:0' ]
232
+ param_imag_shape_x = inputs_info ['param_imag_x:0' ]
233
+
234
+ param_real_shape_y = inputs_info ['param_real_y:0' ]
235
+ param_imag_shape_y = inputs_info ['param_imag_y:0' ]
236
+
237
+ inputs_data = {}
238
+ inputs_data ['param_real_x:0' ] = rng .uniform (- 10.0 , 10.0 , param_real_shape_x ).astype (np .float32 )
239
+ inputs_data ['param_imag_x:0' ] = rng .uniform (- 10.0 , 10.0 , param_imag_shape_x ).astype (np .float32 )
240
+
241
+ inputs_data ['param_real_y:0' ] = rng .uniform (- 10.0 , 10.0 , param_real_shape_y ).astype (np .float32 )
242
+ inputs_data ['param_imag_y:0' ] = rng .uniform (- 10.0 , 10.0 , param_imag_shape_y ).astype (np .float32 )
243
+
244
+ return inputs_data
245
+
246
+ def create_complex_sub_net (self , x_shape , y_shape , ir_version , use_legacy_frontend ):
247
+ import tensorflow as tf
248
+
249
+ tf .compat .v1 .reset_default_graph ()
250
+ with tf .compat .v1 .Session () as sess :
251
+ param_real_x = tf .compat .v1 .placeholder (np .float32 , x_shape , 'param_real_x' )
252
+ param_imag_x = tf .compat .v1 .placeholder (np .float32 , x_shape , 'param_imag_x' )
253
+
254
+ param_real_y = tf .compat .v1 .placeholder (np .float32 , y_shape , 'param_real_y' )
255
+ param_imag_y = tf .compat .v1 .placeholder (np .float32 , y_shape , 'param_imag_y' )
256
+
257
+ x = tf .raw_ops .Complex (real = param_real_x , imag = param_imag_x )
258
+ y = tf .raw_ops .Complex (real = param_real_y , imag = param_imag_y )
259
+
260
+ result = tf .raw_ops .Sub (x = x , y = y , name = 'Sub' )
261
+
262
+ tf .raw_ops .Real (input = result )
263
+ tf .raw_ops .Imag (input = result )
264
+
265
+ tf .compat .v1 .global_variables_initializer ()
266
+ tf_net = sess .graph_def
267
+
268
+ ref_net = None
269
+
270
+ return tf_net , ref_net
271
+
272
+ @pytest .mark .parametrize ('x_shape, y_shape' , [
273
+ [[5 , 5 ], [5 ]],
274
+ [[4 , 10 ], [4 , 1 ]],
275
+ [[1 , 3 , 50 , 224 ], [1 ]],
276
+ [[10 , 10 , 10 ], [10 , 10 , 10 ]],
277
+ ])
278
+ @pytest .mark .precommit
279
+ @pytest .mark .nightly
280
+ def test_complex_sub (self , x_shape , y_shape ,
281
+ ie_device , precision , ir_version , temp_dir , use_legacy_frontend ):
282
+ self ._test (* self .create_complex_sub_net (x_shape , y_shape , ir_version = ir_version ,
283
+ use_legacy_frontend = use_legacy_frontend ),
284
+ ie_device , precision , ir_version , temp_dir = temp_dir ,
285
+ use_legacy_frontend = use_legacy_frontend )
0 commit comments