|
7 | 7 | #include <gtest/gtest.h>
|
8 | 8 |
|
9 | 9 | #include <memory>
|
| 10 | +#include <random> |
10 | 11 | #include <string>
|
11 | 12 |
|
12 | 13 | #include "common_test_utils/ov_test_utils.hpp"
|
@@ -1265,3 +1266,197 @@ TEST_P(FuseLSTMSequencesToBidirectionalLSTMSequenceTest, FusionTest) {
|
1265 | 1266 | INSTANTIATE_TEST_SUITE_P(FuseLSTMSequencesToBidirectionalLSTMSequence,
|
1266 | 1267 | FuseLSTMSequencesToBidirectionalLSTMSequenceTest,
|
1267 | 1268 | testing::Combine(testing::Values(false, true), testing::Values(false, true)));
|
| 1269 | + |
// Parameter tuple for LoopWithLSTMCellToLSTMSequenceFusionTest: three activation function names
// (passed to the LSTM ops in the order f, g, h) followed by the input size and the hidden size.
| 1270 | +using LoopWithLSTMCellToLSTMSequenceFusionParam = std::tuple<std::string, // f activation function |
| 1271 | + std::string, // g activation function |
| 1272 | + std::string, // h activation function |
| 1273 | + size_t, // input size |
| 1274 | + size_t>; // hidden size |
| 1275 | + |
// Parameterized fixture: adds gtest's WithParamInterface for LoopWithLSTMCellToLSTMSequenceFusionParam
// on top of TransformationTestsF. The class itself declares no members.
// NOTE(review): model, model_ref, manager and comparator used by the test body presumably come from
// TransformationTestsF - confirm against common_test_utils.
| 1276 | +class LoopWithLSTMCellToLSTMSequenceFusionTest |
| 1277 | + : public testing::WithParamInterface<LoopWithLSTMCellToLSTMSequenceFusionParam>, |
| 1278 | + public TransformationTestsF {}; |
| 1279 | + |
| 1280 | +namespace { |
// Test-local helper. Resizes weights_value to shape_size(weights_shape) elements and overwrites every
// element with a float drawn from std::uniform_real_distribution<float>(-300, 300).
// The std::mt19937 engine is re-created with the fixed seed 9812 on every call, so the generated
// sequence is the same for each call (results are reproducible, and different tensors start with
// the same values).
| 1281 | +void generate_weights_value(std::vector<float>& weights_value, const Shape& weights_shape) { |
| 1282 | + weights_value.resize(shape_size(weights_shape)); |
| 1283 | + std::mt19937 rng(9812); |
| 1284 | + std::uniform_real_distribution<float> distribution(-300, 300); |
| 1285 | + for (size_t i = 0; i < weights_value.size(); ++i) { |
| 1286 | + weights_value[i] = distribution(rng); |
| 1287 | + } |
| 1288 | +} |
| 1289 | +} // namespace |
| 1290 | + |
// Builds two graphs over the same generated W/R/B weights:
//  - `model`: a Loop whose body contains an LSTMCell, with x sliced along axis 0 per iteration and the
//    unsqueezed hidden state concatenated along axis 0 as the only output; the
//    ConvertLoopWithSlicedInputConcatOutputToLSTMSequence pass is registered on `manager`.
//  - `model_ref`: the expected graph built around a single forward LSTMSequence.
// Finally enables constant-value, attribute and accuracy comparison on `comparator`.
// NOTE(review): running the pass and comparing model with model_ref is presumably done by the
// TransformationTestsF fixture - not visible in this chunk.
| 1291 | +TEST_P(LoopWithLSTMCellToLSTMSequenceFusionTest, FusionTest) { |
| 1292 | + const auto& param = GetParam(); |
| 1293 | + const std::string& f_activation = std::get<0>(param); |
| 1294 | + const std::string& g_activation = std::get<1>(param); |
| 1295 | + const std::string& h_activation = std::get<2>(param); |
| 1296 | + size_t input_size = std::get<3>(param); |
| 1297 | + size_t hidden_size = std::get<4>(param); |
| 1298 | + size_t batch_size = 2; |
| 1299 | + size_t time_len = 10; |
| 1300 | + |
| 1301 | + // generate weights values (the four gate blocks are stacked along the first dimension) |
| 1302 | + // w must be of a shape [4 * hidden_size, input_size] |
| 1303 | + // r must be of a shape [4 * hidden_size, hidden_size] |
| 1304 | + // b must be of a shape [4 * hidden_size] |
| 1305 | + Shape w_shape({4 * hidden_size, input_size}); |
| 1306 | + Shape r_shape({4 * hidden_size, hidden_size}); |
| 1307 | + Shape b_shape({4 * hidden_size}); |
| 1308 | + std::vector<float> w, r, b; |
| 1309 | + generate_weights_value(w, w_shape); |
| 1310 | + generate_weights_value(r, r_shape); |
| 1311 | + generate_weights_value(b, b_shape); |
| 1312 | + |
| 1313 | + { |
| 1314 | + // create body graph with LSTMCell |
| 1315 | + auto xi = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, batch_size, input_size}); |
| 1316 | + auto squeeze_axis = std::make_shared<op::v0::Constant>(element::i64, Shape{}, 0); |
| 1317 | + auto xi_squeeze = std::make_shared<op::v0::Squeeze>(xi, squeeze_axis); |
| 1318 | + auto init_hidden_state = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size}); |
| 1319 | + auto init_cell_state = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size}); |
| 1320 | + auto w_const = op::v0::Constant::create(element::f32, w_shape, w); |
| 1321 | + auto r_const = op::v0::Constant::create(element::f32, r_shape, r); |
| 1322 | + auto b_const = op::v0::Constant::create(element::f32, b_shape, b); |
| 1323 | + auto lstm_cell = |
| 1324 | + std::make_shared<op::v4::LSTMCell>(xi_squeeze, |
| 1325 | + init_hidden_state, |
| 1326 | + init_cell_state, |
| 1327 | + w_const, |
| 1328 | + r_const, |
| 1329 | + b_const, |
| 1330 | + hidden_size, |
| 1331 | + std::vector<std::string>{f_activation, g_activation, h_activation}); |
| 1332 | + |
| 1333 | + auto hidden_state_res = std::make_shared<op::v0::Result>(lstm_cell->output(0)); |
| 1334 | + auto cell_state_res = std::make_shared<op::v0::Result>(lstm_cell->output(1)); |
| 1335 | + auto unsqueeze_axis = std::make_shared<op::v0::Constant>(element::i64, Shape{}, 0); |
| 1336 | + auto unsqueeze_hidden_state = std::make_shared<op::v0::Unsqueeze>(lstm_cell->output(0), unsqueeze_axis); |
| 1337 | + auto unsqueeze_hidden_state_res = std::make_shared<op::v0::Result>(unsqueeze_hidden_state); |
| 1338 | + |
| 1339 | + // conditional graph: outputs the incremented counter and (counter + 1) < num_iters |
| 1340 | + auto num_iters = std::make_shared<op::v0::Parameter>(element::i32, Shape{1}); |
| 1341 | + auto counter = std::make_shared<op::v0::Parameter>(element::i32, Shape{1}); |
| 1342 | + auto increment = std::make_shared<op::v0::Constant>(element::i32, Shape{}, 1); |
| 1343 | + auto add = std::make_shared<op::v1::Add>(counter, increment); |
| 1344 | + auto updated_counter = std::make_shared<op::v0::Result>(add); |
| 1345 | + auto less = std::make_shared<op::v1::Less>(add, num_iters); |
| 1346 | + auto less_res = std::make_shared<op::v0::Result>(less); |
| 1347 | + |
| 1348 | + auto body_graph = std::make_shared<Model>( |
| 1349 | + ResultVector{hidden_state_res, cell_state_res, unsqueeze_hidden_state_res, less_res, updated_counter}, |
| 1350 | + ParameterVector{xi, init_hidden_state, init_cell_state, num_iters, counter}); |
| 1351 | + |
| 1352 | + // create main graph with Loop |
| 1353 | + auto x = std::make_shared<op::v0::Parameter>(element::f32, Shape{time_len, batch_size, input_size}); |
| 1354 | + auto h_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size}); |
| 1355 | + auto c_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size}); |
| 1356 | + auto execution_cond = std::make_shared<op::v0::Constant>(ov::element::boolean, ov::Shape{}, true); |
| 1357 | + auto max_iter = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, -1); |
| 1358 | + auto num_iter_const = |
| 1359 | + std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, static_cast<int32_t>(time_len)); |
| 1360 | + auto counter_const = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, 0); |
| 1361 | + |
| 1362 | + auto loop_node = std::make_shared<op::v5::Loop>(max_iter, execution_cond); |
| 1363 | + |
| 1364 | + loop_node->set_function(body_graph); |
| 1365 | + loop_node->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 3}); |
| 1366 | + |
| 1367 | + // set inputs for Loop |
| 1368 | + // x input will be sliced for each time step |
| 1369 | + loop_node->set_sliced_input(xi, x, 0, 1, 1, -1, 0); |
| 1370 | + // set back edges for cell and hidden states |
| 1371 | + // since they are changing through timeline |
| 1372 | + loop_node->set_merged_input(init_hidden_state, h_init, hidden_state_res); |
| 1373 | + loop_node->set_merged_input(init_cell_state, c_init, cell_state_res); |
| 1374 | + loop_node->set_invariant_input(num_iters, num_iter_const); |
| 1375 | + loop_node->set_merged_input(counter, counter_const, updated_counter); |
| 1376 | + |
| 1377 | + // set external outputs for Loop node |
| 1378 | + // concatenated hidden states from all time steps (the cell state is not exposed as an output) |
| 1379 | + auto hs = loop_node->get_concatenated_slices(unsqueeze_hidden_state_res, 0, 1, 1, -1, 0); |
| 1380 | + auto hs_res = std::make_shared<op::v0::Result>(hs); |
| 1381 | + |
| 1382 | + model = std::make_shared<Model>(ResultVector{hs_res}, ParameterVector{x, h_init, c_init}); |
| 1383 | + manager.register_pass<ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence>(); |
| 1384 | + } |
| 1385 | + |
| 1386 | + { |
| 1387 | + auto x = std::make_shared<op::v0::Parameter>(element::f32, Shape{time_len, batch_size, input_size}); |
| 1388 | + auto h_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size}); |
| 1389 | + auto c_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size}); |
| 1390 | + |
| 1391 | + // transpose x since LSTMSequence expects x in a format [batch_size, time_len, input_size] |
| 1392 | + auto tr_order = |
| 1393 | + std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 0, 2}); |
| 1394 | + auto tr_x = std::make_shared<op::v1::Transpose>(x, tr_order); |
| 1395 | + // prepare init hidden and cell states to have a format [batch_size, num_directions, hidden_size] |
| 1396 | + // where num_directions equals one |
| 1397 | + auto unsqueeze_axis = |
| 1398 | + std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{1}); |
| 1399 | + auto h_init_unsqueeze = std::make_shared<op::v0::Unsqueeze>(h_init, unsqueeze_axis); |
| 1400 | + auto c_init_unsqueeze = std::make_shared<op::v0::Unsqueeze>(c_init, unsqueeze_axis); |
| 1401 | + // prepare seq_lens: slice element [1:2] of ShapeOf(x) and broadcast time_len to that length |
// NOTE: from here on `batch_size` names a graph output and shadows the size_t batch_size declared
// at the top of the test.
| 1402 | + auto batch_size = std::make_shared<op::v3::ShapeOf>(x, element::i64)->output(0); |
| 1403 | + auto begin = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{1}); |
| 1404 | + auto end = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{2}); |
| 1405 | + auto stride = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{1}); |
| 1406 | + batch_size = std::make_shared<op::v1::StridedSlice>(batch_size, |
| 1407 | + begin, |
| 1408 | + end, |
| 1409 | + stride, |
| 1410 | + std::vector<int64_t>{0}, |
| 1411 | + std::vector<int64_t>{0}); |
| 1412 | + auto num_iter_const = |
| 1413 | + std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, static_cast<int32_t>(time_len)); |
| 1414 | + auto seq_lens = std::make_shared<op::v1::Broadcast>(num_iter_const, batch_size); |
| 1415 | + // prepare W, R, B weights to a format with num_directions dimension |
| 1416 | + auto w_const = op::v0::Constant::create(element::f32, w_shape, w); |
| 1417 | + auto r_const = op::v0::Constant::create(element::f32, r_shape, r); |
| 1418 | + auto b_const = op::v0::Constant::create(element::f32, b_shape, b); |
| 1419 | + auto unsqueeze_axis2 = |
| 1420 | + std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{0}); |
// NOTE: the w, r, b nodes below shadow the std::vector<float> w, r, b declared above; the constants
// w_const, r_const, b_const were already created from those vectors.
| 1421 | + auto w = std::make_shared<op::v0::Unsqueeze>(w_const, unsqueeze_axis2); |
| 1422 | + auto r = std::make_shared<op::v0::Unsqueeze>(r_const, unsqueeze_axis2); |
| 1423 | + auto b = std::make_shared<op::v0::Unsqueeze>(b_const, unsqueeze_axis2); |
| 1424 | + |
| 1425 | + // create LSTMSequence |
| 1426 | + auto lstm_sequence = std::make_shared<ov::op::v5::LSTMSequence>( |
| 1427 | + tr_x, |
| 1428 | + h_init_unsqueeze, |
| 1429 | + c_init_unsqueeze, |
| 1430 | + seq_lens, |
| 1431 | + w, |
| 1432 | + r, |
| 1433 | + b, |
| 1434 | + hidden_size, |
| 1435 | + ov::op::RecurrentSequenceDirection::FORWARD, |
| 1436 | + std::vector<float>{}, |
| 1437 | + std::vector<float>{}, |
| 1438 | + std::vector<std::string>{f_activation, g_activation, h_activation}, |
| 1439 | + 0.0f); |
| 1440 | + |
| 1441 | + // prepare output |
| 1442 | + auto squeeze_axis = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, 1); |
| 1443 | + auto squeeze_output_hs = std::make_shared<op::v0::Squeeze>(lstm_sequence->output(0), squeeze_axis); |
| 1444 | + auto tr_order2 = |
| 1445 | + std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 0, 2}); |
| 1446 | + auto tr_squeeze_output_hs = std::make_shared<op::v1::Transpose>(squeeze_output_hs, tr_order2); |
| 1447 | + auto output_hs_res = std::make_shared<op::v0::Result>(tr_squeeze_output_hs); |
| 1448 | + model_ref = std::make_shared<Model>(ResultVector{output_hs_res}, ParameterVector{x, h_init, c_init}); |
| 1449 | + } |
| 1450 | + |
| 1451 | + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); |
| 1452 | + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); |
| 1453 | + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); |
| 1454 | +} |
| 1455 | + |
// Instantiates FusionTest for every combination produced by testing::Combine: two choices for each
// of the f, g and h activation names, input sizes {2, 3} and hidden sizes {3, 4} (2^5 = 32 cases).
| 1456 | +INSTANTIATE_TEST_SUITE_P(LoopWithLSTMCellToLSTMSequenceFusion, |
| 1457 | + LoopWithLSTMCellToLSTMSequenceFusionTest, |
| 1458 | + testing::Combine(testing::Values("sigmoid", "tanh"), |
| 1459 | + testing::Values("sigmoid", "relu"), |
| 1460 | + testing::Values("tanh", "relu"), |
| 1461 | + testing::Values(2, 3), |
| 1462 | + testing::Values(3, 4))); |
0 commit comments