@@ -141,6 +141,14 @@ void LoopInfo::set_first_iter_handler(LoopInfo::FirstIterHandler first_iter_hand
141
141
m_first_iter_handler = std::move (first_iter_handler);
142
142
}
143
143
144
+ void LoopInfo::update_entry_points (const std::function<void (LoopPort&)>& updater) {
145
+ std::for_each (m_entry_points.begin (), m_entry_points.end (), updater);
146
+ }
147
+
148
+ void LoopInfo::update_exit_points (const std::function<void (LoopPort&)>& updater) {
149
+ std::for_each (m_exit_points.begin (), m_exit_points.end (), updater);
150
+ }
151
+
144
152
bool operator ==(const LinearIR::LoopManager::LoopPort& lhs, const LinearIR::LoopManager::LoopPort& rhs) {
145
153
if (&lhs == &rhs)
146
154
return true ;
@@ -194,26 +202,19 @@ std::vector<size_t> LinearIR::LoopManager::get_outer_expr_loops(const Expression
194
202
return std::vector<size_t >(loop_ids.cbegin (), it);
195
203
}
196
204
197
- void LinearIR::LoopManager::get_loop_bounds (const LinearIR &linear_ir,
198
- size_t loop_id,
199
- LinearIR::constExprIt &loop_begin_pos,
200
- LinearIR::constExprIt &loop_end_pos,
201
- bool loop_ops_inserted) const {
205
+ std::pair<LinearIR::constExprIt, LinearIR::constExprIt> LinearIR::LoopManager::get_loop_bounds (const LinearIR &linear_ir, size_t loop_id) const {
202
206
const auto loop_info = get_loop_info (loop_id);
203
- get_loop_bounds (linear_ir, loop_info->get_entry_points (), loop_info->get_exit_points (), loop_begin_pos, loop_end_pos, loop_id, loop_ops_inserted );
207
+ return get_loop_bounds (linear_ir, loop_id, loop_info->get_entry_points (), loop_info->get_exit_points ());
204
208
}
205
209
206
- void LinearIR::LoopManager::get_loop_bounds (const LinearIR &linear_ir,
207
- const std::vector<LoopPort>& entries,
208
- const std::vector<LoopPort>& exits,
209
- LinearIR::constExprIt &loop_begin_pos,
210
- LinearIR::constExprIt &loop_end_pos,
211
- size_t loop_id, bool loop_ops_inserted) {
210
+ std::pair<LinearIR::constExprIt, LinearIR::constExprIt> LinearIR::LoopManager::get_loop_bounds (const LinearIR &linear_ir, size_t loop_id,
211
+ const std::vector<LoopPort>& entries,
212
+ const std::vector<LoopPort>& exits) {
212
213
OPENVINO_ASSERT (!entries.empty (), " Loop must have entry points" );
213
214
OPENVINO_ASSERT (!exits.empty (), " Loop must have entry points" );
214
- const auto & entry_expr = entries.front ().expr_port ->get_expr ();
215
- loop_begin_pos = linear_ir.find (entry_expr);
216
215
216
+ const auto & entry_expr = entries.front ().expr_port ->get_expr ();
217
+ auto loop_begin_pos = linear_ir.find (entry_expr);
217
218
// Some operations in Loop can be before first entry points: Scalars, VectorBuffer.
218
219
// We should iterate by them till the expr is in the corresponding Loop
219
220
auto prev_loop_ids = (*std::prev (loop_begin_pos))->get_loop_ids ();
@@ -222,18 +223,28 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
222
223
prev_loop_ids = (*std::prev (loop_begin_pos))->get_loop_ids ();
223
224
}
224
225
225
- if (loop_ops_inserted) {
226
- const auto loop_begin = ov::as_type_ptr<op::LoopBegin>((*std::prev (loop_begin_pos))->get_node ());
227
- OPENVINO_ASSERT (loop_begin, " Failed explicit loop bounds getting: LoopBegin has not been found" );
228
- const auto loop_end = loop_begin->get_loop_end ();
229
- OPENVINO_ASSERT (loop_end->get_id () == loop_id, " Failed explicit loop bounds getting: Loop bounds with correct ID have not been found" );
230
- loop_begin_pos = std::prev (loop_begin_pos);
231
- loop_end_pos = linear_ir.find_after (loop_begin_pos, linear_ir.get_expr_by_node (loop_end));
232
- } else {
233
- // At the moment all Loops must have exit points
234
- const auto & exit_expr = exits.back ().expr_port ->get_expr ();
235
- loop_end_pos = std::next (linear_ir.find_after (loop_begin_pos, exit_expr));
226
+ const auto & exit_expr = exits.back ().expr_port ->get_expr ();
227
+ auto loop_end_pos = std::next (linear_ir.find_after (loop_begin_pos, exit_expr));
228
+ // There might be LoopEnd with another `loop_id` but in the target Loop as well.
229
+ auto current_loop_ids = (*loop_end_pos)->get_loop_ids ();
230
+ while (std::find (current_loop_ids.begin (), current_loop_ids.end (), loop_id) != current_loop_ids.end ()) {
231
+ loop_end_pos = std::next (loop_end_pos);
232
+ current_loop_ids = (*loop_end_pos)->get_loop_ids ();
236
233
}
234
+
235
+ // Check for the existing LoopBegin/LoopEnd
236
+ if (const auto loop_end = ov::as_type_ptr<op::LoopEnd>((*loop_end_pos)->get_node ())) {
237
+ if (loop_end->get_id () == loop_id) {
238
+ // loop_begin_pos is iterator of LoopBegin now
239
+ // loop_end_pos is iterator of LoopEnd now
240
+ loop_begin_pos = std::prev (loop_begin_pos);
241
+ const auto loop_begin = loop_end->get_loop_begin ();
242
+ OPENVINO_ASSERT ((*loop_begin_pos)->get_node () == loop_begin, " LoopBegin has not been found!" );
243
+ }
244
+ }
245
+
246
+ OPENVINO_ASSERT (loop_begin_pos != linear_ir.cend () && loop_end_pos != linear_ir.cend (), " Loop bounds haven't been found!" );
247
+ return std::make_pair (loop_begin_pos, loop_end_pos);
237
248
}
238
249
239
250
LinearIR::LoopManager::LoopPort LinearIR::LoopManager::get_loop_port_by_expr_port (const ExpressionPort& expr_port, const size_t loop_id) {
@@ -353,27 +364,35 @@ size_t LinearIR::LoopManager::replace_with_new_loop(const LinearIR& linear_ir,
353
364
const std::vector<LoopPort>& entries,
354
365
const std::vector<LoopPort>& exits,
355
366
const size_t old_id) {
356
- const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
357
- const auto loop_id = this ->add_loop_info (loop_info);
358
- for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
367
+ const auto is_bound_explicit_loop_begin = ov::is_type<op::LoopBegin>(loop_begin_pos->get ()->get_node ());
368
+ const auto is_bound_explicit_loop_end = ov::is_type<op::LoopEnd>(std::prev (loop_end_pos)->get ()->get_node ());
369
+ OPENVINO_ASSERT ((is_bound_explicit_loop_begin && is_bound_explicit_loop_end) ||
370
+ (!is_bound_explicit_loop_begin && !is_bound_explicit_loop_end),
371
+ " Incorrect LoopBounds!" );
372
+ const auto explicit_loop_bounds = is_bound_explicit_loop_begin && is_bound_explicit_loop_end;
373
+
374
+ const auto loop_id =
375
+ this ->add_loop_info (std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits));
376
+ const auto begin = explicit_loop_bounds ? std::next (loop_begin_pos) : loop_begin_pos;
377
+ const auto end = explicit_loop_bounds ? std::prev (loop_end_pos) : loop_end_pos;
378
+ for (auto expr_it = begin; expr_it != end; ++expr_it) {
359
379
replace_loop_id (*expr_it, old_id, loop_id);
360
380
}
361
381
362
- const auto old_loop_info = this -> get_loop_info (old_id);
363
- const auto old_loop_begin_pos = linear_ir. find (old_loop_info-> get_entry_points (). front (). expr_port -> get_expr ());
364
- const auto old_loop_end_pos = linear_ir. find (old_loop_info-> get_exit_points (). back (). expr_port -> get_expr () );
382
+ // Check that other expression in LinearIR doesn't have the old loop ID - otherwise completely removed from loop
383
+ // manager
384
+ const auto old_loop_bounds = get_loop_bounds (linear_ir, old_id );
365
385
// If new bounds are equal to old loop bounds, this means that old Loop is removed totally from LIR
366
386
// In this case old loop info must be completely removed from loop manager
367
- if (loop_begin_pos == old_loop_begin_pos && loop_end_pos == old_loop_end_pos ) {
387
+ if (loop_begin_pos == old_loop_bounds. first && loop_end_pos == old_loop_bounds. second ) {
368
388
this ->remove_loop_info (old_id);
369
389
}
370
390
return loop_id;
371
391
}
372
392
373
393
void LinearIR::LoopManager::fuse_loops (const LinearIR& linear_ir, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper) {
374
- LinearIR::constExprIt loop_begin_target, loop_end_target;
375
- get_loop_bounds (linear_ir, fuse_into_upper ? loop_id_lower : loop_id_upper, loop_begin_target, loop_end_target);
376
- fuse_loops (loop_begin_target, loop_end_target, loop_id_upper, loop_id_lower, fuse_into_upper);
394
+ const auto loop_bounds = get_loop_bounds (linear_ir, fuse_into_upper ? loop_id_lower : loop_id_upper);
395
+ fuse_loops (loop_bounds.first , loop_bounds.second , loop_id_upper, loop_id_lower, fuse_into_upper);
377
396
}
378
397
379
398
void LinearIR::LoopManager::fuse_loops (LinearIR::constExprIt loop_begin_target, LinearIR::constExprIt loop_end_target,
@@ -541,6 +560,7 @@ void LinearIR::LoopManager::sort_loop_ports(LinearIR::constExprIt& loop_begin_po
541
560
542
561
void LinearIR::LoopManager::insert_loop_id (const ExpressionPtr& expr, size_t new_id, bool before, size_t target_id) {
543
562
OPENVINO_ASSERT (m_map.count (new_id) == 1 , " Failed marking expression by Loop ID: the Loop with this ID hasn't registered" );
563
+ OPENVINO_ASSERT (!is_loop_id_found (expr, new_id), " Expression cannot have several the same Loop IDs" );
544
564
auto & loop_ids = expr->m_loop_ids ;
545
565
OPENVINO_ASSERT (std::find (loop_ids.cbegin (), loop_ids.cend (), new_id) == loop_ids.cend (),
546
566
" Expression cannot have several the same Loop IDs" );
@@ -568,6 +588,7 @@ void LinearIR::LoopManager::insert_loop_ids(const ExpressionPtr& expr, const std
568
588
569
589
void LinearIR::LoopManager::replace_loop_id (const ExpressionPtr& expr, size_t prev_id, size_t new_id) {
570
590
OPENVINO_ASSERT (m_map.count (new_id), " Failed marking expression by Loop ID: the Loop with this ID hasn't registered" );
591
+ OPENVINO_ASSERT (!is_loop_id_found (expr, new_id), " Expression cannot have several the same Loop IDs" );
571
592
auto & loop_ids = expr->m_loop_ids ;
572
593
OPENVINO_ASSERT (std::find (loop_ids.cbegin (), loop_ids.cend (), new_id) == loop_ids.cend (),
573
594
" Expression already has the Loop with ID " + std::to_string (new_id));
@@ -584,6 +605,11 @@ void LinearIR::LoopManager::remove_loop_id(const ExpressionPtr& expr, size_t id)
584
605
loop_ids.erase (it);
585
606
}
586
607
608
+ bool LinearIR::LoopManager::is_loop_id_found (const ExpressionPtr& expr, size_t id) {
609
+ const auto loop_ids = expr->get_loop_ids ();
610
+ return std::find (loop_ids.cbegin (), loop_ids.cend (), id) != loop_ids.cend ();
611
+ }
612
+
587
613
}// namespace lowered
588
614
}// namespace snippets
589
615
}// namespace ov
0 commit comments