@@ -127,12 +127,14 @@ void InsertTailLoop::propagate_updated_subtensor_through_loop(const LinearIR& li
127
127
(*expr_it)->updateShapes ();
128
128
}
129
129
130
- LinearIR::container InsertTailLoop::copy_loop ( const LinearIR& linear_ir, const size_t loop_id) {
130
+ LinearIR::constExprIt InsertTailLoop::insert_copy_loop ( LinearIR& linear_ir, const size_t loop_id, const LinearIR::constExprIt& insert_pos ) {
131
131
const auto & loop_manager = linear_ir.get_loop_manager ();
132
132
LinearIR::constExprIt loop_begin_pos, loop_end_pos;
133
133
loop_manager->get_loop_bounds (linear_ir, loop_id, loop_begin_pos, loop_end_pos, true );
134
134
ExressionMap expression_map;
135
135
const auto & loop_copy_range = LinearIR::deep_copy_range (loop_begin_pos, std::next (loop_end_pos), expression_map);
136
+ const auto new_loop_begin_pos = linear_ir.insert (insert_pos, loop_copy_range.begin (), loop_copy_range.end ());
137
+ const auto new_loop_end_pos = insert_pos;
136
138
137
139
const auto original_loop_info = loop_manager->get_loop_info (loop_id);
138
140
std::vector<LinearIR::LoopManager::LoopPort> new_entry_points, new_exit_points;
@@ -156,11 +158,9 @@ LinearIR::container InsertTailLoop::copy_loop(const LinearIR& linear_ir, const s
156
158
loop_manager->update_loops_port (outer_loop_ids, expr->get_output_port (i), {expr->get_output_port (i), new_expr->get_output_port (i)}, false );
157
159
}
158
160
159
- const auto new_loop_begin_pos = loop_copy_range.begin ();
160
- const auto new_loop_end_pos = loop_copy_range.end ();
161
161
const auto new_id = loop_manager->replace_with_new_loop (linear_ir,
162
- std::next ( new_loop_begin_pos) ,
163
- std::prev ( new_loop_end_pos) ,
162
+ new_loop_begin_pos,
163
+ new_loop_end_pos,
164
164
original_loop_info->get_work_amount (),
165
165
original_loop_info->get_increment (),
166
166
new_entry_points,
@@ -169,7 +169,7 @@ LinearIR::container InsertTailLoop::copy_loop(const LinearIR& linear_ir, const s
169
169
const auto loop_end = ov::as_type_ptr<op::LoopEnd>(std::prev (new_loop_end_pos)->get ()->get_node ());
170
170
OPENVINO_ASSERT (loop_end, " Cloned Loop does not contain LoopEnd op at the expected place." );
171
171
loop_end->set_id (new_id);
172
- return loop_copy_range ;
172
+ return new_loop_begin_pos ;
173
173
}
174
174
175
175
void InsertTailLoop::create_tail_loop (LinearIR& linear_ir,
@@ -186,17 +186,16 @@ void InsertTailLoop::create_tail_loop(LinearIR& linear_ir,
186
186
auto original_loop_info = loop_manager->get_loop_info (original_loop_id);
187
187
auto tail_loop_info = original_loop_info;
188
188
if (need_vector_loop) {
189
- const auto new_loop_range = copy_loop (linear_ir, original_loop_id);
190
- const auto new_loop_end = ov::as_type_ptr<op::LoopEnd>(std::prev (new_loop_range.end ())->get ()->get_node ());
191
- OPENVINO_ASSERT (new_loop_end, " Cloned Loop does not contain LoopEnd op at the expected place." );
192
- tail_loop_info = original_loop_info;
193
- original_loop_info = loop_manager->get_loop_info (new_loop_end->get_id ());
194
-
195
189
// Note: new loop body is inserted before the original loop
196
190
// So new loop becomes a main vector loop, the original loop becomes tail loop
197
191
// This is done in such way to have original ops from the main body at the end:
198
192
// this allows us to conveniently interact with outer loops in further passes
199
- linear_ir.insert (begin, new_loop_range.begin (), new_loop_range.end ());
193
+ const auto new_loop_begin_pos = insert_copy_loop (linear_ir, original_loop_id, begin);
194
+ const auto new_loop_begin = ov::as_type_ptr<op::LoopBegin>(new_loop_begin_pos->get ()->get_node ());
195
+ OPENVINO_ASSERT (new_loop_begin, " Cloned Loop does not contain LoopBegin op at the expected place." );
196
+ const auto new_loop_end = new_loop_begin->get_loop_end ();
197
+ tail_loop_info = original_loop_info;
198
+ original_loop_info = loop_manager->get_loop_info (new_loop_end->get_id ());
200
199
201
200
const auto new_vector_loop_wa = original_loop_info->get_work_amount () - tail_size;
202
201
original_loop_info->set_work_amount (new_vector_loop_wa);
0 commit comments