Skip to content

Commit aadf600

Browse files
authored
Specialize exsitentials parameters in struct fields. (shader-slang#1565)
* Specialize exsitentials parameters in struct fields. * Cleanup. Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 274c20a commit aadf600

4 files changed

+102
-7
lines changed

source/slang/slang-ir-specialize.cpp

+43-6
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,24 @@ struct SpecializationContext
808808
return false;
809809
}
810810

811+
/// Used by `maybeSpecailizeBufferLoadCall`, this function returns a new specialized callee that
812+
/// replaces a `specialize(.operator[], oldType)` to `specialize(.operator[], newElementType)`.
813+
IRInst* getNewSpecializedBufferLoadCallee(
814+
IRInst* oldSpecializedCallee,
815+
IRType* newContainerType,
816+
IRType* newElementType)
817+
{
818+
auto oldSpecialize = cast<IRSpecialize>(oldSpecializedCallee);
819+
SLANG_ASSERT(oldSpecialize->getArgCount() == 1);
820+
IRBuilder builder;
821+
builder.sharedBuilder = &sharedBuilderStorage;
822+
builder.setInsertBefore(oldSpecializedCallee);
823+
auto calleeType = builder.getFuncType(1, &newContainerType, newElementType);
824+
auto newSpecialize = builder.emitSpecializeInst(
825+
calleeType, oldSpecialize->getBase(), 1, (IRInst**)&newElementType);
826+
return newSpecialize;
827+
}
828+
811829
/// Transform a buffer load intrinsic call.
812830
/// `bufferLoad(wrapExistential(bufferObj, wrapArgs), loadArgs)` should be transformed into
813831
/// `wrapExistential(bufferLoad(bufferObj, loadArgs), wragArgs)`.
@@ -844,11 +862,18 @@ struct SpecializationContext
844862
{
845863
slotOperands.add(wrapExistential->getSlotOperand(ii));
846864
}
847-
auto newCall = builder.emitCallInst(elementType, inst->getCallee(), args);
865+
// The old callee should be in the form of `specialize(.operator[], IInterfaceType)`,
866+
// we should update it to be `specialize(.operator[], elementType)`, so the return type
867+
// of the load call is `elementType`.
868+
auto oldCallee = inst->getCallee();
869+
auto newCallee = getNewSpecializedBufferLoadCallee(inst->getCallee(), sbType, elementType);
870+
auto newCall = builder.emitCallInst(elementType, newCallee, args);
848871
auto newWrapExistential = builder.emitWrapExistential(
849872
resultType, newCall, slotOperandCount, slotOperands.getBuffer());
850873
inst->replaceUsesWith(newWrapExistential);
851874
inst->removeAndDeallocate();
875+
SLANG_ASSERT(!oldCallee->hasUses());
876+
oldCallee->removeAndDeallocate();
852877
addUsersToWorkList(newWrapExistential);
853878
return true;
854879
}
@@ -1080,7 +1105,8 @@ struct SpecializationContext
10801105
//
10811106
if(as<IRInterfaceType>(type))
10821107
return true;
1083-
1108+
if (calcExistentialTypeParamSlotCount(type) != 0)
1109+
return true;
10841110
// Eventually we will also want to handle arrays over
10851111
// existential types, but that will require careful
10861112
// handling in many places.
@@ -1518,6 +1544,11 @@ struct SpecializationContext
15181544
type = arrayType->getElementType();
15191545
goto top;
15201546
}
1547+
else if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(type))
1548+
{
1549+
type = sbType->getElementType();
1550+
goto top;
1551+
}
15211552
else if( auto structType = as<IRStructType>(type) )
15221553
{
15231554
UInt count = 0;
@@ -1800,6 +1831,11 @@ struct SpecializationContext
18001831
type = ptrLikeType->getElementType();
18011832
goto top;
18021833
}
1834+
else if (auto sbType = as<IRHLSLStructuredBufferTypeBase>(type))
1835+
{
1836+
type = sbType->getElementType();
1837+
goto top;
1838+
}
18031839
else if( auto structType = as<IRStructType>(type) )
18041840
{
18051841
UInt count = 0;
@@ -1872,15 +1908,15 @@ struct SpecializationContext
18721908
baseElementType,
18731909
slotOperandCount,
18741910
type->getExistentialArgs());
1875-
addToWorkList(wrappedElementType);
18761911

