@@ -374,15 +374,14 @@ class stmt_flattener_t : public ir_visitor_t {
374
374
375
375
class alloc_injector_t : public ir_mutator_t {
376
376
public:
377
- alloc_injector_t (const stmt_t &root, const std::vector<stmt_t > &allocs,
378
- bool put_innermost)
379
- : root_(root), put_innermost_(put_innermost), allocs_(allocs) {
377
+ alloc_injector_t (const stmt_t &root, const std::vector<stmt_t > &allocs)
378
+ : allocs_(allocs) {
380
379
for (auto &_a : allocs) {
381
380
auto &a = _a.as <alloc_t >();
382
381
if (a.kind != alloc_kind_t ::global) ir_assert (a.size > 0 ) << _a;
383
382
alloc_map_.insert ({a.buf , _a});
384
383
}
385
- mutate (root_ );
384
+ mutate (root );
386
385
buf_total_refs_ = buf_cur_refs_;
387
386
for (auto &kv : buf_cur_refs_)
388
387
kv.second = 0 ;
@@ -404,39 +403,127 @@ class alloc_injector_t : public ir_mutator_t {
404
403
template <typename T>
405
404
object_t mutate_stmt (const T &obj) {
406
405
if (in_ctor_) return ir_mutator_t::_mutate (obj);
407
- object_t new_obj = obj;
408
- object_set_t <expr_t > undef_bufs;
409
- if (put_innermost_) {
410
- for (auto &kv : buf_cur_refs_)
411
- if (kv.second == 0 ) undef_bufs.insert (kv.first );
412
- new_obj = ir_mutator_t::_mutate (obj);
406
+ if (T::_type_id () == ir_type_id_t ::stmt_seq_t ) {
407
+ return mutate_stmt_seq (obj);
413
408
}
414
- for (auto &a : allocs_) {
415
- auto it = alloc_map_.find (a.as <alloc_t >().buf );
416
- auto &buf = it->first ;
417
- if (it->second .is_empty ()) continue ; // Already injected.
418
- bool do_inject = false ;
419
- if (put_innermost_) {
420
- int cur_refs = buf_cur_refs_[buf];
421
- int total_refs = buf_total_refs_[buf];
422
- bool was_undef = (undef_bufs.count (buf) != 0 );
423
- do_inject = was_undef && (cur_refs == total_refs);
424
- } else {
425
- do_inject = root_.is_same (obj);
409
+ auto undef_bufs = get_undef_bufs ();
410
+ auto new_obj = ir_mutator_t::_mutate (obj);
411
+ new_obj = maybe_inject (new_obj, undef_bufs);
412
+ return new_obj;
413
+ }
414
+
415
+ // Handle stmt_seq_t in a special way:
416
+ // 1. Walk through the sequence and record the first and the last statement
417
+ // where a buffer is referenced
418
+ // 2. Inject alloc statements according to the usage
419
+ object_t mutate_stmt_seq (const object_t &obj) {
420
+ auto stmt_vec = obj.as <stmt_seq_t >().vec ;
421
+ ir_assert (!stmt_vec.empty ());
422
+ int nstmts = (int )stmt_vec.size ();
423
+ // Mutate statments and record buffer usage in the form: buf: [first, last].
424
+ object_map_t <expr_t , int > last_undef;
425
+ object_map_t <expr_t , std::pair<int , int >> entries;
426
+ for (int i = 0 ; i < nstmts; i++) {
427
+ auto &s = stmt_vec[i];
428
+ for (auto &b : get_undef_bufs ()) {
429
+ auto it = alloc_map_.find (b);
430
+ if (it == alloc_map_.end () || it->second .is_empty ()) continue ;
431
+ last_undef[b] = i;
432
+ }
433
+ s = mutate (s);
434
+ for (auto &kv : last_undef) {
435
+ auto &buf = kv.first ;
436
+ if (entries.count (buf) != 0 ) continue ;
437
+ if (buf_cur_refs_[buf] == buf_total_refs_[buf]) {
438
+ entries[buf] = std::make_pair (kv.second , i);
439
+ }
426
440
}
427
- if (do_inject) {
428
- auto &a = it->second .as <alloc_t >();
441
+ }
442
+ // Sort buffers based on the number of statements they span. This is to
443
+ // inject more local allocations first.
444
+ std::vector<expr_t > bufs;
445
+ for (auto &kv : entries) {
446
+ if (alloc_map_.at (kv.first ).is_empty ()) continue ;
447
+ bufs.push_back (kv.first );
448
+ }
449
+ std::sort (bufs.begin (), bufs.end (),
450
+ [&](const expr_t &a, const expr_t &b) {
451
+ auto &ea = entries.at (a);
452
+ auto &eb = entries.at (b);
453
+ return (ea.second - ea.first ) < (eb.second - eb.first );
454
+ });
455
+ // Use union-find to incrementally merge statements based on the common
456
+ // buffers.
457
+ std::vector<int > parent (nstmts);
458
+ std::iota (parent.begin (), parent.end (), 0 );
459
+ std::function<int (int )> _find;
460
+ std::function<void (int , int )> _union;
461
+ _find = [&](int i) {
462
+ if (parent[i] == i) return i;
463
+ return parent[i] = _find (parent[i]);
464
+ };
465
+ _union = [&](int i, int j) {
466
+ i = _find (i);
467
+ j = _find (j);
468
+ parent[j] = i;
469
+ };
470
+ std::vector<stmt_t > new_stmt_seq = stmt_vec;
471
+ for (auto &buf : bufs) {
472
+ auto &e = entries.at (buf);
473
+ stmt_t stmt;
474
+ for (int i = e.first ; i <= e.second ; i++) {
475
+ int idx = _find (i);
476
+ stmt = stmt.append (new_stmt_seq[idx]);
477
+ new_stmt_seq[idx] = stmt_t ();
478
+ _union (e.first , i);
479
+ }
480
+ auto it = alloc_map_.find (buf);
481
+ auto &a = it->second .as <alloc_t >();
482
+ stmt = alloc_t::make (a.buf , a.size , a.kind , a.attrs , stmt);
483
+ new_stmt_seq[_find (e.first )] = stmt;
484
+ it->second = stmt_t ();
485
+ }
486
+ stmt_t new_obj;
487
+ for (auto &s : new_stmt_seq) {
488
+ if (s.is_empty ()) continue ;
489
+ new_obj = new_obj.append (s);
490
+ }
491
+ return std::move (new_obj);
492
+ }
493
+
494
+ object_set_t <expr_t > get_undef_bufs () const {
495
+ object_set_t <expr_t > ret;
496
+ for (auto &kv : buf_cur_refs_)
497
+ if (kv.second == 0 ) ret.insert (kv.first );
498
+ return ret;
499
+ }
500
+
501
+ object_t maybe_inject (
502
+ const object_t &obj, const object_set_t <expr_t > &undef_bufs) {
503
+ auto new_obj = obj;
504
+ for (auto &kv : alloc_map_) {
505
+ if (kv.second .is_empty ()) continue ;
506
+ auto &buf = kv.first ;
507
+ auto &a = kv.second .as <alloc_t >();
508
+ if (do_inject (buf, undef_bufs)) {
429
509
new_obj = alloc_t::make (
430
510
a.buf , a.size , a.kind , a.attrs , new_obj);
431
- it-> second = stmt_t ();
511
+ kv. second = stmt_t ();
432
512
}
433
513
}
434
514
return new_obj;
435
515
}
436
516
517
+ bool do_inject (
518
+ const expr_t &buf, const object_set_t <expr_t > &undef_bufs) const {
519
+ if (buf.is_empty ()) return false ; // Already injected.
520
+ int cur_refs = buf_cur_refs_.at (buf);
521
+ int total_refs = buf_total_refs_.at (buf);
522
+ bool was_undef = (undef_bufs.count (buf) != 0 );
523
+ return was_undef && (cur_refs == total_refs);
524
+ }
525
+
437
526
bool in_ctor_ = true ;
438
- const stmt_t &root_;
439
- bool put_innermost_;
440
527
std::vector<stmt_t > allocs_;
441
528
object_map_t <expr_t , stmt_t > alloc_map_;
442
529
object_map_t <expr_t , int > buf_total_refs_;
@@ -480,7 +567,15 @@ std::vector<stmt_t> flatten_statements(const stmt_t &root) {
480
567
481
568
stmt_t inject_alloc_stmts (const stmt_t &stmt, const std::vector<stmt_t > &allocs,
482
569
bool put_innermost) {
483
- alloc_injector_t injector (stmt, allocs, put_innermost);
570
+ if (!put_innermost) {
571
+ auto ret = stmt;
572
+ for (auto &_a : allocs) {
573
+ auto &a = _a.as <alloc_t >();
574
+ ret = alloc_t::make (a.buf , a.size , a.kind , a.attrs , ret);
575
+ }
576
+ return ret;
577
+ }
578
+ alloc_injector_t injector (stmt, allocs);
484
579
return injector.mutate (stmt);
485
580
}
486
581
0 commit comments