@@ -427,17 +427,20 @@ Expr const_false(int w) {
427
427
return make_zero (UInt (1 , w));
428
428
}
429
429
430
- Expr lossless_cast (Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare> *cache) {
430
+ Expr lossless_cast (Type t,
431
+ Expr e,
432
+ const Scope<ConstantInterval> &scope,
433
+ std::map<Expr, ConstantInterval, ExprCompare> *cache) {
431
434
if (!e.defined () || t == e.type ()) {
432
435
return e;
433
436
} else if (t.can_represent (e.type ())) {
434
437
return cast (t, std::move (e));
435
438
} else if (const Cast *c = e.as <Cast>()) {
436
439
if (c->type .can_represent (c->value .type ())) {
437
- return lossless_cast (t, c->value , cache);
440
+ return lossless_cast (t, c->value , scope, cache);
438
441
}
439
442
} else if (const Broadcast *b = e.as <Broadcast>()) {
440
- Expr v = lossless_cast (t.element_of (), b->value , cache);
443
+ Expr v = lossless_cast (t.element_of (), b->value , scope, cache);
441
444
if (v.defined ()) {
442
445
return Broadcast::make (v, b->lanes );
443
446
}
@@ -456,7 +459,7 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
456
459
} else if (const Shuffle *shuf = e.as <Shuffle>()) {
457
460
std::vector<Expr> vecs;
458
461
for (const auto &vec : shuf->vectors ) {
459
- vecs.emplace_back (lossless_cast (t.with_lanes (vec.type ().lanes ()), vec, cache));
462
+ vecs.emplace_back (lossless_cast (t.with_lanes (vec.type ().lanes ()), vec, scope, cache));
460
463
if (!vecs.back ().defined ()) {
461
464
return Expr ();
462
465
}
@@ -465,73 +468,73 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
465
468
} else if (t.is_int_or_uint ()) {
466
469
// Check the bounds. If they're small enough, we can throw narrowing
467
470
// casts around e, or subterms.
468
- ConstantInterval ci = constant_integer_bounds (e, Scope<ConstantInterval>:: empty_scope () , cache);
471
+ ConstantInterval ci = constant_integer_bounds (e, scope , cache);
469
472
470
473
if (t.can_represent (ci)) {
471
474
// There are certain IR nodes where if the result is expressible
472
475
// using some type, and the args are expressible using that type,
473
476
// then the operation can just be done in that type.
474
477
if (const Add *op = e.as <Add>()) {
475
- Expr a = lossless_cast (t, op->a , cache);
476
- Expr b = lossless_cast (t, op->b , cache);
478
+ Expr a = lossless_cast (t, op->a , scope, cache);
479
+ Expr b = lossless_cast (t, op->b , scope, cache);
477
480
if (a.defined () && b.defined ()) {
478
481
return Add::make (a, b);
479
482
}
480
483
} else if (const Sub *op = e.as <Sub>()) {
481
- Expr a = lossless_cast (t, op->a , cache);
482
- Expr b = lossless_cast (t, op->b , cache);
484
+ Expr a = lossless_cast (t, op->a , scope, cache);
485
+ Expr b = lossless_cast (t, op->b , scope, cache);
483
486
if (a.defined () && b.defined ()) {
484
487
return Sub::make (a, b);
485
488
}
486
489
} else if (const Mul *op = e.as <Mul>()) {
487
- Expr a = lossless_cast (t, op->a , cache);
488
- Expr b = lossless_cast (t, op->b , cache);
490
+ Expr a = lossless_cast (t, op->a , scope, cache);
491
+ Expr b = lossless_cast (t, op->b , scope, cache);
489
492
if (a.defined () && b.defined ()) {
490
493
return Mul::make (a, b);
491
494
}
492
495
} else if (const Min *op = e.as <Min>()) {
493
- Expr a = lossless_cast (t, op->a , cache);
494
- Expr b = lossless_cast (t, op->b , cache);
496
+ Expr a = lossless_cast (t, op->a , scope, cache);
497
+ Expr b = lossless_cast (t, op->b , scope, cache);
495
498
if (a.defined () && b.defined ()) {
496
499
debug (0 ) << a << " " << b << " \n " ;
497
500
return Min::make (a, b);
498
501
}
499
502
} else if (const Max *op = e.as <Max>()) {
500
- Expr a = lossless_cast (t, op->a , cache);
501
- Expr b = lossless_cast (t, op->b , cache);
503
+ Expr a = lossless_cast (t, op->a , scope, cache);
504
+ Expr b = lossless_cast (t, op->b , scope, cache);
502
505
if (a.defined () && b.defined ()) {
503
506
return Max::make (a, b);
504
507
}
505
508
} else if (const Mod *op = e.as <Mod>()) {
506
- Expr a = lossless_cast (t, op->a , cache);
507
- Expr b = lossless_cast (t, op->b , cache);
509
+ Expr a = lossless_cast (t, op->a , scope, cache);
510
+ Expr b = lossless_cast (t, op->b , scope, cache);
508
511
if (a.defined () && b.defined ()) {
509
512
return Mod::make (a, b);
510
513
}
511
514
} else if (const Call *op = Call::as_intrinsic (e, {Call::widening_add, Call::widen_right_add})) {
512
- Expr a = lossless_cast (t, op->args [0 ], cache);
513
- Expr b = lossless_cast (t, op->args [1 ], cache);
515
+ Expr a = lossless_cast (t, op->args [0 ], scope, cache);
516
+ Expr b = lossless_cast (t, op->args [1 ], scope, cache);
514
517
if (a.defined () && b.defined ()) {
515
518
return Add::make (a, b);
516
519
}
517
520
} else if (const Call *op = Call::as_intrinsic (e, {Call::widening_sub, Call::widen_right_sub})) {
518
- Expr a = lossless_cast (t, op->args [0 ], cache);
519
- Expr b = lossless_cast (t, op->args [1 ], cache);
521
+ Expr a = lossless_cast (t, op->args [0 ], scope, cache);
522
+ Expr b = lossless_cast (t, op->args [1 ], scope, cache);
520
523
if (a.defined () && b.defined ()) {
521
524
return Sub::make (a, b);
522
525
}
523
526
} else if (const Call *op = Call::as_intrinsic (e, {Call::widening_mul, Call::widen_right_mul})) {
524
- Expr a = lossless_cast (t, op->args [0 ], cache);
525
- Expr b = lossless_cast (t, op->args [1 ], cache);
527
+ Expr a = lossless_cast (t, op->args [0 ], scope, cache);
528
+ Expr b = lossless_cast (t, op->args [1 ], scope, cache);
526
529
if (a.defined () && b.defined ()) {
527
530
return Mul::make (a, b);
528
531
}
529
532
} else if (const Call *op = Call::as_intrinsic (e, {Call::shift_left, Call::widening_shift_left,
530
533
Call::shift_right, Call::widening_shift_right})) {
531
- Expr a = lossless_cast (t, op->args [0 ], cache);
532
- Expr b = lossless_cast (t, op->args [1 ], cache);
534
+ Expr a = lossless_cast (t, op->args [0 ], scope, cache);
535
+ Expr b = lossless_cast (t, op->args [1 ], scope, cache);
533
536
if (a.defined () && b.defined ()) {
534
- ConstantInterval cb = constant_integer_bounds (b, Scope<ConstantInterval>:: empty_scope () , cache);
537
+ ConstantInterval cb = constant_integer_bounds (b, scope , cache);
535
538
if (cb > -t.bits () && cb < t.bits ()) {
536
539
if (op->is_intrinsic ({Call::shift_left, Call::widening_shift_left})) {
537
540
return a << b;
@@ -544,7 +547,7 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
544
547
if (op->op == VectorReduce::Add ||
545
548
op->op == VectorReduce::Min ||
546
549
op->op == VectorReduce::Max) {
547
- Expr v = lossless_cast (t.with_lanes (op->value .type ().lanes ()), op->value , cache);
550
+ Expr v = lossless_cast (t.with_lanes (op->value .type ().lanes ()), op->value , scope, cache);
548
551
if (v.defined ()) {
549
552
return VectorReduce::make (op->op , v, op->type .lanes ());
550
553
}
0 commit comments