Skip to content

Commit dd7e4f9

Browse files
committed
support WaveGetLane* for WGSL and Metal
1 parent e043428 commit dd7e4f9

14 files changed

+102
-28
lines changed

source/slang/hlsl.meta.slang

+23-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ typedef uint UINT;
66
__intrinsic_op($(kIROp_RequireGLSLExtension))
77
void __requireGLSLExtension(String extensionName);
88

9+
__intrinsic_op($(kIROp_RequireWGSLExtension))
10+
void __requireWGSLExtension(String extensionName);
11+
12+
/// Built-in values or system value semantics represented as in/out global variables.
13+
/// This allows the built-ins to be arbitrarily used from a global scope without being
14+
/// explicitly passed as entry point parameters.
15+
in uint __builtinWaveLaneIndex : SV_WaveLaneIndex;
16+
in uint __builtinWaveLaneCount : SV_WaveLaneCount;
17+
918
//@public:
1019
/// Represents an interface for buffer data layout.
1120
/// This interface is used as a base for defining specific data layouts for buffers.
@@ -15037,7 +15046,8 @@ uint WaveActiveCountBits(bool value)
1503715046
__glsl_extension(GL_KHR_shader_subgroup_basic)
1503815047
__spirv_version(1.3)
1503915048
[NonUniformReturn]
15040-
[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
15049+
[ForceInline]
15050+
[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
1504115051
uint WaveGetLaneCount()
1504215052
{
1504315053
__target_switch
@@ -15051,14 +15061,20 @@ uint WaveGetLaneCount()
1505115061
OpCapability GroupNonUniform;
1505215062
result:$$uint = OpLoad builtin(SubgroupSize:uint)
1505315063
};
15064+
case metal:
15065+
return __builtinWaveLaneCount;
15066+
case wgsl:
15067+
__requireWGSLExtension("subgroups");
15068+
return __builtinWaveLaneCount;
1505415069
}
1505515070
}
1505615071

1505715072
/// @category wave
1505815073
__glsl_extension(GL_KHR_shader_subgroup_basic)
1505915074
__spirv_version(1.3)
1506015075
[NonUniformReturn]
15061-
[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
15076+
[ForceInline]
15077+
[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
1506215078
uint WaveGetLaneIndex()
1506315079
{
1506415080
__target_switch
@@ -15072,6 +15088,11 @@ uint WaveGetLaneIndex()
1507215088
OpCapability GroupNonUniform;
1507315089
result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint)
1507415090
};
15091+
case metal:
15092+
return __builtinWaveLaneIndex;
15093+
case wgsl:
15094+
__requireWGSLExtension("subgroups");
15095+
return __builtinWaveLaneIndex;
1507515096
}
1507615097
}
1507715098

source/slang/slang-emit-c-like.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -3055,6 +3055,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
30553055
emitOperand(as<IRGlobalValueRef>(inst)->getOperand(0), getInfo(EmitOp::General));
30563056
break;
30573057
}
3058+
case kIROp_RequireWGSLExtension:
3059+
{
3060+
emitRequireExtension(inst);
3061+
break;
3062+
}
30583063
default:
30593064
diagnoseUnhandledInst(inst);
30603065
break;

source/slang/slang-emit-c-like.h

+2
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,8 @@ class CLikeSourceEmitter : public SourceEmitterBase
678678
void _emitCallArgList(IRCall* call, int startingOperandIndex = 1);
679679
virtual void emitCallArg(IRInst* arg);
680680

681+
/// Hook invoked by `defaultEmitInstExpr` for extension-requirement instructions.
/// The base implementation intentionally ignores `inst`; target emitters override it.
virtual void emitRequireExtension(IRInst* inst) { SLANG_UNUSED(inst); }
682+
681683
String _generateUniqueName(const UnownedStringSlice& slice);
682684

683685
// Sort witnessTable entries according to the order defined in the witnessed interface type.

source/slang/slang-emit-wgsl.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1696,4 +1696,9 @@ void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst)
16961696
}
16971697
}
16981698

