Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WaveGetLane* support for Metal and WGSL #6371

Merged
merged 26 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
dd7e4f9
support WaveGetLane* for WGSL and Metal
fairywreath Feb 16, 2025
07bd6b7
update test and glsl support
fairywreath Feb 16, 2025
3f334df
address review comments and fix metal test
fairywreath Feb 16, 2025
fd4d4e3
add missing pragma guard
fairywreath Feb 16, 2025
87410a9
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 17, 2025
f2b97e9
update test
fairywreath Feb 20, 2025
8221661
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 20, 2025
260cccd
Revert "update test"
fairywreath Feb 20, 2025
70e3eb1
update failing glsl metal test and added new test
fairywreath Feb 20, 2025
b670683
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 20, 2025
51a565a
make hlsl and glsl outputs similar
fairywreath Feb 20, 2025
1db6f43
update test
fairywreath Feb 21, 2025
2824396
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 21, 2025
1c42aa1
disable tests for Metal and cleanup
fairywreath Feb 25, 2025
1fcc307
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 25, 2025
94eadbd
comment fix
fairywreath Feb 25, 2025
96a2e7b
add expected failures
fairywreath Feb 25, 2025
397865f
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 25, 2025
4dc0742
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 25, 2025
a8088ea
correct expected failures list
fairywreath Feb 26, 2025
bf11684
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 26, 2025
bb606e3
remove expected failure
fairywreath Feb 27, 2025
51256fa
Merge branch 'master' into wgsl-global-param-builtin
fairywreath Feb 27, 2025
9d6dd3f
add tests to expected failure
fairywreath Feb 27, 2025
a881dbe
Merge branch 'master' into wgsl-global-param-builtin
csyonghe Feb 28, 2025
0ae2841
Merge branch 'master' into wgsl-global-param-builtin
csyonghe Feb 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions source/slang/glsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -6334,7 +6334,7 @@ void requireGLSLExtForSubgroupBasicBuiltin() {
}
}

[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
void setupExtForSubgroupBasicBuiltIn() {
__target_switch
{
Expand Down Expand Up @@ -6409,7 +6409,8 @@ public property uint gl_SubgroupID

public property uint gl_SubgroupSize
{
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
return WaveGetLaneCount();
Expand All @@ -6418,7 +6419,8 @@ public property uint gl_SubgroupSize

public property uint gl_SubgroupInvocationID
{
[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
return WaveGetLaneIndex();
Expand Down
25 changes: 23 additions & 2 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ typedef uint UINT;
__intrinsic_op($(kIROp_RequireGLSLExtension))
void __requireGLSLExtension(String extensionName);

__intrinsic_op($(kIROp_RequireWGSLExtension))
void __requireWGSLExtension(String extensionName);

/// Built-in values or system value semantics represented as in/out global variables.
/// This allows the built-ins to be arbitrarily used from a global scope without being
/// explicitly passed as entry point parameters.
in uint __builtinWaveLaneIndex : SV_WaveLaneIndex;
in uint __builtinWaveLaneCount : SV_WaveLaneCount;

//@public:
/// Represents an interface for buffer data layout.
/// This interface is used as a base for defining specific data layouts for buffers.
Expand Down Expand Up @@ -15037,7 +15046,8 @@ uint WaveActiveCountBits(bool value)
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
[NonUniformReturn]
[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
[ForceInline]
[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
uint WaveGetLaneCount()
{
__target_switch
Expand All @@ -15051,14 +15061,20 @@ uint WaveGetLaneCount()
OpCapability GroupNonUniform;
result:$$uint = OpLoad builtin(SubgroupSize:uint)
};
case metal:
return __builtinWaveLaneCount;
case wgsl:
__requireWGSLExtension("subgroups");
return __builtinWaveLaneCount;
}
}

/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
[NonUniformReturn]
[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
[ForceInline]
[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
uint WaveGetLaneIndex()
{
__target_switch
Expand All @@ -15072,6 +15088,11 @@ uint WaveGetLaneIndex()
OpCapability GroupNonUniform;
result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint)
};
case metal:
return __builtinWaveLaneIndex;
case wgsl:
__requireWGSLExtension("subgroups");
return __builtinWaveLaneIndex;
}
}

Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3055,6 +3055,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
emitOperand(as<IRGlobalValueRef>(inst)->getOperand(0), getInfo(EmitOp::General));
break;
}
case kIROp_RequireWGSLExtension:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just use a single RequireTargetExtension for all targets?

{
emitRequireExtension(inst);
break;
}
default:
diagnoseUnhandledInst(inst);
break;
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-emit-c-like.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,8 @@ class CLikeSourceEmitter : public SourceEmitterBase
void _emitCallArgList(IRCall* call, int startingOperandIndex = 1);
virtual void emitCallArg(IRInst* arg);

virtual void emitRequireExtension(IRInst* inst) { SLANG_UNUSED(inst); }

String _generateUniqueName(const UnownedStringSlice& slice);

// Sort witnessTable entries according to the order defined in the witnessed interface type.
Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1696,4 +1696,9 @@ void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst)
}
}

void WGSLSourceEmitter::emitRequireExtension(IRInst* inst)
{
_requireExtension(as<IRRequireWGSLExtension>(inst)->getExtensionName());
}

} // namespace Slang
2 changes: 2 additions & 0 deletions source/slang/slang-emit-wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter
EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
virtual void emitGlobalParamDefaultVal(IRGlobalParam* varDecl) SLANG_OVERRIDE;

