Skip to content

Commit 7dabfa7

Browse files
Implement explciit binding for metal and wgsl. (#5778)
* Respect explicit bindings in wgsl emit. * Implement explciit binding generation for metal and wgsl. * Update toc. * Fix warnings in tests. * Fix tests. --------- Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
1 parent ecc5a39 commit 7dabfa7

30 files changed

+83
-27
lines changed

docs/user-guide/a2-02-metal-target-specific.md

+8
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,11 @@ Metal requires explicit address space qualifiers. Slang automatically assigns ap
267267
| RW/Structured Buffers | `device` |
268268
| Group Shared | `threadgroup` |
269269
| Parameter Blocks | `constant` |
270+
271+
## Explicit Parameter Binding
272+
273+
The HLSL `:register()` semantic is respected when emitting Metal code.
274+
275+
Since metal does not differentiate a constant buffer, a shader resource (read-only) buffer and an unordered access buffer, Slang will map `register(tN)`, `register(uN)` and `register(bN)` to `[[buffer(N)]]` when such `register` semantic is declared on a buffer typed parameter.
276+
277+
`spaceN` specifiers inside `register` semantics are ignored.

docs/user-guide/a2-03-wgsl-target-specific.md

+6
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,9 @@ Matrix type translation
156156
A m-row-by-n-column matrix in Slang, represented as float`m`x`n` or matrix<T, m, n>, is translated to `mat[n]x[m]` in WGSL, i.e. a matrix with `n` columns and `m` rows.
157157
The rationale for this inversion of terminology is the same as [the rationale for SPIR-V](a2-01-spirv-target-specific.md#matrix-type-translation).
158158
Since the WGSL matrix multiplication convention is the normal one, where inner products of rows of the matrix on the left are taken with columns of the matrix on the right, the order of matrix products is also reversed in WGSL. This is relying on the fact that the transpose of a matrix product equals the product of the transposed matrix operands in reverse order.
159+
160+
## Explicit Parameter Binding
161+
162+
The `[vk::binding(index,set)]` attribute is respected when emitting WGSL code, and will translate to `@binding(index) @group(set)` in WGSL.
163+
164+
If the `[vk::binding()]` attribute is not specified by a `:register()` semantic is present, Slang will derive the binding from the `register` semantic the same way as the SPIRV and GLSL backends.

docs/user-guide/toc.html

+2
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@
233233
<li data-link="metal-target-specific#value-type-conversion"><span>Value Type Conversion</span></li>
234234
<li data-link="metal-target-specific#conservative-rasterization"><span>Conservative Rasterization</span></li>
235235
<li data-link="metal-target-specific#address-space-assignment"><span>Address Space Assignment</span></li>
236+
<li data-link="metal-target-specific#explicit-parameter-binding"><span>Explicit Parameter Binding</span></li>
236237
</ul>
237238
</li>
238239
<li data-link="wgsl-target-specific"><span>WGSL specific functionalities</span>
@@ -249,6 +250,7 @@
249250
<li data-link="wgsl-target-specific#pointers"><span>Pointers</span></li>
250251
<li data-link="wgsl-target-specific#address-space-assignment"><span>Address Space Assignment</span></li>
251252
<li data-link="wgsl-target-specific#matrix-type-translation"><span>Matrix type translation</span></li>
253+
<li data-link="wgsl-target-specific#explicit-parameter-binding"><span>Explicit Parameter Binding</span></li>
252254
</ul>
253255
</li>
254256
</ul>

source/slang/slang-emit-wgsl.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
663663

664664
EmitVarChain chain = {};
665665
chain.varLayout = layout;
666-
auto space = getBindingSpaceForKinds(&chain, kind);
666+
auto space = getBindingSpaceForKinds(&chain, LayoutResourceKindFlag::make(kind));
667667
m_writer->emit("@group(");
668668
m_writer->emit(space);
669669
m_writer->emit(") ");

source/slang/slang-parameter-binding.cpp

+17-2
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,8 @@ static void addExplicitParameterBindings_HLSL(
984984
//
985985
// For now we do the filtering on target in a very direct fashion:
986986
//
987-
if (!isD3DTarget(context->getTargetRequest()) && !isMetalTarget(context->getTargetRequest()))
987+
bool isMetal = isMetalTarget(context->getTargetRequest());
988+
if (!isD3DTarget(context->getTargetRequest()) && !isMetal)
988989
return;
989990

990991
auto typeLayout = varLayout->typeLayout;
@@ -1018,13 +1019,27 @@ static void addExplicitParameterBindings_HLSL(
10181019
if (kind == LayoutResourceKind::None)
10191020
continue;
10201021

1022+
10211023
// TODO: need to special-case when this is a `c` register binding...
10221024

10231025
// Find the appropriate resource-binding information
10241026
// inside the type, to see if we even use any resources
10251027
// of the given kind.
10261028

10271029
auto typeRes = typeLayout->FindResourceInfo(kind);
1030+
if (isMetal && !typeRes)
1031+
{
1032+
// Metal doesn't distinguish a unordered access and a readonly/uniform buffer.
1033+
switch (kind)
1034+
{
1035+
case LayoutResourceKind::UnorderedAccess:
1036+
case LayoutResourceKind::ShaderResource:
1037+
semanticInfo.kind = LayoutResourceKind::MetalBuffer;
1038+
typeRes = typeLayout->FindResourceInfo(LayoutResourceKind::MetalBuffer);
1039+
break;
1040+
}
1041+
}
1042+
10281043
LayoutSize count = 0;
10291044
if (typeRes)
10301045
{
@@ -1073,7 +1088,7 @@ static void addExplicitParameterBindings_GLSL(
10731088
// so that we are able to distinguish between
10741089
// Vulkan and OpenGL as targets.
10751090
//
1076-
if (!isKhronosTarget(context->getTargetRequest()))
1091+
if (!isKhronosTarget(context->getTargetRequest()) && !isWGPUTarget(context->getTargetRequest()))
10771092
return;
10781093

10791094
auto typeLayout = varLayout->typeLayout;

tests/bugs/gh-471.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ int test(int inVal)
2222
return x * 16;
2323
}
2424

25-
RWStructuredBuffer<int> outputBuffer : register(u0);
25+
RWStructuredBuffer<int> outputBuffer;
2626

2727
[numthreads(4, 1, 1)]
2828
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/bugs/gh-775.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ int test(int inVal)
1818
}
1919

2020
//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name outputBuffer
21-
RWStructuredBuffer<int> outputBuffer : register(u0);
21+
RWStructuredBuffer<int> outputBuffer;
2222

2323
[numthreads(4, 1, 1)]
2424
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/bugs/static-method.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ struct S
1111
}
1212

1313
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
14-
RWStructuredBuffer<int> outputBuffer : register(u0);
14+
RWStructuredBuffer<int> outputBuffer;
1515

1616
int test(int t)
1717
{

tests/bugs/static-var.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ int test(int inVal)
88
}
99

1010
//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name outputBuffer
11-
RWStructuredBuffer<int> outputBuffer : register(u0);
11+
RWStructuredBuffer<int> outputBuffer;
1212

1313
[numthreads(4, 1, 1)]
1414
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/bugs/texture2d-gather.hlsl

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
//TEST_INPUT: Texture2D(size=16, content=chessboard, format=R32_FLOAT):name g_texture
44
//TEST_INPUT: Sampler :name g_sampler
55

6-
Texture2D<float> g_texture : register(t0);
7-
SamplerState g_sampler : register(s0);
6+
Texture2D<float> g_texture;
7+
8+
SamplerState g_sampler;
89

910
cbuffer Uniforms
1011
{

tests/bugs/type-legalize-bug-1.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
//TEST_INPUT:type_conformance A:IFoo=0
66
//TEST_INPUT:type_conformance B:IFoo=1
77

8-
RWStructuredBuffer<int> outputBuffer : register(u0);
8+
RWStructuredBuffer<int> outputBuffer;
99
interface IFoo
1010
{
1111
associatedtype T : IFoo;

tests/compute/break-stmt.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int test(int inVal)
1616
}
1717

1818
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
19-
RWStructuredBuffer<int> outputBuffer : register(u0);
19+
RWStructuredBuffer<int> outputBuffer;
2020

2121
[numthreads(4, 1, 1)]
2222
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/continue-stmt.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ int test(int inVal)
2121
}
2222

2323
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
24-
RWStructuredBuffer<int> outputBuffer : register(u0);
24+
RWStructuredBuffer<int> outputBuffer;
2525

2626
[numthreads(4, 1, 1)]
2727
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/default-initializer.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ int test(int value)
2323
}
2424

2525
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
26-
RWStructuredBuffer<int> outputBuffer : register(u0);
26+
RWStructuredBuffer<int> outputBuffer;
2727

2828
[numthreads(4, 1, 1)]
2929
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/explicit-this-expr.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ struct A
1616
};
1717

1818
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
19-
RWStructuredBuffer<float> outputBuffer : register(u0);
19+
RWStructuredBuffer<float> outputBuffer;
2020

2121

2222
float test(float inVal)

tests/compute/generics-constrained.slang

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ float testHelp(T helper)
2828
}
2929

