Skip to content

Commit 0305099

Browse files
authored
Support shader parameters that are an array of existential type. (shader-slang#1542)
* Support shader parameters that are an array of existential type. * Rename to getFirstNonExistentialValueCategory Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 2e3688a commit 0305099

6 files changed

+192
-8
lines changed

source/slang/slang-ir-insts.h

+14
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,20 @@ struct IRFieldAddress : IRInst
12441244

12451245
};
12461246

1247+
struct IRGetElement : IRInst
1248+
{
1249+
IR_LEAF_ISA(getElement);
1250+
IRInst* getBase() { return getOperand(0); }
1251+
IRInst* getIndex() { return getOperand(1); }
1252+
};
1253+
1254+
struct IRGetElementPtr : IRInst
1255+
{
1256+
IR_LEAF_ISA(getElementPtr);
1257+
IRInst* getBase() { return getOperand(0); }
1258+
IRInst* getIndex() { return getOperand(1); }
1259+
};
1260+
12471261
struct IRGetAddress : IRInst
12481262
{
12491263
IR_LEAF_ISA(getAddr);

source/slang/slang-ir-specialize.cpp

+93-1
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,13 @@ struct SpecializationContext
530530
maybeSpecializeFieldAddress(as<IRFieldAddress>(inst));
531531
break;
532532

533+
case kIROp_getElement:
534+
maybeSpecializeGetElement(as<IRGetElement>(inst));
535+
break;
536+
case kIROp_getElementPtr:
537+
maybeSpecializeGetElementAddress(as<IRGetElementPtr>(inst));
538+
break;
539+
533540
case kIROp_BindExistentialsType:
534541
maybeSpecializeBindExistentialsType(as<IRBindExistentialsType>(inst));
535542
break;
@@ -1506,6 +1513,11 @@ struct SpecializationContext
15061513
type = ptrLikeType->getElementType();
15071514
goto top;
15081515
}
1516+
else if (auto arrayType = as<IRArrayTypeBase>(type))
1517+
{
1518+
type = arrayType->getElementType();
1519+
goto top;
1520+
}
15091521
else if( auto structType = as<IRStructType>(type) )
15101522
{
15111523
UInt count = 0;
@@ -1695,6 +1707,82 @@ struct SpecializationContext
16951707
}
16961708
}
16971709

