Skip to content

Commit 07cbea2

Browse files
authored
[TF FE][MOC] Fuse Keras LSTM to LSTMSequence and Optimize TF While with TensorList ops (#25170)
**Details:** Fuse Keras LSTM to LSTMSequence and Optimize TF While with TensorList ops. Loop operations with TensorListSetItem are transformed to ConcatOutput outputs. Loop operations with TensorListGetItem are transformed to SlicedInput inputs. This helps to fuse six loops with LSTMCell into six LSTMSequence operations in the model. It reduces the customer model size by half and increases throughput by 1.97x. **Tickets:** TBD --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent 76a668b commit 07cbea2

File tree

8 files changed

+1134
-53
lines changed

8 files changed

+1134
-53
lines changed

src/common/transformations/include/transformations/op_conversions/convert_ti_to_sequences.hpp

+19-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class TRANSFORMATIONS_API ConvertTensorIteratorToRNNSequence;
1818
class TRANSFORMATIONS_API ConvertTensorIteratorToGRUSequence;
1919
class TRANSFORMATIONS_API ConvertTensorIteratorToSequence;
2020

21+
class TRANSFORMATIONS_API ConvertLoopWithSlicedInputConcatOutputToLSTMSequence;
22+
class TRANSFORMATIONS_API ConvertLoopWithScatterUpdateToLSTMSequence;
2123
class TRANSFORMATIONS_API ConvertLoopToLSTMSequence;
2224
class TRANSFORMATIONS_API FuseReverseLSTMSequence;
2325

@@ -68,14 +70,29 @@ class ov::pass::ConvertTensorIteratorToSequence : public GraphRewrite {
6870
ConvertTensorIteratorToSequence();
6971
};
7072

73+
class ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence : public ov::pass::MatcherPass {
74+
public:
75+
OPENVINO_RTTI("ConvertLoopWithSlicedInputConcatOutputToLSTMSequence", "0");
76+
ConvertLoopWithSlicedInputConcatOutputToLSTMSequence();
77+
};
78+
79+
class ov::pass::ConvertLoopWithScatterUpdateToLSTMSequence : public ov::pass::MatcherPass {
80+
public:
81+
OPENVINO_RTTI("ConvertLoopWithScatterUpdateToLSTMSequence", "0");
82+
ConvertLoopWithScatterUpdateToLSTMSequence();
83+
};
84+
7185
/**
7286
* @ingroup ov_transformation_common_api
7387
* @brief Replaces Loop with LSTMCell inside to LSTMSequence
7488
*/
75-
class ov::pass::ConvertLoopToLSTMSequence : public ov::pass::MatcherPass {
89+
class ov::pass::ConvertLoopToLSTMSequence : public ov::pass::GraphRewrite {
7690
public:
7791
OPENVINO_RTTI("ConvertLoopToLSTMSequence", "0");
78-
ConvertLoopToLSTMSequence();
92+
ConvertLoopToLSTMSequence() {
93+
add_matcher<ov::pass::ConvertLoopWithScatterUpdateToLSTMSequence>();
94+
add_matcher<ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence>();
95+
}
7996
};
8097

8198
/**

src/common/transformations/include/transformations/utils/utils.hpp

+11
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ inline std::string get_ie_output_name(const Output<Node>& output) {
114114
*/
115115
float cast_eps_to_float(double eps_d);
116116

117+
template <typename T>
118+
bool get_constant_value(const std::shared_ptr<ov::Node>& node, T& value) {
119+
auto constant = ov::as_type_ptr<ov::op::v0::Constant>(node);
120+
if (!constant)
121+
return false;
122+
if (shape_size(constant->get_shape()) != 1)
123+
return false;
124+
value = constant->cast_vector<T>()[0];
125+
return true;
126+
}
127+
117128
template <typename T>
118129
bool has_constant_value(const std::shared_ptr<Node>& node,
119130
const T value,

src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp

+471-37
Large diffs are not rendered by default.

src/common/transformations/tests/op_conversions/convert_ti_to_sequences_test.cpp

+195
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <gtest/gtest.h>
88

99
#include <memory>
10+
#include <random>
1011
#include <string>
1112

1213
#include "common_test_utils/ov_test_utils.hpp"
@@ -1265,3 +1266,197 @@ TEST_P(FuseLSTMSequencesToBidirectionalLSTMSequenceTest, FusionTest) {
12651266
INSTANTIATE_TEST_SUITE_P(FuseLSTMSequencesToBidirectionalLSTMSequence,
12661267
FuseLSTMSequencesToBidirectionalLSTMSequenceTest,
12671268
testing::Combine(testing::Values(false, true), testing::Values(false, true)));
1269+
1270+
using LoopWithLSTMCellToLSTMSequenceFusionParam = std::tuple<std::string, // f activation function
1271+
std::string, // g activation function
1272+
std::string, // h activation function
1273+
size_t, // input size
1274+
size_t>; // hidden size
1275+
1276+
class LoopWithLSTMCellToLSTMSequenceFusionTest
1277+
: public testing::WithParamInterface<LoopWithLSTMCellToLSTMSequenceFusionParam>,
1278+
public TransformationTestsF {};
1279+
1280+
namespace {
1281+
void generate_weights_value(std::vector<float>& weights_value, const Shape& weights_shape) {
1282+
weights_value.resize(shape_size(weights_shape));
1283+
std::mt19937 rng(9812);
1284+
std::uniform_real_distribution<float> distribution(-300, 300);
1285+
for (size_t i = 0; i < weights_value.size(); ++i) {
1286+
weights_value[i] = distribution(rng);
1287+
}
1288+
}
1289+
} // namespace
1290+
1291+
TEST_P(LoopWithLSTMCellToLSTMSequenceFusionTest, FusionTest) {
1292+
const auto& param = GetParam();
1293+
const std::string& f_activation = std::get<0>(param);
1294+
const std::string& g_activation = std::get<1>(param);
1295+
const std::string& h_activation = std::get<2>(param);
1296+
size_t input_size = std::get<3>(param);
1297+
size_t hidden_size = std::get<4>(param);
1298+
size_t batch_size = 2;
1299+
size_t time_len = 10;
1300+
1301+
// generate weights values
1302+
// w must be of a shape [input_size, hidden_size]
1303+
// r must be of a shape [hidden_size, hidden_size]
1304+
// b must be of a shape [hidden_size]
1305+
Shape w_shape({4 * hidden_size, input_size});
1306+
Shape r_shape({4 * hidden_size, hidden_size});
1307+
Shape b_shape({4 * hidden_size});
1308+
std::vector<float> w, r, b;
1309+
generate_weights_value(w, w_shape);
1310+
generate_weights_value(r, r_shape);
1311+
generate_weights_value(b, b_shape);
1312+
1313+
{
1314+
// create body graph with LSTMCell
1315+
auto xi = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, batch_size, input_size});
1316+
auto squeeze_axis = std::make_shared<op::v0::Constant>(element::i64, Shape{}, 0);
1317+
auto xi_squeeze = std::make_shared<op::v0::Squeeze>(xi, squeeze_axis);
1318+
auto init_hidden_state = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
1319+
auto init_cell_state = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
1320+
auto w_const = op::v0::Constant::create(element::f32, w_shape, w);
1321+
auto r_const = op::v0::Constant::create(element::f32, r_shape, r);
1322+
auto b_const = op::v0::Constant::create(element::f32, b_shape, b);
1323+
auto lstm_cell =
1324+
std::make_shared<op::v4::LSTMCell>(xi_squeeze,
1325+
init_hidden_state,
1326+
init_cell_state,
1327+
w_const,
1328+
r_const,
1329+
b_const,
1330+
hidden_size,
1331+
std::vector<std::string>{f_activation, g_activation, h_activation});
1332+
1333+
auto hidden_state_res = std::make_shared<op::v0::Result>(lstm_cell->output(0));
1334+
auto cell_state_res = std::make_shared<op::v0::Result>(lstm_cell->output(1));
1335+
auto unsqueeze_axis = std::make_shared<op::v0::Constant>(element::i64, Shape{}, 0);
1336+
auto unsqueeze_hidden_state = std::make_shared<op::v0::Unsqueeze>(lstm_cell->output(0), unsqueeze_axis);
1337+
auto unsqueeze_hidden_state_res = std::make_shared<op::v0::Result>(unsqueeze_hidden_state);
1338+
1339+
// conditional graph
1340+
auto num_iters = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
1341+
auto counter = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
1342+
auto increment = std::make_shared<op::v0::Constant>(element::i32, Shape{}, 1);
1343+
auto add = std::make_shared<op::v1::Add>(counter, increment);
1344+
auto updated_counter = std::make_shared<op::v0::Result>(add);
1345+
auto less = std::make_shared<op::v1::Less>(add, num_iters);
1346+
auto less_res = std::make_shared<op::v0::Result>(less);
1347+
1348+
auto body_graph = std::make_shared<Model>(
1349+
ResultVector{hidden_state_res, cell_state_res, unsqueeze_hidden_state_res, less_res, updated_counter},
1350+
ParameterVector{xi, init_hidden_state, init_cell_state, num_iters, counter});
1351+
1352+
// create main graph with Loop
1353+
auto x = std::make_shared<op::v0::Parameter>(element::f32, Shape{time_len, batch_size, input_size});
1354+
auto h_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
1355+
auto c_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
1356+
auto execution_cond = std::make_shared<op::v0::Constant>(ov::element::boolean, ov::Shape{}, true);
1357+
auto max_iter = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, -1);
1358+
auto num_iter_const =
1359+
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, static_cast<int32_t>(time_len));
1360+
auto counter_const = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, 0);
1361+
1362+
auto loop_node = std::make_shared<op::v5::Loop>(max_iter, execution_cond);
1363+
1364+
loop_node->set_function(body_graph);
1365+
loop_node->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 3});
1366+
1367+
// set inputs for Loop
1368+
// x input will be sliced for each time step
1369+
loop_node->set_sliced_input(xi, x, 0, 1, 1, -1, 0);
1370+
// set back edges for cell and hidden states
1371+
// since they are changing through timeline
1372+
loop_node->set_merged_input(init_hidden_state, h_init, hidden_state_res);
1373+
loop_node->set_merged_input(init_cell_state, c_init, cell_state_res);
1374+
loop_node->set_invariant_input(num_iters, num_iter_const);
1375+
loop_node->set_merged_input(counter, counter_const, updated_counter);
1376+
1377+
// set external outputs for Loop node
1378+
// concatenated cell and hidden states from all time steps
1379+
auto hs = loop_node->get_concatenated_slices(unsqueeze_hidden_state_res, 0, 1, 1, -1, 0);
1380+
auto hs_res = std::make_shared<op::v0::Result>(hs);
1381+
1382+
model = std::make_shared<Model>(ResultVector{hs_res}, ParameterVector{x, h_init, c_init});
1383+
manager.register_pass<ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence>();
1384+
}
1385+
1386+
{
1387+
auto x = std::make_shared<op::v0::Parameter>(element::f32, Shape{time_len, batch_size, input_size});
1388+
auto h_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
1389+
auto c_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
1390+
1391+
// transpose x since LSTMSequence expects x in a format [batch_size, time_len, input_size]
1392+
auto tr_order =
1393+
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 0, 2});
1394+
auto tr_x = std::make_shared<op::v1::Transpose>(x, tr_order);
1395+
// prepare init hidden and cell states to have a format [batch_size, num_directions, hidden_size]
1396+
// where num_directions equals one
1397+
auto unsqueeze_axis =
1398+
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{1});
1399+
auto h_init_unsqueeze = std::make_shared<op::v0::Unsqueeze>(h_init, unsqueeze_axis);
1400+
auto c_init_unsqueeze = std::make_shared<op::v0::Unsqueeze>(c_init, unsqueeze_axis);
1401+
// prepare seq_lens
1402+
auto batch_size = std::make_shared<op::v3::ShapeOf>(x, element::i64)->output(0);
1403+
auto begin = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{1});
1404+
auto end = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{2});
1405+
auto stride = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{1});
1406+
batch_size = std::make_shared<op::v1::StridedSlice>(batch_size,
1407+
begin,
1408+
end,
1409+
stride,
1410+
std::vector<int64_t>{0},
1411+
std::vector<int64_t>{0});
1412+
auto num_iter_const =
1413+
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, static_cast<int32_t>(time_len));
1414+
auto seq_lens = std::make_shared<op::v1::Broadcast>(num_iter_const, batch_size);
1415+
// prepare W, R, B weights to a format with num_directions dimension
1416+
auto w_const = op::v0::Constant::create(element::f32, w_shape, w);
1417+
auto r_const = op::v0::Constant::create(element::f32, r_shape, r);
1418+
auto b_const = op::v0::Constant::create(element::f32, b_shape, b);
1419+
auto unsqueeze_axis2 =
1420+
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{0});
1421+
auto w = std::make_shared<op::v0::Unsqueeze>(w_const, unsqueeze_axis2);
1422+
auto r = std::make_shared<op::v0::Unsqueeze>(r_const, unsqueeze_axis2);
1423+
auto b = std::make_shared<op::v0::Unsqueeze>(b_const, unsqueeze_axis2);
1424+
1425+
// create LSTMSequence
1426+
auto lstm_sequence = std::make_shared<ov::op::v5::LSTMSequence>(
1427+
tr_x,
1428+
h_init_unsqueeze,
1429+
c_init_unsqueeze,
1430+
seq_lens,
1431+
w,
1432+
r,
1433+
b,
1434+
hidden_size,
1435+
ov::op::RecurrentSequenceDirection::FORWARD,
1436+
std::vector<float>{},
1437+
std::vector<float>{},
1438+
std::vector<std::string>{f_activation, g_activation, h_activation},
1439+
0.0f);
1440+
1441+
// prepare output
1442+
auto squeeze_axis = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, 1);
1443+
auto squeeze_output_hs = std::make_shared<op::v0::Squeeze>(lstm_sequence->output(0), squeeze_axis);
1444+
auto tr_order2 =
1445+
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 0, 2});
1446+
auto tr_squeeze_output_hs = std::make_shared<op::v1::Transpose>(squeeze_output_hs, tr_order2);
1447+
auto output_hs_res = std::make_shared<op::v0::Result>(tr_squeeze_output_hs);
1448+
model_ref = std::make_shared<Model>(ResultVector{output_hs_res}, ParameterVector{x, h_init, c_init});
1449+
}
1450+
1451+
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
1452+
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
1453+
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
1454+
}
1455+
1456+
INSTANTIATE_TEST_SUITE_P(LoopWithLSTMCellToLSTMSequenceFusion,
1457+
LoopWithLSTMCellToLSTMSequenceFusionTest,
1458+
testing::Combine(testing::Values("sigmoid", "tanh"),
1459+
testing::Values("sigmoid", "relu"),
1460+
testing::Values("tanh", "relu"),
1461+
testing::Values(2, 3),
1462+
testing::Values(3, 4)));

