Skip to content

Commit 5da06d4

Browse files
authored
Fix global value inlining for spirv_asm blocks. (shader-slang#4339)
1 parent 7e79669 commit 5da06d4

File tree

3 files changed

+199
-122
lines changed

3 files changed

+199
-122
lines changed

source/slang/slang-ir-insts.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -3190,7 +3190,10 @@ struct IRSPIRVAsmInst : IRInst
31903190

31913191
IRSPIRVAsmOperand* getOpcodeOperand()
31923192
{
3193-
const auto opcodeOperand = cast<IRSPIRVAsmOperand>(getOperand(0));
3193+
auto operand = getOperand(0);
3194+
if (auto globalRef = as<IRGlobalValueRef>(operand))
3195+
operand = globalRef->getValue();
3196+
const auto opcodeOperand = cast<IRSPIRVAsmOperand>(operand);
31943197
// This must be either:
31953198
// - An enum, such as 'OpNop'
31963199
// - The __truncate pseudo-instruction

source/slang/slang-ir-spirv-legalize.cpp

+173-121
Original file line numberDiff line numberDiff line change
@@ -1690,143 +1690,194 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
16901690

16911691
}
16921692

1693-
// Opcodes that can exist in global scope, as long as the operands are.
1694-
bool isLegalGlobalInst(IRInst* inst)
1693+
struct GlobalInstInliningContext
16951694
{
1696-
switch (inst->getOp())
1695+
Dictionary<IRInst*, bool> m_mapGlobalInstToShouldInline;
1696+
1697+
// Opcodes that can exist in global scope, as long as the operands are.
1698+
bool isLegalGlobalInst(IRInst* inst)
16971699
{
1698-
case kIROp_MakeStruct:
1699-
case kIROp_MakeArray:
1700-
case kIROp_MakeArrayFromElement:
1701-
case kIROp_MakeVector:
1702-
case kIROp_MakeMatrix:
1703-
case kIROp_MakeMatrixFromScalar:
1704-
case kIROp_MakeVectorFromScalar:
1705-
return true;
1706-
default:
1707-
return false;
1700+
switch (inst->getOp())
1701+
{
1702+
case kIROp_MakeStruct:
1703+
case kIROp_MakeArray:
1704+
case kIROp_MakeArrayFromElement:
1705+
case kIROp_MakeVector:
1706+
case kIROp_MakeMatrix:
1707+
case kIROp_MakeMatrixFromScalar:
1708+
case kIROp_MakeVectorFromScalar:
1709+
return true;
1710+
default:
1711+
if (as<IRConstant>(inst))
1712+
return true;
1713+
if (as<IRSPIRVAsmOperand>(inst))
1714+
return true;
1715+
return false;
1716+
}
17081717
}
1709-
}
17101718

1711-
// Opcodes that can be inlined into function bodies.
1712-
bool isInlinableGlobalInst(IRInst* inst)
1713-
{
1714-
switch (inst->getOp())
1719+
// Opcodes that can be inlined into function bodies.
1720+
bool isInlinableGlobalInst(IRInst* inst)
1721+
{
1722+
switch (inst->getOp())
1723+
{
1724+
case kIROp_Add:
1725+
case kIROp_Sub:
1726+
case kIROp_Mul:
1727+
case kIROp_FRem:
1728+
case kIROp_IRem:
1729+
case kIROp_Lsh:
1730+
case kIROp_Rsh:
1731+
case kIROp_And:
1732+
case kIROp_Or:
1733+
case kIROp_Not:
1734+
case kIROp_Neg:
1735+
case kIROp_Div:
1736+
case kIROp_FieldExtract:
1737+
case kIROp_FieldAddress:
1738+
case kIROp_GetElement:
1739+
case kIROp_GetElementPtr:
1740+
case kIROp_GetOffsetPtr:
1741+
case kIROp_UpdateElement:
1742+
case kIROp_MakeTuple:
1743+
case kIROp_GetTupleElement:
1744+
case kIROp_MakeStruct:
1745+
case kIROp_MakeArray:
1746+
case kIROp_MakeArrayFromElement:
1747+
case kIROp_MakeVector:
1748+
case kIROp_MakeMatrix:
1749+
case kIROp_MakeMatrixFromScalar:
1750+
case kIROp_MakeVectorFromScalar:
1751+
case kIROp_swizzle:
1752+
case kIROp_swizzleSet:
1753+
case kIROp_MatrixReshape:
1754+
case kIROp_MakeString:
1755+
case kIROp_MakeResultError:
1756+
case kIROp_MakeResultValue:
1757+
case kIROp_GetResultError:
1758+
case kIROp_GetResultValue:
1759+
case kIROp_CastFloatToInt:
1760+
case kIROp_CastIntToFloat:
1761+
case kIROp_CastIntToPtr:
1762+
case kIROp_PtrCast:
1763+
case kIROp_CastPtrToBool:
1764+
case kIROp_CastPtrToInt:
1765+
case kIROp_BitAnd:
1766+
case kIROp_BitNot:
1767+
case kIROp_BitOr:
1768+
case kIROp_BitXor:
1769+
case kIROp_BitCast:
1770+
case kIROp_IntCast:
1771+
case kIROp_FloatCast:
1772+
case kIROp_Greater:
1773+
case kIROp_Less:
1774+
case kIROp_Geq:
1775+
case kIROp_Leq:
1776+
case kIROp_Neq:
1777+
case kIROp_Eql:
1778+
case kIROp_Call:
1779+
case kIROp_SPIRVAsm:
1780+
return true;
1781+
default:
1782+
if (as<IRSPIRVAsmInst>(inst))
1783+
return true;
1784+
if (as<IRSPIRVAsmOperand>(inst))
1785+
return true;
1786+
return false;
1787+
}
1788+
}
1789+
1790+
bool shouldInlineInstImpl(IRInst* inst)
17151791
{
1716-
case kIROp_Add:
1717-
case kIROp_Sub:
1718-
case kIROp_Mul:
1719-
case kIROp_FRem:
1720-
case kIROp_IRem:
1721-
case kIROp_Lsh:
1722-
case kIROp_Rsh:
1723-
case kIROp_And:
1724-
case kIROp_Or:
1725-
case kIROp_Not:
1726-
case kIROp_Neg:
1727-
case kIROp_Div:
1728-
case kIROp_FieldExtract:
1729-
case kIROp_FieldAddress:
1730-
case kIROp_GetElement:
1731-
case kIROp_GetElementPtr:
1732-
case kIROp_GetOffsetPtr:
1733-
case kIROp_UpdateElement:
1734-
case kIROp_MakeTuple:
1735-
case kIROp_GetTupleElement:
1736-
case kIROp_MakeStruct:
1737-
case kIROp_MakeArray:
1738-
case kIROp_MakeArrayFromElement:
1739-
case kIROp_MakeVector:
1740-
case kIROp_MakeMatrix:
1741-
case kIROp_MakeMatrixFromScalar:
1742-
case kIROp_MakeVectorFromScalar:
1743-
case kIROp_swizzle:
1744-
case kIROp_swizzleSet:
1745-
case kIROp_MatrixReshape:
1746-
case kIROp_MakeString:
1747-
case kIROp_MakeResultError:
1748-
case kIROp_MakeResultValue:
1749-
case kIROp_GetResultError:
1750-
case kIROp_GetResultValue:
1751-
case kIROp_CastFloatToInt:
1752-
case kIROp_CastIntToFloat:
1753-
case kIROp_CastIntToPtr:
1754-
case kIROp_PtrCast:
1755-
case kIROp_CastPtrToBool:
1756-
case kIROp_CastPtrToInt:
1757-
case kIROp_BitAnd:
1758-
case kIROp_BitNot:
1759-
case kIROp_BitOr:
1760-
case kIROp_BitXor:
1761-
case kIROp_BitCast:
1762-
case kIROp_IntCast:
1763-
case kIROp_FloatCast:
1764-
case kIROp_Greater:
1765-
case kIROp_Less:
1766-
case kIROp_Geq:
1767-
case kIROp_Leq:
1768-
case kIROp_Neq:
1769-
case kIROp_Eql:
1770-
case kIROp_Call:
1771-
case kIROp_SPIRVAsm:
1792+
if (!isInlinableGlobalInst(inst))
1793+
return false;
1794+
if (isLegalGlobalInst(inst))
1795+
{
1796+
for (UInt i = 0; i < inst->getOperandCount(); i++)
1797+
if (shouldInlineInst(inst->getOperand(i)))
1798+
return true;
1799+
return false;
1800+
}
17721801
return true;
1773-
default:
1774-
return false;
17751802
}
1776-
}
17771803

1778-
bool shouldInlineInst(IRInst* inst)
1779-
{
1780-
if (!isInlinableGlobalInst(inst))
1781-
return false;
1782-
if (isLegalGlobalInst(inst))
1804+
bool shouldInlineInst(IRInst* inst)
17831805
{
1784-
for (UInt i = 0; i < inst->getOperandCount(); i++)
1785-
if (shouldInlineInst(inst->getOperand(i)))
1786-
return true;
1787-
return false;
1806+
bool result = false;
1807+
if (m_mapGlobalInstToShouldInline.tryGetValue(inst, result))
1808+
return result;
1809+
result = shouldInlineInstImpl(inst);
1810+
m_mapGlobalInstToShouldInline[inst] = result;
1811+
return result;
17881812
}
1789-
return true;
1790-
}
17911813

1792-
/// Inline `inst` in the local function body so they can be emitted as a local inst.
1793-
///
1794-
IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* inst, IRCloneEnv& cloneEnv)
1795-
{
1796-
if (!shouldInlineInst(inst))
1814+
IRInst* inlineInst(IRBuilder& builder, IRCloneEnv& cloneEnv, IRInst* inst)
17971815
{
1798-
switch (inst->getOp())
1816+
IRInst* result;
1817+
if (cloneEnv.mapOldValToNew.tryGetValue(inst, result))
1818+
return result;
1819+
1820+
for (UInt i = 0; i < inst->getOperandCount(); i++)
17991821
{
1800-
case kIROp_Func:
1801-
case kIROp_Specialize:
1802-
case kIROp_Generic:
1803-
case kIROp_LookupWitness:
1804-
return inst;
1805-
}
1806-
if (as<IRType>(inst))
1807-
return inst;
1808-
1809-
// If we encounter a global value that shouldn't be inlined, e.g. a const literal,
1810-
// we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses
1811-
// can be pinned to the function body.
1812-
auto result = builder.emitGlobalValueRef(inst);
1822+
auto operand = inst->getOperand(i);
1823+
IRBuilder operandBuilder(builder);
1824+
setInsertBeforeOutsideASM(operandBuilder, builder.getInsertLoc().getInst());
1825+
maybeInlineGlobalValue(operandBuilder, inst, operand, cloneEnv);
1826+
}
1827+
result = cloneInstAndOperands(&cloneEnv, &builder, inst);
18131828
cloneEnv.mapOldValToNew[inst] = result;
1829+
IRBuilder subBuilder(builder);
1830+
subBuilder.setInsertInto(result);
1831+
for (auto child : inst->getDecorations())
1832+
{
1833+
cloneInst(&cloneEnv, &subBuilder, child);
1834+
}
1835+
for (auto child : inst->getChildren())
1836+
{
1837+
inlineInst(subBuilder, cloneEnv, child);
1838+
}
18141839
return result;
18151840
}
18161841