3030
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
31+
[vk::binding(0, 0)]
3132
RWStructuredBuffer<float> outputBuffer : register(u0);
3233

3334

tests/compute/global-init.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ int test(int inVal)
1212
}
1313

1414
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
15-
RWStructuredBuffer<int> outputBuffer : register(u0);
15+
RWStructuredBuffer<int> outputBuffer;
1616

1717
[numthreads(4, 1, 1)]
1818
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/implicit-generic-app.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ int test(int val)
3030
}
3131

3232
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
33-
RWStructuredBuffer<int> outputBuffer : register(u0);
33+
RWStructuredBuffer<int> outputBuffer;
3434

3535
[numthreads(4, 1, 1)]
3636
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/implicit-this-expr.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct A
1515
};
1616

1717
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
18-
RWStructuredBuffer<float> outputBuffer : register(u0);
18+
RWStructuredBuffer<float> outputBuffer;
1919

2020
float test(float inVal)
2121
{

tests/compute/init-list-defaults.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ int test(int inVal)
2424
}
2525

2626
//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer
27-
RWStructuredBuffer<int> outputBuffer : register(u0);
27+
RWStructuredBuffer<int> outputBuffer;
2828

2929
[numthreads(4, 1, 1)]
3030
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/inout.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ int test(int inVal)
3636
}
3737

