Commit 8b6eac6

[TF FE] Optimize TensorList ops decompositions (openvinotoolkit#25003)
**Details:** Simplify the representation of the TensorList* ops and preserve the tensor list rank, which helps the subsequent fusion of Loop with tensor list ops into RNN sequence operations. Currently the decomposition always flattens tensor list elements to 1D, which blocks that fusion.

**Ticket:** TBD

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent 20d91fd commit 8b6eac6
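
To see what preserving the tensor list rank means in practice, the updated reference graphs in convert_tricky_models.cpp below give the picture: the old decomposition flattened each pushed tensor to 1D and later recovered its shape with Reshape/Slice, while the new one only adds the list dimension with Unsqueeze + Concat. A minimal standalone sketch of the new reference graph, assuming a single pushed tensor of shape {2, 3, 5} as in the test:

```cpp
// Sketch of the new (rank-preserving) reference graph from the updated test below.
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/unsqueeze.hpp"

std::shared_ptr<ov::Model> make_rank_preserving_reference() {
    using namespace ov;
    using namespace ov::op;

    auto x = std::make_shared<v0::Parameter>(element::f32, Shape{2, 3, 5});
    // the pushed tensor only gains a leading "list" dimension: [2, 3, 5] -> [1, 2, 3, 5]
    auto axes = std::make_shared<v0::Constant>(element::i32, Shape{1}, 0);
    auto x_unsqueeze = std::make_shared<v0::Unsqueeze>(x, axes);
    auto list_push_back = std::make_shared<v0::Concat>(OutputVector{x_unsqueeze}, 0);
    return std::make_shared<Model>(OutputVector{list_push_back}, ParameterVector{x});
}
```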

File tree

8 files changed: +537 -101 lines changed

src/frontends/tensorflow/src/frontend.cpp (+2)

```diff
@@ -13,6 +13,7 @@
 #include "helper_transforms/embedding_segments_feature_fusing.hpp"
 #include "helper_transforms/saved_model_unused_remover.hpp"
 #include "helper_transforms/tensor_array_v3_replacer.hpp"
+#include "helper_transforms/tensor_list_ops_resolver.hpp"
 #include "input_model.hpp"
 #include "op_table.hpp"
 #include "openvino/core/so_extension.hpp"
@@ -567,6 +568,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.register_pass<pass::TensorArrayV3Replacer>();
     manager.register_pass<pass::ConstToResultRemover>();
     manager.register_pass<pass::SwitchMergeResolver>();
+    manager.register_pass<pass::TensorListOperationsResolver>();
     manager.register_pass<ov::pass::UnrollIf>();
     manager.register_pass<ov::pass::RemoveConcatZeroDimInput>();
     manager.register_pass<ov::pass::TransposeSinkingGeneral>();
```

src/frontends/tensorflow/src/op_table.cpp (+1 -1)

```diff
@@ -247,7 +247,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"DynamicPartition", CreatorFunction(translate_dynamic_partition_op)},
         {"Einsum", CreatorFunction(translate_einsum_op)},
         {"Elu", CreatorFunction(translate_elu_op)},
-        {"EmptyTensorList", CreatorFunction(translate_tensor_list_reserve_op)},
+        {"EmptyTensorList", CreatorFunction(translate_empty_tensor_list_op)},
         {"EnsureShape", CreatorFunction(translate_identity_op)},
         {"ExpandDims", CreatorFunction(translate_expand_dims_op)},
         {"ExtractImagePatches", CreatorFunction(translate_extract_image_patches_op)},
```

src/frontends/tensorflow/tests/convert_tricky_models.cpp (+4 -14)

```diff
@@ -420,20 +420,10 @@ TEST_F(FrontEndConversionWithReferenceTestsF, ModelWithEmptyTensorListAndPushBac
     { model = convert_model("empty_tensor_list/empty_tensor_list.pb"); }
     {
         auto x = make_shared<v0::Parameter>(f32, Shape{2, 3, 5});
-        auto minus_one_const = make_shared<v0::Constant>(i32, Shape{1}, -1);
-        auto x_flatten = make_shared<v1::Reshape>(x, minus_one_const, false);
-        auto zero_const = make_shared<v0::Constant>(i32, Shape{1}, 0);
-        auto x_unsqueeze_flatten = make_shared<v0::Unsqueeze>(x_flatten, zero_const);
-        auto list_push_back = make_shared<v0::Concat>(OutputVector{x_unsqueeze_flatten}, 0);
-        auto list_push_back_shape = make_shared<v3::ShapeOf>(list_push_back, element::i32);
-        auto start = make_shared<v0::Constant>(i32, Shape{1}, 0);
-        auto stop = make_shared<v0::Constant>(i32, Shape{1}, 1);
-        auto step = make_shared<v0::Constant>(i32, Shape{1}, 1);
-        auto batch = make_shared<v8::Slice>(list_push_back_shape, start, stop, step);
-        auto shape_without_batch = make_shared<v0::Constant>(i32, Shape{3}, vector<int32_t>{2, 3, 5});
-        auto recover_item_shape = make_shared<v0::Concat>(OutputVector{batch, shape_without_batch}, 0);
-        auto recover_item = make_shared<v1::Reshape>(list_push_back, recover_item_shape, false);
-        model_ref = make_shared<Model>(OutputVector{recover_item}, ParameterVector{x});
+        auto axes = make_shared<v0::Constant>(i32, Shape{1}, 0);
+        auto x_unsqueeze = make_shared<v0::Unsqueeze>(x, axes);
+        auto list_push_back = make_shared<v0::Concat>(OutputVector{x_unsqueeze}, 0);
+        model_ref = make_shared<Model>(OutputVector{list_push_back}, ParameterVector{x});
     }
     comparator.disable(FunctionsComparator::CmpValues::ATTRIBUTES);
 }
```

src/frontends/tensorflow_common/include/common_op_table.hpp (+1)

```diff
@@ -160,6 +160,7 @@ OP_CONVERTER(translate_square_op);
 OP_CONVERTER(translate_squeeze_op);
 OP_CONVERTER(translate_strided_slice_op);
 OP_CONVERTER(translate_sqrt_op);
+OP_CONVERTER(translate_empty_tensor_list_op);
 OP_CONVERTER(translate_tensor_list_from_tensor_op);
 OP_CONVERTER(translate_tensor_list_get_item_op);
 OP_CONVERTER(translate_tensor_list_length_op);
```
New file (+185 lines): internal operations TensorList, TensorListGetItem and TensorListSetItem

```cpp
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>

#include "internal_operation.hpp"
#include "openvino/op/constant.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {

// Internal operation for TensorList that represents an initial state of the tensor list container
class TensorList : public InternalOperation {
public:
    OPENVINO_OP("TensorList", "ov::frontend::tensorflow", InternalOperation);

    TensorList(const ov::Output<ov::Node>& num_elements,
               const ov::Rank& element_rank,
               const element::Type& element_dtype,
               const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
        : InternalOperation(decoder, OutputVector{num_elements}, 1, "TensorList"),
          m_num_elements(num_elements),
          m_element_rank(element_rank),
          m_element_dtype(element_dtype) {
        validate_and_infer_types();
    }

    void validate_and_infer_types() override {
        if (m_element_rank.is_static()) {
            auto element_rank = m_element_rank.get_length();
            auto output_shape = ov::PartialShape::dynamic(element_rank + 1);
            set_output_type(0, m_element_dtype, output_shape);
        }

        set_output_type(0, m_element_dtype, ov::PartialShape::dynamic());
    }

    ov::element::Type get_element_type() const {
        return m_element_dtype;
    }

    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
        FRONT_END_OP_CONVERSION_CHECK(inputs.size() == 1,
                                      "[TensorFlow Frontend] internal error: TensorList expects no inputs");
        auto tensor_list_node = std::make_shared<TensorList>(inputs[0], m_element_rank, m_element_dtype, m_decoder);
        tensor_list_node->set_attrs(get_attrs());
        return tensor_list_node;
    }

    ov::Rank get_element_rank() const {
        return m_element_rank;
    }

    void set_element_rank(const ov::Rank& element_rank) {
        m_element_rank = element_rank;
    }

    ov::Output<ov::Node> get_num_elements() const {
        return m_num_elements;
    }

private:
    ov::Output<ov::Node> m_num_elements;
    ov::Rank m_element_rank;
    ov::element::Type m_element_dtype;
};

// Internal operation for TensorListGetItem
// it gets an element (Tensor) in tensor list by index
class TensorListGetItem : public InternalOperation {
public:
    OPENVINO_OP("TensorListGetItem", "ov::frontend::tensorflow", InternalOperation);

    TensorListGetItem(const Output<Node>& input_handle,
                      const Output<Node>& index,
                      const Output<Node>& element_shape,
                      const ov::element::Type& element_type,
                      const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
        : InternalOperation(decoder, OutputVector{input_handle, index, element_shape}, 1, "TensorListGetItem"),
          m_element_type(element_type) {
        validate_and_infer_types();
    }

    void validate_and_infer_types() override {
        // deduce an element (Tensor) shape
        ov::PartialShape comp_element_shape = ov::PartialShape::dynamic();
        if (const auto& const_element_shape =
                ov::as_type_ptr<ov::op::v0::Constant>(input_value(2).get_node_shared_ptr())) {
            auto element_shape_value = const_element_shape->get_vector<int32_t>();
            comp_element_shape = ov::PartialShape::dynamic(static_cast<int64_t>(element_shape_value.size()));
            for (size_t idx = 0; idx < element_shape_value.size(); ++idx) {
                comp_element_shape[idx] = (element_shape_value[idx] >= 0)
                                              ? static_cast<int64_t>(element_shape_value[idx])
                                              : ov::Dimension::dynamic();
            }
        } else if (input_value(0).get_partial_shape().rank().is_static()) {
            // the second try to deduce element shape if it is still of dynamic rank
            auto tensor_list_rank = input_value(0).get_partial_shape().rank().get_length();
            OPENVINO_ASSERT(
                tensor_list_rank > 0,
                "[TensorFlow Frontend] internal error or inconsistent model: tensor list rank must be greater than 0");
            // exclude tensor dimension (or batch)
            comp_element_shape = ov::PartialShape::dynamic(tensor_list_rank - 1);
            for (int64_t idx = 1; idx < tensor_list_rank; ++idx) {
                comp_element_shape[idx - 1] = input_value(0).get_partial_shape()[idx];
            }
        }

        // deduce an element (Tensor) type
        if (m_element_type.is_dynamic() && input_value(0).get_element_type().is_static()) {
            m_element_type = input_value(0).get_element_type();
        }

        set_output_type(0, m_element_type, comp_element_shape);
    }

    ov::element::Type get_element_type() const {
        return m_element_type;
    }

    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
        FRONT_END_OP_CONVERSION_CHECK(inputs.size() == 3,
                                      "[TensorFlow Frontend] internal error: TensorListGetItem expects three inputs");
        auto tensor_list_get_item =
            std::make_shared<TensorListGetItem>(inputs[0], inputs[1], inputs[2], m_element_type, m_decoder);
        tensor_list_get_item->set_attrs(get_attrs());
        return tensor_list_get_item;
    }

private:
    ov::element::Type m_element_type;
};

// Internal operation for TensorListSetItem
// it inserts a tensor into the tensor list by index
class TensorListSetItem : public InternalOperation {
public:
    OPENVINO_OP("TensorListSetItem", "ov::frontend::tensorflow", InternalOperation);

    TensorListSetItem(const Output<Node>& input_handle,
                      const Output<Node>& index,
                      const Output<Node>& item,
                      const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
        : InternalOperation(decoder, OutputVector{input_handle, index, item}, 1, "TensorListSetItem") {
        validate_and_infer_types();
    }

    void validate_and_infer_types() override {
        // deduce a type of elements in tensor list
        ov::element::Type element_type = ov::element::dynamic;
        if (input_value(0).get_element_type().is_static()) {
            element_type = input_value(0).get_element_type();
        } else if (input_value(2).get_element_type().is_static()) {
            element_type = input_value(2).get_element_type();
        }

        // deduce a shape of tensor list [num_tensors, <tensor shape>]
        ov::PartialShape tensor_list_shape = ov::PartialShape::dynamic();
        if (input_value(2).get_partial_shape().rank().is_static()) {
            auto element_rank = input_value(2).get_partial_shape().rank().get_length();
            tensor_list_shape = ov::PartialShape::dynamic(element_rank + 1);
            for (int64_t idx = 0; idx < element_rank; ++idx) {
                tensor_list_shape[idx + 1] = input_value(2).get_partial_shape()[idx];
            }
        }

        set_output_type(0, element_type, tensor_list_shape);
    }

    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
        FRONT_END_OP_CONVERSION_CHECK(inputs.size() == 3,
                                      "[TensorFlow Frontend] internal error: TensorListSetItem expects three inputs");
        auto tensor_list_set_item = std::make_shared<TensorListSetItem>(inputs[0], inputs[1], inputs[2], m_decoder);
        tensor_list_set_item->set_attrs(get_attrs());
        return tensor_list_set_item;
    }
};

}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
```
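
As a quick orientation for how these internal nodes fit together, here is a small, hypothetical usage sketch (not taken from the commit); it builds a list, sets one element, and reads it back, relying on the shape and type deduction implemented in `validate_and_infer_types` above:

```cpp
// Hypothetical usage of the internal ops above; not part of the commit.
#include <memory>
#include <vector>

#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
// plus the new TensorList helper-ops header added by this commit (path not shown in this excerpt)

void build_tensor_list_subgraph() {
    namespace tf = ov::frontend::tensorflow;
    using ov::element::f32;
    using ov::element::i32;

    // a tensor list meant to hold ten 2x3 elements of f32
    auto num_elements = std::make_shared<ov::op::v0::Constant>(i32, ov::Shape{1}, 10);
    auto list = std::make_shared<tf::TensorList>(num_elements, ov::Rank(2), f32);

    // insert an element at index 0
    auto item = std::make_shared<ov::op::v0::Parameter>(f32, ov::Shape{2, 3});
    auto index = std::make_shared<ov::op::v0::Constant>(i32, ov::Shape{}, 0);
    auto set_item = std::make_shared<tf::TensorListSetItem>(list, index, item);

    // read the element back; element_shape mirrors the TensorListGetItem input in TF
    auto element_shape = std::make_shared<ov::op::v0::Constant>(i32, ov::Shape{2}, std::vector<int32_t>{2, 3});
    auto get_item = std::make_shared<tf::TensorListGetItem>(set_item, index, element_shape, f32);
    // get_item->output(0) is deduced as f32 with static shape {2, 3}
}
```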

New file (+54 lines): declarations of the TensorList replacement passes (the helper_transforms/tensor_list_ops_resolver.hpp header included in frontend.cpp above)

```cpp
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <utility>

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {

// Replace internal operation TensorListReserve with a sub-graph producing initial container
class TensorListReplacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorListReplacer");
    TensorListReplacer();
};

// Replace internal operation TensorListSetItem with a sub-graph that inserts a new tensor into container
class TensorListSetItemReplacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorListSetItemReplacer");
    TensorListSetItemReplacer();
};

// Replace internal operation TensorListGetItem with a sub-graph that gets a tensor from container by index
class TensorListGetItemReplacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorListGetItemReplacer");
    TensorListGetItemReplacer();
};

// Replace and optimize sub-graphs with TensorList operations such as TensorListReserve,
// TensorListSetItem, TensorListGetItem
class TensorListOperationsResolver : public ov::pass::GraphRewrite {
public:
    OPENVINO_RTTI("TensorListOperationsResolver", "0");
    TensorListOperationsResolver() {
        add_matcher<TensorListReplacer>();
        add_matcher<TensorListSetItemReplacer>();
        add_matcher<TensorListGetItemReplacer>();
    }
};

}  // namespace pass
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
```
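
The resolver itself is registered in `FrontEnd::normalize()` (first diff above). For completeness, a hedged sketch of how this `GraphRewrite` would typically be applied to a converted `ov::Model` through `ov::pass::Manager`; the header path is the one included by frontend.cpp:

```cpp
// Sketch: applying the new resolver with ov::pass::Manager,
// mirroring the registration added to FrontEnd::normalize() above.
#include <memory>

#include "helper_transforms/tensor_list_ops_resolver.hpp"  // frontend-internal header, as in frontend.cpp
#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"

void resolve_tensor_list_ops(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // the GraphRewrite applies the TensorListReplacer, TensorListSetItemReplacer and
    // TensorListGetItemReplacer matchers across the graph in a single registration
    manager.register_pass<ov::frontend::tensorflow::pass::TensorListOperationsResolver>();
    manager.run_passes(model);
}
```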
