diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index da1b47e137..e2fb8bbf27 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -440,8 +440,8 @@ attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribu __intrinsic_op($(kIROp_RequirePrelude)) void __requirePrelude(constexpr String preludeText); -__intrinsic_op($(kIROp_RequireGLSLExtension)) -void __requireGLSLExtension(constexpr String preludeText); +__intrinsic_op($(kIROp_RequireTargetExtension)) +void __requireTargetExtension(constexpr String preludeText); /// @experimetal /// Perform a compile-time condition check and emit a compile-time error if the condition is false. diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index eed6cc6908..2a89f2b66d 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -4296,7 +4296,7 @@ __generic<T : __BuiltinType> case glsl: { if (__type_equals<T, float>()) - __requireGLSLExtension("GL_EXT_shader_atomic_float"); + __requireTargetExtension("GL_EXT_shader_atomic_float"); } case spirv: if (__type_equals<T, float>()) @@ -4318,7 +4318,7 @@ __generic<T : __BuiltinType> case glsl: { if (__type_equals<T, float>()) - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); } case spirv: if (__type_equals<T, float>()) @@ -4758,7 +4758,7 @@ void requireGLSLExtForRayTracingBuiltin() __target_switch { case glsl: - __requireGLSLExtension("GL_EXT_ray_tracing"); + __requireTargetExtension("GL_EXT_ray_tracing"); __intrinsic_asm ""; default: return; @@ -6304,22 +6304,22 @@ public void traceRayMotionNV( __generic<T : __BuiltinType> [ForceInline] void typeRequireChecks_shader_subgroup_GLSL() { - // the following is a seperate function call, since else the `__requireGLSLExtension` and associated __intrinsic_asm is ignored if the calling function also calls an __intrinsic_asm + // the following is a seperate function call, since else the `__requireTargetExtension` and associated __intrinsic_asm is ignored if the calling function also calls an __intrinsic_asm __target_switch { case glsl: if (__type_equals<T, half>() || __type_equals<T, float16_t>() - ) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); else if (__type_equals<T, uint8_t>() || __type_equals<T, int8_t>() - ) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int8"); + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int8"); else if (__type_equals<T, uint16_t>() || __type_equals<T, int16_t>() - ) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int16"); + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int16"); else if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>() - ) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int64"); + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int64"); __intrinsic_asm ""; } @@ -6327,7 +6327,7 @@ void typeRequireChecks_shader_subgroup_GLSL() { __generic<T : __BuiltinType> void shader_subgroup_preamble() { - // checks needed for shader_subgroup functions; __requireGLSLExtension does not work + // checks needed for shader_subgroup functions; __requireTargetExtension does not work // (does not add the ext specified correctly to the compile output; using extended type // will result in error for using the type) __target_switch @@ -6347,14 +6347,14 @@ void requireGLSLExtForSubgroupBasicBuiltin() { __target_switch { case glsl: - __requireGLSLExtension("GL_KHR_shader_subgroup_basic"); + __requireTargetExtension("GL_KHR_shader_subgroup_basic"); __intrinsic_asm ""; default: return; } } -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] void setupExtForSubgroupBasicBuiltIn() { __target_switch { @@ -6371,7 +6371,7 @@ void requireGLSLExtForSubgroupBallotBuiltin() { __target_switch { case glsl: - __requireGLSLExtension("GL_KHR_shader_subgroup_ballot"); + __requireTargetExtension("GL_KHR_shader_subgroup_ballot"); __intrinsic_asm ""; default: return; @@ -6429,7 +6429,8 @@ public property uint gl_SubgroupID public property uint gl_SubgroupSize { - [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] + [ForceInline] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] get { setupExtForSubgroupBasicBuiltIn(); return WaveGetLaneCount(); @@ -6438,7 +6439,8 @@ public property uint gl_SubgroupSize public property uint gl_SubgroupInvocationID { - [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] + [ForceInline] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] get { setupExtForSubgroupBasicBuiltIn(); return WaveGetLaneIndex(); @@ -8388,8 +8390,8 @@ void typeRequireChecks_atomic_using_float0_tier() { case glsl: { - if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>()) - __requireGLSLExtension("GL_EXT_shader_atomic_int64"); + if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>()) + __requireTargetExtension("GL_EXT_shader_atomic_int64"); } case spirv: return; @@ -8405,16 +8407,16 @@ void typeRequireChecks_atomic_using_float1_tier() case glsl: { if (__type_equals<T, float>()) - __requireGLSLExtension("GL_EXT_shader_atomic_float"); + __requireTargetExtension("GL_EXT_shader_atomic_float"); else if (__type_equals<T, half>() || __type_equals<T, float16_t>()) { - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); - __requireGLSLExtension("GL_EXT_shader_explicit_arithmetic_types"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); + __requireTargetExtension("GL_EXT_shader_explicit_arithmetic_types"); } else if (__type_equals<T, double>()) - __requireGLSLExtension("GL_EXT_shader_atomic_float"); + __requireTargetExtension("GL_EXT_shader_atomic_float"); else if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>()) - __requireGLSLExtension("GL_EXT_shader_atomic_int64"); + __requireTargetExtension("GL_EXT_shader_atomic_int64"); } case spirv: return; @@ -8430,16 +8432,16 @@ void typeRequireChecks_atomic_using_float2_tier() case glsl: { if (__type_equals<T, float>()) - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); else if (__type_equals<T, half>() || __type_equals<T, float16_t>()) { - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); - __requireGLSLExtension("GL_EXT_shader_explicit_arithmetic_types"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); + __requireTargetExtension("GL_EXT_shader_explicit_arithmetic_types"); } else if (__type_equals<T, double>()) - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); - else if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>()) - __requireGLSLExtension("GL_EXT_shader_atomic_int64"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); + else if (__type_equals<T, uint64_t>() || __type_equals<T, int64_t>()) + __requireTargetExtension("GL_EXT_shader_atomic_int64"); } case spirv: return; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index a2b685b692..c9f3fb5337 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -3,8 +3,14 @@ typedef uint UINT; -__intrinsic_op($(kIROp_RequireGLSLExtension)) -void __requireGLSLExtension(String extensionName); +__intrinsic_op($(kIROp_RequireTargetExtension)) +void __requireTargetExtension(constexpr String extensionName); + +/// Built-in values or system value semantics represented as in/out global variables. +/// This allows the built-ins to be arbitrarily used from a global scope without being +/// explicitly passed as entry point parameters. +in uint __builtinWaveLaneIndex : SV_WaveLaneIndex; +in uint __builtinWaveLaneCount : SV_WaveLaneCount; //@public: /// Represents an interface for buffer data layout. @@ -3505,7 +3511,7 @@ extension _Texture<T,Shape,isArray,0,sampleCount,0,isShadow,isCombined,format> __intrinsic_asm "<invalid intrinsics>"; case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$ctexelFetch($0, ($1).$w1b, ($1).$w1e)$z"; case spirv: const int lodLoc = Shape.dimensions+isArray; @@ -3569,7 +3575,7 @@ extension _Texture<T,Shape,isArray,0,sampleCount,0,isShadow,isCombined,format> __intrinsic_asm ".Load"; case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$ctexelFetchOffset($0, ($1).$w1b, ($1).$w1e, ($2))$z"; case spirv: const int lodLoc = Shape.dimensions+isArray; @@ -3625,7 +3631,7 @@ extension _Texture<T,Shape,isArray,0,sampleCount,0,isShadow,isCombined,format> return Load(__makeVector(location, 0)); case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); return Load(__makeVector(location, 0)); case spirv: @@ -3702,7 +3708,7 @@ extension _Texture<T,Shape,isArray,1,sampleCount,0,isShadow,isCombined,format> __intrinsic_asm "<Not supported>"; case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$ctexelFetch($0, $1, ($2))$z"; case spirv: if (isCombined != 0) @@ -3752,7 +3758,7 @@ extension _Texture<T,Shape,isArray,1,sampleCount,0,isShadow,isCombined,format> __intrinsic_asm ".Load"; case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$ctexelFetchOffset($0, $1, ($2), ($3))$z"; case spirv: if (isCombined != 0) @@ -3807,7 +3813,7 @@ extension _Texture<T,Shape,isArray,1,sampleCount,0,isShadow,isCombined,format> return Load(location, 0); case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); return Load(location, 0); } } @@ -3830,7 +3836,7 @@ extension _Texture<T,Shape,isArray,1,sampleCount,0,isShadow,isCombined,format> return Load(location, sampleIndex); case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); return Load(location, sampleIndex); } } @@ -13913,7 +13919,7 @@ T WaveMaskSum(WaveMask mask, T expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupAdd($1)"; case cuda: __intrinsic_asm "_waveSum($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveSum($1)"; @@ -13940,7 +13946,7 @@ vector<T,N> WaveMaskSum(WaveMask mask, vector<T,N> expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupAdd($1)"; case cuda: __intrinsic_asm "_waveSumMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveSum($1)"; @@ -13979,7 +13985,7 @@ bool WaveMaskAllEqual(WaveMask mask, T value) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupAllEqual($1)"; case hlsl: __intrinsic_asm "WaveActiveAllEqual($1)"; @@ -14003,7 +14009,7 @@ bool WaveMaskAllEqual(WaveMask mask, vector<T,N> value) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupAllEqual($1)"; case hlsl: __intrinsic_asm "WaveActiveAllEqual($1)"; @@ -14040,7 +14046,7 @@ T WaveMaskPrefixProduct(WaveMask mask, T expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveMul($1)"; case cuda: __intrinsic_asm "_wavePrefixProduct($0, $1)"; case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; @@ -14067,7 +14073,7 @@ vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveMul($1)"; case cuda: __intrinsic_asm "_wavePrefixProductMultiple($0, $1)"; case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; @@ -14105,7 +14111,7 @@ T WaveMaskPrefixSum(WaveMask mask, T expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveAdd($1)"; case cuda: __intrinsic_asm "_wavePrefixSum($0, $1)"; case hlsl: __intrinsic_asm "WavePrefixSum($1)"; @@ -14133,7 +14139,7 @@ vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveAdd($1)"; case cuda: __intrinsic_asm "_wavePrefixSumMultiple($0, $1)"; case hlsl: __intrinsic_asm "WavePrefixSum($1)"; @@ -14761,7 +14767,7 @@ T WaveActive$(opName.hlslName)(T expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroup$(opName.glslName)($0)"; case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case metal: __intrinsic_asm "simd_$(opName.metalName)"; @@ -14796,7 +14802,7 @@ vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroup$(opName.glslName)($0)"; case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case metal: __intrinsic_asm "simd_$(opName.metalName)"; @@ -15018,7 +15024,8 @@ uint WaveActiveCountBits(bool value) __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) [NonUniformReturn] -[require(cuda_glsl_hlsl_spirv, subgroup_basic)] +[ForceInline] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] uint WaveGetLaneCount() { __target_switch @@ -15032,6 +15039,11 @@ uint WaveGetLaneCount() OpCapability GroupNonUniform; result:$$uint = OpLoad builtin(SubgroupSize:uint) }; + case metal: + return __builtinWaveLaneCount; + case wgsl: + __requireTargetExtension("subgroups"); + return __builtinWaveLaneCount; } } @@ -15039,7 +15051,8 @@ uint WaveGetLaneCount() __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) [NonUniformReturn] -[require(cuda_glsl_hlsl_spirv, subgroup_basic)] +[ForceInline] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] uint WaveGetLaneIndex() { __target_switch @@ -15053,6 +15066,11 @@ uint WaveGetLaneIndex() OpCapability GroupNonUniform; result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint) }; + case metal: + return __builtinWaveLaneIndex; + case wgsl: + __requireTargetExtension("subgroups"); + return __builtinWaveLaneIndex; } } @@ -15122,7 +15140,7 @@ T WavePrefixProduct(T expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveMul($0)"; case hlsl: __intrinsic_asm "WavePrefixProduct"; case metal: __intrinsic_asm "simd_prefix_exclusive_product"; @@ -15158,7 +15176,7 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveMul($0)"; case hlsl: __intrinsic_asm "WavePrefixProduct"; case metal: __intrinsic_asm "simd_prefix_exclusive_product"; @@ -15209,7 +15227,7 @@ T WavePrefixSum(T expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveAdd($0)"; case hlsl: __intrinsic_asm "WavePrefixSum"; case metal: __intrinsic_asm "simd_prefix_exclusive_sum"; @@ -15241,7 +15259,7 @@ vector<T,N> WavePrefixSum(vector<T,N> expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveAdd($0)"; case hlsl: __intrinsic_asm "WavePrefixSum"; case metal: __intrinsic_asm "simd_prefix_exclusive_sum"; @@ -15292,7 +15310,7 @@ T WaveReadLaneFirst(T expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupBroadcastFirst($0)"; case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case metal: __intrinsic_asm "simd_broadcast_first"; @@ -15314,7 +15332,7 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupBroadcastFirst($0)"; case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case metal: __intrinsic_asm "simd_broadcast_first"; @@ -15360,7 +15378,7 @@ T WaveBroadcastLaneAt(T value, constexpr int lane) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupBroadcast($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_broadcast($0, ushort($1))"; @@ -15384,7 +15402,7 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupBroadcast($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_broadcast($0, ushort($1))"; @@ -15426,7 +15444,7 @@ T WaveReadLaneAt(T value, int lane) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupShuffle($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))"; @@ -15449,7 +15467,7 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupShuffle($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))"; @@ -15492,7 +15510,7 @@ T WaveShuffle(T value, int lane) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupShuffle($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))"; @@ -15516,7 +15534,7 @@ vector<T,N> WaveShuffle(vector<T,N> value, int lane) __target_switch { case glsl: - if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf<T>()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupShuffle($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))"; @@ -16158,7 +16176,7 @@ extension _Texture<T, __ShapeBuffer, 0, 0, 0, $(aa), 0, 0, format> { case hlsl: __intrinsic_asm ".GetDimensions"; case glsl: - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "($1 = $(glslTextureSizeFunc)($0))"; case metal: __intrinsic_asm "(*($1) = $0.get_width())"; case spirv: @@ -16178,7 +16196,7 @@ extension _Texture<T, __ShapeBuffer, 0, 0, 0, $(aa), 0, 0, format> case hlsl: __intrinsic_asm ".Load"; case metal: __intrinsic_asm "$c$0.read(uint($1))$z"; case glsl: - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$(glslLoadFuncName)($0, $1)$z"; case spirv: return spirv_asm { %sampled:__sampledType(T) = $(spvLoadInstName) $this $location; diff --git a/source/slang/slang-core-module-textures.cpp b/source/slang/slang-core-module-textures.cpp index 22c1fc63fe..f703a8a3b7 100644 --- a/source/slang/slang-core-module-textures.cpp +++ b/source/slang/slang-core-module-textures.cpp @@ -439,7 +439,7 @@ void TextureTypeInfo::writeGetDimensionFunctions() } }; glsl << "if (isCombined == 0) { " - "__requireGLSLExtension(\"GL_EXT_samplerless_texture_functions\"); }\n"; + "__requireTargetExtension(\"GL_EXT_samplerless_texture_functions\"); }\n"; glsl << "if (access == " << kCoreModule_ResourceAccessReadOnly << ") __intrinsic_asm \""; emitIntrinsic(toSlice("textureSize"), !isMultisample); diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 946e9c429f..1c48d98efd 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3061,10 +3061,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO m_requiredPreludes.add(preludeTextInst); break; } - case kIROp_RequireGLSLExtension: - { - break; // should already have set requirement; case covered for empty intrinsic block - } case kIROp_RequireComputeDerivative: { break; // should already have been parsed and used. @@ -3074,6 +3070,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO emitOperand(as<IRGlobalValueRef>(inst)->getOperand(0), getInfo(EmitOp::General)); break; } + case kIROp_RequireTargetExtension: + { + emitRequireExtension(as<IRRequireTargetExtension>(inst)); + break; + } default: diagnoseUnhandledInst(inst); break; diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 6fe7f5d34d..ca915ab2da 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -678,6 +678,8 @@ class CLikeSourceEmitter : public SourceEmitterBase void _emitCallArgList(IRCall* call, int startingOperandIndex = 1); virtual void emitCallArg(IRInst* arg); + virtual void emitRequireExtension(IRRequireTargetExtension* inst) { SLANG_UNUSED(inst); } + String _generateUniqueName(const UnownedStringSlice& slice); // Sort witnessTable entries according to the order defined in the witnessed interface type. diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 776c539b48..696830bf24 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -30,7 +30,7 @@ void GLSLSourceEmitter::_beforeComputeEmitProcessInstruction( IRInst* inst, IRBuilder& builder) { - if (auto requireGLSLExt = as<IRRequireGLSLExtension>(inst)) + if (auto requireGLSLExt = as<IRRequireTargetExtension>(inst)) { _requireGLSLExtension(requireGLSLExt->getExtensionName()); return; diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 13c79e9acc..7c83b194d5 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1696,4 +1696,9 @@ void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) } } +void WGSLSourceEmitter::emitRequireExtension(IRRequireTargetExtension* inst) +{ + _requireExtension(inst->getExtensionName()); +} + } // namespace Slang diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h index 441933b570..a29f39a1d7 100644 --- a/source/slang/slang-emit-wgsl.h +++ b/source/slang/slang-emit-wgsl.h @@ -57,6 +57,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; virtual void emitGlobalParamDefaultVal(IRGlobalParam* varDecl) SLANG_OVERRIDE; + virtual void emitRequireExtension(IRRequireTargetExtension* inst) SLANG_OVERRIDE; + virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE; void emit(const AddressSpace addressSpace); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 847c5b55c0..ddb4ea67ac 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -100,7 +100,7 @@ #include "slang-ir-strip-default-construct.h" #include "slang-ir-strip-legalization-insts.h" #include "slang-ir-synthesize-active-mask.h" -#include "slang-ir-translate-glsl-global-var.h" +#include "slang-ir-translate-global-varying-var.h" #include "slang-ir-uniformity.h" #include "slang-ir-user-type-hint.h" #include "slang-ir-validate.h" @@ -318,7 +318,7 @@ struct RequiredLoweringPassSet bool bindingQuery; bool meshOutput; bool higherOrderFunc; - bool glslGlobalVar; + bool globalVaryingVar; bool glslSSBO; bool byteAddressBuffer; bool dynamicResource; @@ -422,7 +422,7 @@ void calcRequiredLoweringPassSet( case kIROp_GlobalInputDecoration: case kIROp_GlobalOutputDecoration: case kIROp_GetWorkGroupSize: - result.glslGlobalVar = true; + result.globalVaryingVar = true; break; case kIROp_BindExistentialSlotsDecoration: result.bindExistential = true; @@ -667,8 +667,8 @@ Result linkAndOptimizeIR( if (!isKhronosTarget(targetRequest) && requiredLoweringPassSet.glslSSBO) lowerGLSLShaderStorageBufferObjectsToStructuredBuffers(irModule, sink); - if (requiredLoweringPassSet.glslGlobalVar) - translateGLSLGlobalVar(codeGenContext, irModule); + if (requiredLoweringPassSet.globalVaryingVar) + translateGlobalVaryingVar(codeGenContext, irModule); if (requiredLoweringPassSet.resolveVaryingInputRef) resolveVaryingInputRef(irModule); diff --git a/source/slang/slang-ir-call-graph.h b/source/slang/slang-ir-call-graph.h index 4ee6423566..b7290ef790 100644 --- a/source/slang/slang-ir-call-graph.h +++ b/source/slang/slang-ir-call-graph.h @@ -1,3 +1,6 @@ +// slang-ir-call-graph.h +#pragma once + #include "slang-ir-clone.h" #include "slang-ir-insts.h" diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 3e2872cb7a..714ba146db 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -407,7 +407,7 @@ INST(WitnessTableEntry, witness_table_entry, 2, 0) INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL) // An inst to represent the workgroup size of the calling entry point. -// We will materialize this inst during `translateGLSLGlobalVar`. +// We will materialize this inst during `translateGlobalVaryingVar`. INST(GetWorkGroupSize, GetWorkGroupSize, 0, HOISTABLE) // An inst that returns the current stage of the calling entry point. @@ -666,7 +666,7 @@ INST_RANGE(TerminatorInst, Return, Unreachable) INST(discard, discard, 0, 0) INST(RequirePrelude, RequirePrelude, 1, 0) -INST(RequireGLSLExtension, RequireGLSLExtension, 1, 0) +INST(RequireTargetExtension, RequireTargetExtension, 1, 0) INST(RequireComputeDerivative, RequireComputeDerivative, 0, 0) INST(StaticAssert, StaticAssert, 2, 0) INST(Printf, Printf, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4efb7d6715..5231592ca5 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3506,9 +3506,9 @@ struct IRRequirePrelude : IRInst UnownedStringSlice getPrelude() { return as<IRStringLit>(getOperand(0))->getStringSlice(); } }; -struct IRRequireGLSLExtension : IRInst +struct IRRequireTargetExtension : IRInst { - IR_LEAF_ISA(RequireGLSLExtension) + IR_LEAF_ISA(RequireTargetExtension) UnownedStringSlice getExtensionName() { return as<IRStringLit>(getOperand(0))->getStringSlice(); diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 3b65ee59af..e744969dbe 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3228,6 +3228,20 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); break; } + case SystemValueSemanticName::WaveLaneCount: + { + result.systemValueName = toSlice("threads_per_simdgroup"); + result.permittedTypes.add(builder.getUIntType()); + result.permittedTypes.add(builder.getUInt16Type()); + break; + } + case SystemValueSemanticName::WaveLaneIndex: + { + result.systemValueName = toSlice("thread_index_in_simdgroup"); + result.permittedTypes.add(builder.getUIntType()); + result.permittedTypes.add(builder.getUInt16Type()); + break; + } default: m_sink->diagnose( parentVar, @@ -3845,6 +3859,20 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext break; } + case SystemValueSemanticName::WaveLaneCount: + { + result.systemValueName = toSlice("subgroup_size"); + result.permittedTypes.add(builder.getUIntType()); + break; + } + + case SystemValueSemanticName::WaveLaneIndex: + { + result.systemValueName = toSlice("subgroup_invocation_id"); + result.permittedTypes.add(builder.getUIntType()); + break; + } + case SystemValueSemanticName::ViewID: case SystemValueSemanticName::ViewportArrayIndex: case SystemValueSemanticName::StartVertexLocation: @@ -3853,7 +3881,6 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext result.isUnsupported = true; break; } - default: { m_sink->diagnose( diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index e742f30936..0a7c3be8e7 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -68,6 +68,8 @@ void depointerizeInputParams(IRFunc* entryPoint); M(Target, SV_Target) \ M(StartVertexLocation, SV_StartVertexLocation) \ M(StartInstanceLocation, SV_StartInstanceLocation) \ + M(WaveLaneCount, SV_WaveLaneCount) \ + M(WaveLaneIndex, SV_WaveLaneIndex) \ /* end */ /// A known system-value semantic name that can be applied to a parameter diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-global-varying-var.cpp similarity index 98% rename from source/slang/slang-ir-translate-glsl-global-var.cpp rename to source/slang/slang-ir-translate-global-varying-var.cpp index 80ed3c3e4f..80f5c42c33 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-global-varying-var.cpp @@ -1,4 +1,4 @@ -#include "slang-ir-translate-glsl-global-var.h" +#include "slang-ir-translate-global-varying-var.h" #include "slang-ir-call-graph.h" #include "slang-ir-insts.h" @@ -152,8 +152,8 @@ struct GlobalVarTranslationContext builder.getPtrType(kIROp_ConstRefType, inputStructType, AddressSpace::Input)); builder.addLayoutDecoration(inputParam, paramLayout); - // Initialize all global variables. - for (Index i = 0; i < inputVars.getCount(); i++) + // Initialize all global variables in the order of struct member declaration. + for (Index i = inputVars.getCount() - 1; i >= 0; i--) { auto input = inputVars[i]; setInsertBeforeOrdinaryInst(&builder, firstBlock->getFirstOrdinaryInst()); @@ -373,7 +373,7 @@ struct GlobalVarTranslationContext } }; -void translateGLSLGlobalVar(CodeGenContext* context, IRModule* module) +void translateGlobalVaryingVar(CodeGenContext* context, IRModule* module) { GlobalVarTranslationContext ctx; ctx.context = context; diff --git a/source/slang/slang-ir-translate-global-varying-var.h b/source/slang/slang-ir-translate-global-varying-var.h new file mode 100644 index 0000000000..f976837003 --- /dev/null +++ b/source/slang/slang-ir-translate-global-varying-var.h @@ -0,0 +1,14 @@ +// slang-ir-translate-global-varying-var.h +#pragma once + +namespace Slang +{ + +struct IRModule; +struct CodeGenContext; + +/// Translate GLSL-flavored global in/out variables into +/// entry point parameters with system value semantics. +void translateGlobalVaryingVar(CodeGenContext* context, IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-translate-glsl-global-var.h b/source/slang/slang-ir-translate-glsl-global-var.h deleted file mode 100644 index 5821ba5c5d..0000000000 --- a/source/slang/slang-ir-translate-glsl-global-var.h +++ /dev/null @@ -1,17 +0,0 @@ -// slang-ir-translate-glsl-global-var.h -#ifndef SLANG_IR_TRANSLATE_GLSL_GLOBAL_VAR_H -#define SLANG_IR_TRANSLATE_GLSL_GLOBAL_VAR_H - -namespace Slang -{ - -struct IRModule; -struct CodeGenContext; - -/// Translate global in/out variables defined in GLSL-flavored code -/// into entry point parameters with system value semantics. -void translateGLSLGlobalVar(CodeGenContext* context, IRModule* module); - -} // namespace Slang - -#endif // SLANG_IR_TRANSLATE_GLSL_GLOBAL_VAR_H diff --git a/tests/expected-failure-github.txt b/tests/expected-failure-github.txt index e9f634ea8f..10897b31e7 100644 --- a/tests/expected-failure-github.txt +++ b/tests/expected-failure-github.txt @@ -10,4 +10,6 @@ tests/autodiff/custom-intrinsic.slang.2 syn (wgpu) tests/bugs/buffer-swizzle-store.slang.3 syn (wgpu) tests/compute/interface-shader-param-in-struct.slang.4 syn (wgpu) tests/compute/interface-shader-param.slang.5 syn (wgpu) -tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu) \ No newline at end of file +tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu) +tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang.8 (mtl) +tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables-2.slang.3 (mtl) diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables-2.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables-2.slang new file mode 100644 index 0000000000..2e3896cc53 --- /dev/null +++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables-2.slang @@ -0,0 +1,36 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl + + +// There are some issues with the Metal backend when using glsl-style syntax - this should be fixed. +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-metal -compute -entry computeMain -allow-glsl + +#version 430 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +buffer MyBlockName2 +{ + int data[]; +} outputBuffer; + +layout(local_size_x = 4) in; + +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + // There may be some issues with structure padding for global context containing + // global builtin variables. + // int idx = gl_GlobalInvocationID.x; + int idx = dispatchThreadID.x; + + uint laneId = gl_SubgroupInvocationID; + // The laneCount will be dependent on target hardware. It seems a count of 1 is valid in spec. + // For now we'll just check it's not 0. + uint laneCount = gl_SubgroupSize; + outputBuffer.data[idx] = int(((laneCount > 0) ? 0x100 : 0) + laneId); + + // BUF: 100 + // BUF-NEXT: 101 + // BUF-NEXT: 102 + // BUF-NEXT: 103 +} diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang index 21b533178e..2d11ca5fb3 100644 --- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang +++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-builtin-variables.slang @@ -1,5 +1,5 @@ -//TEST:SIMPLE(filecheck=CHECK_GLSL): -allow-glsl -stage compute -entry computeMain -target glsl -//TEST:SIMPLE(filecheck=CHECK_SPV): -allow-glsl -stage compute -entry computeMain -target spirv -emit-spirv-directly +//TEST:SIMPLE(filecheck=CHECK_GLSL): -allow-glsl -stage compute -entry computeMain -target glsl -DTARGET_VK +//TEST:SIMPLE(filecheck=CHECK_SPV): -allow-glsl -stage compute -entry computeMain -target spirv -emit-spirv-directly -DTARGET_VK // missing implementation of most builtin values due to non trivial translation //DISABLE_TEST:SIMPLE(filecheck=CHECK_HLSL): -allow-glsl -stage compute -entry computeMain -target hlsl -DTARGET_HLSL @@ -8,8 +8,13 @@ //missing implementation of system (varying?) values //DISABLE_TEST:SIMPLE(filecheck=CHECK_CPP): -allow-glsl -stage compute -entry computeMain -target cpp -DTARGET_CPP -//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -xslang -DTARGET_VK +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly -xslang -DTARGET_VK +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl + +// There are some issues with the Metal backend when using glsl-style syntax - this should be fixed. +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-metal -compute -entry computeMain -allow-glsl -xslang -DTARGET_METAL + #version 430 //TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer @@ -24,15 +29,18 @@ void computeMain() { if (gl_GlobalInvocationID.x == 3) { outputBuffer.data[0] = true - && gl_NumSubgroups == 1 - && gl_SubgroupID == 0 //1 subgroup, 0 based indexing && gl_SubgroupSize == 32 && gl_SubgroupInvocationID == 3 + // These intrinsics are only available on Vulkan(SPIRV and GLSL). +#if defined(TARGET_VK) + && gl_SubgroupID == 0 //1 subgroup, 0 based indexing + && gl_NumSubgroups == 1 && gl_SubgroupEqMask == uvec4(0b1000,0,0,0) && gl_SubgroupGeMask == uvec4(0xFFFFFFF8,0,0,0) && gl_SubgroupGtMask == uvec4(0xFFFFFFF0,0,0,0) && gl_SubgroupLeMask == uvec4(0b1111,0,0,0) && gl_SubgroupLtMask == uvec4(0b111,0,0,0) +#endif ; } // CHECK_GLSL: void main( diff --git a/tests/hlsl-intrinsic/quad-control/quad-control-comp-functionality.slang b/tests/hlsl-intrinsic/quad-control/quad-control-comp-functionality.slang index 01771f7a40..77ca031782 100644 --- a/tests/hlsl-intrinsic/quad-control/quad-control-comp-functionality.slang +++ b/tests/hlsl-intrinsic/quad-control/quad-control-comp-functionality.slang @@ -8,11 +8,7 @@ RWStructuredBuffer<uint> outputBuffer; [numthreads(16, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { -#if !defined(METAL) uint index = WaveGetLaneIndex(); -#else - uint index = dispatchThreadID.x; -#endif if (index < 4) { diff --git a/tests/hlsl-intrinsic/wave-get-lane-index.slang b/tests/hlsl-intrinsic/wave-get-lane-index.slang index fb09022c23..e1a9262a97 100644 --- a/tests/hlsl-intrinsic/wave-get-lane-index.slang +++ b/tests/hlsl-intrinsic/wave-get-lane-index.slang @@ -4,6 +4,8 @@ //TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj //TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj //TEST:COMPARE_COMPUTE_EX:-cuda -compute -shaderobj +//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj +//TEST:COMPARE_COMPUTE_EX:-metal -compute -shaderobj //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer RWStructuredBuffer<int> outputBuffer; @@ -17,4 +19,4 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) // For now we'll just check it's not 0. uint laneCount = WaveGetLaneCount(); outputBuffer[idx] = int(((laneCount > 0) ? 0x100 : 0) + laneId); -} \ No newline at end of file +}