Skip to content

Commit

Permalink
Fix argument buffer tier2 layout computation. (#6101)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Jan 16, 2025
1 parent 387f2be commit edf5e9f
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 2 deletions.
75 changes: 73 additions & 2 deletions source/slang/slang-type-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ struct MetalLayoutRulesImpl : public CPULayoutRulesImpl
auto alignedElementCount = 1 << Math::Log2Ceil((uint32_t)elementCount);

// Metal aligns vectors to 2/4 element boundaries.
size_t size = elementSize * elementCount;
size_t size = alignedElementCount * elementSize;
size_t alignment = alignedElementCount * elementSize;

SimpleLayoutInfo vectorInfo;
Expand Down Expand Up @@ -1147,6 +1147,14 @@ struct MetalLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
LayoutRulesImpl* getStructuredBufferRules(CompilerOptionSet& compilerOptions) override;
};

struct MetalArgumentBufferTier2LayoutRulesFamilyImpl : MetalLayoutRulesFamilyImpl
{
virtual LayoutRulesImpl* getConstantBufferRules(
CompilerOptionSet& compilerOptions,
Type* containerType) override;
virtual LayoutRulesImpl* getParameterBlockRules(CompilerOptionSet& compilerOptions) override;
};

struct WGSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
{
virtual LayoutRulesImpl* getAnyValueRules() override;
Expand Down Expand Up @@ -1175,6 +1183,7 @@ HLSLLayoutRulesFamilyImpl kHLSLLayoutRulesFamilyImpl;
CPULayoutRulesFamilyImpl kCPULayoutRulesFamilyImpl;
CUDALayoutRulesFamilyImpl kCUDALayoutRulesFamilyImpl;
MetalLayoutRulesFamilyImpl kMetalLayoutRulesFamilyImpl;
MetalArgumentBufferTier2LayoutRulesFamilyImpl kMetalArgumentBufferTier2LayoutRulesFamilyImpl;
WGSLLayoutRulesFamilyImpl kWGSLLayoutRulesFamilyImpl;

// CPU case
Expand Down Expand Up @@ -1969,8 +1978,44 @@ struct MetalArgumentBufferElementLayoutRulesImpl : ObjectLayoutRulesImpl, Defaul
}
};

struct MetalTier2ObjectLayoutRulesImpl : ObjectLayoutRulesImpl
{
virtual ObjectLayoutInfo GetObjectLayout(ShaderParameterKind kind, const Options& /* options */)
override
{
switch (kind)
{
case ShaderParameterKind::ConstantBuffer:
case ShaderParameterKind::ParameterBlock:
case ShaderParameterKind::StructuredBuffer:
case ShaderParameterKind::MutableStructuredBuffer:
case ShaderParameterKind::RawBuffer:
case ShaderParameterKind::Buffer:
case ShaderParameterKind::MutableRawBuffer:
case ShaderParameterKind::MutableBuffer:
case ShaderParameterKind::ShaderStorageBuffer:
case ShaderParameterKind::AccelerationStructure:
return SimpleLayoutInfo(LayoutResourceKind::Uniform, 8, 8);
case ShaderParameterKind::AppendConsumeStructuredBuffer:
return SimpleLayoutInfo(LayoutResourceKind::Uniform, 16, 8);
case ShaderParameterKind::MutableTexture:
case ShaderParameterKind::TextureUniformBuffer:
case ShaderParameterKind::Texture:
case ShaderParameterKind::SamplerState:
return SimpleLayoutInfo(LayoutResourceKind::Uniform, 8, 8);
case ShaderParameterKind::TextureSampler:
case ShaderParameterKind::MutableTextureSampler:
return SimpleLayoutInfo(LayoutResourceKind::Uniform, 16, 8);
default:
SLANG_UNEXPECTED("unhandled shader parameter kind");
UNREACHABLE_RETURN(SimpleLayoutInfo());
}
}
};

static MetalObjectLayoutRulesImpl kMetalObjectLayoutRulesImpl;
static MetalArgumentBufferElementLayoutRulesImpl kMetalArgumentBufferElementLayoutRulesImpl;
static MetalTier2ObjectLayoutRulesImpl kMetalTier2ObjectLayoutRulesImpl;
static MetalLayoutRulesImpl kMetalLayoutRulesImpl;

