Skip to content

Commit 3983c35

Browse files
hibahassan1mlukaszerkazants
authored
[Good First Issue][TF FE]: Support complex tensors for Pack operation (openvinotoolkit#25193)
### Details: - ***Addresses the issue** : [[Good First Issue][TF FE]: Support complex tensors for Pack operation openvinotoolkit#22954](openvinotoolkit#22954 - *Added support for complex tensors for pack operation* - *Let me know if any changes are required* --------- Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
1 parent 04e0fea commit 3983c35

File tree

2 files changed

+93
-1
lines changed

2 files changed

+93
-1
lines changed

src/frontends/tensorflow_common/src/op/pack.cpp

+23-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
//
44

55
#include "common_op_table.hpp"
6+
#include "helper_ops/complex_type_mark.hpp"
67
#include "openvino/op/concat.hpp"
78
#include "openvino/op/constant.hpp"
9+
#include "openvino/op/shape_of.hpp"
810
#include "openvino/op/unsqueeze.hpp"
911

1012
using namespace std;
@@ -16,20 +18,40 @@ namespace tensorflow {
1618
namespace op {
1719

1820
OutputVector translate_pack_op(const NodeContext& node) {
19-
default_op_checks(node, 1, {"Pack", "PACK"});
21+
default_op_checks(node, 1, {"Pack", "PACK"}, true);
2022
auto num_size = static_cast<int>(node.get_input_size());
2123

2224
auto axis = node.get_attribute<int64_t>("axis", 0);
25+
if (axis < 0 && as_type_ptr<ComplexTypeMark>(node.get_input(0).get_node_shared_ptr())) {
26+
// need to account auxiliary dimension for real and imaginary parts
27+
axis -= 1;
28+
}
29+
2330
auto axis_const = make_shared<v0::Constant>(element::i64, Shape{}, axis);
2431

2532
OutputVector concat_inputs;
33+
bool has_complex_input = false;
34+
element::Type complex_part_type;
35+
2636
for (int ind = 0; ind < num_size; ++ind) {
2737
auto in = node.get_input(ind);
38+
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(in.get_node_shared_ptr());
39+
if (complex_type_mark) {
40+
has_complex_input = true;
41+
complex_part_type = complex_type_mark->get_complex_part_type();
42+
in = complex_type_mark->input_value(0);
43+
}
2844
concat_inputs.push_back(make_shared<v0::Unsqueeze>(in, axis_const));
2945
}
3046

3147
auto pack = make_shared<v0::Concat>(concat_inputs, axis);
3248
set_node_name(node.get_name(), pack);
49+
50+
if (has_complex_input) {
51+
auto complex_pack = make_shared<ComplexTypeMark>(pack, complex_part_type);
52+
return {complex_pack->output(0)};
53+
}
54+
3355
return {pack};
3456
}
3557
} // namespace op

tests/layer_tests/tensorflow_tests/test_tf_Pack.py

+70
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,73 @@ def test_pack_basic(self, params, ie_device, precision, ir_version, temp_dir,
5151
self._test(*self.create_pack_net(**params),
5252
ie_device, precision, ir_version, temp_dir=temp_dir,
5353
use_legacy_frontend=use_legacy_frontend)
54+
55+
test_data_negative_axis = [
56+
dict(input_shape=[2, 4], input_num=2, axis=-1, input_type=np.float32),
57+
dict(input_shape=[3, 1, 2], input_num=3, axis=-2, input_type=np.int32),
58+
]
59+
60+
@pytest.mark.parametrize("params", test_data_negative_axis)
61+
@pytest.mark.precommit
62+
@pytest.mark.nightly
63+
def test_pack_negative_axis(self, params, ie_device, precision, ir_version, temp_dir,
64+
use_legacy_frontend):
65+
self._test(*self.create_pack_net(**params),
66+
ie_device, precision, ir_version, temp_dir=temp_dir,
67+
use_legacy_frontend=use_legacy_frontend)
68+
69+
70+
class TestComplexPack(CommonTFLayerTest):
71+
def _prepare_input(self, inputs_info):
72+
inputs_data = {}
73+
for input_name, input_shape in inputs_info.items():
74+
inputs_data[input_name] = np.random.randint(-5, 5, input_shape).astype(np.float32)
75+
return inputs_data
76+
77+
def create_complex_pack_net(self, input_shape, input_num, axis):
78+
tf.compat.v1.reset_default_graph()
79+
with tf.compat.v1.Session() as sess:
80+
inputs_real = []
81+
inputs_imag = []
82+
for ind in range(input_num):
83+
input_real = tf.compat.v1.placeholder(tf.float32, input_shape, 'input' + str(ind) + '_real')
84+
input_imag = tf.compat.v1.placeholder(tf.float32, input_shape, 'input' + str(ind) + '_imag')
85+
inputs_real.append(input_real)
86+
inputs_imag.append(input_imag)
87+
if axis is not None:
88+
tf.raw_ops.Pack(values=inputs_real + inputs_imag, axis=axis)
89+
else:
90+
tf.raw_ops.Pack(values=inputs_real + inputs_imag)
91+
tf.compat.v1.global_variables_initializer()
92+
93+
tf_net = sess.graph_def
94+
95+
return tf_net, None
96+
97+
test_data_basic = [
98+
dict(input_shape=[2, 4], input_num=2, axis=None),
99+
dict(input_shape=[3, 1, 2], input_num=3, axis=1),
100+
]
101+
102+
@pytest.mark.parametrize("params", test_data_basic)
103+
@pytest.mark.precommit
104+
@pytest.mark.nightly
105+
def test_complex_pack_basic(self, params, ie_device, precision, ir_version, temp_dir,
106+
use_legacy_frontend):
107+
self._test(*self.create_complex_pack_net(**params),
108+
ie_device, precision, ir_version, temp_dir=temp_dir,
109+
use_legacy_frontend=use_legacy_frontend)
110+
111+
test_data_negative_axis = [
112+
dict(input_shape=[2, 4], input_num=2, axis=-1),
113+
dict(input_shape=[3, 1, 2], input_num=3, axis=-2),
114+
]
115+
116+
@pytest.mark.parametrize("params", test_data_negative_axis)
117+
@pytest.mark.precommit
118+
@pytest.mark.nightly
119+
def test_complex_pack_negative_axis(self, params, ie_device, precision, ir_version, temp_dir,
120+
use_legacy_frontend):
121+
self._test(*self.create_complex_pack_net(**params),
122+
ie_device, precision, ir_version, temp_dir=temp_dir,
123+
use_legacy_frontend=use_legacy_frontend)

0 commit comments

Comments
 (0)