virtual void emitRequireExtension(IRInst* inst) SLANG_OVERRIDE;

virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE;

void emit(const AddressSpace addressSpace);
Expand Down
10 changes: 5 additions & 5 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
#include "slang-ir-strip-default-construct.h"
#include "slang-ir-strip-legalization-insts.h"
#include "slang-ir-synthesize-active-mask.h"
#include "slang-ir-translate-glsl-global-var.h"
#include "slang-ir-translate-in-out-global-var.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can name it translate-global-varying-var.

#include "slang-ir-uniformity.h"
#include "slang-ir-user-type-hint.h"
#include "slang-ir-validate.h"
Expand Down Expand Up @@ -318,7 +318,7 @@ struct RequiredLoweringPassSet
bool bindingQuery;
bool meshOutput;
bool higherOrderFunc;
bool glslGlobalVar;
bool inOutGlobalVar;
bool glslSSBO;
bool byteAddressBuffer;
bool dynamicResource;
Expand Down Expand Up @@ -422,7 +422,7 @@ void calcRequiredLoweringPassSet(
case kIROp_GlobalInputDecoration:
case kIROp_GlobalOutputDecoration:
case kIROp_GetWorkGroupSize:
result.glslGlobalVar = true;
result.inOutGlobalVar = true;
break;
case kIROp_BindExistentialSlotsDecoration:
result.bindExistential = true;
Expand Down Expand Up @@ -641,8 +641,8 @@ Result linkAndOptimizeIR(
if (!isKhronosTarget(targetRequest) && requiredLoweringPassSet.glslSSBO)
lowerGLSLShaderStorageBufferObjectsToStructuredBuffers(irModule, sink);

if (requiredLoweringPassSet.glslGlobalVar)
translateGLSLGlobalVar(codeGenContext, irModule);
if (requiredLoweringPassSet.inOutGlobalVar)
translateInOutGlobalVar(codeGenContext, irModule);

if (requiredLoweringPassSet.resolveVaryingInputRef)
resolveVaryingInputRef(irModule);
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-call-graph.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// slang-ir-call-graph.h
#pragma once

#include "slang-ir-clone.h"
#include "slang-ir-insts.h"

Expand Down
3 changes: 2 additions & 1 deletion source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ INST(WitnessTableEntry, witness_table_entry, 2, 0)
INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL)

// An inst to represent the workgroup size of the calling entry point.
// We will materialize this inst during `translateGLSLGlobalVar`.
// We will materialize this inst during `translateGlobalInOutVar`.
INST(GetWorkGroupSize, GetWorkGroupSize, 0, HOISTABLE)

// An inst that returns the current stage of the calling entry point.
Expand Down Expand Up @@ -667,6 +667,7 @@ INST(discard, discard, 0, 0)

INST(RequirePrelude, RequirePrelude, 1, 0)
INST(RequireGLSLExtension, RequireGLSLExtension, 1, 0)
INST(RequireWGSLExtension, RequireWGSLExtension, 1, 0)
INST(RequireComputeDerivative, RequireComputeDerivative, 0, 0)
INST(StaticAssert, StaticAssert, 2, 0)
INST(Printf, Printf, 1, 0)
Expand Down
9 changes: 9 additions & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -3479,6 +3479,15 @@ struct IRRequireGLSLExtension : IRInst
}
};

struct IRRequireWGSLExtension : IRInst
{
IR_LEAF_ISA(RequireWGSLExtension)
UnownedStringSlice getExtensionName()
{
return as<IRStringLit>(getOperand(0))->getStringSlice();
}
};

struct IRRequireComputeDerivative : IRInst
{
IR_LEAF_ISA(RequireComputeDerivative)
Expand Down
29 changes: 28 additions & 1 deletion source/slang/slang-ir-legalize-varying-params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3228,6 +3228,20 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext
result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
break;
}
case SystemValueSemanticName::WaveLaneCount:
{
result.systemValueName = toSlice("threads_per_simdgroup");
result.permittedTypes.add(builder.getUIntType());
result.permittedTypes.add(builder.getUInt16Type());
break;
}
case SystemValueSemanticName::WaveLaneIndex:
{
result.systemValueName = toSlice("thread_index_in_simdgroup");
result.permittedTypes.add(builder.getUIntType());
result.permittedTypes.add(builder.getUInt16Type());
break;
}
default:
m_sink->diagnose(
parentVar,
Expand Down Expand Up @@ -3845,6 +3859,20 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext
break;
}

case SystemValueSemanticName::WaveLaneCount:
{
result.systemValueName = toSlice("subgroup_size");
result.permittedTypes.add(builder.getUIntType());
break;
}

