Skip to content

Commit b11c257

Browse files
Fix def-use issue from multi-level break elimination (#6134)
1 parent f3d7aa6 commit b11c257

5 files changed

+172
-107
lines changed

source/slang/slang-ir-autodiff-cfg-norm.cpp

-107
Original file line numberDiff line numberDiff line change
@@ -707,113 +707,6 @@ struct CFGNormalizationPass
707707
}
708708
};
709709

710-
static void legalizeDefUse(IRGlobalValueWithCode* func)
711-
{
712-
auto dom = computeDominatorTree(func);
713-
for (auto block : func->getBlocks())
714-
{
715-
for (auto inst : block->getModifiableChildren())
716-
{
717-
// Inspect all uses of `inst` and find the common dominator of all use sites.
718-
IRBlock* commonDominator = block;
719-
for (auto use = inst->firstUse; use; use = use->nextUse)
720-
{
721-
auto userBlock = as<IRBlock>(use->getUser()->getParent());
722-
if (!userBlock)
723-
continue;
724-
while (commonDominator && !dom->dominates(commonDominator, userBlock))
725-
{
726-
commonDominator = dom->getImmediateDominator(commonDominator);
727-
}
728-
}
729-
SLANG_ASSERT(commonDominator);
730-
731-
if (commonDominator == block)
732-
continue;
733-
734-
// If the common dominator is not `block`, it means we have detected
735-
// uses that is no longer dominated by the current definition, and need
736-
// to be fixed.
737-
738-
// Normally, we can simply move the definition to the common dominator.
739-
// An exception is when the common dominator is the target block of a
740-
// loop. Note that after normalization, loops are in the form of:
741-
// ```
742-
// loop { if (condition) block; else break; }
743-
// ```
744-
// If we find ourselves needing to make the inst available right before
745-
// the `if`, it means we are seeing uses of the inst outside the loop.
746-
// In this case, we should insert a var/move the inst before the loop
747-
// instead of before the `if`. This situation can occur in the IR if
748-
// the original code is lowered from a `do-while` loop.
749-
for (auto use = commonDominator->firstUse; use; use = use->nextUse)
750-
{
751-
if (auto loopUser = as<IRLoop>(use->getUser()))
752-
{
753-
if (loopUser->getTargetBlock() == commonDominator)
754-
{
755-
bool shouldMoveToHeader = false;
756-
// Check that the break-block dominates any of the uses are past the break
757-
// block
758-
for (auto _use = inst->firstUse; _use; _use = _use->nextUse)
759-
{
760-
if (dom->dominates(
761-
loopUser->getBreakBlock(),
762-
_use->getUser()->getParent()))
763-
{
764-
shouldMoveToHeader = true;
765-
break;
766-
}
767-
}
768-
769-
if (shouldMoveToHeader)
770-
commonDominator = as<IRBlock>(loopUser->getParent());
771-
break;
772-
}
773-
}
774-
}
775-
// Now we can legalize uses based on the type of `inst`.
776-
if (auto var = as<IRVar>(inst))
777-
{
778-
// If inst is an var, this is easy, we just move it to the
779-
// common dominator.
780-
var->insertBefore(commonDominator->getTerminator());
781-
}
782-
else
783-
{
784-
// For all other insts, we need to create a local var for it,
785-
// and replace all uses with a load from the local var.
786-
IRBuilder builder(func);
787-
builder.setInsertBefore(commonDominator->getTerminator());
788-
IRVar* tempVar = builder.emitVar(inst->getFullType());
789-
auto defaultVal = builder.emitDefaultConstruct(inst->getFullType());
790-
builder.emitStore(tempVar, defaultVal);
791-
792-
builder.setInsertAfter(inst);
793-
builder.emitStore(tempVar, inst);
794-
795-
traverseUses(
796-
inst,
797-
[&](IRUse* use)
798-
{
799-
auto userBlock = as<IRBlock>(use->getUser()->getParent());
800-
if (!userBlock)
801-
return;
802-
// Only fix the use of the current definition of `inst` does not
803-
// dominate it.
804-
if (!dom->dominates(block, userBlock))
805-
{
806-
// Replace the use with a load of tempVar.
807-
builder.setInsertBefore(use->getUser());
808-
auto load = builder.emitLoad(tempVar);
809-
builder.replaceOperand(use, load);
810-
}
811-
});
812-
}
813-
}
814-
}
815-
}
816-
817710
void normalizeCFG(
818711
IRModule* module,
819712
IRGlobalValueWithCode* func,

source/slang/slang-ir-eliminate-multilevel-break.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "slang-ir-dominators.h"
66
#include "slang-ir-eliminate-phis.h"
77
#include "slang-ir-insts.h"
8+
#include "slang-ir-util.h"
89
#include "slang-ir.h"
910

1011
namespace Slang
@@ -475,6 +476,8 @@ struct EliminateMultiLevelBreakContext
475476
}
476477
}
477478
}
479+
480+
legalizeDefUse(func);
478481
}
479482
};
480483

