From 1cf592cf8558fdd34df58e1950274bf99314568b Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Fri, 20 Sep 2024 17:42:54 +0400 Subject: [PATCH] [Snippets] Fixed the pass ExtractLoopInvariants (#26590) ### Details: - *The pass `ExtractLoopInvariants` cannot extract loop invariant from the outermost or any intermediate loops which have inner loops. This PR adds check that `inner_loop_id == potential_expr->get_loop_ids().back`* - *Added snippets test* ### Tickets: - *152426* --- .../lowered/pass/extract_loop_invariants.cpp | 11 ++++-- .../pass/extracted_loop_invariants.cpp | 34 +++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp b/src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp index f20ace893df463..27f4713810bfe0 100644 --- a/src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp +++ b/src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp @@ -61,7 +61,14 @@ int64_t get_stride_after_move_outer(const LoopPort& loop_port) { } } -bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info) { +bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info, size_t loop_id) { + // Extraction is possible only from the innermost Loop! + // We cannot extract Expression from the outermost or any intermediate Loop with other Loops inside + const auto& loop_ids = expr->get_loop_ids(); + OPENVINO_ASSERT(!loop_ids.empty(), "Expression must be in a Loop"); + if (loop_ids.back() != loop_id) + return false; + const auto& expr_input_ports = expr->get_input_ports(); const auto& input_port_size = expr_input_ports.size(); if (input_port_size == 0) @@ -162,7 +169,7 @@ bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir) { const auto& potential_extractable_exprs = get_loop_input_exprs(inner_loop_input_ports); bool expr_extracted = false; for (const auto& port_expr : potential_extractable_exprs) { - if (is_extraction_applicable(port_expr, inner_loop_info)) { + if (is_extraction_applicable(port_expr, inner_loop_info, inner_loop_id)) { status = true; LinearIR::constExprIt inner_loop_begin_pos, inner_loop_end_pos; std::tie(inner_loop_begin_pos, inner_loop_end_pos) = loop_manager->get_loop_bounds(linear_ir, inner_loop_id); diff --git a/src/common/snippets/tests/src/lowered/pass/extracted_loop_invariants.cpp b/src/common/snippets/tests/src/lowered/pass/extracted_loop_invariants.cpp index 3e148d3c1cf329..ee76c5af7234d8 100644 --- a/src/common/snippets/tests/src/lowered/pass/extracted_loop_invariants.cpp +++ b/src/common/snippets/tests/src/lowered/pass/extracted_loop_invariants.cpp @@ -288,6 +288,40 @@ TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsFromInnermostToLoopOuts } } +TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsImpossible) { + const auto input_precision = ov::element::f32; + const ov::Shape input_shape_0{32, 8, 1}; + ov::snippets::VectorDims order{1, 2, 0}; + ov::snippets::VectorDims layout{0, 1, 2}; + ov::snippets::VectorDims subtensor{1, 1}; + /* + * < Transpose decomposition > + * + * Param0(32,8,1) + * | + * LoadReshape with order (1,2,0) + * | + * Store + * | + * Result + */ + { + auto param = linear_ir->push_node(input_precision, input_shape_0); + auto load_reshape = linear_ir->push_node(param.second, 1, 0, layout); + auto store = linear_ir->push_node(load_reshape.second, 1, 0); + init_expr_descriptors(*load_reshape.first, {subtensor, subtensor}, {order, layout}); + init_expr_descriptors(*store.first, {subtensor, subtensor}, {layout, layout}); + auto result = linear_ir->push_node(store.second); + linear_ir->get_loop_manager()->mark_loop(load_reshape.first, result.first, 32, 1, + std::vector{LoopPort((*load_reshape.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*store.first)->get_output_port(0), true, 0)}); + linear_ir->get_loop_manager()->mark_loop(load_reshape.first, result.first, 1, 1, + std::vector{LoopPort((*load_reshape.first)->get_input_port(0), true, 1)}, + std::vector{LoopPort((*store.first)->get_output_port(0), true, 1)}); + linear_ir->set_loop_depth(2); + } +} + TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsSplitLoops) { size_t vector_size = 16; size_t block_size = 32;