Skip to content

Commit e5b796d

Browse files
authored
Allow existential types in StructuredBuffer element type. (shader-slang#1536)
* Allow existential types in `StructuredBuffer` element type. * Handle StructuredBuffer.Load/.Consume methods * Clean up unnecessary changes * Code cleanup * Update test comment
1 parent d6a2d29 commit e5b796d

10 files changed

+163
-20
lines changed

source/slang/slang-check-shader.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ namespace Slang
7474
loc);
7575
return;
7676
}
77+
else if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type))
78+
{
79+
_collectExistentialSpecializationParamsRec(
80+
astBuilder, ioSpecializationParams, structuredBufferType->getElementType(), loc);
81+
return;
82+
}
7783

7884
if( auto declRefType = as<DeclRefType>(type) )
7985
{

source/slang/slang-emit-c-like.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
222222
List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWitnessTable* witnessTable)
223223
{
224224
List<IRWitnessTableEntry*> sortedWitnessTableEntries;
225-
auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
225+
auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
226226
auto witnessTableItems = witnessTable->getChildren();
227227
// Build a dictionary of witness table entries for fast lookup.
228228
Dictionary<IRInst*, IRWitnessTableEntry*> witnessTableEntryDictionary;

source/slang/slang-emit-cpp.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1606,7 +1606,7 @@ void CPPSourceEmitter::emitParamTypeImpl(IRType* type, String const& name)
16061606

16071607
void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
16081608
{
1609-
auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
1609+
auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
16101610

16111611
// Ignore witness tables for builtin interface types.
16121612
if (isBuiltin(interfaceType))
@@ -1634,7 +1634,7 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
16341634
{
16351635
for (auto witnessTable : pendingWitnessTableDefinitions)
16361636
{
1637-
auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
1637+
auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
16381638
List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable);
16391639
m_writer->emit("extern \"C\"\n{\n");
16401640
m_writer->indent();

source/slang/slang-ir-insts.h

+1-6
Original file line numberDiff line numberDiff line change
@@ -1531,12 +1531,7 @@ struct IRWitnessTable : IRInst
15311531

15321532
IRInst* getConformanceType()
15331533
{
1534-
return getOperand(0);
1535-
}
1536-
1537-
void setConformanceType(IRInst* type)
1538-
{
1539-
setOperand(0, type);
1534+
return cast<IRWitnessTableType>(getDataType())->getConformanceType();
15401535
}
15411536

15421537
IR_LEAF_ISA(WitnessTable)

source/slang/slang-ir-link.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ IRWitnessTable* cloneWitnessTableImpl(
571571
IRWitnessTable* clonedTable = dstTable;
572572
if (!clonedTable)
573573
{
574-
auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getOperand(0)));
574+
auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getConformanceType()));
575575
clonedTable = builder->createWitnessTable(clonedBaseType);
576576
}
577577
cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue);

source/slang/slang-ir-lower-generic-function.cpp

+8-5
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,15 @@ namespace Slang
201201
void lowerWitnessTable(IRWitnessTable* witnessTable)
202202
{
203203
auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTable->getConformanceType()));
204+
IRBuilder builderStorage;
205+
auto builder = &builderStorage;
206+
builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
207+
builder->setInsertBefore(witnessTable);
204208
if (interfaceType != witnessTable->getConformanceType())
205-
witnessTable->setConformanceType(interfaceType);
209+
{
210+
auto newWitnessTableType = builder->getWitnessTableType(interfaceType);
211+
witnessTable->setFullType(newWitnessTableType);
212+
}
206213
if (isBuiltin(interfaceType))
207214
return;
208215
for (auto child : witnessTable->getChildren())
@@ -223,10 +230,6 @@ namespace Slang
223230
{
224231
// Translate a Type value to an RTTI object pointer.
225232
auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal());
226-
IRBuilder builderStorage;
227-
auto builder = &builderStorage;
228-
builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
229-
builder->setInsertBefore(witnessTable);
230233
auto rttiObjectPtr = builder->emitGetAddress(
231234
builder->getPtrType(builder->getRTTIType()),
232235
rttiObject);

source/slang/slang-ir-specialize.cpp

+94-3
Original file line numberDiff line numberDiff line change
@@ -770,13 +770,99 @@ struct SpecializationContext
770770
}
771771
}
772772