source/slang/slang-ir-util.cpp

+107
Original file line numberDiff line numberDiff line change
@@ -1986,4 +1986,111 @@ Int getSpecializationConstantId(IRGlobalParam* param)
19861986
return offset->getOffset();
19871987
}
19881988

1989+
void legalizeDefUse(IRGlobalValueWithCode* func)
1990+
{
1991+
auto dom = computeDominatorTree(func);
1992+
for (auto block : func->getBlocks())
1993+
{
1994+
for (auto inst : block->getModifiableChildren())
1995+
{
1996+
// Inspect all uses of `inst` and find the common dominator of all use sites.
1997+
IRBlock* commonDominator = block;
1998+
for (auto use = inst->firstUse; use; use = use->nextUse)
1999+
{
2000+
auto userBlock = as<IRBlock>(use->getUser()->getParent());
2001+
if (!userBlock)
2002+
continue;
2003+
while (commonDominator && !dom->dominates(commonDominator, userBlock))
2004+
{
2005+
commonDominator = dom->getImmediateDominator(commonDominator);
2006+
}
2007+
}
2008+
SLANG_ASSERT(commonDominator);
2009+
2010+
if (commonDominator == block)
2011+
continue;
2012+
2013+
// If the common dominator is not `block`, it means we have detected
2014+
// uses that is no longer dominated by the current definition, and need
2015+
// to be fixed.
2016+
2017+
// Normally, we can simply move the definition to the common dominator.
2018+
// An exception is when the common dominator is the target block of a
2019+
// loop. Note that after normalization, loops are in the form of:
2020+
// ```
2021+
// loop { if (condition) block; else break; }
2022+
// ```
2023+
// If we find ourselves needing to make the inst available right before
2024+
// the `if`, it means we are seeing uses of the inst outside the loop.
2025+
// In this case, we should insert a var/move the inst before the loop
2026+
// instead of before the `if`. This situation can occur in the IR if
2027+
// the original code is lowered from a `do-while` loop.
2028+
for (auto use = commonDominator->firstUse; use; use = use->nextUse)
2029+
{
2030+
if (auto loopUser = as<IRLoop>(use->getUser()))
2031+
{
2032+
if (loopUser->getTargetBlock() == commonDominator)
2033+
{
2034+
bool shouldMoveToHeader = false;
2035+
// Check that the break-block dominates any of the uses are past the break
2036+
// block
2037+
for (auto _use = inst->firstUse; _use; _use = _use->nextUse)
2038+
{
2039+
if (dom->dominates(
2040+
loopUser->getBreakBlock(),
2041+
_use->getUser()->getParent()))
2042+
{
2043+
shouldMoveToHeader = true;
2044+
break;
2045+
}
2046+
}
2047+
2048+
if (shouldMoveToHeader)
2049+
commonDominator = as<IRBlock>(loopUser->getParent());
2050+
break;
2051+
}
2052+
}
2053+
}
2054+
// Now we can legalize uses based on the type of `inst`.
2055+
if (auto var = as<IRVar>(inst))
2056+
{
2057+
// If inst is an var, this is easy, we just move it to the
2058+
// common dominator.
2059+
var->insertBefore(commonDominator->getTerminator());
2060+
}
2061+
else
2062+
{
2063+
// For all other insts, we need to create a local var for it,
2064+
// and replace all uses with a load from the local var.
2065+
IRBuilder builder(func);
2066+
builder.setInsertBefore(commonDominator->getTerminator());
2067+
IRVar* tempVar = builder.emitVar(inst->getFullType());
2068+
auto defaultVal = builder.emitDefaultConstruct(inst->getFullType());
2069+
builder.emitStore(tempVar, defaultVal);
2070+
2071+
builder.setInsertAfter(inst);
2072+
builder.emitStore(tempVar, inst);
2073+
2074+
traverseUses(
2075+
inst,
2076+
[&](IRUse* use)
2077+
{
2078+
auto userBlock = as<IRBlock>(use->getUser()->getParent());
2079+
if (!userBlock)
2080+
return;
2081+
// Only fix the use of the current definition of `inst` does not
2082+
// dominate it.
2083+
if (!dom->dominates(block, userBlock))
2084+
{
2085+
// Replace the use with a load of tempVar.
2086+
builder.setInsertBefore(use->getUser());
2087+
auto load = builder.emitLoad(tempVar);
2088+
builder.replaceOperand(use, load);
2089+
}
2090+
});
2091+
}
2092+
}
2093+
}
2094+
}
2095+
19892096
} // namespace Slang

