Skip to content

Commit 66984eb

Browse files
Add WaveGetLane* support for Metal and WGSL (#6371)
* support WaveGetLane* for WGSL and Metal
* update test and glsl support
* address review comments and fix metal test
* add missing pragma guard
* update test
* Revert "update test" This reverts commit f2b97e9.
* update failing glsl metal test and added new test
* make hlsl and glsl outputs similar
* update test
* disable tests for Metal and cleanup
* comment fix
* add expected failures
* correct expected failures list
* remove expected failure
* add tests to expected failure
---------
Co-authored-by: Yong He <yonghe@outlook.com>
1 parent e4b9600 commit 66984eb

23 files changed

+216
-113
lines changed

source/slang/core.meta.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,8 @@ attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribu
440440
__intrinsic_op($(kIROp_RequirePrelude))
441441
void __requirePrelude(constexpr String preludeText);
442442

443-
__intrinsic_op($(kIROp_RequireGLSLExtension))
444-
void __requireGLSLExtension(constexpr String preludeText);
443+
__intrinsic_op($(kIROp_RequireTargetExtension))
444+
void __requireTargetExtension(constexpr String preludeText);
445445

446446
/// @experimental
447447
/// Perform a compile-time condition check and emit a compile-time error if the condition is false.

source/slang/glsl.meta.slang

+29-27
Original file line numberDiff line numberDiff line change
@@ -4296,7 +4296,7 @@ __generic<T : __BuiltinType>
42964296
case glsl:
42974297
{
42984298
if (__type_equals<T, float>())
4299-
__requireGLSLExtension("GL_EXT_shader_atomic_float");
4299+
__requireTargetExtension("GL_EXT_shader_atomic_float");
43004300
}
43014301
case spirv:
43024302
if (__type_equals<T, float>())
@@ -4318,7 +4318,7 @@ __generic<T : __BuiltinType>
43184318
case glsl:
43194319
{
43204320
if (__type_equals<T, float>())
4321-
__requireGLSLExtension("GL_EXT_shader_atomic_float2");
4321+
__requireTargetExtension("GL_EXT_shader_atomic_float2");
43224322
}
43234323
case spirv:
43244324
if (__type_equals<T, float>())
@@ -4758,7 +4758,7 @@ void requireGLSLExtForRayTracingBuiltin()
47584758
__target_switch
47594759
{
47604760
case glsl:
4761-
__requireGLSLExtension("GL_EXT_ray_tracing");
4761+
__requireTargetExtension("GL_EXT_ray_tracing");
47624762
__intrinsic_asm "";
47634763
default:
47644764
return;
@@ -6304,30 +6304,30 @@ public void traceRayMotionNV(
63046304
__generic<T : __BuiltinType>
63056305
[ForceInline]
63066306
void typeRequireChecks_shader_subgroup_GLSL() {
6307-
// the following is a seperate function call, since else the `__requireGLSLExtension` and associated __intrinsic_asm is ignored if the calling function also calls an __intrinsic_asm
6307+
// the following is a separate function call, since otherwise the `__requireTargetExtension` and associated __intrinsic_asm are ignored if the calling function also calls an __intrinsic_asm
63086308
__target_switch
63096309
{
63106310
case glsl:
63116311
if (__type_equals<T, half>()
63126312
|| __type_equals<T, float16_t>()
6313-
) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
6313+
) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16");
63146314
else if (__type_equals<T, uint8_t>()
63156315
|| __type_equals<T, int8_t>()
6316-
) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int8");
6316+
) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int8");
63176317
else if (__type_equals<T, uint16_t>()
63186318
|| __type_equals<T, int16_t>()
6319-
) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int16");
6319+
) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int16");
63206320
else if (__type_equals<T, uint64_t>()
63216321
|| __type_equals<T, int64_t>()
6322-
) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int64");
6322+
) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int64");
63236323

63246324
__intrinsic_asm "";
63256325
}
63266326
}
63276327

