13
13
#include " openvino/op/shape_of.hpp"
14
14
#include " openvino/op/squeeze.hpp"
15
15
#include " openvino/op/subtract.hpp"
16
+ #include " openvino/op/unsqueeze.hpp"
16
17
17
18
using namespace std ;
18
19
using namespace ov ;
@@ -31,7 +32,19 @@ OutputVector translate_select_base_op(const NodeContext& node,
31
32
set_node_name (node.get_name (), select );
32
33
return {select };
33
34
}
34
-
35
+ bool has_complex_inputs (Output<Node>& x, Output<Node>& y, element::Type& complex_part_type) {
36
+ auto complex_type_mark_x = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr ());
37
+ auto complex_type_mark_y = as_type_ptr<ComplexTypeMark>(y.get_node_shared_ptr ());
38
+ if (complex_type_mark_x) {
39
+ x = complex_type_mark_x->input_value (0 );
40
+ complex_part_type = complex_type_mark_x->get_complex_part_type ();
41
+ }
42
+ if (complex_type_mark_y) {
43
+ y = complex_type_mark_y->input_value (0 );
44
+ complex_part_type = complex_type_mark_y->get_complex_part_type ();
45
+ }
46
+ return (complex_type_mark_x || complex_type_mark_y);
47
+ }
35
48
OutputVector translate_select_v2_op(const NodeContext& node) {
    // according to the TensorFlow documentation. See in the code:
    // https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/lite/kernels/select.cc#L188-L211
    // the value of 'x' is taken where 'condition' is true, and the value of 'y' if false.
    // There are valid condition input sizes:
    // 1. Either the same shape (in which case the select is elementwise), or
    // 2. Broadcastable shapes between 'condition', 'x' and 'y'.
    default_op_checks(node, 3, {"SelectV2", "SELECT_V2"}, true);
    auto condition = node.get_input(0);
    auto x = node.get_input(1);
    auto y = node.get_input(2);

    // unwrap complex operands, if any, remembering the element type of their parts
    element::Type complex_part_type;
    if (!has_complex_inputs(x, y, complex_part_type)) {
        // plain (real-valued) case: inputs are already NumPy broadcastable
        return translate_select_base_op(node, condition, x, y);
    }

    // complex case: the unwrapped operands carry an extra trailing dimension of
    // size 2 (real and imaginary parts), so the condition must gain a trailing
    // axis to stay broadcastable against them
    auto minus_one = make_shared<v0::Constant>(element::i32, Shape{1}, -1);
    auto unsqueezed_condition = make_shared<v0::Unsqueeze>(condition, minus_one);

    auto select_result = translate_select_base_op(node, unsqueezed_condition, x, y);

    // re-attach the complex mark so downstream translators see a complex tensor
    auto marked_result = make_shared<ComplexTypeMark>(select_result[0].get_node_shared_ptr(), complex_part_type);
    return {marked_result->output(0)};
}
48
74
49
75
OutputVector translate_select_op (const NodeContext& node) {
@@ -59,21 +85,9 @@ OutputVector translate_select_op(const NodeContext& node) {
59
85
auto condition = node.get_input (0 );
60
86
auto x = node.get_input (1 );
61
87
auto y = node.get_input (2 );
62
- auto complex_type_mark_x = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr ());
63
- auto complex_type_mark_y = as_type_ptr<ComplexTypeMark>(y.get_node_shared_ptr ());
64
88
65
- auto is_complex = (complex_type_mark_x || complex_type_mark_y);
66
89
element::Type complex_part_type;
67
-
68
- if (complex_type_mark_x) {
69
- x = complex_type_mark_x->input_value (0 );
70
- complex_part_type = complex_type_mark_x->get_complex_part_type ();
71
- }
72
-
73
- if (complex_type_mark_y) {
74
- y = complex_type_mark_y->input_value (0 );
75
- complex_part_type = complex_type_mark_y->get_complex_part_type ();
76
- }
90
+ auto is_complex = has_complex_inputs (x, y, complex_part_type);
77
91
78
92
// compute number of dimensions to unsqueeze the condition
79
93
auto cond_rank = compute_subgraph_scalar_rank (condition, element::i32);
@@ -85,14 +99,13 @@ OutputVector translate_select_op(const NodeContext& node) {
85
99
auto new_subshape = make_shared<v3::Broadcast>(const_one, num_new_axes);
86
100
auto cond_shape = make_shared<v3::ShapeOf>(condition, element::i32);
87
101
// use extra dimensions in the begin to avoid concatenation of empty tensors that is not supported by Concat
88
- auto const_1 = make_shared<v0::Constant>(element::i32, Shape{1 }, 1 );
89
- auto new_cond_shape = make_shared<v0::Concat>(OutputVector{const_1, cond_shape, new_subshape}, 0 );
102
+ auto new_cond_shape = make_shared<v0::Concat>(OutputVector{const_one, cond_shape, new_subshape}, 0 );
90
103
91
104
// prepare the condition to have the same rank as operands `x` and `y`
92
105
auto prep_cond = make_shared<v1::Reshape>(condition, new_cond_shape, false )->output (0 );
93
106
// squeeze prep_cond by one extra dimension specially added
94
- auto const_0 = make_shared<v0::Constant>(element::i32, Shape{1 }, 0 );
95
- prep_cond = make_shared<v0::Squeeze>(prep_cond, const_0 );
107
+ auto const_zero = make_shared<v0::Constant>(element::i32, Shape{1 }, 0 );
108
+ prep_cond = make_shared<v0::Squeeze>(prep_cond, const_zero );
96
109
97
110
auto result = translate_select_base_op (node, prep_cond, x, y);
98
111
if (is_complex) {
0 commit comments