1710+
void maybeSpecializeGetElement(IRGetElement* inst)
1711+
{
1712+
auto baseArg = inst->getBase();
1713+
if (auto wrapInst = as<IRWrapExistential>(baseArg))
1714+
{
1715+
// We have `getElement(wrapExistential(val, ...), index)`
1716+
// We need to replace this instruction with
1717+
// `wrapExistential(getElement(val, index), ...)`
1718+
auto index = inst->getIndex();
1719+
1720+
auto val = wrapInst->getWrappedValue();
1721+
auto resultType = inst->getFullType();
1722+
1723+
IRBuilder builder;
1724+
builder.sharedBuilder = &sharedBuilderStorage;
1725+
builder.setInsertBefore(inst);
1726+
1727+
auto elementType = cast<IRArrayTypeBase>(val->getDataType())->getElementType();
1728+
1729+
List<IRInst*> slotOperands;
1730+
UInt slotOperandCount = wrapInst->getSlotOperandCount();
1731+
1732+
for (UInt ii = 0; ii < slotOperandCount; ++ii)
1733+
{
1734+
slotOperands.add(wrapInst->getSlotOperand(ii));
1735+
}
1736+
1737+
auto newGetElement = builder.emitElementExtract(elementType, val, index);
1738+
1739+
auto newWrapExistentialInst = builder.emitWrapExistential(
1740+
resultType, newGetElement, slotOperandCount, slotOperands.getBuffer());
1741+
1742+
addUsersToWorkList(inst);
1743+
inst->replaceUsesWith(newWrapExistentialInst);
1744+
inst->removeAndDeallocate();
1745+
}
1746+
}
1747+
1748+
void maybeSpecializeGetElementAddress(IRGetElementPtr* inst)
1749+
{
1750+
auto baseArg = inst->getBase();
1751+
if (auto wrapInst = as<IRWrapExistential>(baseArg))
1752+
{
1753+
// We have `getElementPtr(wrapExistential(val, ...), index)`
1754+
// We need to replace this instruction with
1755+
// `wrapExistential(getElementPtr(val, index), ...)`
1756+
auto index = inst->getIndex();
1757+
1758+
auto val = wrapInst->getWrappedValue();
1759+
auto resultType = inst->getFullType();
1760+
1761+
IRBuilder builder;
1762+
builder.sharedBuilder = &sharedBuilderStorage;
1763+
builder.setInsertBefore(inst);
1764+
1765+
auto elementType = cast<IRArrayTypeBase>(val->getDataType())->getElementType();
1766+
1767+
List<IRInst*> slotOperands;
1768+
UInt slotOperandCount = wrapInst->getSlotOperandCount();
1769+
1770+
for (UInt ii = 0; ii < slotOperandCount; ++ii)
1771+
{
1772+
slotOperands.add(wrapInst->getSlotOperand(ii));
1773+
}
1774+
1775+
auto newElementAddr = builder.emitElementAddress(elementType, val, index);
1776+
1777+
auto newWrapExistentialInst = builder.emitWrapExistential(
1778+
resultType, newElementAddr, slotOperandCount, slotOperands.getBuffer());
1779+
1780+
addUsersToWorkList(inst);
1781+
inst->replaceUsesWith(newWrapExistentialInst);
1782+
inst->removeAndDeallocate();
1783+
}
1784+
}
1785+
16981786
UInt calcExistentialTypeParamSlotCount(IRType* type)
16991787
{
17001788
top:
@@ -1764,7 +1852,9 @@ struct SpecializationContext
17641852
type->removeAndDeallocate();
17651853
return;
17661854
}
1767-
else if( as<IRPointerLikeType>(baseType) || as<IRHLSLStructuredBufferTypeBase>(baseType) )
1855+
else if( as<IRPointerLikeType>(baseType) ||
1856+
as<IRHLSLStructuredBufferTypeBase>(baseType) ||
1857+
as<IRArrayTypeBase>(baseType))
17681858
{
17691859
// A `BindExistentials<P<T>, ...>` can be simplified to
17701860
// `P<BindExistentials<T, ...>>` when `P` is a pointer-like
@@ -1773,6 +1863,8 @@ struct SpecializationContext
17731863
IRType* baseElementType = nullptr;
17741864
if (auto basePtrLikeType = as<IRPointerLikeType>(baseType))
17751865
baseElementType = basePtrLikeType->getElementType();
1866+
else if (auto arrayType = as<IRArrayTypeBase>(baseType))
1867+
baseElementType = arrayType->getElementType();
17761868
else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType))
17771869
baseElementType = baseSBType->getElementType();
17781870

source/slang/slang-type-layout.cpp

