Skip to content

Commit d8f63e7

Browse files
authored
Issue/legalize resource (shader-slang#4769)
* Fix the issue that NonUniformResourceIndex is ignored Fix the issue that after `specializeFunctionCalls`, `NonUniformResourceIndex` is ignored in the generated specialized function. The reason is that if the function has a non-uniform resource parameter, we will legalize it by replacing the resource parameter with an index, and the indexing of the resource will be moved inside the specialized function. e.g. ``` void func(ResourceType resource) { ... } func(resource[NonUniformResourceIndex(0)]) ``` will be specialized into ``` void func(int index) { resource[index]; } func(0); ``` In this case, inside the function, we lose the information about whether the resource is non-uniform. So we handle this corner case by inserting a `NonUniformResourceIndex` into the specialized function: ``` void func(int index) { int nonUniformIdx = NonUniformResourceIndex(index); resource[nonUniformIdx]; } ``` * Fix the issue that arguments mismatch after specialization A call site processed by specializeCall() could end up with arguments that mismatch the parameters of the specialized function. For example, if the function parameter contains a resource type ``` void func(ResourceType res) { ... } int index = ... func(resources[index]); ``` This will be specialized into ``` void func(int index) { resources[index] } int index = ... func(index); ``` However, if we have more than one call site, and another call site doesn't use `int` as the index, e.g. ``` uint index = ... func(resources[index]); ``` this call site will be specialized into ``` uint index = ... func(index); ``` This will be invalid, because the argument doesn't match the parameter. So we add the data type of the new arguments into the function key. For the uniformity info, we add a new attribute "IROp_NonUniformAttr", so we will form an IRAttributedType that encodes both uniformity and data type, and use it as the key of the call info.
So if there is a call site using a different data type for the resource index, we will specialize a new function for it. * Handle the intCast and uintCast operations Since the intCast/uintCast of a non-uniform index is still a non-uniform index, we handle this case as well. Also, add a new test to cover this.
1 parent f4ff423 commit d8f63e7

File tree

3 files changed

+233
-2
lines changed

3 files changed

+233
-2
lines changed

source/slang/slang-ir-inst-defs.h

+1
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,7 @@ INST_RANGE(Layout, VarLayout, EntryPointLayout)
11731173
INST(UNormAttr, unorm, 0, HOISTABLE)
11741174
INST(SNormAttr, snorm, 0, HOISTABLE)
11751175
INST(NoDiffAttr, no_diff, 0, HOISTABLE)
1176+
INST(NonUniformAttr, nonuniform, 0, HOISTABLE)
11761177

11771178
/* SemanticAttr */
11781179
INST(UserSemanticAttr, userSemantic, 2, HOISTABLE)

source/slang/slang-ir-specialize-function-call.cpp

+91-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "slang-ir-clone.h"
66
#include "slang-ir-insts.h"
77
#include "slang-ir-ssa-simplification.h"
8+
#include "slang-ir-util.h"
89

910
namespace Slang
1011
{
@@ -363,7 +364,7 @@ struct FunctionParameterSpecializationContext
363364
// a new callee function based on the original
364365
// function and the information we gathered.
365366
//
366-
newFunc = generateSpecializedFunc(oldFunc, funcInfo);
367+
newFunc = generateSpecializedFunc(oldFunc, funcInfo, callInfo);
367368
specializedFuncs.add(callInfo.key, newFunc);
368369
}
369370

@@ -381,6 +382,7 @@ struct FunctionParameterSpecializationContext
381382
newCall->insertBefore(oldCall);
382383
oldCall->replaceUsesWith(newCall);
383384
oldCall->removeAndDeallocate();
385+
384386
}
385387

386388
// Before diving into the details on how we gather information
@@ -559,6 +561,21 @@ struct FunctionParameterSpecializationContext
559561
// the arguments at the new call site, and
560562
// don't add anything to the specialization key.
561563
//
564+
// We should also add 2 more things such that our specialization
565+
// can handle the corner cases that if the oldBase is a nonuniform
566+
// resource and also the data type of oldIndex will be handled correctly.
567+
// By doing so, we form an IRAttributedType to include both information
568+
// and add it to the key of call info.
569+
570+
List<IRAttr*> irAttrs;
571+
if (findNonuniformIndexInst(oldIndex))
572+
{
573+
IRAttr* attr = getBuilder()->getAttr(kIROp_NonUniformAttr);
574+
irAttrs.add(attr);
575+
}
576+
auto irType = getBuilder()->getAttributedType(oldIndex->getDataType(), irAttrs);
577+
ioInfo.key.vals.add(irType);
578+
562579
ioInfo.newArgs.add(oldIndex);
563580
}
564581
else if (oldArg->getOp() == kIROp_Load)
@@ -577,6 +594,27 @@ struct FunctionParameterSpecializationContext
577594
}
578595
}
579596

