Skip to content

Commit 2cc9690

Browse files
authored
Implement for metal SV_GroupIndex (shader-slang#4385)
* Implement for metal `SV_GroupIndex` 1. If we don't have `sv_GroupThreadId` available we create one using `SV_GroupIndex`s location data. 2. We emit code emulating `sv_GroupThreadId` from the same logic that CUDA/CPP uses. * address most review comments Addressed all but two: [1](shader-slang#4385 (comment)) and [2](shader-slang#4385 (comment)) I want to enable tests and be sure there is no bugs using CI before I redesign the code so I have a working fallback. * address comment, enable tests enable now functioning tests due to `SV_GroupIndex` working with metal * syntax error with groupThreadID search did `= param` instead of `= i.param` * add `sv_groupid` for test + test fixes * disable test that won't work regardless
1 parent a6b8348 commit 2cc9690

14 files changed

+235
-129
lines changed

source/slang/slang-ir-legalize-varying-params.cpp

+113-111
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,118 @@ void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, IRInst* src)
177177
assign(builder, dest, LegalizedVaryingVal::makeValue(src));
178178
}
179179

180+
181+
// Several of the derived calcluations rely on having
182+
// access to the "group extents" of a compute shader.
183+
// That information is expected to be present on
184+
// the entry point as a `[numthreads(...)]` attribute,
185+
// and we define a convenience routine for accessing
186+
// that information.
187+
188+
IRInst* emitCalcGroupExtents(
189+
IRBuilder& builder,
190+
IRFunc* entryPoint,
191+
IRVectorType* type)
192+
{
193+
if (auto numThreadsDecor = entryPoint->findDecoration<IRNumThreadsDecoration>())
194+
{
195+
static const int kAxisCount = 3;
196+
IRInst* groupExtentAlongAxis[kAxisCount] = {};
197+
198+
for (int axis = 0; axis < kAxisCount; axis++)
199+
{
200+
auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis));
201+
if (!litValue)
202+
return nullptr;
203+
204+
groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue());
205+
}
206+
207+
return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis);
208+
}
209+
210+
// TODO: We may want to implement a backup option here,
211+
// in case we ever want to support compute shaders with
212+
// dynamic/flexible group size on targets that allow it.
213+
//
214+
SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point.");
215+
UNREACHABLE_RETURN(nullptr);
216+
}
217+
218+
// There are some cases of system-value inputs that can be derived
219+
// from other inputs; notably compute shaders support `SV_DispatchThreadID`
220+
// and `SV_GroupIndex` which can both be derived from the more primitive
221+
// `SV_GroupID` and `SV_GroupThreadID`, together with the extents
222+
// of the thread group (which are specified with `[numthreads(...)]`).
223+
//
224+
// As a utilty to target-specific subtypes, we define helpers for
225+
// calculating the value of these derived system values from the
226+
// more primitive ones.
227+
228+
/// Emit code to calculate `SV_DispatchThreadID`
229+
IRInst* emitCalcDispatchThreadID(
230+
IRBuilder& builder,
231+
IRType* type,
232+
IRInst* groupID,
233+
IRInst* groupThreadID,
234+
IRInst* groupExtents)
235+
{
236+
// The dispatch thread ID can be computed as:
237+
//
238+
// dispatchThreadID = groupID*groupExtents + groupThreadID
239+
//
240+
// where `groupExtents` is the X,Y,Z extents of
241+
// each thread group in threads (as given by
242+
// `[numthreads(X,Y,Z)]`).
243+
244+
return builder.emitAdd(type,
245+
builder.emitMul(type,
246+
groupID,
247+
groupExtents),
248+
groupThreadID);
249+
}
250+
251+
/// Emit code to calculate `SV_GroupIndex`
252+
IRInst* emitCalcGroupThreadIndex(
253+
IRBuilder& builder,
254+
IRInst* groupThreadID,
255+
IRInst* groupExtents)
256+
{
257+
auto intType = builder.getIntType();
258+
auto uintType = builder.getBasicType(BaseType::UInt);
259+
260+
// The group thread index can be computed as:
261+
//
262+
// groupThreadIndex = groupThreadID.x
263+
// + groupThreadID.y*groupExtents.x
264+
// + groupThreadID.z*groupExtents.x*groupExtents.z;
265+
//
266+
// or equivalently (with one less multiply):
267+
//
268+
// groupThreadIndex = (groupThreadID.z * groupExtents.y
269+
// + groupThreadID.y) * groupExtents.x
270+
// + groupThreadID.x;
271+
//
272+
273+
// `offset = groupThreadID.z`
274+
auto zAxis = builder.getIntValue(intType, 2);
275+
IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis);
276+
277+
// `offset *= groupExtents.y`
278+
// `offset += groupExtents.y`
279+
auto yAxis = builder.getIntValue(intType, 1);
280+
offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis));
281+
offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis));
282+
283+
// `offset *= groupExtents.x`
284+
// `offset += groupExtents.x`
285+
auto xAxis = builder.getIntValue(intType, 0);
286+
offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis));
287+
offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis));
288+
289+
return offset;
290+
}
291+
180292
/// Context for the IR pass that legalizing entry-point
181293
/// varying parameters for a target.
182294
///
@@ -915,116 +1027,6 @@ struct EntryPointVaryingParamLegalizeContext
9151027

