Skip to content

Commit a6d92e3

Browse files
authored
[TF FE]: Support complex tensors for Sub operation (openvinotoolkit#26342)
### Details: - Support complex tensors for Sub operation + tests ### Tickets: - [None](openvinotoolkit#22948)
1 parent 3983c35 commit a6d92e3

File tree

4 files changed

+96
-2
lines changed

4 files changed

+96
-2
lines changed

src/frontends/tensorflow/src/op_table.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
#include "openvino/op/softplus.hpp"
6565
#include "openvino/op/softsign.hpp"
6666
#include "openvino/op/squared_difference.hpp"
67-
#include "openvino/op/subtract.hpp"
6867
#include "openvino/op/swish.hpp"
6968
#include "openvino/op/tan.hpp"
7069
#include "openvino/op/tanh.hpp"
@@ -195,7 +194,6 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
195194
{"Pow", CreatorFunction(translate_binary_op<v1::Power>)},
196195
{"RealDiv", CreatorFunction(translate_binary_op<v1::Divide>)},
197196
{"SquaredDifference", CreatorFunction(translate_binary_op<v0::SquaredDifference>)},
198-
{"Sub", CreatorFunction(translate_binary_op<v1::Subtract>)},
199197

200198
// note: ReduceOp translator declaration for each op must to be added in reduce.cpp file
201199
{"Any", CreatorFunction(translate_direct_reduce_op<v1::ReduceLogicalOr>)},
@@ -396,6 +394,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
396394
{"StatelessIf", CreatorFunction(translate_if_op)},
397395
{"StatelessWhile", CreatorFunction(translate_while_op)},
398396
{"StridedSlice", CreatorFunction(translate_strided_slice_op)},
397+
{"Sub", CreatorFunction(translate_sub_op)},
399398
{"Switch", CreatorFunction(translate_switch_op)},
400399
{"TensorArrayCloseV3", CreatorFunction(translate_tensor_array_close_v3_op)},
401400
{"TensorArrayConcatV3", CreatorFunction(translate_tensor_array_concat_v3_op)},

src/frontends/tensorflow_common/include/common_op_table.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ OP_CONVERTER(translate_split_v_op);
162162
OP_CONVERTER(translate_square_op);
163163
OP_CONVERTER(translate_squeeze_op);
164164
OP_CONVERTER(translate_strided_slice_op);
165+
OP_CONVERTER(translate_sub_op);
165166
OP_CONVERTER(translate_sqrt_op);
166167
OP_CONVERTER(translate_empty_tensor_list_op);
167168
OP_CONVERTER(translate_tensor_list_from_tensor_op);

src/frontends/tensorflow_common/src/op/binary_op.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,33 @@ OutputVector translate_addv2_op(const NodeContext& node) {
171171
return {result};
172172
}
173173

174+
OutputVector translate_sub_op(const NodeContext& node) {
175+
default_op_checks(node, 2, {"Sub"}, true);
176+
auto lhs = node.get_input(0);
177+
auto rhs = node.get_input(1);
178+
179+
auto complex_type_mark_lhs = as_type_ptr<ComplexTypeMark>(lhs.get_node_shared_ptr());
180+
auto complex_type_mark_rhs = as_type_ptr<ComplexTypeMark>(rhs.get_node_shared_ptr());
181+
auto complex_type_inputs = (complex_type_mark_lhs && complex_type_mark_rhs);
182+
183+
if (complex_type_inputs) {
184+
lhs = complex_type_mark_lhs->input_value(0);
185+
rhs = complex_type_mark_rhs->input_value(0);
186+
}
187+
188+
// performing an actual operation
189+
auto result = make_shared<v1::Subtract>(lhs, rhs);
190+
191+
if (complex_type_inputs) {
192+
auto complex_result = make_shared<ComplexTypeMark>(result, complex_type_mark_lhs->get_complex_part_type());
193+
set_node_name(node.get_name(), result);
194+
195+
return {complex_result};
196+
}
197+
set_node_name(node.get_name(), result);
198+
return {result};
199+
}
200+
174201
template OutputVector translate_binary_op<v1::Add>(const NodeContext& node);
175202
template OutputVector translate_binary_op<v13::BitwiseAnd>(const NodeContext& node);
176203
template OutputVector translate_binary_op<v13::BitwiseOr>(const NodeContext& node);

tests/layer_tests/tensorflow_tests/test_tf_Sub.py

+67
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,70 @@ def test_sub_placeholder_const_broadcast_5D(self, params, ie_device, precision,
216216
use_legacy_frontend=use_legacy_frontend),
217217
ie_device, precision, ir_version,
218218
temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend)
219+
220+
221+
class TestComplexSub(CommonTFLayerTest):
222+
def _prepare_input(self, inputs_info):
223+
rng = np.random.default_rng(84821)
224+
225+
assert 'param_real_x:0' in inputs_info
226+
assert 'param_imag_x:0' in inputs_info
227+
228+
assert 'param_real_y:0' in inputs_info
229+
assert 'param_imag_y:0' in inputs_info
230+
231+
param_real_shape_x = inputs_info['param_real_x:0']
232+
param_imag_shape_x = inputs_info['param_imag_x:0']
233+
234+
param_real_shape_y = inputs_info['param_real_y:0']
235+
param_imag_shape_y = inputs_info['param_imag_y:0']
236+
237+
inputs_data = {}
238+
inputs_data['param_real_x:0'] = rng.uniform(-10.0, 10.0, param_real_shape_x).astype(np.float32)
239+
inputs_data['param_imag_x:0'] = rng.uniform(-10.0, 10.0, param_imag_shape_x).astype(np.float32)
240+
241+
inputs_data['param_real_y:0'] = rng.uniform(-10.0, 10.0, param_real_shape_y).astype(np.float32)
242+
inputs_data['param_imag_y:0'] = rng.uniform(-10.0, 10.0, param_imag_shape_y).astype(np.float32)
243+
244+
return inputs_data
245+
246+
def create_complex_sub_net(self, x_shape, y_shape, ir_version, use_legacy_frontend):
247+
import tensorflow as tf
248+
249+
tf.compat.v1.reset_default_graph()
250+
with tf.compat.v1.Session() as sess:
251+
param_real_x = tf.compat.v1.placeholder(np.float32, x_shape, 'param_real_x')
252+
param_imag_x = tf.compat.v1.placeholder(np.float32, x_shape, 'param_imag_x')
253+
254+
param_real_y = tf.compat.v1.placeholder(np.float32, y_shape, 'param_real_y')
255+
param_imag_y = tf.compat.v1.placeholder(np.float32, y_shape, 'param_imag_y')
256+
257+
x = tf.raw_ops.Complex(real=param_real_x, imag=param_imag_x)
258+
y = tf.raw_ops.Complex(real=param_real_y, imag=param_imag_y)
259+
260+
result = tf.raw_ops.Sub(x=x, y=y, name='Sub')
261+
262+
tf.raw_ops.Real(input=result)
263+
tf.raw_ops.Imag(input=result)
264+
265+
tf.compat.v1.global_variables_initializer()
266+
tf_net = sess.graph_def
267+
268+
ref_net = None
269+
270+
return tf_net, ref_net
271+
272+
@pytest.mark.parametrize('x_shape, y_shape', [
273+
[[5, 5], [5]],
274+
[[4, 10], [4, 1]],
275+
[[1, 3, 50, 224], [1]],
276+
[[10, 10, 10], [10, 10, 10]],
277+
])
278+
@pytest.mark.precommit
279+
@pytest.mark.nightly
280+
def test_complex_sub(self, x_shape, y_shape,
281+
ie_device, precision, ir_version, temp_dir, use_legacy_frontend):
282+
self._test(*self.create_complex_sub_net(x_shape, y_shape, ir_version=ir_version,
283+
use_legacy_frontend=use_legacy_frontend),
284+
ie_device, precision, ir_version, temp_dir=temp_dir,
285+
use_legacy_frontend=use_legacy_frontend)

0 commit comments

Comments
 (0)