Skip to content

Commit 4c01a98

Browse files
11happyrkazants
andauthored
[TF FE] feat: implement complex type support for selectv2 (#28773)
**Overview:** This pull request fixes #28678. **Testing:** - Tested the implementation. Verified other implementations remain unaffected. ![Screenshot from 2025-02-01 09-00-09](https://github.com/user-attachments/assets/3dccf1ba-9fa4-42c6-ac73-c14ecb2ce441) **CC:** - @rkazants --------- Signed-off-by: 11happy <soni5happy@gmail.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
1 parent ce69c62 commit 4c01a98

File tree

2 files changed

+84
-22
lines changed

2 files changed

+84
-22
lines changed

src/frontends/tensorflow_common/src/op/select.cpp

+35-22
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "openvino/op/shape_of.hpp"
1414
#include "openvino/op/squeeze.hpp"
1515
#include "openvino/op/subtract.hpp"
16+
#include "openvino/op/unsqueeze.hpp"
1617

1718
using namespace std;
1819
using namespace ov;
@@ -31,7 +32,19 @@ OutputVector translate_select_base_op(const NodeContext& node,
3132
set_node_name(node.get_name(), select);
3233
return {select};
3334
}
34-
35+
bool has_complex_inputs(Output<Node>& x, Output<Node>& y, element::Type& complex_part_type) {
36+
auto complex_type_mark_x = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());
37+
auto complex_type_mark_y = as_type_ptr<ComplexTypeMark>(y.get_node_shared_ptr());
38+
if (complex_type_mark_x) {
39+
x = complex_type_mark_x->input_value(0);
40+
complex_part_type = complex_type_mark_x->get_complex_part_type();
41+
}
42+
if (complex_type_mark_y) {
43+
y = complex_type_mark_y->input_value(0);
44+
complex_part_type = complex_type_mark_y->get_complex_part_type();
45+
}
46+
return (complex_type_mark_x || complex_type_mark_y);
47+
}
3548
OutputVector translate_select_v2_op(const NodeContext& node) {
3649
// according to the TensorFlow documentation. See in the code:
3750
// https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/lite/kernels/select.cc#L188-L211
@@ -40,10 +53,23 @@ OutputVector translate_select_v2_op(const NodeContext& node) {
4053
// is true or the value of 'y' if false. There are valid condition input sizes:
4154
// 1. Either the same shape (in which case the select is elementwise), or
4255
// 2. Broadcastable shapes between 'condition', 'x' and 'y'.
43-
default_op_checks(node, 3, {"SelectV2", "SELECT_V2"});
44-
// no preparation for inputs are needed
45-
// inputs are already NumPy broadcastable
46-
return translate_select_base_op(node, node.get_input(0), node.get_input(1), node.get_input(2));
56+
default_op_checks(node, 3, {"SelectV2", "SELECT_V2"}, true);
57+
auto condition = node.get_input(0);
58+
auto x = node.get_input(1);
59+
auto y = node.get_input(2);
60+
61+
element::Type complex_part_type;
62+
auto is_complex = has_complex_inputs(x, y, complex_part_type);
63+
64+
if (is_complex) {
65+
auto const_negative_one = make_shared<v0::Constant>(element::i32, Shape{1}, -1);
66+
auto new_condition = make_shared<v0::Unsqueeze>(condition, const_negative_one);
67+
auto result = translate_select_base_op(node, new_condition, x, y);
68+
auto complex_result = make_shared<ComplexTypeMark>(result[0].get_node_shared_ptr(), complex_part_type);
69+
return {complex_result->output(0)};
70+
} else {
71+
return translate_select_base_op(node, condition, x, y);
72+
}
4773
}
4874

4975
OutputVector translate_select_op(const NodeContext& node) {
@@ -59,21 +85,9 @@ OutputVector translate_select_op(const NodeContext& node) {
5985
auto condition = node.get_input(0);
6086
auto x = node.get_input(1);
6187
auto y = node.get_input(2);
62-
auto complex_type_mark_x = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());
63-
auto complex_type_mark_y = as_type_ptr<ComplexTypeMark>(y.get_node_shared_ptr());
6488

65-
auto is_complex = (complex_type_mark_x || complex_type_mark_y);
6689
element::Type complex_part_type;
67-
68-
if (complex_type_mark_x) {
69-
x = complex_type_mark_x->input_value(0);
70-
complex_part_type = complex_type_mark_x->get_complex_part_type();
71-
}
72-
73-
if (complex_type_mark_y) {
74-
y = complex_type_mark_y->input_value(0);
75-
complex_part_type = complex_type_mark_y->get_complex_part_type();
76-
}
90+
auto is_complex = has_complex_inputs(x, y, complex_part_type);
7791

7892
// compute number of dimensions to unsqueeze the condition
7993
auto cond_rank = compute_subgraph_scalar_rank(condition, element::i32);
@@ -85,14 +99,13 @@ OutputVector translate_select_op(const NodeContext& node) {
8599
auto new_subshape = make_shared<v3::Broadcast>(const_one, num_new_axes);
86100
auto cond_shape = make_shared<v3::ShapeOf>(condition, element::i32);
87101
// use extra dimensions in the begin to avoid concatenation of empty tensors that is not supported by Concat
88-
auto const_1 = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
89-
auto new_cond_shape = make_shared<v0::Concat>(OutputVector{const_1, cond_shape, new_subshape}, 0);
102+
auto new_cond_shape = make_shared<v0::Concat>(OutputVector{const_one, cond_shape, new_subshape}, 0);
90103

91104
// prepare the condition to have the same rank as operands `x` and `y`
92105
auto prep_cond = make_shared<v1::Reshape>(condition, new_cond_shape, false)->output(0);
93106
// squeeze prep_cond by one extra dimension specially added
94-
auto const_0 = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
95-
prep_cond = make_shared<v0::Squeeze>(prep_cond, const_0);
107+
auto const_zero = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
108+
prep_cond = make_shared<v0::Squeeze>(prep_cond, const_zero);
96109

97110
auto result = translate_select_base_op(node, prep_cond, x, y);
98111
if (is_complex) {

tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py

+49
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,52 @@ def test_select_v2_basic(self, params, ie_device, precision, ir_version, temp_di
5151
self._test(*self.create_select_v2_net(**params),
5252
ie_device, precision, ir_version, temp_dir=temp_dir,
5353
use_legacy_frontend=use_legacy_frontend)
54+
55+
56+
class TestComplexSelectV2(CommonTFLayerTest):
57+
def _prepare_input(self, inputs_info):
58+
rng = np.random.default_rng()
59+
assert 'cond:0' in inputs_info, "Test error: inputs_info must contain `cond`"
60+
assert 'x_real:0' in inputs_info, "Test error: inputs_info must contain `x_real`"
61+
assert 'x_imag:0' in inputs_info, "Test error: inputs_info must contain `x_imag`"
62+
assert 'y_real:0' in inputs_info, "Test error: inputs_info must contain `y_real`"
63+
assert 'y_imag:0' in inputs_info, "Test error: inputs_info must contain `y_imag`"
64+
cond_shape = inputs_info['cond:0']
65+
inputs_data = {}
66+
inputs_data['cond:0'] = np.random.randint(0, 2, cond_shape).astype(bool)
67+
for part in ['x_real:0', 'x_imag:0', 'y_real:0', 'y_imag:0']:
68+
inputs_data[part] = 4 * rng.random(inputs_info[part]).astype(np.float32) - 2
69+
return inputs_data
70+
71+
def create_complex_select_v2_net(self, cond_shape, x_shape, y_shape):
72+
tf.compat.v1.reset_default_graph()
73+
# Create the graph and model
74+
with tf.compat.v1.Session() as sess:
75+
cond = tf.compat.v1.placeholder(tf.bool, cond_shape, 'cond')
76+
x_real = tf.compat.v1.placeholder(tf.float32, x_shape, 'x_real')
77+
x_imag = tf.compat.v1.placeholder(tf.float32, x_shape, 'x_imag')
78+
y_real = tf.compat.v1.placeholder(tf.float32, y_shape, 'y_real')
79+
y_imag = tf.compat.v1.placeholder(tf.float32, y_shape, 'y_imag')
80+
complex_x = tf.raw_ops.Complex(real=x_real, imag=x_imag)
81+
complex_y = tf.raw_ops.Complex(real=y_real, imag=y_imag)
82+
complex_select = tf.raw_ops.SelectV2(condition=cond, t=complex_x, e=complex_y)
83+
tf.raw_ops.Real(input=complex_select)
84+
tf.raw_ops.Imag(input=complex_select)
85+
tf.compat.v1.global_variables_initializer()
86+
tf_net = sess.graph_def
87+
return tf_net, None
88+
89+
test_data_basic = [
90+
dict(cond_shape=[3, 1], x_shape=[3, 1], y_shape=[3, 1]),
91+
dict(cond_shape=[], x_shape=[2], y_shape=[3, 2]),
92+
dict(cond_shape=[4], x_shape=[3, 2, 1], y_shape=[2, 4]),
93+
]
94+
95+
@pytest.mark.parametrize("params", test_data_basic)
96+
@pytest.mark.precommit
97+
@pytest.mark.nightly
98+
def test_complex_select_v2(self, params, ie_device, precision, ir_version, temp_dir,
99+
use_legacy_frontend):
100+
self._test(*self.create_complex_select_v2_net(**params),
101+
ie_device, precision, ir_version, temp_dir=temp_dir,
102+
use_legacy_frontend=use_legacy_frontend)

0 commit comments

Comments
 (0)