597+
IRInst* findNonuniformIndexInst(IRInst* inst)
598+
{
599+
while(1)
600+
{
601+
if (inst == nullptr)
602+
return nullptr;
603+
604+
if (inst->getOp() == kIROp_NonUniformResourceIndex)
605+
return inst;
606+
607+
if (inst->getOp() == kIROp_IntCast)
608+
{
609+
inst = inst->getOperand(0);
610+
}
611+
else
612+
{
613+
return nullptr;
614+
}
615+
}
616+
}
617+
580618
// The remaining information we've discussed is only
581619
// gathered once we decide we want to generate a
582620
// specialized function, but it follows much the same flow.
@@ -803,7 +841,8 @@ struct FunctionParameterSpecializationContext
803841
//
804842
IRFunc* generateSpecializedFunc(
805843
IRFunc* oldFunc,
806-
FuncSpecializationInfo const& funcInfo)
844+
FuncSpecializationInfo const& funcInfo,
845+
CallSpecializationInfo const& callInfo)
807846
{
808847
// We will make use of the infrastructure for cloning
809848
// IR code, that is defined in `ir-clone.{h,cpp}`.
@@ -933,6 +972,18 @@ struct FunctionParameterSpecializationContext
933972
newBodyInst->insertBefore(newFirstOrdinary);
934973
}
935974

975+
// We need to handle a corner case where the new argument of
976+
// the callee of this specialized function could be a use of
977+
// NonUniformResourceIndex(), in such case, any indexing operation
978+
// on the global buffer by using this new argument should be
979+
// decorated with NonUniformDecoration. However, inside the new
980+
// specialized function, we don't have that information anymore.
981+
// Therefore, we will need to scan the new argument list to find out
982+
// this case, and insert the NonUniformResourceIndex() instruction
983+
// on the corresponding parameter of the new specialized function.
984+
maybeInsertNonUniformResourceIndex(newFunc, funcInfo, callInfo);
985+
986+
936987
// At this point we've created a new specialized function,
937988
// and as such it may contain call sites that were not
938989
// covered when we built our initial work list.
@@ -964,6 +1015,44 @@ struct FunctionParameterSpecializationContext
9641015

9651016
return newFunc;
9661017
}
1018+
1019+
void maybeInsertNonUniformResourceIndex(
1020+
IRFunc* newFunc,
1021+
FuncSpecializationInfo const& funcInfo,
1022+
CallSpecializationInfo const& callInfo)
1023+
{
1024+
auto builder = getBuilder();
1025+
uint32_t paramIndex = 0;
1026+
1027+
SLANG_ASSERT(callInfo.newArgs.getCount() == funcInfo.newParams.getCount());
1028+
1029+
// Iterate over the new arguments, new parameters pair, and
1030+
// find out if there is any use of NonUniformResourceIndex()
1031+
// in the new arguments.
1032+
for (auto newArg : callInfo.newArgs)
1033+
{
1034+
if (auto nonuniformIdxInst = findNonuniformIndexInst(newArg))
1035+
{
1036+
auto firstOrdinary = newFunc->getFirstOrdinaryInst();
1037+
1038+
IRCloneEnv cloneEnv;
1039+
auto newParam = funcInfo.newParams[paramIndex];
1040+
1041+
// Clone the NonUniformResourceIndex call and insert it at beginning
1042+
// of the function. Then replace every use of the parameter with the
1043+
// NonUniformResourceIndex.
1044+
auto clonedInst = cloneInstAndOperands(&cloneEnv, builder, nonuniformIdxInst);
1045+
clonedInst->insertBefore(firstOrdinary);
1046+
newParam->replaceUsesWith(clonedInst);
1047+
1048+
// At last, set the operand of the NonUniformResourceIndex to the new parameter
1049+
// because we haven't done it yet during inst clone.
1050+
clonedInst->setOperand(0, newParam);
1051+
}
1052+
paramIndex++;
1053+
}
1054+
1055+
}
9671056
};
9681057