3838
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
39-
RWStructuredBuffer<int> outputBuffer : register(u0);
39+
RWStructuredBuffer<int> outputBuffer;
4040

4141
[numthreads(4, 1, 1)]
4242
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/multiple-continue-sites.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ int test(int inVal)
2828
}
2929

3030
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
31-
RWStructuredBuffer<int> outputBuffer : register(u0);
31+
RWStructuredBuffer<int> outputBuffer;
3232

3333
[numthreads(4, 1, 1)]
3434
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/struct-default-init.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ int test(int inVal)
2626
}
2727

2828
//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer
29-
RWStructuredBuffer<int> outputBuffer : register(u0);
29+
RWStructuredBuffer<int> outputBuffer;
3030

3131
[numthreads(4, 1, 1)]
3232
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/switch-stmt.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ int test(int inVal)
3030
}
3131

3232
//TEST_INPUT:ubuffer(data=[0 1 2 3 4 5 6 7], stride=4):out,name=outputBuffer
33-
RWStructuredBuffer<int> outputBuffer : register(u0);
33+
RWStructuredBuffer<int> outputBuffer;
3434

3535
[numthreads(8, 1, 1)]
3636
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/this-type.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ int test(int value)
3636
}
3737

3838
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
39-
RWStructuredBuffer<int> outputBuffer : register(u0);
39+
RWStructuredBuffer<int> outputBuffer;
4040

4141
[numthreads(4, 1, 1)]
4242
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/compute/user-defined-initializer.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ int test(int value)
2828
}
2929

3030
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
31-
RWStructuredBuffer<int> outputBuffer : register(u0);
31+
RWStructuredBuffer<int> outputBuffer;
3232

3333
[numthreads(4, 1, 1)]
3434
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/language-feature/properties/property-in-interface.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ int test(int value)
4646
}
4747

4848
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
49-
RWStructuredBuffer<int> outputBuffer : register(u0);
49+
RWStructuredBuffer<int> outputBuffer;
5050

5151
[numthreads(4, 1, 1)]
5252
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

tests/preprocessor/line-macro.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type
22

33
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
4-
RWStructuredBuffer<int> outputBuffer : register(u0);
4+
RWStructuredBuffer<int> outputBuffer;
55

66
#define T(x) x
77
#define LL T(__LINE__)

tests/serialization/std-lib-serialize.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct A
1212
};
1313

1414
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
15-
RWStructuredBuffer<float> outputBuffer : register(u0);
15+
RWStructuredBuffer<float> outputBuffer;
1616

1717

1818
float test(float inVal)

tests/wgsl/explicit-binding.slang

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//TEST:SIMPLE(filecheck=METAL): -target metal
2+
//TEST:SIMPLE(filecheck=CHECK): -target wgsl -entry computeMain -stage compute
3+
4+
// CHECK-DAG: @binding(9) @group(7)
5+
// CHECK-DAG: @binding(3) @group(4)
6+
// CHECK-DAG: @binding(1) @group(2)
7+
8+
// METAL-DAG: buffer(9)
9+
// METAL-DAG: texture(7)
10+
11+
[vk::binding(1, 2)]
12+
Texture2D texA : register(t7);
13+
14+
[vk::binding(3, 4)]
15+
ConstantBuffer<float> cb;
16+
17+
RWStructuredBuffer<float> ob : register(u9, space7);
18+
19+
[numthreads(1,1,1)]
20+
void computeMain()
21+
{
22+
ob[0] = cb + texA.Load(int3(0)).x;
23+
}

0 commit comments

Comments
 (0)