Skip to content

Commit cb45c6d

Browse files
committed
back port full fix
1 parent 22568aa commit cb45c6d

9 files changed

+155
-119
lines changed

src/common/snippets/include/snippets/lowered/loop_manager.hpp

+50-15
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ class LinearIR::LoopManager {
8585
void set_outer_splited_loop(bool outer_splited_loop);
8686
void set_first_iter_handler(FirstIterHandler handler);
8787

88+
// Update the parameters of existing LoopPorts
89+
void update_entry_points(const std::function<void(LoopPort&)>& updater);
90+
void update_exit_points(const std::function<void(LoopPort&)>& updater);
91+
8892
private:
8993
size_t m_work_amount = 0;
9094
size_t m_increment = 0;
@@ -101,8 +105,6 @@ class LinearIR::LoopManager {
101105
using LoopInfoPtr = std::shared_ptr<LoopInfo>;
102106

103107
std::shared_ptr<LoopManager> clone_with_new_expr(const ExressionMap& expr_map) const;
104-
size_t add_loop_info(const LoopInfoPtr& loop);
105-
void remove_loop_info(size_t index);
106108
LoopInfoPtr get_loop_info(size_t index) const;
107109
size_t get_loop_count() const { return m_map.size(); }
108110
const std::map<size_t, LoopInfoPtr>& get_map() const;
@@ -182,23 +184,55 @@ class LinearIR::LoopManager {
182184
void expression_replacement(constExprIt new_expr_begin, constExprIt new_expr_end, const ExpressionPtr& decomposed_expr,
183185
size_t loop_id, const std::vector<ExpressionPort>& new_entries, const std::vector<ExpressionPort>& exits);
184186

185-
// Note: these methods find iterators of first entry loop point and last exit point (bounds of Loop)
186-
// If there are already inserted LoopBegin and LoopEnd in Linear IR, the methods can find them as well if `loop_ops_inserted` = true
187-
void get_loop_bounds(const LinearIR& linear_ir,
188-
size_t loop_id,
189-
LinearIR::constExprIt& loop_begin_pos,
190-
LinearIR::constExprIt& loop_end_pos,
191-
bool loop_ops_inserted = false) const;
192-
static void get_loop_bounds(const LinearIR& linear_ir,
193-
const std::vector<LoopPort>& entries,
194-
const std::vector<LoopPort>& exits,
195-
LinearIR::constExprIt& loop_begin_pos,
196-
LinearIR::constExprIt& loop_end_pos,
197-
size_t loop_id, bool loop_ops_inserted = false);
187+
/**
188+
* @brief Find bounds of Loop:
189+
* - If the explicit Loop exprs with the target `loop_id` have been inserted,
190+
* Loop bounds are these iterators of the corresponding LoopBegin and LoopEnd.
191+
* - Otherwise Loop bounds are iterators of the first entry loop port (or Scalar, VectorBuffer and another LoopBegin that
192+
* are in this Loop but have another `loop_id`) and the next iterator of the last exit loop port (or another LoopEnd that
193+
* are in this Loop but have another `loop_id`).
194+
* @param linear_ir linear IR
195+
* @param loop_id target Loop ID
196+
* @return the pair of loop_begin_pos and loop_end_pos iterators
197+
*/
198+
std::pair<LinearIR::constExprIt, LinearIR::constExprIt> get_loop_bounds(const LinearIR& linear_ir, size_t loop_id) const;
199+
/**
200+
* @brief Find bounds of Loop:
201+
* - If the explicit Loop exprs with the target `loop_id` have been inserted,
202+
* Loop bounds are these iterators of the corresponding LoopBegin and LoopEnd.
203+
* - Otherwise Loop bounds are iterators of the first entry loop port (or Scalar, VectorBuffer and another LoopBegin that
204+
* are in this Loop but have another `loop_id`) and the next iterator of the last exit loop port (or another LoopEnd that
205+
* are in this Loop but have another `loop_id`).
206+
* @param linear_ir linear IR
207+
* @param loop_id target Loop ID
208+
* @param entries input loop ports
209+
* @param exits output loop ports
210+
* @return the pair of loop_begin_pos and loop_end_pos iterators
211+
*/
212+
static std::pair<LinearIR::constExprIt, LinearIR::constExprIt> get_loop_bounds(const LinearIR& linear_ir, size_t loop_id,
213+
const std::vector<LoopPort>& entries, const std::vector<LoopPort>& exits);
198214

199215
LoopPort get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id);
200216

201217
private:
218+
/**
219+
* @brief Add new Loop Info to the map
220+
* @param loop target loop info
221+
* @return the loop ID
222+
*/
223+
size_t add_loop_info(const LoopInfoPtr& loop);
224+
/**
225+
* @brief Remove LoopInfo from the map
226+
* @param index the target index of Loop
227+
*/
228+
void remove_loop_info(size_t index);
229+
/**
230+
* @brief Find expression ports in bounds that are connected to consumers or parent that aren't in these bounds
231+
* @param loop_begin_pos the first expression iterator of the Loop
232+
* @param loop_end_pos the next iterator after the last expression
233+
* @param entries found input expression ports
234+
* @param exits found output expression ports
235+
*/
202236
static void get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
203237
LinearIR::constExprIt loop_end_pos,
204238
std::vector<ExpressionPort>& entries,
@@ -220,6 +254,7 @@ class LinearIR::LoopManager {
220254
// for `before` the new Loop is the most outer Loop
221255
void insert_loop_id(const ExpressionPtr& expr, size_t new_id, bool before = true, size_t target_id = SIZE_MAX);
222256
void insert_loop_ids(const ExpressionPtr& expr, const std::vector<size_t>& new_ids, bool before = true, size_t target_id = SIZE_MAX);
257+
static bool is_loop_id_found(const ExpressionPtr& expr, size_t id);
223258

224259
std::map<size_t, LoopInfoPtr> m_map = {};
225260
size_t next_id = 0;

src/common/snippets/src/lowered/loop_manager.cpp

+61-35
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ void LoopInfo::set_first_iter_handler(LoopInfo::FirstIterHandler first_iter_hand
141141
m_first_iter_handler = std::move(first_iter_handler);
142142
}
143143

144+
void LoopInfo::update_entry_points(const std::function<void(LoopPort&)>& updater) {
145+
std::for_each(m_entry_points.begin(), m_entry_points.end(), updater);
146+
}
147+
148+
void LoopInfo::update_exit_points(const std::function<void(LoopPort&)>& updater) {
149+
std::for_each(m_exit_points.begin(), m_exit_points.end(), updater);
150+
}
151+
144152
bool operator==(const LinearIR::LoopManager::LoopPort& lhs, const LinearIR::LoopManager::LoopPort& rhs) {
145153
if (&lhs == &rhs)
146154
return true;
@@ -194,26 +202,19 @@ std::vector<size_t> LinearIR::LoopManager::get_outer_expr_loops(const Expression
194202
return std::vector<size_t>(loop_ids.cbegin(), it);
195203
}
196204

197-
void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
198-
size_t loop_id,
199-
LinearIR::constExprIt &loop_begin_pos,
200-
LinearIR::constExprIt &loop_end_pos,
201-
bool loop_ops_inserted) const {
205+
std::pair<LinearIR::constExprIt, LinearIR::constExprIt> LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir, size_t loop_id) const {
202206
const auto loop_info = get_loop_info(loop_id);
203-
get_loop_bounds(linear_ir, loop_info->get_entry_points(), loop_info->get_exit_points(), loop_begin_pos, loop_end_pos, loop_id, loop_ops_inserted);
207+
return get_loop_bounds(linear_ir, loop_id, loop_info->get_entry_points(), loop_info->get_exit_points());
204208
}
205209

206-
void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
207-
const std::vector<LoopPort>& entries,
208-
const std::vector<LoopPort>& exits,
209-
LinearIR::constExprIt &loop_begin_pos,
210-
LinearIR::constExprIt &loop_end_pos,
211-
size_t loop_id, bool loop_ops_inserted) {
210+
std::pair<LinearIR::constExprIt, LinearIR::constExprIt> LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir, size_t loop_id,
211+
const std::vector<LoopPort>& entries,
212+
const std::vector<LoopPort>& exits) {
212213
OPENVINO_ASSERT(!entries.empty(), "Loop must have entry points");
213214
OPENVINO_ASSERT(!exits.empty(), "Loop must have entry points");
214-
const auto& entry_expr = entries.front().expr_port->get_expr();
215-
loop_begin_pos = linear_ir.find(entry_expr);
216215

216+
const auto& entry_expr = entries.front().expr_port->get_expr();
217+
auto loop_begin_pos = linear_ir.find(entry_expr);
217218
// Some operations in Loop can be before first entry points: Scalars, VectorBuffer.
218219
// We should iterate by them till the expr is in the corresponding Loop
219220
auto prev_loop_ids = (*std::prev(loop_begin_pos))->get_loop_ids();
@@ -222,18 +223,28 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
222223
prev_loop_ids = (*std::prev(loop_begin_pos))->get_loop_ids();
223224
}
224225

225-
if (loop_ops_inserted) {
226-
const auto loop_begin = ov::as_type_ptr<op::LoopBegin>((*std::prev(loop_begin_pos))->get_node());
227-
OPENVINO_ASSERT(loop_begin, "Failed explicit loop bounds getting: LoopBegin has not been found");
228-
const auto loop_end = loop_begin->get_loop_end();
229-
OPENVINO_ASSERT(loop_end->get_id() == loop_id, "Failed explicit loop bounds getting: Loop bounds with correct ID have not been found");
230-
loop_begin_pos = std::prev(loop_begin_pos);
231-
loop_end_pos = linear_ir.find_after(loop_begin_pos, linear_ir.get_expr_by_node(loop_end));
232-
} else {
233-
// At the moment all Loops must have exit points
234-
const auto& exit_expr = exits.back().expr_port->get_expr();
235-
loop_end_pos = std::next(linear_ir.find_after(loop_begin_pos, exit_expr));
226+
const auto& exit_expr = exits.back().expr_port->get_expr();
227+
auto loop_end_pos = std::next(linear_ir.find_after(loop_begin_pos, exit_expr));
228+
// There might be LoopEnd with another `loop_id` but in the target Loop as well.
229+
auto current_loop_ids = (*loop_end_pos)->get_loop_ids();
230+
while (std::find(current_loop_ids.begin(), current_loop_ids.end(), loop_id) != current_loop_ids.end()) {
231+
loop_end_pos = std::next(loop_end_pos);
232+
current_loop_ids = (*loop_end_pos)->get_loop_ids();
236233
}
234+
235+
// Check for the existing LoopBegin/LoopEnd
236+
if (const auto loop_end = ov::as_type_ptr<op::LoopEnd>((*loop_end_pos)->get_node())) {
237+
if (loop_end->get_id() == loop_id) {
238+
// loop_begin_pos is iterator of LoopBegin now
239+
// loop_end_pos is iterator of LoopEnd now
240+
loop_begin_pos = std::prev(loop_begin_pos);
241+
const auto loop_begin = loop_end->get_loop_begin();
242+
OPENVINO_ASSERT((*loop_begin_pos)->get_node() == loop_begin, "LoopBegin has not been found!");
243+
}
244+
}
245+
246+
OPENVINO_ASSERT(loop_begin_pos != linear_ir.cend() && loop_end_pos != linear_ir.cend(), "Loop bounds haven't been found!");
247+
return std::make_pair(loop_begin_pos, loop_end_pos);
237248
}
238249

239250
LinearIR::LoopManager::LoopPort LinearIR::LoopManager::get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id) {
@@ -353,27 +364,35 @@ size_t LinearIR::LoopManager::replace_with_new_loop(const LinearIR& linear_ir,
353364
const std::vector<LoopPort>& entries,
354365
const std::vector<LoopPort>& exits,
355366
const size_t old_id) {
356-
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
357-
const auto loop_id = this->add_loop_info(loop_info);
358-
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
367+
const auto is_bound_explicit_loop_begin = ov::is_type<op::LoopBegin>(loop_begin_pos->get()->get_node());
368+
const auto is_bound_explicit_loop_end = ov::is_type<op::LoopEnd>(std::prev(loop_end_pos)->get()->get_node());
369+
OPENVINO_ASSERT((is_bound_explicit_loop_begin && is_bound_explicit_loop_end) ||
370+
(!is_bound_explicit_loop_begin && !is_bound_explicit_loop_end),
371+
"Incorrect LoopBounds!");
372+
const auto explicit_loop_bounds = is_bound_explicit_loop_begin && is_bound_explicit_loop_end;
373+
374+
const auto loop_id =
375+
this->add_loop_info(std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits));
376+
const auto begin = explicit_loop_bounds ? std::next(loop_begin_pos) : loop_begin_pos;
377+
const auto end = explicit_loop_bounds ? std::prev(loop_end_pos) : loop_end_pos;
378+
for (auto expr_it = begin; expr_it != end; ++expr_it) {
359379
replace_loop_id(*expr_it, old_id, loop_id);
360380
}
361381

362-
const auto old_loop_info = this->get_loop_info(old_id);
363-
const auto old_loop_begin_pos = linear_ir.find(old_loop_info->get_entry_points().front().expr_port->get_expr());
364-
const auto old_loop_end_pos = linear_ir.find(old_loop_info->get_exit_points().back().expr_port->get_expr());
382+
// Check that other expression in LinearIR doesn't have the old loop ID - otherwise completely removed from loop
383+
// manager
384+
const auto old_loop_bounds = get_loop_bounds(linear_ir, old_id);
365385
// If new bounds are equal to old loop bounds, this means that old Loop is removed totally from LIR
366386
// In this case old loop info must be completely removed from loop manager
367-
if (loop_begin_pos == old_loop_begin_pos && loop_end_pos == old_loop_end_pos) {
387+
if (loop_begin_pos == old_loop_bounds.first && loop_end_pos == old_loop_bounds.second) {
368388
this->remove_loop_info(old_id);
369389
}
370390
return loop_id;
371391
}
372392

373393
void LinearIR::LoopManager::fuse_loops(const LinearIR& linear_ir, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper) {
374-
LinearIR::constExprIt loop_begin_target, loop_end_target;
375-
get_loop_bounds(linear_ir, fuse_into_upper ? loop_id_lower : loop_id_upper, loop_begin_target, loop_end_target);
376-
fuse_loops(loop_begin_target, loop_end_target, loop_id_upper, loop_id_lower, fuse_into_upper);
394+
const auto loop_bounds = get_loop_bounds(linear_ir, fuse_into_upper ? loop_id_lower : loop_id_upper);
395+
fuse_loops(loop_bounds.first, loop_bounds.second, loop_id_upper, loop_id_lower, fuse_into_upper);
377396
}
378397

379398
void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target, LinearIR::constExprIt loop_end_target,
@@ -541,6 +560,7 @@ void LinearIR::LoopManager::sort_loop_ports(LinearIR::constExprIt& loop_begin_po
541560

542561
void LinearIR::LoopManager::insert_loop_id(const ExpressionPtr& expr, size_t new_id, bool before, size_t target_id) {
543562
OPENVINO_ASSERT(m_map.count(new_id) == 1, "Failed marking expression by Loop ID: the Loop with this ID hasn't registered");
563+
OPENVINO_ASSERT(!is_loop_id_found(expr, new_id), "Expression cannot have several the same Loop IDs");
544564
auto& loop_ids = expr->m_loop_ids;
545565
OPENVINO_ASSERT(std::find(loop_ids.cbegin(), loop_ids.cend(), new_id) == loop_ids.cend(),
546566
"Expression cannot have several the same Loop IDs");
@@ -568,6 +588,7 @@ void LinearIR::LoopManager::insert_loop_ids(const ExpressionPtr& expr, const std
568588

569589
void LinearIR::LoopManager::replace_loop_id(const ExpressionPtr& expr, size_t prev_id, size_t new_id) {
570590
OPENVINO_ASSERT(m_map.count(new_id), "Failed marking expression by Loop ID: the Loop with this ID hasn't registered");
591+
OPENVINO_ASSERT(!is_loop_id_found(expr, new_id), "Expression cannot have several the same Loop IDs");
571592
auto& loop_ids = expr->m_loop_ids;
572593
OPENVINO_ASSERT(std::find(loop_ids.cbegin(), loop_ids.cend(), new_id) == loop_ids.cend(),
573594
"Expression already has the Loop with ID " + std::to_string(new_id));
@@ -584,6 +605,11 @@ void LinearIR::LoopManager::remove_loop_id(const ExpressionPtr& expr, size_t id)
584605
loop_ids.erase(it);
585606
}
586607

608+
bool LinearIR::LoopManager::is_loop_id_found(const ExpressionPtr& expr, size_t id) {
609+
const auto loop_ids = expr->get_loop_ids();
610+
return std::find(loop_ids.cbegin(), loop_ids.cend(), id) != loop_ids.cend();
611+
}
612+
587613
}// namespace lowered
588614
}// namespace snippets
589615
}// namespace ov

src/common/snippets/src/lowered/pass/fuse_loops.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@ void FuseLoops::move(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_m
6868
std::map<size_t, std::pair<LinearIR::constExprIt, LinearIR::constExprIt>> outer_loops; // The map: LoopID -> [ LoopBegin, LoopEnd ]
6969
const auto outer_loop_ids = LinearIR::LoopManager::get_outer_expr_loops(*loop_begin_pos, loop_id);
7070
for (const auto& loop_id : outer_loop_ids) {
71-
LinearIR::constExprIt begin, end;
72-
loop_manager->get_loop_bounds(linear_ir, loop_id, begin, end);
71+
const auto loop_bounds = loop_manager->get_loop_bounds(linear_ir, loop_id);
7372
// save previos iterator since the current iterator can be moved
74-
outer_loops[loop_id] = {std::prev(begin), end};
73+
outer_loops[loop_id] = {std::prev(loop_bounds.first), loop_bounds.second};
7574
}
7675
// Secondly, move expressions
7776
for (auto it = loop_begin_pos; it != loop_end_pos;) {
@@ -122,7 +121,7 @@ bool FuseLoops::fuse_upper_into_current(LinearIR& linear_ir, const LinearIR::Loo
122121
return false;
123122

124123
LinearIR::constExprIt target_loop_begin_pos, target_loop_end_pos;
125-
loop_manager->get_loop_bounds(linear_ir, target_loop_id, target_loop_begin_pos, target_loop_end_pos);
124+
std::tie(target_loop_begin_pos, target_loop_end_pos) = loop_manager->get_loop_bounds(linear_ir, target_loop_id);
126125
loop_manager->fuse_loops(target_loop_begin_pos, target_loop_end_pos, target_loop_id, current_loop_id, false);
127126
// Update work_amount for Loop (increment is constant because increments must be the identical for fusion):
128127
loop_current->set_work_amount(std::max(loop_current->get_work_amount(), loop_target->get_work_amount()));
@@ -167,7 +166,7 @@ bool FuseLoops::fuse_lower_into_current(LinearIR& linear_ir, const LinearIR::Loo
167166
return false;
168167

169168
LinearIR::constExprIt target_loop_begin_pos, target_loop_end_pos;
170-
loop_manager->get_loop_bounds(linear_ir, target_loop_id, target_loop_begin_pos, target_loop_end_pos);
169+
std::tie(target_loop_begin_pos, target_loop_end_pos) = loop_manager->get_loop_bounds(linear_ir, target_loop_id);
171170
loop_manager->fuse_loops(target_loop_begin_pos, target_loop_end_pos, current_loop_id, target_loop_id);
172171
// Update work_amount for Loop (increment is constant because increments must be the identical for fusion):
173172
loop_current->set_work_amount(std::max(loop_current->get_work_amount(), loop_target->get_work_amount()));
@@ -212,7 +211,7 @@ bool FuseLoops::run(LinearIR& linear_ir) {
212211

213212
const auto current_loop_info = loop_manager->get_loop_info(current_loop_id);
214213
LinearIR::constExprIt current_loop_begin_pos, current_loop_end_pos;
215-
loop_manager->get_loop_bounds(linear_ir, current_loop_id, current_loop_begin_pos, current_loop_end_pos);
214+
std::tie(current_loop_begin_pos, current_loop_end_pos) = loop_manager->get_loop_bounds(linear_ir, current_loop_id);
216215

217216
// We fuse upper Loops into the current till we can do it.
218217
// After that we fuse lower Loops into the current till we can do it.

0 commit comments

Comments
 (0)