18771912
auto newPtrLikeType = builder.getType(
18781913
baseType->op,
18791914
1,
18801915
&wrappedElementType);
1916+
addUsersToWorkList(type);
18811917
addToWorkList(newPtrLikeType);
1918+
addToWorkList(wrappedElementType);
18821919

1883-
addUsersToWorkList(type);
18841920
type->replaceUsesWith(newPtrLikeType);
18851921
type->removeAndDeallocate();
18861922
return;
@@ -1911,10 +1947,13 @@ struct SpecializationContext
19111947
}
19121948

19131949
IRStructType* newStructType = nullptr;
1950+
addUsersToWorkList(type);
1951+
19141952
if( !existentialSpecializedStructs.TryGetValue(key, newStructType) )
19151953
{
19161954
builder.setInsertBefore(baseStructType);
19171955
newStructType = builder.createStructType();
1956+
addToWorkList(newStructType);
19181957

19191958
auto fieldSlotArgs = type->getExistentialArgs();
19201959

@@ -1939,10 +1978,8 @@ struct SpecializationContext
19391978
}
19401979

19411980
existentialSpecializedStructs.Add(key, newStructType);
1942-
addToWorkList(newStructType);
19431981
}
19441982

1945-
addUsersToWorkList(type);
19461983
type->replaceUsesWith(newStructType);
19471984
type->removeAndDeallocate();
19481985
return;

source/slang/slang-type-layout.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -2536,8 +2536,19 @@ createStructuredBufferTypeLayout(
25362536
typeLayout->addResourceUsage(info.kind, info.size);
25372537
}
25382538

2539+
// If element type contains existential type params and object params,
2540+
// we need to propagate them through the StructuredBufferLayout.
2541+
if (auto existentialTypeInfo = elementTypeLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam))
2542+
{
2543+
typeLayout->addResourceUsage(existentialTypeInfo->kind, existentialTypeInfo->count);
2544+
}
2545+
if (auto existentialObjInfo = elementTypeLayout->FindResourceInfo(LayoutResourceKind::ExistentialObjectParam))
2546+
{
2547+
typeLayout->addResourceUsage(existentialObjInfo->kind, existentialObjInfo->count);
2548+
}
2549+
25392550
// Note: for now we don't deal with the case of a structured
2540-
// buffer that might contain anything other than "uniform" data,
2551+
// buffer that might contain any other resource types,
25412552
// because there really isn't a way to implement that.
25422553

25432554
return typeLayout;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Tests specializing a function with existential-struct-typed param.
2+
3+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cuda
4+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cpu
5+
6+
[anyValueSize(8)]
7+
interface IInterface
8+
{
9+
uint eval();
10+
}
11+
12+
struct Impl : IInterface
13+
{
14+
uint val;
15+
uint eval()
16+
{
17+
return val;
18+
}
19+
};
20+
21+
struct Params
22+
{
23+
StructuredBuffer<IInterface> obj;
24+
};
25+
26+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer
27+
RWStructuredBuffer<uint> gOutputBuffer;
28+
29+
void compute(uint tid, Params p)
30+
{
31+
gOutputBuffer[tid] = p.obj[0].eval();
32+
}
33+
34+
//TEST_INPUT: entryPointExistentialType Impl
35+
36+
[numthreads(4, 1, 1)]
37+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID,
38+
//TEST_INPUT:ubuffer(data=[0 0 0 0 1 0], stride=4):name=params.obj
39+
uniform Params params)
40+
{
41+
uint tid = dispatchThreadID.x;
42+
compute(tid, params);
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
1
2+
1
3+
1
4+
1

0 commit comments

Comments
 (0)