1817-
// If the global value is inlinable, we make all its operands avaialble locally, and then copy it
1818-
// to the local scope.
1819-
ShortList<IRInst*> args;
1820-
for (UInt i = 0; i < inst->getOperandCount(); i++)
1842+
/// Inline `inst` in the local function body so they can be emitted as a local inst.
1843+
///
1844+
IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* user, IRInst* inst, IRCloneEnv& cloneEnv)
18211845
{
1822-
auto operand = inst->getOperand(i);
1823-
auto inlinedOperand = maybeInlineGlobalValue(builder, operand, cloneEnv);
1824-
args.add(inlinedOperand);
1846+
if (!shouldInlineInst(inst))
1847+
{
1848+
switch (inst->getOp())
1849+
{
1850+
case kIROp_Func:
1851+
case kIROp_Specialize:
1852+
case kIROp_Generic:
1853+
case kIROp_LookupWitness:
1854+
return inst;
1855+
}
1856+
if (as<IRType>(inst))
1857+
return inst;
1858+
1859+
// If we encounter a global value that shouldn't be inlined, e.g. a const literal,
1860+
// we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses
1861+
// can be pinned to the function body.
1862+
auto result = inst;
1863+
bool shouldWrapGlobalRef = true;
1864+
if (!isLegalGlobalInst(user) && !getIROpInfo(user->getOp()).isHoistable())
1865+
shouldWrapGlobalRef = false;
1866+
else if (as<IRSPIRVAsmOperand>(user) && as<IRSPIRVAsmOperandInst>(user))
1867+
shouldWrapGlobalRef = false;
1868+
else if (as<IRSPIRVAsmInst>(user))
1869+
shouldWrapGlobalRef = false;
1870+
if (shouldWrapGlobalRef)
1871+
result = builder.emitGlobalValueRef(inst);
1872+
cloneEnv.mapOldValToNew[inst] = result;
1873+
return result;
1874+
}
1875+
1876+
// If the global value is inlinable, we make all its operands avaialble locally, and then copy it
1877+
// to the local scope.
1878+
return inlineInst(builder, cloneEnv, inst);
18251879
}
1826-
auto result = cloneInst(&cloneEnv, &builder, inst);
1827-
cloneEnv.mapOldValToNew[inst] = result;
1828-
return result;
1829-
}
1880+
};
18301881