+12-3
Original file line numberDiff line numberDiff line change
@@ -3489,6 +3489,13 @@ static TypeLayoutResult _createTypeLayout(
34893489
{
34903490
arrayResourceCount = elementResourceInfo.count;
34913491
}
3492+
// The second exception to this is arrays of an existential type
3493+
// where the entire array should be specialized to a single concrete type.
3494+
//
3495+
else if (elementResourceInfo.kind == LayoutResourceKind::ExistentialTypeParam)
3496+
{
3497+
arrayResourceCount = elementResourceInfo.count;
3498+
}
34923499
//
34933500
// The next big exception is when we are forming an unbounded-size
34943501
// array and the element type got "adjusted," because that means
@@ -3677,6 +3684,7 @@ static TypeLayoutResult _createTypeLayout(
36773684
typeLayout->rules = rules;
36783685

36793686
LayoutSize fixedExistentialValueSize = 0;
3687+
LayoutSize uniformSlotSize = 0;
36803688
bool targetSupportsPointer =
36813689
isCPUTarget(context.targetReq) || isCUDATarget(context.targetReq);
36823690

@@ -3689,7 +3697,7 @@ static TypeLayoutResult _createTypeLayout(
36893697
fixedExistentialValueSize = anyValueAttr->size;
36903698
}
36913699
// Append 16 bytes to accommodate RTTI pointer and witness table pointer.
3692-
auto uniformSlotSize = fixedExistentialValueSize + 16;
3700+
uniformSlotSize = fixedExistentialValueSize + 16;
36933701
typeLayout->addResourceUsage(LayoutResourceKind::Uniform, uniformSlotSize);
36943702
}
36953703
typeLayout->addResourceUsage(LayoutResourceKind::ExistentialTypeParam, 1);
@@ -3736,8 +3744,9 @@ static TypeLayoutResult _createTypeLayout(
37363744
typeLayout->pendingDataTypeLayout = concreteTypeLayout;
37373745
}
37383746
}
3739-
3740-
return TypeLayoutResult(typeLayout, SimpleLayoutInfo());
3747+
// Interface type occupies a uniform slot for the fixed size storage, with alignment of 4 bytes.
3748+
return TypeLayoutResult(
3749+
typeLayout, SimpleLayoutInfo(LayoutResourceKind::Uniform, uniformSlotSize, 4));
37413750
}
37423751
else if( auto enumDeclRef = declRef.as<EnumDecl>() )
37433752
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Test using existential shader parameter that is an interface array.
2+
3+
//TEST(compute):COMPARE_COMPUTE:-cpu
4+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda
5+
6+
[anyValueSize(8)]
7+
interface IInterface
8+
{
9+
int run(int input);
10+
}
11+
12+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer
13+
RWStructuredBuffer<int> gOutputBuffer;
14+
15+
struct Params
16+
{
17+
IInterface values[2];
18+
};
19+
20+
//TEST_INPUT:cbuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb
21+
ConstantBuffer<Params> gCb;
22+
23+
//TEST_INPUT:cbuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb2
24+
ConstantBuffer<Params> gCb2;
25+
26+
[numthreads(4, 1, 1)]
27+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
28+
{
29+
let tid = dispatchThreadID.x;
30+
31+
let inputVal : int = tid;
32+
let outputVal = gCb.values[0].run(inputVal) + gCb2.values[0].run(inputVal);
33+
34+
gOutputBuffer[tid] = outputVal;
35+
}
36+
37+
//TEST_INPUT: globalExistentialType MyImpl
38+
//TEST_INPUT: globalExistentialType __Dynamic
39+
40+
// Type must be marked `public` to ensure it is visible in the generated DLL.
41+
public struct MyImpl : IInterface
42+
{
43+
int val;
44+
int run(int input)
45+
{
46+
return input + val;
47+
}
48+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2
2+
4
3+
6
4+
8

tools/render-test/bind-location.cpp

+21-4
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,23 @@ void BindSet::calcValueLocations(const BindLocation& location, Slang::List<BindL
328328
}
329329
}
330330

331+
// Finds the first category from layout reflection that represents an actual value
332+
// i.e. that is not ExistentialType or ExistentialObject.
333+
template<typename LayoutReflectionType>
334+
slang::ParameterCategory getFirstNonExistentialValueCategory(LayoutReflectionType* layout)
335+
{
336+
slang::ParameterCategory category = slang::ParameterCategory::None;
337+
for (UInt i = 0; i < layout->getCategoryCount(); i++)
338+
{
339+
auto currentCategory = layout->getCategoryByIndex((unsigned int)i);
340+
if (currentCategory == slang::ParameterCategory::ExistentialTypeParam ||
341+
currentCategory == slang::ParameterCategory::ExistentialObjectParam)
342+
continue;
343+
category = currentCategory;
344+
}
345+
return category;
346+
}
347+
331348
BindLocation BindSet::toField(const BindLocation& loc, slang::VariableLayoutReflection* field) const
332349
{
333350
const Index categoryCount = Index(field->getCategoryCount());
@@ -363,8 +380,8 @@ BindLocation BindSet::toField(const BindLocation& loc, slang::VariableLayoutRefl
363380
}
364381
else
365382
{
366-
SLANG_ASSERT(categoryCount == 1);
367-
auto category = field->getCategoryByIndex(0);
383+
slang::ParameterCategory category = getFirstNonExistentialValueCategory(field);
384+
SLANG_ASSERT(category != slang::ParameterCategory::None);
368385

369386
// If I'm going from mixed, then I will have multiple items being tracked (so won't be here)
370387
// If I'm not, then I'm getting an inplace field. It must be relative
@@ -496,8 +513,8 @@ BindLocation BindSet::toIndex(const BindLocation& loc, Index index) const
496513
}
497514
else
498515
{
499-
SLANG_ASSERT(categoryCount == 1);
500-
auto category = elementTypeLayout->getCategoryByIndex(0);
516+
slang::ParameterCategory category = getFirstNonExistentialValueCategory(elementTypeLayout);
517+
SLANG_ASSERT(category != slang::ParameterCategory::None);
501518

502519
const auto elementStride = typeLayout->getElementStride(category);
503520

0 commit comments

Comments
 (0)