9691058
// The top-level function for invoking the specialization pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
//TEST:SIMPLE(filecheck=CHECK_SPV):-target spirv -entry main -stage compute
2+
//TEST:SIMPLE(filecheck=CHECK_GLSL_SPV):-target spirv -entry main -stage compute -emit-spirv-via-glsl
3+
//TEST:SIMPLE(filecheck=CHECK_GLSL):-target glsl -entry main -stage compute
4+
//TEST:SIMPLE(filecheck=CHECK_HLSL):-target hlsl -entry main -stage compute
5+
RWStructuredBuffer<uint> globalBuffer[] : register(t0, space0);
6+
RWStructuredBuffer<uint3> outputBuffer;
7+
8+
struct MyStruct
9+
{
10+
uint a;
11+
uint b;
12+
uint c;
13+
};
14+
15+
16+
MyStruct func(RWStructuredBuffer<uint> buffer)
17+
{
18+
MyStruct a;
19+
20+
// CHECK_GLSL: globalBuffer_0[nonuniformEXT({{.*}})]
21+
// CHECK_GLSL: globalBuffer_0[nonuniformEXT({{.*}})]
22+
23+
// For the last test case 3 that the callee passes globalBuffer[bufferIdx3] to the function,
24+
// we should not see nonuniformEXT here.
25+
26+
// CHECK_GLSL: globalBuffer_0[_{{.*}})]
27+
// CHECK_GLSL: globalBuffer_0[_{{.*}})]
28+
a.a = buffer[0];
29+
a.b = a.a + 1;
30+
a.c = a.a + a.b + 1;
31+
32+
return a;
33+
}
34+
35+
[shader("compute")]
36+
[numthreads(1, 1, 1)]
37+
void main(uint2 pixelIndex : SV_DispatchThreadID)
38+
{
39+
40+
// CHECK_SPV: OpDecorate %[[VAR1:[a-zA-Z0-9_]+]] NonUniform
41+
// CHECK_SPV: OpDecorate %[[VAR2:[a-zA-Z0-9_]+]] NonUniform
42+
// CHECK_SPV: OpDecorate %[[VAR3:[a-zA-Z0-9_]+]] NonUniform
43+
44+
45+
// CHECK_GLSL_SPV: OpDecorate %[[VAR1:[a-zA-Z0-9_]+]] NonUniform
46+
// CHECK_GLSL_SPV: OpDecorate %[[VAR2:[a-zA-Z0-9_]+]] NonUniform
47+
// CHECK_GLSL_SPV: OpDecorate %[[VAR3:[a-zA-Z0-9_]+]] NonUniform
48+
// CHECK_GLSL_SPV: OpDecorate %[[VAR4:[a-zA-Z0-9_]+]] NonUniform
49+
// CHECK_GLSL_SPV: OpDecorate %[[VAR5:[a-zA-Z0-9_]+]] NonUniform
50+
// CHECK_GLSL_SPV: OpDecorate %[[VAR6:[a-zA-Z0-9_]+]] NonUniform
51+
// CHECK_GLSL_SPV: OpDecorate %[[VAR7:[a-zA-Z0-9_]+]] NonUniform
52+
// CHECK_GLSL_SPV: OpDecorate %[[VAR8:[a-zA-Z0-9_]+]] NonUniform
53+
// CHECK_GLSL_SPV: OpDecorate %[[VAR9:[a-zA-Z0-9_]+]] NonUniform
54+
// CHECK_GLSL_SPV: OpDecorate %[[VAR10:[a-zA-Z0-9_]+]] NonUniform
55+
// CHECK_GLSL_SPV: OpDecorate %[[VAR11:[a-zA-Z0-9_]+]] NonUniform
56+
// CHECK_GLSL_SPV: OpDecorate %[[VAR12:[a-zA-Z0-9_]+]] NonUniform
57+
// CHECK_GLSL_SPV: OpDecorate %[[VAR13:[a-zA-Z0-9_]+]] NonUniform
58+
// CHECK_GLSL_SPV: OpDecorate %[[VAR14:[a-zA-Z0-9_]+]] NonUniform
59+
// CHECK_GLSL_SPV: OpDecorate %[[VAR15:[a-zA-Z0-9_]+]] NonUniform
60+
// CHECK_GLSL_SPV: OpDecorate %[[VAR16:[a-zA-Z0-9_]+]] NonUniform
61+
// CHECK_GLSL_SPV: OpDecorate %[[VAR17:[a-zA-Z0-9_]+]] NonUniform
62+
// CHECK_GLSL_SPV: OpDecorate %[[VAR18:[a-zA-Z0-9_]+]] NonUniform
63+
64+
65+
// Test case 1: slang will specialize the func call to 'MyStruct func(uint)'
66+
uint bufferIdx = pixelIndex.x;
67+
uint nonUniformIdx = NonUniformResourceIndex(bufferIdx);
68+
RWStructuredBuffer<uint> buffer = globalBuffer[nonUniformIdx];
69+
70+
// CHECK_SPV: %[[VAR1]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %bufferIdx
71+
72+
// CHECK_GLSL_SPV: %[[VAR1]] = OpCopyObject %uint %{{.*}}
73+
74+
// CHECK_GLSL_SPV: %[[VAR4]] = OpCopyObject %uint %[[VAR1]]
75+
// CHECK_GLSL_SPV: %[[VAR5]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR4]] %int_0 %int_0
76+
// CHECK_GLSL_SPV: %[[VAR6]] = OpLoad %uint %[[VAR5]]
77+
78+
// CHECK_GLSL_SPV: %[[VAR7]] = OpCopyObject %uint %[[VAR1]]
79+
// CHECK_GLSL_SPV: %[[VAR8]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR7]] %int_0 %int_0
80+
// CHECK_GLSL_SPV: %[[VAR9]] = OpLoad %uint %[[VAR8]]
81+
82+
// CHECK_GLSL: func_0({{.*}}nonuniformEXT({{.*}}))
83+
// CHECK_HLSL: func_0(globalBuffer_0[NonUniformResourceIndex({{.*}})])
84+
MyStruct myStruct = func(buffer);
85+
86+
int bufferIdx2 = pixelIndex.y;
87+
88+
// Test case 2: Make sure we cover the case for the different data type of the index.
89+
// In this case, slang will specialize the function to 'MyStruct func(int)'
90+
// CHECK_SPV: %[[VAR2]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %bufferIdx2
91+
92+
93+
// CHECK_GLSL_SPV: %[[VAR2]] = OpCopyObject %int %{{.*}}
94+
95+
// CHECK_GLSL_SPV: %[[VAR10]] = OpCopyObject %int %[[VAR2]]
96+
// CHECK_GLSL_SPV: %[[VAR11]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR10]] %int_0 %int_0
97+
// CHECK_GLSL_SPV: %[[VAR12]] = OpLoad %uint %[[VAR11]]
98+
99+
// CHECK_GLSL_SPV: %[[VAR13]] = OpCopyObject %int %[[VAR2]]
100+
// CHECK_GLSL_SPV: %[[VAR14]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR13]] %int_0 %int_0
101+
// CHECK_GLSL_SPV: %[[VAR15]] = OpLoad %uint %[[VAR14]]
102+
RWStructuredBuffer<uint> buffer2 = globalBuffer[NonUniformResourceIndex(bufferIdx2)];
103+
104+
// CHECK_GLSL: func_1({{.*}}nonuniformEXT({{.*}}))
105+
// CHECK_HLSL: func_0(globalBuffer_0[NonUniformResourceIndex({{.*}})])
106+
MyStruct myStruct2 = func(buffer2);
107+
108+
// Test case 3: Test the case that we handle the uniformity correctly; the NonUniformResourceIndex will not propagate
109+
// to the function, so there should be no NonUniform decoration.
110+
int bufferIdx3 = pixelIndex.y;
111+
RWStructuredBuffer<uint> buffer3 = globalBuffer[bufferIdx3];
112+
113+
// CHECK_SPV: %[[VAR4:[a-zA-Z0-9_]+]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %bufferIdx2
114+
115+
// Test to make sure this command is not decorated with NonUniform:
116+
// CHECK_SPV-NOT: OpDecorate %[[VAR4]] NonUniform
117+
MyStruct myStruct3 = func(buffer3);
118+
119+
120+
// Test case 4: Test to make sure we correctly cover the case that intCast or uintCast of a NonUniformResourceIndex
121+
// is still a NonUniformResourceIndex.
122+
123+
// CHECK_SPV: %[[VAR5:[a-zA-Z0-9_]+]] = OpBitcast %uint %{{.*}}
124+
// CHECK_SPV: %[[VAR3]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %[[VAR5]]
125+
126+
// CHECK_GLSL_SPV: %[[VAR19:[a-zA-Z0-9_]+]] = OpBitcast %int %[[VAR3]]
127+
// CHECK_GLSL_SPV: %[[VAR16]] = OpCopyObject %int %[[VAR19]]
128+
// CHECK_GLSL_SPV: %[[VAR17]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR16]] %int_0 %int_0
129+
// CHECK_GLSL_SPV: %[[VAR18]] = OpLoad %uint %[[VAR17]]
130+
//
131+
// Since after the nested cast, the index data type is 'uint' now, make sure it calls the same function as the test case 1.
132+
// CHECK_GLSL: func_0({{.*}}nonuniformEXT({{.*}}))
133+
RWStructuredBuffer<uint> buffer4 = globalBuffer[(uint)((int)NonUniformResourceIndex(bufferIdx))];
134+
MyStruct myStruct4 = func(buffer4);
135+
136+
outputBuffer[0] = uint3(myStruct.a, myStruct.b, myStruct.c);
137+
outputBuffer[1] = uint3(myStruct2.a, myStruct2.b, myStruct2.c);
138+
outputBuffer[2] = uint3(myStruct3.a, myStruct3.b, myStruct3.c);
139+
outputBuffer[3] = uint3(myStruct4.a, myStruct4.b, myStruct4.c);
140+
}
141+

0 commit comments

Comments
 (0)