case SystemValueSemanticName::WaveLaneIndex:
{
result.systemValueName = toSlice("subgroup_invocation_id");
result.permittedTypes.add(builder.getUIntType());
break;
}

case SystemValueSemanticName::ViewID:
case SystemValueSemanticName::ViewportArrayIndex:
case SystemValueSemanticName::StartVertexLocation:
Expand All @@ -3853,7 +3881,6 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext
result.isUnsupported = true;
break;
}

default:
{
m_sink->diagnose(
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-legalize-varying-params.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ void depointerizeInputParams(IRFunc* entryPoint);
M(Target, SV_Target) \
M(StartVertexLocation, SV_StartVertexLocation) \
M(StartInstanceLocation, SV_StartInstanceLocation) \
M(WaveLaneCount, SV_WaveLaneCount) \
M(WaveLaneIndex, SV_WaveLaneIndex) \
/* end */

/// A known system-value semantic name that can be applied to a parameter
Expand Down
17 changes: 0 additions & 17 deletions source/slang/slang-ir-translate-glsl-global-var.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "slang-ir-translate-glsl-global-var.h"
#include "slang-ir-translate-in-out-global-var.h"

#include "slang-ir-call-graph.h"
#include "slang-ir-insts.h"
Expand Down Expand Up @@ -373,7 +373,7 @@ struct GlobalVarTranslationContext
}
};

void translateGLSLGlobalVar(CodeGenContext* context, IRModule* module)
void translateInOutGlobalVar(CodeGenContext* context, IRModule* module)
{
GlobalVarTranslationContext ctx;
ctx.context = context;
Expand Down
14 changes: 14 additions & 0 deletions source/slang/slang-ir-translate-in-out-global-var.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// slang-ir-translate-in-out-global-var.h
// #pragma once

namespace Slang
{

struct IRModule;
struct CodeGenContext;

/// Translate GLSL-flavored global in/out variables into
/// entry point parameters with system value semantics.
void translateInOutGlobalVar(CodeGenContext* context, IRModule* module);

} // namespace Slang
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//TEST:SIMPLE(filecheck=CHECK_GLSL): -allow-glsl -stage compute -entry computeMain -target glsl
//TEST:SIMPLE(filecheck=CHECK_SPV): -allow-glsl -stage compute -entry computeMain -target spirv -emit-spirv-directly
//TEST:SIMPLE(filecheck=CHECK_GLSL): -allow-glsl -stage compute -entry computeMain -target glsl -DTARGET_VK
//TEST:SIMPLE(filecheck=CHECK_SPV): -allow-glsl -stage compute -entry computeMain -target spirv -emit-spirv-directly -DTARGET_VK

// missing implementation of most builtin values due to non trivial translation
//DISABLE_TEST:SIMPLE(filecheck=CHECK_HLSL): -allow-glsl -stage compute -entry computeMain -target hlsl -DTARGET_HLSL
Expand All @@ -8,8 +8,11 @@
//missing implementation of system (varying?) values
//DISABLE_TEST:SIMPLE(filecheck=CHECK_CPP): -allow-glsl -stage compute -entry computeMain -target cpp -DTARGET_CPP

//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -xslang -DTARGET_VK
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly -xslang -DTARGET_VK
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-metal -compute -entry computeMain -allow-glsl

#version 430

//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
Expand All @@ -24,15 +27,17 @@ void computeMain()
{
if (gl_GlobalInvocationID.x == 3) {
outputBuffer.data[0] = true
&& gl_NumSubgroups == 1
&& gl_SubgroupID == 0 //1 subgroup, 0 based indexing
&& gl_SubgroupSize == 32
&& gl_SubgroupInvocationID == 3
#if defined(TARGET_VK)
&& gl_SubgroupID == 0 //1 subgroup, 0 based indexing
&& gl_NumSubgroups == 1
&& gl_SubgroupEqMask == uvec4(0b1000,0,0,0)
&& gl_SubgroupGeMask == uvec4(0xFFFFFFF8,0,0,0)
&& gl_SubgroupGtMask == uvec4(0xFFFFFFF0,0,0,0)
&& gl_SubgroupLeMask == uvec4(0b1111,0,0,0)
&& gl_SubgroupLtMask == uvec4(0b111,0,0,0)
#endif
;
}
// CHECK_GLSL: void main(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ RWStructuredBuffer<uint> outputBuffer;
[numthreads(16, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
#if !defined(METAL)
uint index = WaveGetLaneIndex();
#else
uint index = dispatchThreadID.x;
#endif

if (index < 4)
{
Expand Down
4 changes: 3 additions & 1 deletion tests/hlsl-intrinsic/wave-get-lane-index.slang
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -shaderobj
//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST:COMPARE_COMPUTE_EX:-metal -compute -shaderobj

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
Expand All @@ -17,4 +19,4 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
// For now we'll just check it's not 0.
uint laneCount = WaveGetLaneCount();
outputBuffer[idx] = int(((laneCount > 0) ? 0x100 : 0) + laneId);
}
}
Loading