Skip to content

Commit c5c8cfb

Browse files
Handle the case where the parent if-else region's after-block is unreachable. (shader-slang#3241)
Also added a test for this. Co-authored-by: Yong He <yonghe@outlook.com>
1 parent a18dca2 commit c5c8cfb

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

source/slang/slang-ir-autodiff-cfg-norm.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,30 @@ struct CFGNormalizationPass
356356
//
357357
afterBaseRegion = true;
358358

359+
// One case we do check for is if the after block is 'unreachable'
360+
// i.e. the terminator is an `unreachable` instruction.
361+
// In this case, we can safely assume that the after block does not
362+
// have anything to execute. Further, we need to re-wire the
363+
// previously unreachable block to the parent break block.
364+
// Note that this operation is safe because if the after block was
365+
// originally unreachable, all potential paths to it must have
366+
// broken out of the region.
367+
//
368+
if (auto unreachInst = as<IRUnreachable>(afterBlock->getTerminator()))
369+
{
370+
// Link it to the parentAfterBlock.
371+
builder.setInsertInto(afterBlock);
372+
unreachInst->removeAndDeallocate();
373+
374+
builder.emitBranch(parentAfterBlock);
375+
376+
// We can now safely assume that the after block is empty.
377+
// Set 'afterBaseRegion' to false, which should lead the rest
378+
// of the logic to avoid splitting the after block
379+
//
380+
afterBaseRegion = false;
381+
}
382+
359383
// Do we need to split the after region?
360384
if (afterBaseRegion && afterBreakRegion)
361385
{

tests/autodiff/control-flow-bug.slang

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
[Differentiable] [PreferRecompute]
8+
float3 fetch(float2 uv)
9+
{
10+
if (uv.x > 0.5f)
11+
{
12+
if (uv.x > 0.7f)
13+
return float3(2.) * uv.y;
14+
else
15+
return float3(1.) * uv.y;
16+
}
17+
else
18+
{
19+
if (uv.x > 0.3f)
20+
return float3(4.) * uv.y;
21+
else
22+
return float3(3.) * uv.y;
23+
}
24+
}
25+
26+
[numthreads(1, 1, 1)]
27+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
28+
{
29+
float2 uv = (float2)dispatchThreadID.xy / float2(512, 512);
30+
float3 color = fetch(uv);
31+
outputBuffer[0] = color.x; // Expect: 0.0
32+
33+
{
34+
DifferentialPair<float2> dpuv = diffPair(float2(0.6f));
35+
bwd_diff(fetch)(dpuv, float3(1.f));
36+
37+
outputBuffer[1] = dpuv.d.y; // Expect: 1.0 * 3 = 3
38+
}
39+
40+
{
41+
DifferentialPair<float2> dpuv = diffPair(float2(0.8f));
42+
bwd_diff(fetch)(dpuv, float3(1.f));
43+
44+
outputBuffer[2] = dpuv.d.y; // Expect: 2.0 * 3 = 6
45+
}
46+
47+
{
48+
DifferentialPair<float2> dpuv = diffPair(float2(0.1f));
49+
bwd_diff(fetch)(dpuv, float3(1.f));
50+
51+
outputBuffer[3] = dpuv.d.y; // Expect: 3.0 * 3 = 9
52+
}
53+
54+
{
55+
DifferentialPair<float2> dpuv = diffPair(float2(0.4f));
56+
bwd_diff(fetch)(dpuv, float3(1.f));
57+
58+
outputBuffer[4] = dpuv.d.y; // Expect: 4.0 * 3 = 12
59+
}
60+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
type: float
2+
0.000000
3+
3.000000
4+
6.000000
5+
9.000000
6+
12.000000

0 commit comments

Comments
 (0)