Skip to content

Commit 5783534

Browse files
authored
Fix UB-introducing rewrite in FindIntrinsics (#8539)
FindIntrinsics was rewriting i8(rounding_shift_left(i16(foo_i8), 11)) to rounding_shift_left(foo_i8, 11), i.e. it decided it could do the shift in the narrower type. However, this isn't correct, because 11 is a valid shift amount for a 16-bit int but not for an 8-bit int: the former yields zero, while the latter gets turned into a poison value because we lower it to LLVM's shl. This was discovered by a random failure in test/correctness/lossless_cast.cpp in another PR, for seed 826708018. This PR fixes this case by adding a compile-time check that the shift is in range. For the examples in test/correctness/intrinsics.cpp the shift amount ends up in a let, so making the check work on those cases required resolving a TODO: tracking the constant integer bounds of variables in scope in FindIntrinsics, and therefore also in lossless_cast.
1 parent 097aee9 commit 5783534

File tree

4 files changed

+109
-48
lines changed

4 files changed

+109
-48
lines changed

src/FindIntrinsics.cpp

+58-5
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,10 @@ class FindIntrinsics : public IRMutator {
101101
Scope<ConstantInterval> let_var_bounds;
102102

103103
Expr lossless_cast(Type t, const Expr &e) {
104-
return Halide::Internal::lossless_cast(t, e, &bounds_cache);
104+
return Halide::Internal::lossless_cast(t, e, let_var_bounds, &bounds_cache);
105105
}
106106

107107
ConstantInterval constant_integer_bounds(const Expr &e) {
108-
// TODO: Use the scope - add let visitors
109108
return Halide::Internal::constant_integer_bounds(e, let_var_bounds, &bounds_cache);
110109
}
111110

@@ -210,6 +209,51 @@ class FindIntrinsics : public IRMutator {
210209
return Expr();
211210
}
212211

212+
template<typename LetOrLetStmt>
213+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
214+
struct Frame {
215+
const LetOrLetStmt *orig;
216+
Expr new_value;
217+
ScopedBinding<ConstantInterval> bind;
218+
Frame(const LetOrLetStmt *orig,
219+
Expr &&new_value,
220+
ScopedBinding<ConstantInterval> &&bind)
221+
: orig(orig), new_value(std::move(new_value)), bind(std::move(bind)) {
222+
}
223+
};
224+
std::vector<Frame> frames;
225+
decltype(op->body) body;
226+
while (op) {
227+
Expr v = mutate(op->value);
228+
ConstantInterval b = constant_integer_bounds(v);
229+
frames.emplace_back(op,
230+
std::move(v),
231+
ScopedBinding<ConstantInterval>(let_var_bounds, op->name, b));
232+
body = op->body;
233+
op = body.template as<LetOrLetStmt>();
234+
}
235+
236+
body = mutate(body);
237+
238+
for (const auto &f : reverse_view(frames)) {
239+
if (f.new_value.same_as(f.orig->value) && body.same_as(f.orig->body)) {
240+
body = f.orig;
241+
} else {
242+
body = LetOrLetStmt::make(f.orig->name, f.new_value, body);
243+
}
244+
}
245+
246+
return body;
247+
}
248+
249+
Expr visit(const Let *op) override {
250+
return visit_let(op);
251+
}
252+
253+
Stmt visit(const LetStmt *op) override {
254+
return visit_let(op);
255+
}
256+
213257
Expr visit(const Add *op) override {
214258
if (!find_intrinsics_for_type(op->type)) {
215259
return IRMutator::visit(op);
@@ -697,7 +741,12 @@ class FindIntrinsics : public IRMutator {
697741
bool is_saturated = op->value.as<Max>() || op->value.as<Min>();
698742
Expr a = lossless_cast(op->type, shift->args[0]);
699743
Expr b = lossless_cast(op->type.with_code(shift->args[1].type().code()), shift->args[1]);
700-
if (a.defined() && b.defined()) {
744+
// Doing the shift in the narrower type might introduce UB where
745+
// there was no UB before, so we need to make sure b is bounded.
746+
auto b_bounds = constant_integer_bounds(b);
747+
const int max_shift = op->type.bits() - 1;
748+
749+
if (a.defined() && b.defined() && b_bounds >= -max_shift && b_bounds <= max_shift) {
701750
if (!is_saturated ||
702751
(shift->is_intrinsic(Call::rounding_shift_right) && can_prove(b >= 0)) ||
703752
(shift->is_intrinsic(Call::rounding_shift_left) && can_prove(b <= 0))) {
@@ -1118,8 +1167,12 @@ class SubstituteInWideningLets : public IRMutator {
11181167
std::string name;
11191168
Expr new_value;
11201169
ScopedBinding<Expr> bind;
1121-
Frame(const std::string &name, const Expr &new_value, ScopedBinding<Expr> &&bind)
1122-
: name(name), new_value(new_value), bind(std::move(bind)) {
1170+
Frame(const std::string &name,
1171+
const Expr &new_value,
1172+
ScopedBinding<Expr> &&bind)
1173+
: name(name),
1174+
new_value(new_value),
1175+
bind(std::move(bind)) {
11231176
}
11241177
};
11251178
std::vector<Frame> frames;

src/IROperator.cpp

+30-27
Original file line numberDiff line numberDiff line change
@@ -427,17 +427,20 @@ Expr const_false(int w) {
427427
return make_zero(UInt(1, w));
428428
}
429429

430-
Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare> *cache) {
430+
Expr lossless_cast(Type t,
431+
Expr e,
432+
const Scope<ConstantInterval> &scope,
433+
std::map<Expr, ConstantInterval, ExprCompare> *cache) {
431434
if (!e.defined() || t == e.type()) {
432435
return e;
433436
} else if (t.can_represent(e.type())) {
434437
return cast(t, std::move(e));
435438
} else if (const Cast *c = e.as<Cast>()) {
436439
if (c->type.can_represent(c->value.type())) {
437-
return lossless_cast(t, c->value, cache);
440+
return lossless_cast(t, c->value, scope, cache);
438441
}
439442
} else if (const Broadcast *b = e.as<Broadcast>()) {
440-
Expr v = lossless_cast(t.element_of(), b->value, cache);
443+
Expr v = lossless_cast(t.element_of(), b->value, scope, cache);
441444
if (v.defined()) {
442445
return Broadcast::make(v, b->lanes);
443446
}
@@ -456,7 +459,7 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
456459
} else if (const Shuffle *shuf = e.as<Shuffle>()) {
457460
std::vector<Expr> vecs;
458461
for (const auto &vec : shuf->vectors) {
459-
vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, cache));
462+
vecs.emplace_back(lossless_cast(t.with_lanes(vec.type().lanes()), vec, scope, cache));
460463
if (!vecs.back().defined()) {
461464
return Expr();
462465
}
@@ -465,73 +468,73 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
465468
} else if (t.is_int_or_uint()) {
466469
// Check the bounds. If they're small enough, we can throw narrowing
467470
// casts around e, or subterms.
468-
ConstantInterval ci = constant_integer_bounds(e, Scope<ConstantInterval>::empty_scope(), cache);
471+
ConstantInterval ci = constant_integer_bounds(e, scope, cache);
469472

470473
if (t.can_represent(ci)) {
471474
// There are certain IR nodes where if the result is expressible
472475
// using some type, and the args are expressible using that type,
473476
// then the operation can just be done in that type.
474477
if (const Add *op = e.as<Add>()) {
475-
Expr a = lossless_cast(t, op->a, cache);
476-
Expr b = lossless_cast(t, op->b, cache);
478+
Expr a = lossless_cast(t, op->a, scope, cache);
479+
Expr b = lossless_cast(t, op->b, scope, cache);
477480
if (a.defined() && b.defined()) {
478481
return Add::make(a, b);
479482
}
480483
} else if (const Sub *op = e.as<Sub>()) {
481-
Expr a = lossless_cast(t, op->a, cache);
482-
Expr b = lossless_cast(t, op->b, cache);
484+
Expr a = lossless_cast(t, op->a, scope, cache);
485+
Expr b = lossless_cast(t, op->b, scope, cache);
483486
if (a.defined() && b.defined()) {
484487
return Sub::make(a, b);
485488
}
486489
} else if (const Mul *op = e.as<Mul>()) {
487-
Expr a = lossless_cast(t, op->a, cache);
488-
Expr b = lossless_cast(t, op->b, cache);
490+
Expr a = lossless_cast(t, op->a, scope, cache);
491+
Expr b = lossless_cast(t, op->b, scope, cache);
489492
if (a.defined() && b.defined()) {
490493
return Mul::make(a, b);
491494
}
492495
} else if (const Min *op = e.as<Min>()) {
493-
Expr a = lossless_cast(t, op->a, cache);
494-
Expr b = lossless_cast(t, op->b, cache);
496+
Expr a = lossless_cast(t, op->a, scope, cache);
497+
Expr b = lossless_cast(t, op->b, scope, cache);
495498
if (a.defined() && b.defined()) {
496499
debug(0) << a << " " << b << "\n";
497500
return Min::make(a, b);
498501
}
499502
} else if (const Max *op = e.as<Max>()) {
500-
Expr a = lossless_cast(t, op->a, cache);
501-
Expr b = lossless_cast(t, op->b, cache);
503+
Expr a = lossless_cast(t, op->a, scope, cache);
504+
Expr b = lossless_cast(t, op->b, scope, cache);
502505
if (a.defined() && b.defined()) {
503506
return Max::make(a, b);
504507
}
505508
} else if (const Mod *op = e.as<Mod>()) {
506-
Expr a = lossless_cast(t, op->a, cache);
507-
Expr b = lossless_cast(t, op->b, cache);
509+
Expr a = lossless_cast(t, op->a, scope, cache);
510+
Expr b = lossless_cast(t, op->b, scope, cache);
508511
if (a.defined() && b.defined()) {
509512
return Mod::make(a, b);
510513
}
511514
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_add, Call::widen_right_add})) {
512-
Expr a = lossless_cast(t, op->args[0], cache);
513-
Expr b = lossless_cast(t, op->args[1], cache);
515+
Expr a = lossless_cast(t, op->args[0], scope, cache);
516+
Expr b = lossless_cast(t, op->args[1], scope, cache);
514517
if (a.defined() && b.defined()) {
515518
return Add::make(a, b);
516519
}
517520
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_sub, Call::widen_right_sub})) {
518-
Expr a = lossless_cast(t, op->args[0], cache);
519-
Expr b = lossless_cast(t, op->args[1], cache);
521+
Expr a = lossless_cast(t, op->args[0], scope, cache);
522+
Expr b = lossless_cast(t, op->args[1], scope, cache);
520523
if (a.defined() && b.defined()) {
521524
return Sub::make(a, b);
522525
}
523526
} else if (const Call *op = Call::as_intrinsic(e, {Call::widening_mul, Call::widen_right_mul})) {
524-
Expr a = lossless_cast(t, op->args[0], cache);
525-
Expr b = lossless_cast(t, op->args[1], cache);
527+
Expr a = lossless_cast(t, op->args[0], scope, cache);
528+
Expr b = lossless_cast(t, op->args[1], scope, cache);
526529
if (a.defined() && b.defined()) {
527530
return Mul::make(a, b);
528531
}
529532
} else if (const Call *op = Call::as_intrinsic(e, {Call::shift_left, Call::widening_shift_left,
530533
Call::shift_right, Call::widening_shift_right})) {
531-
Expr a = lossless_cast(t, op->args[0], cache);
532-
Expr b = lossless_cast(t, op->args[1], cache);
534+
Expr a = lossless_cast(t, op->args[0], scope, cache);
535+
Expr b = lossless_cast(t, op->args[1], scope, cache);
533536
if (a.defined() && b.defined()) {
534-
ConstantInterval cb = constant_integer_bounds(b, Scope<ConstantInterval>::empty_scope(), cache);
537+
ConstantInterval cb = constant_integer_bounds(b, scope, cache);
535538
if (cb > -t.bits() && cb < t.bits()) {
536539
if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
537540
return a << b;
@@ -544,7 +547,7 @@ Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare>
544547
if (op->op == VectorReduce::Add ||
545548
op->op == VectorReduce::Min ||
546549
op->op == VectorReduce::Max) {
547-
Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, cache);
550+
Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, scope, cache);
548551
if (v.defined()) {
549552
return VectorReduce::make(op->op, v, op->type.lanes());
550553
}

src/IROperator.h

+12-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
#include <map>
1212
#include <optional>
1313

14+
#include "ConstantInterval.h"
1415
#include "Expr.h"
16+
#include "Scope.h"
1517
#include "Target.h"
1618
#include "Tuple.h"
1719

@@ -150,13 +152,16 @@ Expr const_false(int lanes = 1);
150152
/** Attempt to cast an expression to a smaller type while provably not losing
151153
* information. If it can't be done, return an undefined Expr.
152154
*
153-
* Optionally accepts a map that gives the constant bounds of exprs already
154-
* analyzed to avoid redoing work across many calls to lossless_cast. It is not
155-
* safe to use this optional map in contexts where the same Expr object may
156-
* take on a different value. For example:
157-
* (let x = 4 in some_expr_object) + (let x = 5 in the_same_expr_object)).
158-
* It is safe to use it after uniquify_variable_names has been run. */
159-
Expr lossless_cast(Type t, Expr e, std::map<Expr, ConstantInterval, ExprCompare> *cache = nullptr);
155+
* Optionally accepts a scope giving the constant bounds of any variables, and a
156+
* map that gives the constant bounds of exprs already analyzed to avoid redoing
157+
* work across many calls to lossless_cast. It is not safe to use this optional
158+
* map in contexts where the same Expr object may take on a different value. For
159+
* example: (let x = 4 in some_expr_object) + (let x = 5 in
160+
* the_same_expr_object)). It is safe to use it after uniquify_variable_names
161+
* has been run. */
162+
Expr lossless_cast(Type t, Expr e,
163+
const Scope<ConstantInterval> &scope = Scope<ConstantInterval>::empty_scope(),
164+
std::map<Expr, ConstantInterval, ExprCompare> *cache = nullptr);
160165

161166
/** Attempt to negate x without introducing new IR and without overflow.
162167
* If it can't be done, return an undefined Expr. */

test/correctness/intrinsics.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -255,17 +255,17 @@ int main(int argc, char **argv) {
255255
check((u64(u32x) + 8) / 16, u64(rounding_shift_right(u32x, 4)));
256256
check(u16(min((u64(u32x) + 8) / 16, 65535)), u16_sat(rounding_shift_right(u32x, 4)));
257257

258-
// And with variable shifts.
259-
check(i8(widening_add(i8x, (i8(1) << u8y) / 2) >> u8y), rounding_shift_right(i8x, u8y));
258+
// And with variable shifts. These won't match unless Halide can statically
259+
// prove it's not an out-of-range shift.
260+
Expr u8yc = Min::make(u8y, make_const(u8y.type(), 7));
261+
Expr i8yc = Max::make(Min::make(i8y, make_const(i8y.type(), 7)), make_const(i8y.type(), -7));
262+
check(i8(widening_add(i8x, (i8(1) << u8yc) / 2) >> u8yc), rounding_shift_right(i8x, u8yc));
260263
check((i32x + (i32(1) << u32y) / 2) >> u32y, rounding_shift_right(i32x, u32y));
261-
262-
check(i8(widening_add(i8x, (i8(1) << max(i8y, 0)) / 2) >> i8y), rounding_shift_right(i8x, i8y));
264+
check(i8(widening_add(i8x, (i8(1) << max(i8yc, 0)) / 2) >> i8yc), rounding_shift_right(i8x, i8yc));
263265
check((i32x + (i32(1) << max(i32y, 0)) / 2) >> i32y, rounding_shift_right(i32x, i32y));
264-
265-
check(i8(widening_add(i8x, (i8(1) >> min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y));
266+
check(i8(widening_add(i8x, (i8(1) >> min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc));
266267
check((i32x + (i32(1) >> min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));
267-
268-
check(i8(widening_add(i8x, (i8(1) << -min(i8y, 0)) / 2) << i8y), rounding_shift_left(i8x, i8y));
268+
check(i8(widening_add(i8x, (i8(1) << -min(i8yc, 0)) / 2) << i8yc), rounding_shift_left(i8x, i8yc));
269269
check((i32x + (i32(1) << -min(i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));
270270
check((i32x + (i32(1) << max(-i32y, 0)) / 2) << i32y, rounding_shift_left(i32x, i32y));
271271

@@ -372,7 +372,7 @@ int main(int argc, char **argv) {
372372
f(x) = cast<uint8_t>(x);
373373
f.compute_root();
374374

375-
g(x) = rounding_shift_right(x, 0) + rounding_shift_left(x, 8);
375+
g(x) = rounding_shift_right(f(x), 0) + u8(rounding_shift_left(u16(f(x)), 11));
376376

377377
g.compile_jit();
378378
}

0 commit comments

Comments
 (0)