Skip to content

Commit f114433

Browse files
authored
Support parameter block in metal shader objects. (shader-slang#4671)
* Support parameter block in metal shader objects. * Ingore parameter block tests on devices without tier2 argument buffer. * Fix warning. * Fix texture subscript test. --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent adf758c commit f114433

24 files changed

+317
-52
lines changed

include/slang.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -2316,7 +2316,7 @@ extern "C"
23162316
// The input_attachment_index subpass occupancy tracker
23172317
SLANG_PARAMETER_CATEGORY_SUBPASS,
23182318

2319-
// Metal resource binding points.
2319+
// Metal tier-1 argument buffer element [[id]].
23202320
SLANG_PARAMETER_CATEGORY_METAL_ARGUMENT_BUFFER_ELEMENT,
23212321

23222322
// Metal [[attribute]] inputs.
@@ -2398,6 +2398,7 @@ extern "C"
23982398
enum SlangLayoutRules : SlangLayoutRulesIntegral
23992399
{
24002400
SLANG_LAYOUT_RULES_DEFAULT,
2401+
SLANG_LAYOUT_RULES_METAL_ARGUMENT_BUFFER_TIER_2,
24012402
};
24022403

24032404
typedef SlangUInt32 SlangModifierIDIntegral;
@@ -3585,6 +3586,7 @@ namespace slang
35853586
enum class LayoutRules : SlangLayoutRulesIntegral
35863587
{
35873588
Default = SLANG_LAYOUT_RULES_DEFAULT,
3589+
MetalArgumentBufferTier2 = SLANG_LAYOUT_RULES_METAL_ARGUMENT_BUFFER_TIER_2,
35883590
};
35893591

35903592
typedef struct ShaderReflection ProgramLayout;

source/slang/slang-compiler.h

+19-3
Original file line numberDiff line numberDiff line change
@@ -1787,11 +1787,27 @@ namespace Slang
17871787
CodeGenTarget getTarget() { return optionSet.getEnumOption<CodeGenTarget>(CompilerOptionName::Target); }
17881788

17891789
// TypeLayouts created on the fly by reflection API
1790-
Dictionary<Type*, RefPtr<TypeLayout>> typeLayouts;
1790+
struct TypeLayoutKey
1791+
{
1792+
Type* type;
1793+
slang::LayoutRules rules;
1794+
HashCode getHashCode() const
1795+
{
1796+
Hasher hasher;
1797+
hasher.hashValue(type);
1798+
hasher.hashValue(rules);
1799+
return hasher.getResult();
1800+
}
1801+
bool operator==(TypeLayoutKey other) const
1802+
{
1803+
return type == other.type && rules == other.rules;
1804+
}
1805+
};
1806+
Dictionary<TypeLayoutKey, RefPtr<TypeLayout>> typeLayouts;
17911807

1792-
Dictionary<Type*, RefPtr<TypeLayout>>& getTypeLayouts() { return typeLayouts; }
1808+
Dictionary<TypeLayoutKey, RefPtr<TypeLayout>>& getTypeLayouts() { return typeLayouts; }
17931809

1794-
TypeLayout* getTypeLayout(Type* type);
1810+
TypeLayout* getTypeLayout(Type* type, slang::LayoutRules rules);
17951811

17961812
CompilerOptionSet& getOptionSet() { return optionSet; }
17971813

source/slang/slang-parameter-binding.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -3892,7 +3892,7 @@ RefPtr<ProgramLayout> generateParameterBindings(
38923892
}
38933893

38943894
// Try to find rules based on the selected code-generation target
3895-
auto layoutContext = getInitialLayoutContextForTarget(targetReq, programLayout);
3895+
auto layoutContext = getInitialLayoutContextForTarget(targetReq, programLayout, slang::LayoutRules::Default);
38963896

38973897
// If there was no target, or there are no rules for the target,
38983898
// then bail out here.

source/slang/slang-reflection-api.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -810,13 +810,13 @@ SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * re
810810
SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout(
811811
SlangReflection* reflection,
812812
SlangReflectionType* inType,
813-
SlangLayoutRules /*rules*/)
813+
SlangLayoutRules rules)
814814
{
815815
auto context = convert(reflection);
816816
auto type = convert(inType);
817817
auto targetReq = context->getTargetReq();
818818

819-
auto typeLayout = targetReq->getTypeLayout(type);
819+
auto typeLayout = targetReq->getTypeLayout(type, (slang::LayoutRules)rules);
820820
return convert(typeLayout);
821821
}
822822

@@ -1875,6 +1875,7 @@ namespace Slang
18751875
case LayoutResourceKind::DescriptorTableSlot:
18761876
case LayoutResourceKind::Uniform:
18771877
case LayoutResourceKind::ConstantBuffer: // for metal
1878+
case LayoutResourceKind::MetalArgumentBufferElement:
18781879
resInfo = info;
18791880
break;
18801881
}

