Skip to content

Commit

Permalink
Handle the case where the parent if-else region's after-block is unre…
Browse files Browse the repository at this point in the history
…achable. (#3241)

Also added a test for this.

Co-authored-by: Yong He <yonghe@outlook.com>
  • Loading branch information
saipraveenb25 and csyonghe authored Sep 27, 2023
1 parent a18dca2 commit c5c8cfb
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
24 changes: 24 additions & 0 deletions source/slang/slang-ir-autodiff-cfg-norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,30 @@ struct CFGNormalizationPass
//
afterBaseRegion = true;

// One case we do check for is if the after block is 'unreachable'
// i.e. the terminator is an `unreachable` instruction.
// In this case, we can safely assume that the after block does not
// have anything to execute. Further, we need to re-wire the
// previously unreachable block to the parent break block.
// Note that this operation is safe because if the after block was
// originally unreachable, all potential paths to it must have
// broken out of the region.
//
if (auto unreachInst = as<IRUnreachable>(afterBlock->getTerminator()))
{
// Link it to the parentAfterBlock.
builder.setInsertInto(afterBlock);
unreachInst->removeAndDeallocate();

builder.emitBranch(parentAfterBlock);

// We can now safely assume that the after block is empty.
// Set 'afterBaseRegion' to false, which should lead the rest
// of the logic to avoid splitting the after block
//
afterBaseRegion = false;
}

// Do we need to split the after region?
if (afterBaseRegion && afterBreakRegion)
{
Expand Down
60 changes: 60 additions & 0 deletions tests/autodiff/control-flow-bug.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

[Differentiable] [PreferRecompute]
float3 fetch(float2 uv)
{
if (uv.x > 0.5f)
{
if (uv.x > 0.7f)
return float3(2.) * uv.y;
else
return float3(1.) * uv.y;
}
else
{
if (uv.x > 0.3f)
return float3(4.) * uv.y;
else
return float3(3.) * uv.y;
}
}

[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
float2 uv = (float2)dispatchThreadID.xy / float2(512, 512);
float3 color = fetch(uv);
outputBuffer[0] = color.x; // Expect: 0.0

{
DifferentialPair<float2> dpuv = diffPair(float2(0.6f));
bwd_diff(fetch)(dpuv, float3(1.f));

outputBuffer[1] = dpuv.d.y; // Expect: 1.0 * 3 = 3
}

{
DifferentialPair<float2> dpuv = diffPair(float2(0.8f));
bwd_diff(fetch)(dpuv, float3(1.f));

outputBuffer[2] = dpuv.d.y; // Expect: 2.0 * 3 = 6
}

{
DifferentialPair<float2> dpuv = diffPair(float2(0.1f));
bwd_diff(fetch)(dpuv, float3(1.f));

outputBuffer[3] = dpuv.d.y; // Expect: 3.0 * 3 = 9
}

{
DifferentialPair<float2> dpuv = diffPair(float2(0.4f));
bwd_diff(fetch)(dpuv, float3(1.f));

outputBuffer[4] = dpuv.d.y; // Expect: 4.0 * 3 = 12
}
}
6 changes: 6 additions & 0 deletions tests/autodiff/control-flow-bug.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
type: float
0.000000
3.000000
6.000000
9.000000
12.000000

0 comments on commit c5c8cfb

Please sign in to comment.