LayoutRulesImpl kMetalAnyValueLayoutRulesImpl_ = {
Expand All @@ -1991,6 +2036,18 @@ LayoutRulesImpl kMetalParameterBlockLayoutRulesImpl_ = {
&kMetalArgumentBufferElementLayoutRulesImpl,
};

LayoutRulesImpl kMetalTier2ConstantBufferLayoutRulesImpl_ = {
&kMetalLayoutRulesFamilyImpl,
&kMetalLayoutRulesImpl,
&kMetalTier2ObjectLayoutRulesImpl,
};

LayoutRulesImpl kMetalTier2ParameterBlockLayoutRulesImpl_ = {
&kMetalLayoutRulesFamilyImpl,
&kMetalLayoutRulesImpl,
&kMetalTier2ObjectLayoutRulesImpl,
};

LayoutRulesImpl kMetalStructuredBufferLayoutRulesImpl_ = {
&kMetalLayoutRulesFamilyImpl,
&kMetalLayoutRulesImpl,
Expand Down Expand Up @@ -2079,6 +2136,20 @@ LayoutRulesImpl* MetalLayoutRulesFamilyImpl::getHitAttributesParameterRules()
return nullptr;
}

LayoutRulesImpl* MetalArgumentBufferTier2LayoutRulesFamilyImpl::getConstantBufferRules(
CompilerOptionSet&,
Type*)
{
return &kMetalTier2ConstantBufferLayoutRulesImpl_;
}

LayoutRulesImpl* MetalArgumentBufferTier2LayoutRulesFamilyImpl::getParameterBlockRules(
CompilerOptionSet&)
{
return &kMetalTier2ParameterBlockLayoutRulesImpl_;
}


// WGSL Family

LayoutRulesImpl kWGSLConstantBufferLayoutRulesImpl_ = {
Expand Down Expand Up @@ -2229,7 +2300,7 @@ TypeLayoutContext getInitialLayoutContextForTarget(
rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq);
break;
case slang::LayoutRules::MetalArgumentBufferTier2:
rulesFamily = &kCPULayoutRulesFamilyImpl;
rulesFamily = &kMetalArgumentBufferTier2LayoutRulesFamilyImpl;
break;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// unit-test-argument-buffer-tier-2-reflection.cpp

#include "../../source/core/slang-io.h"
#include "../../source/core/slang-process.h"
#include "slang-com-ptr.h"
#include "slang.h"
#include "unit-test/slang-unit-test.h"

#include <stdio.h>
#include <stdlib.h>

using namespace Slang;

// Test metal argument buffer tier2 layout rules.

SLANG_UNIT_TEST(metalArgumentBufferTier2Reflection)
{
const char* userSourceBody = R"(
struct A
{
float3 one;
float3 two;
float three;
}
struct Args{
ParameterBlock<A> a;
}
ParameterBlock<Args> argument_buffer;
RWStructuredBuffer<float> outputBuffer;
[numthreads(1,1,1)]
void computeMain()
{
outputBuffer[0] = argument_buffer.a.two.x;
}
)";

auto moduleName = "moduleG" + String(Process::getId());
String userSource = "import " + moduleName + ";\n" + userSourceBody;
ComPtr<slang::IGlobalSession> globalSession;
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
targetDesc.format = SLANG_SPIRV;
targetDesc.profile = globalSession->findProfile("spirv_1_5");
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;
ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString(
"m",
"m.slang",
userSourceBody,
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

auto layout = module->getLayout();

auto type = layout->findTypeByName("A");
auto typeLayout = layout->getTypeLayout(type, slang::LayoutRules::MetalArgumentBufferTier2);
SLANG_CHECK(typeLayout->getFieldByIndex(0)->getOffset() == 0);
SLANG_CHECK(typeLayout->getFieldByIndex(0)->getTypeLayout()->getSize() == 16);
SLANG_CHECK(typeLayout->getFieldByIndex(1)->getOffset() == 16);
SLANG_CHECK(typeLayout->getFieldByIndex(1)->getTypeLayout()->getSize() == 16);
SLANG_CHECK(typeLayout->getFieldByIndex(2)->getOffset() == 32);
SLANG_CHECK(typeLayout->getFieldByIndex(2)->getTypeLayout()->getSize() == 4);
}

0 comments on commit edf5e9f

Please sign in to comment.