1699+
void WGSLSourceEmitter::emitRequireExtension(IRInst* inst)
1700+
{
1701+
_requireExtension(as<IRRequireWGSLExtension>(inst)->getExtensionName());
1702+
}
1703+
16991704
} // namespace Slang

source/slang/slang-emit-wgsl.h

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter
5757
EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
5858
virtual void emitGlobalParamDefaultVal(IRGlobalParam* varDecl) SLANG_OVERRIDE;
5959

60+
virtual void emitRequireExtension(IRInst* inst) SLANG_OVERRIDE;
61+
6062
virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE;
6163

6264
void emit(const AddressSpace addressSpace);

source/slang/slang-emit.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
#include "slang-ir-strip-default-construct.h"
101101
#include "slang-ir-strip-legalization-insts.h"
102102
#include "slang-ir-synthesize-active-mask.h"
103-
#include "slang-ir-translate-glsl-global-var.h"
103+
#include "slang-ir-translate-in-out-global-var.h"
104104
#include "slang-ir-uniformity.h"
105105
#include "slang-ir-user-type-hint.h"
106106
#include "slang-ir-validate.h"
@@ -318,7 +318,7 @@ struct RequiredLoweringPassSet
318318
bool bindingQuery;
319319
bool meshOutput;
320320
bool higherOrderFunc;
321-
bool glslGlobalVar;
321+
bool inOutGlobalVar;
322322
bool glslSSBO;
323323
bool byteAddressBuffer;
324324
bool dynamicResource;
@@ -422,7 +422,7 @@ void calcRequiredLoweringPassSet(
422422
case kIROp_GlobalInputDecoration:
423423
case kIROp_GlobalOutputDecoration:
424424
case kIROp_GetWorkGroupSize:
425-
result.glslGlobalVar = true;
425+
result.inOutGlobalVar = true;
426426
break;
427427
case kIROp_BindExistentialSlotsDecoration:
428428
result.bindExistential = true;
@@ -641,8 +641,8 @@ Result linkAndOptimizeIR(
641641
if (!isKhronosTarget(targetRequest) && requiredLoweringPassSet.glslSSBO)
642642
lowerGLSLShaderStorageBufferObjectsToStructuredBuffers(irModule, sink);
643643

644-
if (requiredLoweringPassSet.glslGlobalVar)
645-
translateGLSLGlobalVar(codeGenContext, irModule);
644+
if (requiredLoweringPassSet.inOutGlobalVar)
645+
translateInOutGlobalVar(codeGenContext, irModule);
646646

647647
if (requiredLoweringPassSet.resolveVaryingInputRef)
648648
resolveVaryingInputRef(irModule);

source/slang/slang-ir-call-graph.h

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// slang-ir-call-graph.h
2+
#pragma once
3+
14
#include "slang-ir-clone.h"
25
#include "slang-ir-insts.h"
36

source/slang/slang-ir-inst-defs.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ INST(WitnessTableEntry, witness_table_entry, 2, 0)
407407
INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL)
408408

409409
// An inst to represent the workgroup size of the calling entry point.
410-
// We will materialize this inst during `translateGLSLGlobalVar`.
410+
// We will materialize this inst during `translateInOutGlobalVar`.
411411
INST(GetWorkGroupSize, GetWorkGroupSize, 0, HOISTABLE)
412412

413413
// An inst that returns the current stage of the calling entry point.
@@ -667,6 +667,7 @@ INST(discard, discard, 0, 0)
667667

668668
INST(RequirePrelude, RequirePrelude, 1, 0)
669669
INST(RequireGLSLExtension, RequireGLSLExtension, 1, 0)
670+
INST(RequireWGSLExtension, RequireWGSLExtension, 1, 0)
670671
INST(RequireComputeDerivative, RequireComputeDerivative, 0, 0)
671672
INST(StaticAssert, StaticAssert, 2, 0)
672673
INST(Printf, Printf, 1, 0)

source/slang/slang-ir-insts.h

+9
Original file line numberDiff line numberDiff line change
@@ -3479,6 +3479,15 @@ struct IRRequireGLSLExtension : IRInst
34793479
}
34803480
};
34813481