63286328
__generic<T : __BuiltinType>
63296329
void shader_subgroup_preamble() {
6330-
// checks needed for shader_subgroup functions; __requireGLSLExtension does not work
6330+
// checks needed for shader_subgroup functions; __requireTargetExtension does not work
63316331
// (does not add the ext specified correctly to the compile output; using extended type
63326332
// will result in error for using the type)
63336333
__target_switch
@@ -6347,14 +6347,14 @@ void requireGLSLExtForSubgroupBasicBuiltin() {
63476347
__target_switch
63486348
{
63496349
case glsl:
6350-
__requireGLSLExtension("GL_KHR_shader_subgroup_basic");
6350+
__requireTargetExtension("GL_KHR_shader_subgroup_basic");
63516351
__intrinsic_asm "";
63526352
default:
63536353
return;
63546354
}
63556355
}
63566356

6357-
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
6357+
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
63586358
void setupExtForSubgroupBasicBuiltIn() {
63596359
__target_switch
63606360
{
@@ -6371,7 +6371,7 @@ void requireGLSLExtForSubgroupBallotBuiltin() {
63716371
__target_switch
63726372
{
63736373
case glsl:
6374-
__requireGLSLExtension("GL_KHR_shader_subgroup_ballot");
6374+
__requireTargetExtension("GL_KHR_shader_subgroup_ballot");
63756375
__intrinsic_asm "";
63766376
default:
63776377
return;
@@ -6429,7 +6429,8 @@ public property uint gl_SubgroupID
64296429

64306430
public property uint gl_SubgroupSize
64316431
{
6432-
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
6432+
[ForceInline]
6433+
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
64336434
get {
64346435
setupExtForSubgroupBasicBuiltIn();
64356436
return WaveGetLaneCount();
@@ -6438,7 +6439,8 @@ public property uint gl_SubgroupSize
64386439

64396440
public property uint gl_SubgroupInvocationID
64406441
{
6441-
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
6442+
[ForceInline]
6443+
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
64426444
get {
64436445
setupExtForSubgroupBasicBuiltIn();
64446446
return WaveGetLaneIndex();
@@ -8388,8 +8390,8 @@ void typeRequireChecks_atomic_using_float0_tier()
83888390
{
83898391
case glsl:
83908392
{
8391-
if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>())
8392-
__requireGLSLExtension("GL_EXT_shader_atomic_int64");
8393+
if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>())
8394+
__requireTargetExtension("GL_EXT_shader_atomic_int64");
83938395
}
83948396
case spirv:
83958397
return;
@@ -8405,16 +8407,16 @@ void typeRequireChecks_atomic_using_float1_tier()
84058407
case glsl:
84068408
{
84078409
if (__type_equals<T, float>())
8408-
__requireGLSLExtension("GL_EXT_shader_atomic_float");
8410+
__requireTargetExtension("GL_EXT_shader_atomic_float");
84098411
else if (__type_equals<T, half>() || __type_equals<T, float16_t>())
84108412
{
8411-
__requireGLSLExtension("GL_EXT_shader_atomic_float2");
8412-
__requireGLSLExtension("GL_EXT_shader_explicit_arithmetic_types");
8413+
__requireTargetExtension("GL_EXT_shader_atomic_float2");
8414+
__requireTargetExtension("GL_EXT_shader_explicit_arithmetic_types");
84138415
}
84148416
else if (__type_equals<T, double>())
8415-
__requireGLSLExtension("GL_EXT_shader_atomic_float");
8417+
__requireTargetExtension("GL_EXT_shader_atomic_float");
84168418
else if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>())
8417-
__requireGLSLExtension("GL_EXT_shader_atomic_int64");
8419+
__requireTargetExtension("GL_EXT_shader_atomic_int64");
84188420
}
84198421
case spirv:
84208422
return;
@@ -8430,16 +8432,16 @@ void typeRequireChecks_atomic_using_float2_tier()
84308432
case glsl:
84318433
{
84328434
if (__type_equals<T, float>())
8433-
__requireGLSLExtension("GL_EXT_shader_atomic_float2");
8435+
__requireTargetExtension("GL_EXT_shader_atomic_float2");
84348436
else if (__type_equals<T, half>() || __type_equals<T, float16_t>())
84358437
{
8436-
__requireGLSLExtension("GL_EXT_shader_atomic_float2");
8437-
__requireGLSLExtension("GL_EXT_shader_explicit_arithmetic_types");
8438+
__requireTargetExtension("GL_EXT_shader_atomic_float2");
8439+
__requireTargetExtension("GL_EXT_shader_explicit_arithmetic_types");
84388440
}
84398441
else if (__type_equals<T, double>())
8440-
__requireGLSLExtension("GL_EXT_shader_atomic_float2");
8441-
else if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>())
8442-
__requireGLSLExtension("GL_EXT_shader_atomic_int64");
8442+
__requireTargetExtension("GL_EXT_shader_atomic_float2");
8443+
else if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>())
8444+
__requireTargetExtension("GL_EXT_shader_atomic_int64");
84438445
}
84448446
case spirv:
84458447
return;

0 commit comments

Comments
 (0)