@@ -177,6 +177,118 @@ void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, IRInst* src)
177
177
assign (builder, dest, LegalizedVaryingVal::makeValue (src));
178
178
}
179
179
180
+
181
+ // Several of the derived calcluations rely on having
182
+ // access to the "group extents" of a compute shader.
183
+ // That information is expected to be present on
184
+ // the entry point as a `[numthreads(...)]` attribute,
185
+ // and we define a convenience routine for accessing
186
+ // that information.
187
+
188
+ IRInst* emitCalcGroupExtents (
189
+ IRBuilder& builder,
190
+ IRFunc* entryPoint,
191
+ IRVectorType* type)
192
+ {
193
+ if (auto numThreadsDecor = entryPoint->findDecoration <IRNumThreadsDecoration>())
194
+ {
195
+ static const int kAxisCount = 3 ;
196
+ IRInst* groupExtentAlongAxis[kAxisCount ] = {};
197
+
198
+ for (int axis = 0 ; axis < kAxisCount ; axis++)
199
+ {
200
+ auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis (axis));
201
+ if (!litValue)
202
+ return nullptr ;
203
+
204
+ groupExtentAlongAxis[axis] = builder.getIntValue (type->getElementType (), litValue->getValue ());
205
+ }
206
+
207
+ return builder.emitMakeVector (type, kAxisCount , groupExtentAlongAxis);
208
+ }
209
+
210
+ // TODO: We may want to implement a backup option here,
211
+ // in case we ever want to support compute shaders with
212
+ // dynamic/flexible group size on targets that allow it.
213
+ //
214
+ SLANG_UNEXPECTED (" Expected '[numthreads(...)]' attribute on compute entry point." );
215
+ UNREACHABLE_RETURN (nullptr );
216
+ }
217
+
218
+ // There are some cases of system-value inputs that can be derived
219
+ // from other inputs; notably compute shaders support `SV_DispatchThreadID`
220
+ // and `SV_GroupIndex` which can both be derived from the more primitive
221
+ // `SV_GroupID` and `SV_GroupThreadID`, together with the extents
222
+ // of the thread group (which are specified with `[numthreads(...)]`).
223
+ //
224
+ // As a utilty to target-specific subtypes, we define helpers for
225
+ // calculating the value of these derived system values from the
226
+ // more primitive ones.
227
+
228
+ // / Emit code to calculate `SV_DispatchThreadID`
229
+ IRInst* emitCalcDispatchThreadID (
230
+ IRBuilder& builder,
231
+ IRType* type,
232
+ IRInst* groupID,
233
+ IRInst* groupThreadID,
234
+ IRInst* groupExtents)
235
+ {
236
+ // The dispatch thread ID can be computed as:
237
+ //
238
+ // dispatchThreadID = groupID*groupExtents + groupThreadID
239
+ //
240
+ // where `groupExtents` is the X,Y,Z extents of
241
+ // each thread group in threads (as given by
242
+ // `[numthreads(X,Y,Z)]`).
243
+
244
+ return builder.emitAdd (type,
245
+ builder.emitMul (type,
246
+ groupID,
247
+ groupExtents),
248
+ groupThreadID);
249
+ }
250
+
251
+ // / Emit code to calculate `SV_GroupIndex`
252
+ IRInst* emitCalcGroupThreadIndex (
253
+ IRBuilder& builder,
254
+ IRInst* groupThreadID,
255
+ IRInst* groupExtents)
256
+ {
257
+ auto intType = builder.getIntType ();
258
+ auto uintType = builder.getBasicType (BaseType::UInt);
259
+
260
+ // The group thread index can be computed as:
261
+ //
262
+ // groupThreadIndex = groupThreadID.x
263
+ // + groupThreadID.y*groupExtents.x
264
+ // + groupThreadID.z*groupExtents.x*groupExtents.z;
265
+ //
266
+ // or equivalently (with one less multiply):
267
+ //
268
+ // groupThreadIndex = (groupThreadID.z * groupExtents.y
269
+ // + groupThreadID.y) * groupExtents.x
270
+ // + groupThreadID.x;
271
+ //
272
+
273
+ // `offset = groupThreadID.z`
274
+ auto zAxis = builder.getIntValue (intType, 2 );
275
+ IRInst* offset = builder.emitElementExtract (uintType, groupThreadID, zAxis);
276
+
277
+ // `offset *= groupExtents.y`
278
+ // `offset += groupExtents.y`
279
+ auto yAxis = builder.getIntValue (intType, 1 );
280
+ offset = builder.emitMul (uintType, offset, builder.emitElementExtract (uintType, groupExtents, yAxis));
281
+ offset = builder.emitAdd (uintType, offset, builder.emitElementExtract (uintType, groupThreadID, yAxis));
282
+
283
+ // `offset *= groupExtents.x`
284
+ // `offset += groupExtents.x`
285
+ auto xAxis = builder.getIntValue (intType, 0 );
286
+ offset = builder.emitMul (uintType, offset, builder.emitElementExtract (uintType, groupExtents, xAxis));
287
+ offset = builder.emitAdd (uintType, offset, builder.emitElementExtract (uintType, groupThreadID, xAxis));
288
+
289
+ return offset;
290
+ }
291
+
180
292
// / Context for the IR pass that legalizing entry-point
181
293
// / varying parameters for a target.
182
294
// /
@@ -915,116 +1027,6 @@ struct EntryPointVaryingParamLegalizeContext
915
1027
916
1028
return LegalizedVaryingVal ();
917
1029
}
918
-
919
- // There are some cases of system-value inputs that can be derived
920
- // from other inputs; notably compute shaders support `SV_DispatchThreadID`
921
- // and `SV_GroupIndex` which can both be derived from the more primitive
922
- // `SV_GroupID` and `SV_GroupThreadID`, together with the extents
923
- // of the thread group (which are specified with `[numthreads(...)]`).
924
- //
925
- // As a utilty to target-specific subtypes, we define helpers for
926
- // calculating the value of these derived system values from the
927
- // more primitive ones.
928
-
929
- // / Emit code to calculate `SV_DispatchThreadID`
930
- IRInst* emitCalcDispatchThreadID (
931
- IRBuilder& builder,
932
- IRType* type,
933
- IRInst* groupID,
934
- IRInst* groupThreadID,
935
- IRInst* groupExtents)
936
- {
937
- // The dispatch thread ID can be computed as:
938
- //
939
- // dispatchThreadID = groupID*groupExtents + groupThreadID
940
- //
941
- // where `groupExtents` is the X,Y,Z extents of
942
- // each thread group in threads (as given by
943
- // `[numthreads(X,Y,Z)]`).
944
-
945
- return builder.emitAdd (type,
946
- builder.emitMul (type,
947
- groupID,
948
- groupExtents),
949
- groupThreadID);
950
- }
951
-
952
- // / Emit code to calculate `SV_GroupIndex`
953
- IRInst* emitCalcGroupThreadIndex (
954
- IRBuilder& builder,
955
- IRInst* groupThreadID,
956
- IRInst* groupExtents)
957
- {
958
- auto intType = builder.getIntType ();
959
- auto uintType = builder.getBasicType (BaseType::UInt);
960
-
961
- // The group thread index can be computed as:
962
- //
963
- // groupThreadIndex = groupThreadID.x
964
- // + groupThreadID.y*groupExtents.x
965
- // + groupThreadID.z*groupExtents.x*groupExtents.z;
966
- //
967
- // or equivalently (with one less multiply):
968
- //
969
- // groupThreadIndex = (groupThreadID.z * groupExtents.y
970
- // + groupThreadID.y) * groupExtents.x
971
- // + groupThreadID.x;
972
- //
973
-
974
- // `offset = groupThreadID.z`
975
- auto zAxis = builder.getIntValue (intType, 2 );
976
- IRInst* offset = builder.emitElementExtract (uintType, groupThreadID, zAxis);
977
-
978
- // `offset *= groupExtents.y`
979
- // `offset += groupExtents.y`
980
- auto yAxis = builder.getIntValue (intType, 1 );
981
- offset = builder.emitMul (uintType, offset, builder.emitElementExtract (uintType, groupExtents, yAxis));
982
- offset = builder.emitAdd (uintType, offset, builder.emitElementExtract (uintType, groupThreadID, yAxis));
983
-
984
- // `offset *= groupExtents.x`
985
- // `offset += groupExtents.x`
986
- auto xAxis = builder.getIntValue (intType, 0 );
987
- offset = builder.emitMul (uintType, offset, builder.emitElementExtract (uintType, groupExtents, xAxis));
988
- offset = builder.emitAdd (uintType, offset, builder.emitElementExtract (uintType, groupThreadID, xAxis));
989
-
990
- return offset;
991
- }
992
-
993
- // Several of the derived calcluations rely on having
994
- // access to the "group extents" of a compute shader.
995
- // That information is expected to be present on
996
- // the entry point as a `[numthreads(...)]` attribute,
997
- // and we define a convenience routine for accessing
998
- // that information.
999
-
1000
- IRInst* emitCalcGroupExtents (
1001
- IRBuilder& builder,
1002
- IRVectorType* type)
1003
- {
1004
- if (auto numThreadsDecor = m_entryPointFunc->findDecoration <IRNumThreadsDecoration>())
1005
- {
1006
- static const int kAxisCount = 3 ;
1007
- IRInst* groupExtentAlongAxis[kAxisCount ] = {};
1008
-
1009
- for ( int axis = 0 ; axis < kAxisCount ; axis++ )
1010
- {
1011
- auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis (axis));
1012
- if (!litValue)
1013
- return nullptr ;
1014
-
1015
- groupExtentAlongAxis[axis] = builder.getIntValue (type->getElementType (), litValue->getValue ());
1016
- }
1017
-
1018
- return builder.emitMakeVector (type, kAxisCount , groupExtentAlongAxis);
1019
- }
1020
-
1021
- // TODO: We may want to implement a backup option here,
1022
- // in case we ever want to support compute shaders with
1023
- // dynamic/flexible group size on targets that allow it.
1024
- //
1025
- SLANG_UNEXPECTED (" Expected '[numthreads(...)]' attribute on compute entry point." );
1026
- UNREACHABLE_RETURN (nullptr );
1027
- }
1028
1030
};
1029
1031
1030
1032
// With the target-independent core of the pass out of the way, we can
@@ -1391,7 +1393,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
1391
1393
// CPU target, we'd need to change it so that the thread-group size can
1392
1394
// be passed in as part of `ComputeVaryingThreadInput`.
1393
1395
//
1394
- groupExtents = emitCalcGroupExtents (builder, uint3Type);
1396
+ groupExtents = emitCalcGroupExtents (builder, m_entryPointFunc, uint3Type);
1395
1397
1396
1398
dispatchThreadID = emitCalcDispatchThreadID (builder, uint3Type, groupID, groupThreadID, groupExtents);
1397
1399
0 commit comments