9161028
return LegalizedVaryingVal();
9171029
}
918-
919-
// There are some cases of system-value inputs that can be derived
920-
// from other inputs; notably compute shaders support `SV_DispatchThreadID`
921-
// and `SV_GroupIndex` which can both be derived from the more primitive
922-
// `SV_GroupID` and `SV_GroupThreadID`, together with the extents
923-
// of the thread group (which are specified with `[numthreads(...)]`).
924-
//
925-
// As a utilty to target-specific subtypes, we define helpers for
926-
// calculating the value of these derived system values from the
927-
// more primitive ones.
928-
929-
/// Emit code to calculate `SV_DispatchThreadID`
930-
IRInst* emitCalcDispatchThreadID(
931-
IRBuilder& builder,
932-
IRType* type,
933-
IRInst* groupID,
934-
IRInst* groupThreadID,
935-
IRInst* groupExtents)
936-
{
937-
// The dispatch thread ID can be computed as:
938-
//
939-
// dispatchThreadID = groupID*groupExtents + groupThreadID
940-
//
941-
// where `groupExtents` is the X,Y,Z extents of
942-
// each thread group in threads (as given by
943-
// `[numthreads(X,Y,Z)]`).
944-
945-
return builder.emitAdd(type,
946-
builder.emitMul(type,
947-
groupID,
948-
groupExtents),
949-
groupThreadID);
950-
}
951-
952-
/// Emit code to calculate `SV_GroupIndex`
953-
IRInst* emitCalcGroupThreadIndex(
954-
IRBuilder& builder,
955-
IRInst* groupThreadID,
956-
IRInst* groupExtents)
957-
{
958-
auto intType = builder.getIntType();
959-
auto uintType = builder.getBasicType(BaseType::UInt);
960-
961-
// The group thread index can be computed as:
962-
//
963-
// groupThreadIndex = groupThreadID.x
964-
// + groupThreadID.y*groupExtents.x
965-
// + groupThreadID.z*groupExtents.x*groupExtents.z;
966-
//
967-
// or equivalently (with one less multiply):
968-
//
969-
// groupThreadIndex = (groupThreadID.z * groupExtents.y
970-
// + groupThreadID.y) * groupExtents.x
971-
// + groupThreadID.x;
972-
//
973-
974-
// `offset = groupThreadID.z`
975-
auto zAxis = builder.getIntValue(intType, 2);
976-
IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis);
977-
978-
// `offset *= groupExtents.y`
979-
// `offset += groupExtents.y`
980-
auto yAxis = builder.getIntValue(intType, 1);
981-
offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis));
982-
offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis));
983-
984-
// `offset *= groupExtents.x`
985-
// `offset += groupExtents.x`
986-
auto xAxis = builder.getIntValue(intType, 0);
987-
offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis));
988-
offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis));
989-
990-
return offset;
991-
}
992-
993-
// Several of the derived calcluations rely on having
994-
// access to the "group extents" of a compute shader.
995-
// That information is expected to be present on
996-
// the entry point as a `[numthreads(...)]` attribute,
997-
// and we define a convenience routine for accessing
998-
// that information.
999-
1000-
IRInst* emitCalcGroupExtents(
1001-
IRBuilder& builder,
1002-
IRVectorType* type)
1003-
{
1004-
if(auto numThreadsDecor = m_entryPointFunc->findDecoration<IRNumThreadsDecoration>())
1005-
{
1006-
static const int kAxisCount = 3;
1007-
IRInst* groupExtentAlongAxis[kAxisCount] = {};
1008-
1009-
for( int axis = 0; axis < kAxisCount; axis++ )
1010-
{
1011-
auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis));
1012-
if(!litValue)
1013-
return nullptr;
1014-
1015-
groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue());
1016-
}
1017-
1018-
return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis);
1019-
}
1020-
1021-
// TODO: We may want to implement a backup option here,
1022-
// in case we ever want to support compute shaders with
1023-
// dynamic/flexible group size on targets that allow it.
1024-
//
1025-
SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point.");
1026-
UNREACHABLE_RETURN(nullptr);
1027-
}
10281030
};
10291031

10301032
// With the target-independent core of the pass out of the way, we can
@@ -1391,7 +1393,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
13911393
// CPU target, we'd need to change it so that the thread-group size can
13921394
// be passed in as part of `ComputeVaryingThreadInput`.
13931395
//
1394-
groupExtents = emitCalcGroupExtents(builder, uint3Type);
1396+
groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type);
13951397

13961398
dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents);
13971399

source/slang/slang-ir-legalize-varying-params.h

+13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ class DiagnosticSink;
88

99
struct IRFunc;
1010
struct IRModule;
11+
struct IRInst;
12+
struct IRFunc;
13+
struct IRVectorType;
14+
struct IRBuilder;
1115

1216
void legalizeEntryPointVaryingParamsForCPU(
1317
IRModule* module,
@@ -17,4 +21,13 @@ void legalizeEntryPointVaryingParamsForCUDA(
1721
IRModule* module,
1822
DiagnosticSink* sink);
1923

24+
IRInst* emitCalcGroupThreadIndex(
25+
IRBuilder& builder,
26+
IRInst* groupThreadID,
27+
IRInst* groupExtents);
28+
29+
IRInst* emitCalcGroupExtents(
30+
IRBuilder& builder,
31+
IRFunc* entryPoint,
32+
IRVectorType* type);
2033
}

0 commit comments

Comments
 (0)