@@ -6,6 +6,15 @@ typedef uint UINT;
6
6
__intrinsic_op($(kIROp_RequireGLSLExtension))
7
7
void __requireGLSLExtension(String extensionName);
8
8
9
+ __intrinsic_op($(kIROp_RequireWGSLExtension))
10
+ void __requireWGSLExtension(String extensionName);
11
+
12
+ /// Built-in values or system value semantics represented as in/out global variables.
13
+ /// This allows the built-ins to be arbitrarily used from a global scope without being
14
+ /// explicitly passed as entry point parameters.
15
+ in uint __builtinWaveLaneIndex : SV_WaveLaneIndex;
16
+ in uint __builtinWaveLaneCount : SV_WaveLaneCount;
17
+
9
18
//@public:
10
19
/// Represents an interface for buffer data layout.
11
20
/// This interface is used as a base for defining specific data layouts for buffers.
@@ -15037,7 +15046,8 @@ uint WaveActiveCountBits(bool value)
15037
15046
__glsl_extension(GL_KHR_shader_subgroup_basic)
15038
15047
__spirv_version(1.3)
15039
15048
[NonUniformReturn]
15040
- [require(cuda_glsl_hlsl_spirv, subgroup_basic)]
15049
+ [ForceInline]
15050
+ [require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
15041
15051
uint WaveGetLaneCount()
15042
15052
{
15043
15053
__target_switch
@@ -15051,14 +15061,20 @@ uint WaveGetLaneCount()
15051
15061
OpCapability GroupNonUniform;
15052
15062
result:$$uint = OpLoad builtin(SubgroupSize:uint)
15053
15063
};
15064
+ case metal:
15065
+ return __builtinWaveLaneCount;
15066
+ case wgsl:
15067
+ __requireWGSLExtension("subgroups");
15068
+ return __builtinWaveLaneCount;
15054
15069
}
15055
15070
}
15056
15071
15057
15072
/// @category wave
15058
15073
__glsl_extension(GL_KHR_shader_subgroup_basic)
15059
15074
__spirv_version(1.3)
15060
15075
[NonUniformReturn]
15061
- [require(cuda_glsl_hlsl_spirv, subgroup_basic)]
15076
+ [ForceInline]
15077
+ [require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
15062
15078
uint WaveGetLaneIndex()
15063
15079
{
15064
15080
__target_switch
@@ -15072,6 +15088,11 @@ uint WaveGetLaneIndex()
15072
15088
OpCapability GroupNonUniform;
15073
15089
result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint)
15074
15090
};
15091
+ case metal:
15092
+ return __builtinWaveLaneIndex;
15093
+ case wgsl:
15094
+ __requireWGSLExtension("subgroups");
15095
+ return __builtinWaveLaneIndex;
15075
15096
}
15076
15097
}
15077
15098
0 commit comments