src/frontends/tensorflow/src/frontend.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "openvino/util/file_util.hpp"
2828
#include "openvino/util/log.hpp"
2929
#include "tf_framework_node.hpp"
30+
#include "transformations/common_optimizations/eliminate_loop_inputs_outputs.hpp"
3031
#include "transformations/common_optimizations/remove_concat_zero_dim_input.hpp"
3132
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
3233
#include "transformations/control_flow/unroll_if.hpp"
@@ -568,7 +569,16 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
568569
manager.register_pass<pass::TensorArrayV3Replacer>();
569570
manager.register_pass<pass::ConstToResultRemover>();
570571
manager.register_pass<pass::SwitchMergeResolver>();
571-
manager.register_pass<pass::TensorListOperationsResolver>();
572+
573+
// apply EliminateLoopInputsOutputs to avoid extra Results
574+
// that output the same value as receiving on input
575+
// it is needed for applying TensorListInLoopOptimization
576+
manager.register_pass<ov::pass::EliminateLoopInputsOutputs>();
577+
manager.register_pass<pass::TensorListReplacer>();
578+
manager.register_pass<pass::TensorListInLoopOptimization>();
579+
manager.register_pass<pass::TensorListSetItemReplacer>();
580+
manager.register_pass<pass::TensorListGetItemReplacer>();
581+
572582
manager.register_pass<ov::pass::UnrollIf>();
573583
manager.register_pass<ov::pass::RemoveConcatZeroDimInput>();
574584
manager.register_pass<ov::pass::TransposeSinkingGeneral>();

