Skip to content

Commit 4bc2a6d

Browse files
committed
xe: ir: improve inject_alloc_stmts() for nested stmt_seq_t
1 parent a21a8bf commit 4bc2a6d

File tree

3 files changed

+148
-50
lines changed

3 files changed

+148
-50
lines changed

src/gpu/intel/jit/ir/core.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,30 @@ expr_t linear_t::to_expr() const {
307307
return simplify_rewrite(ret);
308308
}
309309

310+
void stmt_seq_flatten(std::vector<stmt_t> &out, const stmt_t &s) {
311+
if (auto *seq = s.as_ptr<stmt_seq_t>()) {
312+
out.insert(out.end(), seq->vec.begin(), seq->vec.end());
313+
return;
314+
}
315+
out.push_back(s);
316+
}
317+
318+
stmt_t stmt_seq_t::make(const std::vector<stmt_t> &_vec) {
319+
std::vector<stmt_t> vec;
320+
for (auto &s : _vec)
321+
stmt_seq_flatten(vec, s);
322+
return stmt_t(new stmt_seq_t(vec));
323+
}
324+
325+
stmt_t stmt_t::append(const stmt_t &s) const {
326+
if (is_empty()) return s;
327+
if (s.is_empty()) return *this;
328+
std::vector<stmt_t> vec;
329+
stmt_seq_flatten(vec, *this);
330+
stmt_seq_flatten(vec, s);
331+
return stmt_seq_t::make(vec);
332+
}
333+
310334
expr_t expr_t::operator[](const expr_t &off) const {
311335
if (is<shuffle_t>()) {
312336
ir_assert(is_const(off)) << "Offset is not constant.";

src/gpu/intel/jit/ir/core.hpp

+1-22
Original file line numberDiff line numberDiff line change
@@ -2437,9 +2437,7 @@ class stmt_seq_t : public stmt_impl_t {
24372437
public:
24382438
IR_DECL_STMT_TYPE_ID(stmt_seq_t)
24392439

2440-
static stmt_t make(const std::vector<stmt_t> &vec) {
2441-
return stmt_t(new stmt_seq_t(vec));
2442-
}
2440+
static stmt_t make(const std::vector<stmt_t> &vec);
24432441

24442442
static stmt_t make(const stmt_t &head, const stmt_t &tail) {
24452443
return head.append(tail);
@@ -2463,25 +2461,6 @@ class stmt_seq_t : public stmt_impl_t {
24632461
: stmt_impl_t(_type_info()), vec(vec) {}
24642462
};
24652463

2466-
inline stmt_t stmt_t::append(const stmt_t &s) const {
2467-
if (is_empty()) return s;
2468-
if (s.is_empty()) return *this;
2469-
auto *seq1 = this->as_ptr<stmt_seq_t>();
2470-
auto *seq2 = s.as_ptr<stmt_seq_t>();
2471-
std::vector<stmt_t> vec;
2472-
if (seq1) {
2473-
vec.insert(vec.end(), seq1->vec.begin(), seq1->vec.end());
2474-
} else {
2475-
vec.push_back(*this);
2476-
}
2477-
if (seq2) {
2478-
vec.insert(vec.end(), seq2->vec.begin(), seq2->vec.end());
2479-
} else {
2480-
vec.push_back(s);
2481-
}
2482-
return stmt_seq_t::make(vec);
2483-
}
2484-
24852464
// Function call attribute.
24862465
class func_call_attr_impl_t : public object_impl_t {
24872466
public:

src/gpu/intel/jit/ir/ir.cpp

+123-28
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,14 @@ class stmt_flattener_t : public ir_visitor_t {
374374

375375
class alloc_injector_t : public ir_mutator_t {
376376
public:
377-
alloc_injector_t(const stmt_t &root, const std::vector<stmt_t> &allocs,
378-
bool put_innermost)
379-
: root_(root), put_innermost_(put_innermost), allocs_(allocs) {
377+
alloc_injector_t(const stmt_t &root, const std::vector<stmt_t> &allocs)
378+
: allocs_(allocs) {
380379
for (auto &_a : allocs) {
381380
auto &a = _a.as<alloc_t>();
382381
if (a.kind != alloc_kind_t::global) ir_assert(a.size > 0) << _a;
383382
alloc_map_.insert({a.buf, _a});
384383
}
385-
mutate(root_);
384+
mutate(root);
386385
buf_total_refs_ = buf_cur_refs_;
387386
for (auto &kv : buf_cur_refs_)
388387
kv.second = 0;
@@ -404,39 +403,127 @@ class alloc_injector_t : public ir_mutator_t {
404403
template <typename T>
405404
object_t mutate_stmt(const T &obj) {
406405
if (in_ctor_) return ir_mutator_t::_mutate(obj);
407-
object_t new_obj = obj;
408-
object_set_t<expr_t> undef_bufs;
409-
if (put_innermost_) {
410-
for (auto &kv : buf_cur_refs_)
411-
if (kv.second == 0) undef_bufs.insert(kv.first);
412-
new_obj = ir_mutator_t::_mutate(obj);
406+
if (T::_type_id() == ir_type_id_t::stmt_seq_t) {
407+
return mutate_stmt_seq(obj);
413408
}
414-
for (auto &a : allocs_) {
415-
auto it = alloc_map_.find(a.as<alloc_t>().buf);
416-
auto &buf = it->first;
417-
if (it->second.is_empty()) continue; // Already injected.
418-
bool do_inject = false;
419-
if (put_innermost_) {
420-
int cur_refs = buf_cur_refs_[buf];
421-
int total_refs = buf_total_refs_[buf];
422-
bool was_undef = (undef_bufs.count(buf) != 0);
423-
do_inject = was_undef && (cur_refs == total_refs);
424-
} else {
425-
do_inject = root_.is_same(obj);
409+
auto undef_bufs = get_undef_bufs();
410+
auto new_obj = ir_mutator_t::_mutate(obj);
411+
new_obj = maybe_inject(new_obj, undef_bufs);
412+
return new_obj;
413+
}
414+
415+
// Handle stmt_seq_t in a special way:
416+
// 1. Walk through the sequence and record the first and the last statement
417+
// where a buffer is referenced
418+
// 2. Inject alloc statements according to the usage
419+
object_t mutate_stmt_seq(const object_t &obj) {
420+
auto stmt_vec = obj.as<stmt_seq_t>().vec;
421+
ir_assert(!stmt_vec.empty());
422+
int nstmts = (int)stmt_vec.size();
423+
// Mutate statments and record buffer usage in the form: buf: [first, last].
424+
object_map_t<expr_t, int> last_undef;
425+
object_map_t<expr_t, std::pair<int, int>> entries;
426+
for (int i = 0; i < nstmts; i++) {
427+
auto &s = stmt_vec[i];
428+
for (auto &b : get_undef_bufs()) {
429+
auto it = alloc_map_.find(b);
430+
if (it == alloc_map_.end() || it->second.is_empty()) continue;
431+
last_undef[b] = i;
432+
}
433+
s = mutate(s);
434+
for (auto &kv : last_undef) {
435+
auto &buf = kv.first;
436+
if (entries.count(buf) != 0) continue;
437+
if (buf_cur_refs_[buf] == buf_total_refs_[buf]) {
438+
entries[buf] = std::make_pair(kv.second, i);
439+
}
426440
}
427-
if (do_inject) {
428-
auto &a = it->second.as<alloc_t>();
441+
}
442+
// Sort buffers based on the number of statements they span. This is to
443+
// inject more local allocations first.
444+
std::vector<expr_t> bufs;
445+
for (auto &kv : entries) {
446+
if (alloc_map_.at(kv.first).is_empty()) continue;
447+
bufs.push_back(kv.first);
448+
}
449+
std::sort(bufs.begin(), bufs.end(),
450+
[&](const expr_t &a, const expr_t &b) {
451+
auto &ea = entries.at(a);
452+
auto &eb = entries.at(b);
453+
return (ea.second - ea.first) < (eb.second - eb.first);
454+
});
455+
// Use union-find to incrementally merge statements based on the common
456+
// buffers.
457+
std::vector<int> parent(nstmts);
458+
std::iota(parent.begin(), parent.end(), 0);
459+
std::function<int(int)> _find;
460+
std::function<void(int, int)> _union;
461+
_find = [&](int i) {
462+
if (parent[i] == i) return i;
463+
return parent[i] = _find(parent[i]);
464+
};
465+
_union = [&](int i, int j) {
466+
i = _find(i);
467+
j = _find(j);
468+
parent[j] = i;
469+
};
470+
std::vector<stmt_t> new_stmt_seq = stmt_vec;
471+
for (auto &buf : bufs) {
472+
auto &e = entries.at(buf);
473+
stmt_t stmt;
474+
for (int i = e.first; i <= e.second; i++) {
475+
int idx = _find(i);
476+
stmt = stmt.append(new_stmt_seq[idx]);
477+
new_stmt_seq[idx] = stmt_t();
478+
_union(e.first, i);
479+
}
480+
auto it = alloc_map_.find(buf);
481+
auto &a = it->second.as<alloc_t>();
482+
stmt = alloc_t::make(a.buf, a.size, a.kind, a.attrs, stmt);
483+
new_stmt_seq[_find(e.first)] = stmt;
484+
it->second = stmt_t();
485+
}
486+
stmt_t new_obj;
487+
for (auto &s : new_stmt_seq) {
488+
if (s.is_empty()) continue;
489+
new_obj = new_obj.append(s);
490+
}
491+
return std::move(new_obj);
492+
}
493+
494+
object_set_t<expr_t> get_undef_bufs() const {
495+
object_set_t<expr_t> ret;
496+
for (auto &kv : buf_cur_refs_)
497+
if (kv.second == 0) ret.insert(kv.first);
498+
return ret;
499+
}
500+
501+
object_t maybe_inject(
502+
const object_t &obj, const object_set_t<expr_t> &undef_bufs) {
503+
auto new_obj = obj;
504+
for (auto &kv : alloc_map_) {
505+
if (kv.second.is_empty()) continue;
506+
auto &buf = kv.first;
507+
auto &a = kv.second.as<alloc_t>();
508+
if (do_inject(buf, undef_bufs)) {
429509
new_obj = alloc_t::make(
430510
a.buf, a.size, a.kind, a.attrs, new_obj);
431-
it->second = stmt_t();
511+
kv.second = stmt_t();
432512
}
433513
}
434514
return new_obj;
435515
}
436516

517+
bool do_inject(
518+
const expr_t &buf, const object_set_t<expr_t> &undef_bufs) const {
519+
if (buf.is_empty()) return false; // Already injected.
520+
int cur_refs = buf_cur_refs_.at(buf);
521+
int total_refs = buf_total_refs_.at(buf);
522+
bool was_undef = (undef_bufs.count(buf) != 0);
523+
return was_undef && (cur_refs == total_refs);
524+
}
525+
437526
bool in_ctor_ = true;
438-
const stmt_t &root_;
439-
bool put_innermost_;
440527
std::vector<stmt_t> allocs_;
441528
object_map_t<expr_t, stmt_t> alloc_map_;
442529
object_map_t<expr_t, int> buf_total_refs_;
@@ -480,7 +567,15 @@ std::vector<stmt_t> flatten_statements(const stmt_t &root) {
480567

481568
stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs,
482569
bool put_innermost) {
483-
alloc_injector_t injector(stmt, allocs, put_innermost);
570+
if (!put_innermost) {
571+
auto ret = stmt;
572+
for (auto &_a : allocs) {
573+
auto &a = _a.as<alloc_t>();
574+
ret = alloc_t::make(a.buf, a.size, a.kind, a.attrs, ret);
575+
}
576+
return ret;
577+
}
578+
alloc_injector_t injector(stmt, allocs);
484579
return injector.mutate(stmt);
485580
}
486581

0 commit comments

Comments
 (0)