18311882
void processBranch(IRInst* branch)
18321883
{
@@ -2079,7 +2130,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
20792130
}
20802131
}
20812132

2082-
void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst)
2133+
static void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst)
20832134
{
20842135
auto parent = beforeInst->getParent();
20852136
while (parent)
@@ -2234,6 +2285,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
22342285
// Inline global values that can't represented by SPIRV constant inst
22352286
// to their use sites.
22362287
List<IRUse*> globalInstUsesToInline;
2288+
GlobalInstInliningContext globalInstInliningContext;
22372289

22382290
for (auto globalInst : m_module->getGlobalInsts())
22392291
{
@@ -2248,7 +2300,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
22482300
sortBlocksInFunc(func);
22492301
}
22502302

2251-
if (isInlinableGlobalInst(globalInst))
2303+
if (globalInstInliningContext.isInlinableGlobalInst(globalInst))
22522304
{
22532305
for (auto use = globalInst->firstUse; use; use = use->nextUse)
22542306
{
@@ -2264,7 +2316,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
22642316
IRBuilder builder(user);
22652317
setInsertBeforeOutsideASM(builder, user);
22662318
IRCloneEnv cloneEnv;
2267-
auto val = maybeInlineGlobalValue(builder, use->get(), cloneEnv);
2319+
auto val = globalInstInliningContext.maybeInlineGlobalValue(builder, use->getUser(), use->get(), cloneEnv);
22682320
if (val != use->get())
22692321
builder.replaceOperand(use, val);
22702322
}
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target spirv
2+
3+
// Test that we can use intrinsics in global scope constant array, which causes
4+
// the spirv_asm to be inlined in global module scope.
5+
// Our global value inlining pass should be able to clean up the global scope spirv_asm
6+
// blocks and inlining them to use sites.
7+
8+
// CHECK: %main = OpFunction
9+
// CHECK: OpStore
10+
11+
static const uint staticArr[] = {
12+
uint((((uint)round(saturate(1) * 255) << 24) | ((uint)round(saturate(0) * 255) << 16) | ((uint)round(saturate(0) * 255) << 8) | 0xff)),
13+
uint((((uint)round(saturate(1) * 255) << 24) | ((uint)round(saturate(0) * 255) << 16) | ((uint)round(saturate(1) * 255) << 8) | 0xff))
14+
};
15+
16+
RWStructuredBuffer<int> buffer;
17+
18+
[numthreads(1,1,1)]
19+
void main(int i : SV_DispatchThreadID)
20+
{
21+
buffer[0] = staticArr[i] + staticArr[1];
22+
}

0 commit comments

Comments
 (0)