source/slang/slang-type-layout.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -1811,11 +1811,21 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe
18111811
}
18121812
}
18131813

1814-
TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq, ProgramLayout* programLayout)
1814+
TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq, ProgramLayout* programLayout, slang::LayoutRules rules)
18151815
{
18161816
auto astBuilder = targetReq->getLinkage()->getASTBuilder();
18171817

1818-
LayoutRulesFamilyImpl* rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq);
1818+
LayoutRulesFamilyImpl* rulesFamily;
1819+
switch (rules)
1820+
{
1821+
case slang::LayoutRules::Default:
1822+
default:
1823+
rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq);
1824+
break;
1825+
case slang::LayoutRules::MetalArgumentBufferTier2:
1826+
rulesFamily = &kCPULayoutRulesFamilyImpl;
1827+
break;
1828+
}
18191829

18201830
TypeLayoutContext context;
18211831
context.astBuilder = astBuilder;

source/slang/slang-type-layout.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,8 @@ struct StructTypeLayoutBuilder
12881288
//
12891289
TypeLayoutContext getInitialLayoutContextForTarget(
12901290
TargetRequest* targetRequest,
1291-
ProgramLayout* programLayout);
1291+
ProgramLayout* programLayout,
1292+
slang::LayoutRules rules);
12921293

12931294
/// Direction(s) of a varying shader parameter
12941295
typedef unsigned int EntryPointParameterDirectionMask;

source/slang/slang.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,7 @@ SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL Linkage::getTypeLayout(
13681368
//
13691369
SLANG_UNUSED(rules);
13701370

1371-
auto typeLayout = target->getTypeLayout(type);
1371+
auto typeLayout = target->getTypeLayout(type, rules);
13721372

13731373
// TODO: We currently don't have a path for capturing
13741374
// errors that occur during layout (e.g., types that
@@ -1827,7 +1827,7 @@ CapabilitySet TargetRequest::getTargetCaps()
18271827
}
18281828

18291829

1830-
TypeLayout* TargetRequest::getTypeLayout(Type* type)
1830+
TypeLayout* TargetRequest::getTypeLayout(Type* type, slang::LayoutRules rules)
18311831
{
18321832
SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder());
18331833

@@ -1841,13 +1841,14 @@ TypeLayout* TargetRequest::getTypeLayout(Type* type)
18411841
// parameter instead (leaving the user to figure out how that
18421842
// maps to the ordering via some API on the program layout).
18431843
//
1844-
auto layoutContext = getInitialLayoutContextForTarget(this, nullptr);
1844+
auto layoutContext = getInitialLayoutContextForTarget(this, nullptr, rules);
18451845

18461846
RefPtr<TypeLayout> result;
1847-
if (getTypeLayouts().tryGetValue(type, result))
1847+
auto key = TypeLayoutKey{ type, rules };
1848+
if (getTypeLayouts().tryGetValue(key, result))
18481849
return result.Ptr();
18491850
result = createTypeLayout(layoutContext, type);
1850-
getTypeLayouts()[type] = result;
1851+
getTypeLayouts()[key] = result;
18511852
return result.Ptr();
18521853
}
18531854

tests/autodiff/global-param-hoisting.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
88
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
99
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
10-
//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
10+
//TEST(compute):COMPARE_COMPUTE:-slang -compute -mtl -output-using-type -render-features argument-buffer-tier-2
1111

1212
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
1313
RWStructuredBuffer<float> outputBuffer;

tests/bindings/nested-parameter-block-2.slang

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
22
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12 -use-dxil -shaderobj -output-using-type
33
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
4+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal -shaderobj -output-using-type -render-features argument-buffer-tier-2
45
// nested-parameter-block-2.slang
56

67
struct CB

tests/bugs/buffer-swizzle-store.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
22
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
3-
//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
3+
//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl -output-using-type
44

55
//TEST_INPUT: RWTexture2D(format=R16G16_FLOAT, size=4, content = one, mipMaps = 1):name g_test
66
[format("rg16f")]

tests/compute/entry-point-uniform-params.slang

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
//TEST(compute):COMPARE_COMPUTE: -dx11 -shaderobj
1010
//TEST(compute):COMPARE_COMPUTE: -cuda -shaderobj
1111
//TEST(compute):COMPARE_COMPUTE: -cpu -shaderobj
12-
13-
12+
//TEST(compute):COMPARE_COMPUTE: -metal -shaderobj
1413