src/frontends/tensorflow/src/op/block_lstm.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ void create_decomposed_block_lstm(const Output<Node>& x,
7171
auto squeeze_axis = std::make_shared<v0::Constant>(element::i32, Shape{1}, 0);
7272
auto xi = std::make_shared<v0::Squeeze>(xi_param, squeeze_axis);
7373

74-
auto lstm_cell = std::make_shared<v0::LSTMCell>(xi,
74+
auto lstm_cell = std::make_shared<v4::LSTMCell>(xi,
7575
h_prev_param,
7676
c_prev_param,
7777
w_param,

src/frontends/tensorflow_common/include/helper_transforms/tensor_list_ops_resolver.hpp

+5-12
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44

55
#pragma once
66

7-
#include <memory>
8-
#include <utility>
9-
107
#include "openvino/pass/graph_rewrite.hpp"
118
#include "openvino/pass/pass.hpp"
129

@@ -36,16 +33,12 @@ class TensorListGetItemReplacer : public ov::pass::MatcherPass {
3633
TensorListGetItemReplacer();
3734
};
3835

39-
// Replace and optimize sub-graphs with TensorList operations such as TensorListReserve,
40-
// TensorListSetItem, TensorListGetItem
41-
class TensorListOperationsResolver : public ov::pass::GraphRewrite {
36+
// Optimize sub-graphs with TensorList operations in Loop body graph
37+
// Replace TensorListSetItem and TensorListGetItem with ConcatOutput and SlicedInput
38+
class TensorListInLoopOptimization : public ov::pass::MatcherPass {
4239
public:
43-
OPENVINO_RTTI("TensorListOperationsResolver", "0");
44-
TensorListOperationsResolver() {
45-
add_matcher<TensorListReplacer>();
46-
add_matcher<TensorListSetItemReplacer>();
47-
add_matcher<TensorListGetItemReplacer>();
48-
}
40+
OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorListInLoopOptimization");
41+
TensorListInLoopOptimization();
4942
};
5043

5144
} // namespace pass

0 commit comments

Comments
 (0)