Skip to content

Commit 5dd401e

Browse files
authored
Fix div-by-zero error during sccp. (shader-slang#2911)
* Diagnose on div-by-zero during sccp. * fix --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 57f0ab4 commit 5dd401e

8 files changed

+89
-25
lines changed

source/slang/slang-emit.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ Result linkAndOptimizeIR(
324324
}
325325

326326
lowerOptionalType(irModule, sink);
327-
simplifyIR(irModule);
327+
simplifyIR(irModule, sink);
328328

329329
switch (target)
330330
{
@@ -450,7 +450,7 @@ Result linkAndOptimizeIR(
450450

451451
validateIRModuleIfEnabled(codeGenContext, irModule);
452452

453-
simplifyIR(irModule);
453+
simplifyIR(irModule, sink);
454454

455455
if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc))
456456
{
@@ -483,7 +483,7 @@ Result linkAndOptimizeIR(
483483
// up downstream passes like type legalization, so we
484484
// will run a DCE pass to clean up after the specialization.
485485
//
486-
simplifyIR(irModule);
486+
simplifyIR(irModule, sink);
487487

488488
#if 0
489489
dumpIRIfEnabled(codeGenContext, irModule, "AFTER DCE");
@@ -565,7 +565,7 @@ Result linkAndOptimizeIR(
565565
// to see if we can clean up any temporaries created by legalization.
566566
// (e.g., things that used to be aggregated might now be split up,
567567
// so that we can work with the individual fields).
568-
simplifyIR(irModule);
568+
simplifyIR(irModule, sink);
569569

570570
#if 0
571571
dumpIRIfEnabled(codeGenContext, irModule, "AFTER SSA");
@@ -591,7 +591,7 @@ Result linkAndOptimizeIR(
591591
{
592592
specializeArrayParameters(codeGenContext, irModule);
593593
}
594-
simplifyIR(irModule);
594+
simplifyIR(irModule, sink);
595595

596596
// Rewrite functions that return arrays to return them via `out` parameter,
597597
// since our target languages doesn't allow returning arrays.
@@ -835,7 +835,7 @@ Result linkAndOptimizeIR(
835835
//
836836
// We run IR simplification passes again to clean things up.
837837
//
838-
simplifyIR(irModule);
838+
simplifyIR(irModule, sink);
839839

840840
if (isKhronosTarget(targetRequest))
841841
{
@@ -865,7 +865,7 @@ Result linkAndOptimizeIR(
865865
// Lower all bit_cast operations on complex types into leaf-level
866866
// bit_cast on basic types.
867867
lowerBitCast(targetRequest, irModule);
868-
simplifyIR(irModule);
868+
simplifyIR(irModule, sink);
869869

870870
eliminateMultiLevelBreak(irModule);
871871

source/slang/slang-ir-autodiff-fwd.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,8 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
591591
}
592592
else
593593
{
594-
diffCallee = findOrTranscribeDiffInst(builder, origCallee);
594+
if (_isDifferentiableFunc(origCallee))
595+
diffCallee = findOrTranscribeDiffInst(builder, origCallee);
595596
primalCallee = substPrimalCallee;
596597
}
597598

source/slang/slang-ir-sccp.cpp

+37-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace Slang {
1515
struct SharedSCCPContext
1616
{
1717
IRModule* module;
18+
DiagnosticSink* sink;
1819
};
1920
//
2021
// Next we have a context struct that will be applied for each function (or other
@@ -580,7 +581,7 @@ struct SCCPContext
580581
type,
581582
v0,
582583
v1,
583-
[](IRIntegerValue c0, IRIntegerValue c1) { return c0 / c1; },
584+
[](IRIntegerValue c0, IRIntegerValue c1) { return c0 / c1; },
584585
[](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 / c1; });
585586
}
586587
LatticeVal evalEql(IRType* type, LatticeVal v0, LatticeVal v1)
@@ -870,10 +871,27 @@ struct SCCPContext
870871
getLatticeVal(inst->getOperand(0)),
871872
getLatticeVal(inst->getOperand(1)));
872873
case kIROp_Div:
874+
{
875+
// Detect divide by zero error.
876+
auto divisor = getLatticeVal(inst->getOperand(1));
877+
if (divisor.flavor == LatticeVal::Flavor::Constant)
878+
{
879+
if (isIntegralType(divisor.value->getDataType()))
880+
{
881+
auto c = as<IRConstant>(divisor.value);
882+
if (c->value.intVal == 0)
883+
{
884+
if (shared->sink)
885+
shared->sink->diagnose(inst->sourceLoc, Diagnostics::divideByZero);
886+
return LatticeVal::getAny();
887+
}
888+
}
889+
}
873890
return evalDiv(
874891
inst->getDataType(),
875892
getLatticeVal(inst->getOperand(0)),
876-
getLatticeVal(inst->getOperand(1)));
893+
divisor);
894+
}
877895
case kIROp_Eql:
878896
return evalEql(
879897
inst->getDataType(),
@@ -1658,10 +1676,15 @@ static bool applySparseConditionalConstantPropagationRec(
16581676
}
16591677

16601678
bool applySparseConditionalConstantPropagation(
1661-
IRModule* module)
1679+
IRModule* module,
1680+
DiagnosticSink* sink)
16621681
{
1682+
if (sink && sink->getErrorCount())
1683+
return false;
1684+
16631685
SharedSCCPContext shared;
16641686
shared.module = module;
1687+
shared.sink = sink;
16651688

16661689
// First we fold constants at global scope.
16671690
SCCPContext globalContext;
@@ -1676,21 +1699,30 @@ bool applySparseConditionalConstantPropagation(
16761699
}
16771700

16781701
bool applySparseConditionalConstantPropagationForGlobalScope(
1679-
IRModule* module)
1702+
IRModule* module,
1703+
DiagnosticSink* sink)
16801704
{
1705+
if (sink && sink->getErrorCount())
1706+
return false;
1707+
16811708
SharedSCCPContext shared;
16821709
shared.module = module;
1710+
shared.sink = sink;
16831711
SCCPContext globalContext;
16841712
globalContext.shared = &shared;
16851713
globalContext.code = nullptr;
16861714
bool changed = globalContext.applyOnGlobalScope(module);
16871715
return changed;
16881716
}
16891717

1690-
bool applySparseConditionalConstantPropagation(IRInst* func)
1718+
bool applySparseConditionalConstantPropagation(IRInst* func, DiagnosticSink* sink)
16911719
{
1720+
if (sink && sink->getErrorCount())
1721+
return false;
1722+
16921723
SharedSCCPContext shared;
16931724
shared.module = func->getModule();
1725+
shared.sink = sink;
16941726

16951727
SCCPContext globalContext;
16961728
globalContext.shared = &shared;

source/slang/slang-ir-sccp.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ namespace Slang
55
{
66
struct IRModule;
77
struct IRInst;
8+
class DiagnosticSink;
89

910
/// Apply Sparse Conditional Constant Propagation (SCCP) to a module.
1011
///
@@ -15,11 +16,13 @@ namespace Slang
1516
/// becoming dead code)
1617
/// Returns true if IR is changed.
1718
bool applySparseConditionalConstantPropagation(
18-
IRModule* module);
19+
IRModule* module,
20+
DiagnosticSink* sink);
1921
bool applySparseConditionalConstantPropagationForGlobalScope(
20-
IRModule* module);
22+
IRModule* module,
23+
DiagnosticSink* sink);
2124

22-
bool applySparseConditionalConstantPropagation(IRInst* func);
25+
bool applySparseConditionalConstantPropagation(IRInst* func, DiagnosticSink* sink);
2326

2427
IRInst* tryConstantFoldInst(IRModule* module, IRInst* inst);
2528
}

source/slang/slang-ir-ssa-simplification.cpp

+13-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace Slang
1616
{
1717
// Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass
1818
// until no more changes are possible.
19-
void simplifyIR(IRModule* module)
19+
void simplifyIR(IRModule* module, DiagnosticSink* sink)
2020
{
2121
bool changed = true;
2222
const int kMaxIterations = 8;
@@ -25,12 +25,16 @@ namespace Slang
2525

2626
while (changed && iterationCounter < kMaxIterations)
2727
{
28+
if (sink && sink->getErrorCount())
29+
break;
30+
2831
changed = false;
32+
2933
changed |= hoistConstants(module);
3034
changed |= deduplicateGenericChildren(module);
3135
changed |= propagateFuncProperties(module);
3236
changed |= removeUnusedGenericParam(module);
33-
changed |= applySparseConditionalConstantPropagationForGlobalScope(module);
37+
changed |= applySparseConditionalConstantPropagationForGlobalScope(module, sink);
3438
changed |= peepholeOptimize(module);
3539

3640
for (auto inst : module->getGlobalInsts())
@@ -43,7 +47,7 @@ namespace Slang
4347
while (funcChanged && funcIterationCount < kMaxFuncIterations)
4448
{
4549
funcChanged = false;
46-
funcChanged |= applySparseConditionalConstantPropagation(func);
50+
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
4751
funcChanged |= peepholeOptimize(func);
4852
funcChanged |= removeRedundancyInFunc(func);
4953
funcChanged |= simplifyCFG(func);
@@ -85,15 +89,18 @@ namespace Slang
8589
}
8690

8791

88-
void simplifyFunc(IRGlobalValueWithCode* func)
92+
void simplifyFunc(IRGlobalValueWithCode* func, DiagnosticSink* sink)
8993
{
9094
bool changed = true;
9195
const int kMaxIterations = 8;
9296
int iterationCounter = 0;
9397
while (changed && iterationCounter < kMaxIterations)
9498
{
99+
if (sink && sink->getErrorCount())
100+
break;
101+
95102
changed = false;
96-
changed |= applySparseConditionalConstantPropagation(func);
103+
changed |= applySparseConditionalConstantPropagation(func, sink);
97104
changed |= peepholeOptimize(func);
98105
changed |= removeRedundancyInFunc(func);
99106
changed |= simplifyCFG(func);
@@ -106,6 +113,7 @@ namespace Slang
106113
changed |= constructSSA(func);
107114

108115
iterationCounter++;
116+
109117
}
110118
}
111119
}

source/slang/slang-ir-ssa-simplification.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ namespace Slang
55
{
66
struct IRModule;
77
struct IRGlobalValueWithCode;
8+
class DiagnosticSink;
89

910
// Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass
1011
// until no more changes are possible.
11-
void simplifyIR(IRModule* module);
12+
void simplifyIR(IRModule* module, DiagnosticSink* sink = nullptr);
1213

1314
// Run simplifications on IR that is out of SSA form.
1415
void simplifyNonSSAIR(IRModule* module);
1516

16-
void simplifyFunc(IRGlobalValueWithCode* func);
17+
void simplifyFunc(IRGlobalValueWithCode* func, DiagnosticSink* sink = nullptr);
1718
}

source/slang/slang-lower-to-ir.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -9657,7 +9657,7 @@ RefPtr<IRModule> generateIRForTranslationUnit(
96579657
//
96589658
constructSSA(module);
96599659
simplifyCFG(module);
9660-
applySparseConditionalConstantPropagation(module);
9660+
applySparseConditionalConstantPropagation(module, compileRequest->getSink());
96619661

96629662
// Next, inline calls to any functions that have been
96639663
// marked for mandatory "early" inlining.
@@ -9677,7 +9677,7 @@ RefPtr<IRModule> generateIRForTranslationUnit(
96779677
//
96789678
constructSSA(module);
96799679
simplifyCFG(module);
9680-
applySparseConditionalConstantPropagation(module);
9680+
applySparseConditionalConstantPropagation(module, compileRequest->getSink());
96819681

96829682
// Propagate `constexpr`-ness through the dataflow graph (and the
96839683
// call graph) based on constraints imposed by different instructions.
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -entry computeMain -profile cs_5_0 -target hlsl
2+
RWStructuredBuffer<uint> outputBuffer;
3+
4+
// CHECK: divide by zero
5+
uint check<let b : bool>()
6+
{
7+
return 1 / int(b);
8+
}
9+
10+
[numthreads(4, 1, 1)]
11+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
12+
{
13+
uint tid = dispatchThreadID.x;
14+
15+
uint a = check<false>();
16+
uint b = check<true>();
17+
18+
outputBuffer[tid] = a + b;
19+
}

0 commit comments

Comments
 (0)