773+
// Finds the `IRTargetIntrinsicDecoration` reachable from `inst`, if any.
//
// The lookup first strips every enclosing `specialize` instruction to reach
// its base, then steps through every `IRGeneric` to the value that generic
// returns, and finally queries the resulting instruction for the decoration.
//
// Returns the decoration, or `nullptr` when the resolved instruction does not
// carry one.
774+
IRTargetIntrinsicDecoration* findTargetIntrinsicDecorationRec(IRInst* inst)
775+
{
776+
while (auto specialize = as<IRSpecialize>(inst))
777+
{
778+
inst = specialize->getBase();
779+
}
780+
while (auto genericInst = as<IRGeneric>(inst))
781+
{
782+
inst = findGenericReturnVal(genericInst);
783+
}
784+
if (auto decor = inst->findDecoration<IRTargetIntrinsicDecoration>())
785+
return decor;
786+
return nullptr;
787+
}
788+
789+
// Returns true if the call inst represents a call to one of the
790+
// StructuredBuffer::operator[]/Load/Consume methods.
//
// The check is made purely on the callee's target-intrinsic definition
// string (".operator[]", ".Load" or ".Consume"); it does not inspect the
// type of the object the method is invoked on. Callees without a target
// intrinsic decoration never match.
791+
bool isBufferLoadCall(IRCall* inst)
792+
{
793+
if (auto targetIntrinsic = findTargetIntrinsicDecorationRec(inst->getCallee()))
794+
{
795+
auto name = targetIntrinsic->getDefinition();
796+
if (name == ".operator[]" || name == ".Load" || name == ".Consume")
797+
{
798+
return true;
799+
}
800+
}
801+
return false;
802+
}
803+
804+
/// Transform a buffer load intrinsic call.
805+
/// `bufferLoad(wrapExistential(bufferObj, wrapArgs), loadArgs)` should be transformed into
806+
/// `wrapExistential(bufferLoad(bufferObj, loadArgs), wrapArgs)`.
807+
/// Returns true if `inst` matches the pattern and the load is transformed; otherwise
808+
/// returns false and leaves `inst` untouched.
///
/// On success the original `inst` is removed and deallocated, so callers must not
/// use it after a `true` result.
809+
bool maybeSpecializeBufferLoadCall(IRCall* inst)
810+
{
811+
if (isBufferLoadCall(inst))
812+
{
813+
SLANG_ASSERT(inst->getArgCount() > 0);
814+
if (auto wrapExistential = as<IRWrapExistential>(inst->getArg(0)))
815+
{
816+
if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(
817+
wrapExistential->getWrappedValue()->getDataType()))
818+
{
819+
// We are seeing an instruction sequence of the form
820+
// .operator[](wrapExistential(structuredBuffer), idx).
821+
// Similar to how load(wrapExistential(..)) insts are handled,
822+
// we need to rewrite it into wrapExistential(.operator[](sb, idx)).
823+
auto resultType = inst->getFullType();
824+
auto elementType = sbType->getElementType();
825+
826+
IRBuilder builder;
827+
builder.sharedBuilder = &sharedBuilderStorage;
828+
builder.setInsertBefore(inst);
829+
830+
// The new call takes the unwrapped buffer as its first argument,
// followed by the original call's remaining arguments unchanged.
List<IRInst*> args;
831+
args.add(wrapExistential->getWrappedValue());
832+
for (UInt i = 1; i < inst->getArgCount(); i++)
833+
args.add(inst->getArg(i));
834+
// Copy the slot operands of the old wrapper so they can be
// re-applied to the loaded element.
List<IRInst*> slotOperands;
835+
UInt slotOperandCount = wrapExistential->getSlotOperandCount();
836+
for (UInt ii = 0; ii < slotOperandCount; ++ii)
837+
{
838+
slotOperands.add(wrapExistential->getSlotOperand(ii));
839+
}
840+
// The inner call yields the buffer's element type; the outer
// wrapper restores the type the original call produced.
auto newCall = builder.emitCallInst(elementType, inst->getCallee(), args);
841+
auto newWrapExistential = builder.emitWrapExistential(
842+
resultType, newCall, slotOperandCount, slotOperands.getBuffer());
843+
inst->replaceUsesWith(newWrapExistential);
844+
inst->removeAndDeallocate();
845+
// Queue the users of the new wrapper on the work list so they
// are visited again.
addUsersToWorkList(newWrapExistential);
846+
return true;
847+
}
848+
}
849+
}
850+
return false;
851+
}
852+
773853
// Given a `call` instruction in the IR, we need to detect the case
774854
// where the callee has some interface-type parameter(s) and at the
775855
// call site it is statically clear what concrete type(s) the arguments
776856
// will have.
777857
//
778858
void maybeSpecializeExistentialsForCall(IRCall* inst)
779859
{
860+
// Handle a special case of `StructuredBuffer.operator[]/Load/Consume`
861+
// calls first. These calls on builtin generic types should be handled
862+
// the same way as a `load` inst.
863+
if (maybeSpecializeBufferLoadCall(inst))
864+
return;
865+
780866
// We can only specialize a call when the callee function is known.
781867
//
782868
auto calleeFunc = as<IRFunc>(inst->getCallee());
@@ -1678,21 +1764,26 @@ struct SpecializationContext
16781764
type->removeAndDeallocate();
16791765
return;
16801766
}
1681-
else if( auto basePtrLikeType = as<IRPointerLikeType>(baseType) )
1767+
else if( as<IRPointerLikeType>(baseType) || as<IRHLSLStructuredBufferTypeBase>(baseType) )
16821768
{
16831769
// A `BindExistentials<P<T>, ...>` can be simplified to
16841770
// `P<BindExistentials<T, ...>>` when `P` is a pointer-like
16851771
// type constructor or a structured-buffer type constructor.
16861772
//
1687-
auto baseElementType = basePtrLikeType->getElementType();
1773+
IRType* baseElementType = nullptr;
1774+
if (auto basePtrLikeType = as<IRPointerLikeType>(baseType))
1775+
baseElementType = basePtrLikeType->getElementType();
1776+
else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType))
1777+
baseElementType = baseSBType->getElementType();
1778+
16881779
IRInst* wrappedElementType = builder.getBindExistentialsType(
16891780
baseElementType,
16901781
slotOperandCount,
16911782
type->getExistentialArgs());
16921783
addToWorkList(wrappedElementType);
16931784