3482+
struct IRRequireWGSLExtension : IRInst
3483+
{
3484+
IR_LEAF_ISA(RequireWGSLExtension)
3485+
UnownedStringSlice getExtensionName()
3486+
{
3487+
return as<IRStringLit>(getOperand(0))->getStringSlice();
3488+
}
3489+
};
3490+
34823491
struct IRRequireComputeDerivative : IRInst
34833492
{
34843493
IR_LEAF_ISA(RequireComputeDerivative)

source/slang/slang-ir-legalize-varying-params.cpp

+28-1
Original file line numberDiff line numberDiff line change
@@ -3228,6 +3228,20 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext
32283228
result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
32293229
break;
32303230
}
3231+
case SystemValueSemanticName::WaveLaneCount:
3232+
{
3233+
result.systemValueName = toSlice("threads_per_simdgroup");
3234+
result.permittedTypes.add(builder.getUIntType());
3235+
result.permittedTypes.add(builder.getUInt16Type());
3236+
break;
3237+
}
3238+
case SystemValueSemanticName::WaveLaneIndex:
3239+
{
3240+
result.systemValueName = toSlice("thread_index_in_simdgroup");
3241+
result.permittedTypes.add(builder.getUIntType());
3242+
result.permittedTypes.add(builder.getUInt16Type());
3243+
break;
3244+
}
32313245
default:
32323246
m_sink->diagnose(
32333247
parentVar,
@@ -3845,6 +3859,20 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext
38453859
break;
38463860
}
38473861

3862+
case SystemValueSemanticName::WaveLaneCount:
3863+
{
3864+
result.systemValueName = toSlice("subgroup_size");
3865+
result.permittedTypes.add(builder.getUIntType());
3866+
break;
3867+
}
3868+
3869+
case SystemValueSemanticName::WaveLaneIndex:
3870+
{
3871+
result.systemValueName = toSlice("subgroup_invocation_id");
3872+
result.permittedTypes.add(builder.getUIntType());
3873+
break;
3874+
}
3875+
38483876
case SystemValueSemanticName::ViewID:
38493877
case SystemValueSemanticName::ViewportArrayIndex:
38503878
case SystemValueSemanticName::StartVertexLocation:
@@ -3853,7 +3881,6 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext
38533881
result.isUnsupported = true;
38543882
break;
38553883
}
3856-
38573884
default:
38583885
{
38593886
m_sink->diagnose(

source/slang/slang-ir-legalize-varying-params.h

+2
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ void depointerizeInputParams(IRFunc* entryPoint);
6868
M(Target, SV_Target) \
6969
M(StartVertexLocation, SV_StartVertexLocation) \
7070
M(StartInstanceLocation, SV_StartInstanceLocation) \
71+
M(WaveLaneCount, SV_WaveLaneCount) \
72+
M(WaveLaneIndex, SV_WaveLaneIndex) \
7173
/* end */
7274

7375
/// A known system-value semantic name that can be applied to a parameter

source/slang/slang-ir-translate-glsl-global-var.h

-17
This file was deleted.

source/slang/slang-ir-translate-glsl-global-var.cpp source/slang/slang-ir-translate-in-out-global-var.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "slang-ir-translate-glsl-global-var.h"
1+
#include "slang-ir-translate-in-out-global-var.h"
22

33
#include "slang-ir-call-graph.h"
44
#include "slang-ir-insts.h"
@@ -373,7 +373,7 @@ struct GlobalVarTranslationContext
373373
}
374374
};
375375

376-
void translateGLSLGlobalVar(CodeGenContext* context, IRModule* module)
376+
void translateInOutGlobalVar(CodeGenContext* context, IRModule* module)
377377
{
378378
GlobalVarTranslationContext ctx;
379379
ctx.context = context;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// slang-ir-translate-in-out-global-var.h
2+
#pragma once
3+
4+
namespace Slang
5+
{
6+
7+
struct IRModule;
8+
struct CodeGenContext;
9+
10+
/// Translate global in/out variables (GLSL-flavored or built-in system values) into
11+
/// entry point parameters with system value semantics.
12+
void translateInOutGlobalVar(CodeGenContext* context, IRModule* module);
13+
14+
} // namespace Slang

0 commit comments

Comments
 (0)