Skip to content

Commit 7297835

Browse files
committedSep 25, 2024
[Snippets] Fixed the pass ExtractLoopInvariants (#26590)
### Details: - *The pass `ExtractLoopInvariants` cannot extract loop invariant from the outermost or any intermediate loops which have inner loops. This PR adds check that `inner_loop_id == potential_expr->get_loop_ids().back`* - *Added snippets test* ### Tickets: - *152426*
1 parent 2234272 commit 7297835

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed
 

‎src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,14 @@ int64_t get_stride_after_move_outer(const LoopPort& loop_port) {
6161
}
6262
}
6363

64-
bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info) {
64+
bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info, size_t loop_id) {
65+
// Extraction is possible only from the innermost Loop!
66+
// We cannot extract Expression from the outermost or any intermediate Loop with other Loops inside
67+
const auto& loop_ids = expr->get_loop_ids();
68+
OPENVINO_ASSERT(!loop_ids.empty(), "Expression must be in a Loop");
69+
if (loop_ids.back() != loop_id)
70+
return false;
71+
6572
const auto& expr_input_ports = expr->get_input_ports();
6673
const auto& input_port_size = expr_input_ports.size();
6774
if (input_port_size == 0)
@@ -162,7 +169,7 @@ bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir) {
162169
const auto& potential_extractable_exprs = get_loop_input_exprs(inner_loop_input_ports);
163170
bool expr_extracted = false;
164171
for (const auto& port_expr : potential_extractable_exprs) {
165-
if (is_extraction_applicable(port_expr, inner_loop_info)) {
172+
if (is_extraction_applicable(port_expr, inner_loop_info, inner_loop_id)) {
166173
status = true;
167174
LinearIR::constExprIt inner_loop_begin_pos, inner_loop_end_pos;
168175
std::tie(inner_loop_begin_pos, inner_loop_end_pos) = loop_manager->get_loop_bounds(linear_ir, inner_loop_id);

‎src/common/snippets/tests/src/lowered/pass/extracted_loop_invariants.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,40 @@ TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsFromInnermostToLoopOuts
288288
}
289289
}
290290

291+
TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsImpossible) {
292+
const auto input_precision = ov::element::f32;
293+
const ov::Shape input_shape_0{32, 8, 1};
294+
ov::snippets::VectorDims order{1, 2, 0};
295+
ov::snippets::VectorDims layout{0, 1, 2};
296+
ov::snippets::VectorDims subtensor{1, 1};
297+
/*
298+
* < Transpose decomposition >
299+
*
300+
* Param0(32,8,1)
301+
* |
302+
* LoadReshape with order (1,2,0)
303+
* |
304+
* Store
305+
* |
306+
* Result
307+
*/
308+
{
309+
auto param = linear_ir->push_node<ov::opset10::Parameter>(input_precision, input_shape_0);
310+
auto load_reshape = linear_ir->push_node<ov::snippets::op::LoadReshape>(param.second, 1, 0, layout);
311+
auto store = linear_ir->push_node<ov::snippets::op::Store>(load_reshape.second, 1, 0);
312+
init_expr_descriptors(*load_reshape.first, {subtensor, subtensor}, {order, layout});
313+
init_expr_descriptors(*store.first, {subtensor, subtensor}, {layout, layout});
314+
auto result = linear_ir->push_node<ov::opset10::Result>(store.second);
315+
linear_ir->get_loop_manager()->mark_loop(load_reshape.first, result.first, 32, 1,
316+
std::vector<LoopPort>{LoopPort((*load_reshape.first)->get_input_port(0), true, 0)},
317+
std::vector<LoopPort>{LoopPort((*store.first)->get_output_port(0), true, 0)});
318+
linear_ir->get_loop_manager()->mark_loop(load_reshape.first, result.first, 1, 1,
319+
std::vector<LoopPort>{LoopPort((*load_reshape.first)->get_input_port(0), true, 1)},
320+
std::vector<LoopPort>{LoopPort((*store.first)->get_output_port(0), true, 1)});
321+
linear_ir->set_loop_depth(2);
322+
}
323+
}
324+
291325
TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsSplitLoops) {
292326
size_t vector_size = 16;
293327
size_t block_size = 32;

0 commit comments

Comments
 (0)