source/slang/slang-ir-util.h

+2
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ IRType* getIRVectorBaseType(IRType* type);
375375

376376
Int getSpecializationConstantId(IRGlobalParam* param);
377377

378+
void legalizeDefUse(IRGlobalValueWithCode* func);
379+
378380
} // namespace Slang
379381

380382
#endif
+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
2+
3+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
4+
RWStructuredBuffer<float> outputBuffer;
5+
6+
typedef DifferentialPair<float> dpfloat;
7+
typedef float.Differential dfloat;
8+
9+
10+
struct SpatialVertex : IDifferentiable
11+
{
12+
float x;
13+
};
14+
15+
struct MaterialVertex
16+
{
17+
float x;
18+
};
19+
20+
//TEST_INPUT:ubuffer(data=[2.0 2.0 2.0 2.0 2.0], stride=4):name=pathVertices
21+
RWStructuredBuffer<MaterialVertex> pathVertices;
22+
23+
[Differentiable]
24+
SpatialVertex transform(float p, MaterialVertex m)
25+
{
26+
return { p * m.x };
27+
}
28+
29+
[Differentiable]
30+
float test_simple_loop(float y)
31+
{
32+
SpatialVertex vShade[2];
33+
int pathLength = 1;
34+
35+
[ForceUnroll]
36+
for (int i = 0; i < 2; i++)
37+
{
38+
if (!(pathVertices[i].x > 1.4))
39+
{
40+
pathLength = i;
41+
break;
42+
}
43+
44+
vShade[i] = transform(y, pathVertices[i]);
45+
}
46+
47+
return vShade[0].x + vShade[1].x;
48+
}
49+
50+
[numthreads(1, 1, 1)]
51+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
52+
{
53+
{
54+
dpfloat dpy = dpfloat(1.0, 1.0);
55+
56+
var dpresult = fwd_diff(test_simple_loop)(dpy);
57+
outputBuffer[0] = pathVertices[0].x; // CHECK: 2.0
58+
outputBuffer[1] = dpresult.d; // CHECK: 4.0
59+
}
60+
}

0 commit comments

Comments
 (0)