16941785
auto newPtrLikeType = builder.getType(
1695-
basePtrLikeType->op,
1786+
baseType->op,
16961787
1,
16971788
&wrappedElementType);
16981789
addToWorkList(newPtrLikeType);

source/slang/slang-ir.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -2979,8 +2979,7 @@ namespace Slang
29792979
IRWitnessTable* witnessTable = createInst<IRWitnessTable>(
29802980
this,
29812981
kIROp_WitnessTable,
2982-
getWitnessTableType(baseType),
2983-
baseType);
2982+
getWitnessTableType(baseType));
29842983
addGlobalValue(this, witnessTable);
29852984
return witnessTable;
29862985
}
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Test using interface-typed shader parameters wrapped inside a `StructuredBuffer`.
2+
3+
//TEST(compute):COMPARE_COMPUTE:-cpu
4+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization
5+
6+
[anyValueSize(8)]
7+
interface IInterface
8+
{
9+
int run(int input);
10+
}
11+
12+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer
13+
RWStructuredBuffer<int> gOutputBuffer;
14+
15+
//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb
16+
StructuredBuffer<IInterface> gCb;
17+
18+
//TEST_INPUT:ubuffer(data=[rtti(MyImpl) witness(MyImpl, IInterface) 1 0], stride=4):name=gCb1
19+
StructuredBuffer<IInterface> gCb1;
20+
21+
[numthreads(4, 1, 1)]
22+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
23+
{
24+
let tid = dispatchThreadID.x;
25+
26+
let inputVal : int = tid;
27+
IInterface v0 = gCb.Load(0);
28+
IInterface v1 = gCb1[0];
29+
let outputVal = v0.run(inputVal) + v1.run(inputVal);
30+
31+
gOutputBuffer[tid] = outputVal;
32+
}
33+
34+
// Specialize gCb, but not gCb1
35+
//TEST_INPUT: globalExistentialType MyImpl
36+
//TEST_INPUT: globalExistentialType __Dynamic
37+
// Type must be marked `public` to ensure it is visible in the generated DLL.
38+
public struct MyImpl : IInterface
39+
{
40+
int val;
41+
int run(int input)
42+
{
43+
return input + val;
44+
}
45+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2
2+
4
3+
6
4+
8

0 commit comments

Comments
 (0)