1514
struct Signs
1615
{

tests/compute/parameter-block.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj
33
//TEST(compute):COMPARE_COMPUTE:-vk -shaderobj
44
//TEST(compute):COMPARE_COMPUTE:-shaderobj
5-
//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
5+
//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl -render-features argument-buffer-tier-2
66

77
// Ensure that Slang `ParameterBlock` type is lowered
88
// to HLSL in the fashion that we expect.

tests/compute/texture-subscript.slang

+17-13
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//TEST:SIMPLE(filecheck=METALLIB): -target metallib -entry computeMain -stage compute
33
// Metal lacks RWTexture GFX backend support.
44
// Due to this, Metal compute test is disabled
5-
//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF): -slang -output-using-type -shaderobj -mtl
5+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF): -slang -output-using-type -shaderobj -mtl
66
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF): -slang -output-using-type -shaderobj -vk
77
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF): -slang -output-using-type -shaderobj -vk -glsl
88

@@ -28,18 +28,22 @@ RWStructuredBuffer<uint> outputBuffer;
2828
[numthreads(1,1,1)]
2929
void computeMain()
3030
{
31-
outputTexture1D[0].xz = int2(1,2).xx;
32-
outputTexture1D[1].x = int2(3,4).y;
33-
34-
outputTexture2D[0].xz = int2(1,2).xx;
35-
outputTexture2D[int2(0, 1)].x = int2(3,4).y;
36-
37-
outputTexture3D[0].xz = int2(1,2).xx;
38-
outputTexture3D[int3(0, 0, 1)].x = int2(3,4).y;
39-
40-
outputTexture2DArray[0].xz = int2(1,2);
41-
outputTexture2DArray[int3(0, 0, 1)].xz = int2(3,4);
42-
31+
outputTexture1D[0].xz = int2(1, 2).xx;
32+
AllMemoryBarrier();
33+
outputTexture1D[1].x = int2(3, 4).y;
34+
AllMemoryBarrier();
35+
outputTexture2D[0].xz = int2(1, 2).xx;
36+
AllMemoryBarrier();
37+
outputTexture2D[int2(0, 1)].x = int2(3, 4).y;
38+
AllMemoryBarrier();
39+
outputTexture3D[0].xz = int2(1, 2).xx;
40+
AllMemoryBarrier();
41+
outputTexture3D[int3(0, 0, 1)].x = int2(3, 4).y;
42+
AllMemoryBarrier();
43+
outputTexture2DArray[0].xz = int2(1, 2);
44+
AllMemoryBarrier();
45+
outputTexture2DArray[int3(0, 0, 1)].xz = int2(3, 4);
46+
AllMemoryBarrier();
4347
outputBuffer[0] = uint(true
4448
&& all(outputTexture1D[0] == int4(1, 0, 1, 0)) == true
4549
&& all(outputTexture1D[1] == int4(4, 0, 0, 0)) == true

tests/language-feature/shader-params/entry-point-uniform-params.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//TEST(compute):COMPARE_COMPUTE: -shaderobj
44
//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj
55
//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj
6-
//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
6+
//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
77

88
// Test that a shader can be written that
99
// only uses entry point `uniform` parameters,

tests/language-feature/types/opaque/inout-param-opaque-type-in-struct.slang

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
// aggregate type that includes an opaque type
55

66
//TEST(compute):COMPARE_COMPUTE:
7-
// GFX backend fails
8-
//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
7+
//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
98

109
struct Things
1110
{

tests/language-feature/types/opaque/out-param-opaque-type-in-struct.slang

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
// aggregate type that includes an opaque type
55

66
//TEST(compute):COMPARE_COMPUTE:
7-
// GFX backend fails
8-
//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
7+
//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
98

109
struct Things
1110
{

tests/language-feature/types/opaque/return-opaque-type-in-struct.slang

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
// aggregate type that includes an opaque type
55

66
//TEST(compute):COMPARE_COMPUTE:
7-
// GFX backend fails
8-
//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
7+
//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
98

109
struct Things
1110
{

tools/gfx/metal/metal-command-encoder.cpp

+15-5
Original file line numberDiff line numberDiff line change
@@ -478,18 +478,28 @@ Result ComputeCommandEncoder::bindPipelineWithRootObject(
478478

479479
Result ComputeCommandEncoder::dispatchCompute(int x, int y, int z)
480480
{
481-
auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr());
482-
pipeline->ensureAPIPipelineStateCreated();
483-
484481
MTL::ComputeCommandEncoder* encoder = m_commandBuffer->getMetalComputeCommandEncoder();
485-
encoder->setComputePipelineState(pipeline->m_computePipelineState.get());
486482

487483
ComputeBindingContext bindingContext;
488484
bindingContext.init(m_commandBuffer->m_device, encoder);
489485
auto program = static_cast<ShaderProgramImpl*>(m_currentPipeline->m_program.get());
490486
m_commandBuffer->m_rootObject.bindAsRoot(&bindingContext, program->m_rootObjectLayout);
491487

492-
encoder->dispatchThreadgroups(MTL::Size(x, y, z), pipeline->m_threadGroupSize);
488+
auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr());
489+
RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootObject;
490+
RefPtr<PipelineStateBase> newPipeline;
491+
SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->maybeSpecializePipeline(
492+
m_currentPipeline, rootObjectImpl, newPipeline));
493+
PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr());
494+
495+
SLANG_RETURN_ON_FAIL(newPipelineImpl->ensureAPIPipelineStateCreated());
496+
m_currentPipeline = newPipelineImpl;
497+
498+
m_currentPipeline->ensureAPIPipelineStateCreated();
499+
encoder->setComputePipelineState(m_currentPipeline->m_computePipelineState.get());
500+
501+
502+
encoder->dispatchThreadgroups(MTL::Size(x, y, z), m_currentPipeline->m_threadGroupSize);
493503

494504
return SLANG_OK;
495505
}

tools/gfx/metal/metal-device.cpp

+18-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ SlangResult DeviceImpl::initialize(const Desc& desc)
7070

7171
m_device = NS::TransferPtr(MTL::CreateSystemDefaultDevice());
7272
m_commandQueue = NS::TransferPtr(m_device->newCommandQueue(64));
73+
m_hasArgumentBufferTier2 = m_device->argumentBuffersSupport() >= MTL::ArgumentBuffersTier2;
74+
75+
if (m_hasArgumentBufferTier2)
76+
{
77+
m_features.add("argument-buffer-tier-2");
78+
}
7379

7480
SLANG_RETURN_ON_FAIL(slangContext.initialize(
7581
desc.slang,
@@ -415,8 +421,19 @@ Result DeviceImpl::createTextureResource(
415421
}
416422
if (desc.allowedStates.contains(ResourceState::UnorderedAccess))
417423
{
424+
textureUsage |= MTL::TextureUsageShaderRead;
418425
textureUsage |= MTL::TextureUsageShaderWrite;
419-
textureUsage |= MTL::TextureUsageShaderAtomic;
426+
427+
// Request atomic access if the format allows it.
428+
switch (desc.format)
429+
{
430+
case Format::R32_UINT:
431+
case Format::R32_SINT:
432+
case Format::R32G32_UINT:
433+
case Format::R32G32_SINT:
434+
textureUsage |= MTL::TextureUsageShaderAtomic;
435+
break;
436+
}
420437
}
421438

422439
textureDesc->setMipmapLevelCount(desc.numMipLevels);

tools/gfx/metal/metal-device.h

+2
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ class DeviceImpl : public RendererBase
137137

138138
uint32_t m_queueAllocCount;
139139

140+
bool m_hasArgumentBufferTier2 = false;
141+
140142
// A list to hold objects that may have a strong back reference to the device
141143
// instance. Because of the pipeline cache in `RendererBase`, there could be a reference
142144
// cycle among `DeviceImpl`->`PipelineStateImpl`->`ShaderProgramImpl`->`DeviceImpl`.

tools/gfx/metal/metal-shader-object-layout.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,16 @@ SlangResult ShaderObjectLayoutImpl::Builder::build(ShaderObjectLayoutImpl** outL
219219
return SLANG_OK;
220220
}
221221

222+
slang::TypeLayoutReflection* ShaderObjectLayoutImpl::getParameterBlockTypeLayout()
223+
{
224+
if (!m_parameterBlockTypeLayout)
225+
{
226+
m_parameterBlockTypeLayout = m_slangSession->getTypeLayout(
227+
m_elementTypeLayout->getType(), 0, slang::LayoutRules::MetalArgumentBufferTier2);
228+
}
229+
return m_parameterBlockTypeLayout;
230+
}
231+
222232
Result ShaderObjectLayoutImpl::createForElementType(
223233
RendererBase* renderer,
224234
slang::ISession* session,

tools/gfx/metal/metal-shader-object-layout.h

+3
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ class ShaderObjectLayoutImpl : public ShaderObjectLayoutBase
177177

178178
uint32_t getTotalOrdinaryDataSize() const { return m_totalOrdinaryDataSize; }
179179

180+
slang::TypeLayoutReflection* getParameterBlockTypeLayout();
180181
protected:
181182
Result _init(Builder const* builder);
182183

@@ -190,6 +191,8 @@ class ShaderObjectLayoutImpl : public ShaderObjectLayoutBase
190191
Index m_subObjectCount = 0;
191192
uint32_t m_totalOrdinaryDataSize = 0;
192193
List<SubObjectRangeInfo> m_subObjectRanges;
194+
// The type layout to use when the shader object is bind as a parameter block.
195+
slang::TypeLayoutReflection* m_parameterBlockTypeLayout = nullptr;
193196
};
194197

195198
class RootShaderObjectLayoutImpl : public ShaderObjectLayoutImpl

0 commit comments

Comments
 (0)