From 778545dba4eb04829712270cf78d8532c00ffb17 Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska <ellieh@nvidia.com> Date: Tue, 11 Mar 2025 16:56:18 +0800 Subject: [PATCH] test that link time extern struct layouts are visible for nested types closes https://github.com/shader-slang/slang/issues/6556 --- .../link-time-type-layout-nested.cpp | 241 ++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 tools/gfx-unit-test/link-time-type-layout-nested.cpp diff --git a/tools/gfx-unit-test/link-time-type-layout-nested.cpp b/tools/gfx-unit-test/link-time-type-layout-nested.cpp new file mode 100644 index 0000000000..b6ea0e64d1 --- /dev/null +++ b/tools/gfx-unit-test/link-time-type-layout-nested.cpp @@ -0,0 +1,241 @@ +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "slang-gfx.h" +#include "unit-test/slang-unit-test.h" + +using namespace gfx; + +namespace gfx_test +{ + +static void diagnoseIfNeeded(Slang::ComPtr<slang::IBlob>& diagnosticsBlob) +{ + if (diagnosticsBlob && diagnosticsBlob->getBufferSize() > 0) + { + fprintf(stderr, "%s\n", (const char*)diagnosticsBlob->getBufferPointer()); + } +} + +static Slang::Result loadProgram( + gfx::IDevice* device, + Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram, + slang::ProgramLayout*& slangReflection) +{ + // main.slang: declares the interface, extern struct Inner, and Outer struct with Inner field + const char* mainSrc = R"( + // Define an interface + public interface IFoo + { + public float4 getFoo(); + }; + + // Define an extern struct that implements the interface + public extern struct Inner : IFoo; + + // Define a regular struct that contains an Inner field + public struct Outer + { + float2 position; + Inner innerData; + float2 texCoord; + }; + + // Vertex shader entry point that takes an Outer parameter + [shader("vertex")] + float4 vertexMain(Outer params) : SV_Position + { + return float4(params.position, 0.0f, 1.0f) + params.innerData.getFoo(); + } + )"; + + // inner.slang: defines Inner with its field layout and its implementation of getFoo() + const char* innerSrc = R"( + import main; + + // Define the implementation of Inner with its field layout + export public struct Inner : IFoo + { + public float4 getFoo() { return this.data; } + float4 data; + } + )"; + + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + + // Create blobs for the two modules + auto mainBlob = Slang::UnownedRawBlob::create(mainSrc, strlen(mainSrc)); + auto innerBlob = Slang::UnownedRawBlob::create(innerSrc, strlen(innerSrc)); + + // Load modules from source + slang::IModule* mainModule = slangSession->loadModuleFromSource("main", "main.slang", mainBlob); + slang::IModule* innerModule = + slangSession->loadModuleFromSource("inner", "inner.slang", innerBlob); + + // Find the entry point from main.slang + Slang::ComPtr<slang::IEntryPoint> vsEntryPoint; + SLANG_RETURN_ON_FAIL(mainModule->findEntryPointByName("vertexMain", vsEntryPoint.writeRef())); + + // Compose the program from both modules and the entry point + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(mainModule); + componentTypes.add(innerModule); + componentTypes.add(vsEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SLANG_RETURN_ON_FAIL(slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef())); + diagnoseIfNeeded(diagnosticsBlob); + + // Link the composite program + Slang::ComPtr<slang::IComponentType> linkedProgram; + SLANG_RETURN_ON_FAIL( + composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef())); + diagnoseIfNeeded(diagnosticsBlob); + + // Retrieve the reflection information + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + // Create a shader program + gfx::IShaderProgram::Desc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + auto shaderProgram = device->createProgram(programDesc); + outShaderProgram = shaderProgram; + + return SLANG_OK; +} + +// Function to validate the type layout of Outer struct with nested Inner struct +static void validateNestedExternStructLayout( + UnitTestContext* context, + slang::ProgramLayout* slangReflection) +{ + // Check reflection is available + SLANG_CHECK(slangReflection != nullptr); + + // Get the entry point layout for vertexMain + slang::EntryPointLayout* entryPointLayout = slangReflection->findEntryPointByName("vertexMain"); + + SLANG_CHECK_MSG(entryPointLayout != nullptr, "Could not find vertexMain entry point"); + + // Get the parameter count for the entry point + auto paramCount = entryPointLayout->getParameterCount(); + SLANG_CHECK_MSG(paramCount >= 1, "Entry point has no parameters"); + + // Get the first parameter, which should be of type Outer + auto paramLayout = entryPointLayout->getParameterByIndex(0); + SLANG_CHECK_MSG(paramLayout != nullptr, "Could not get first parameter layout"); + + // Get the type layout of the parameter + auto outerTypeLayout = paramLayout->getTypeLayout(); + SLANG_CHECK_MSG(outerTypeLayout != nullptr, "Parameter has no type layout"); + + // Check if it's a struct type + auto kind = outerTypeLayout->getKind(); + SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Parameter is not a struct type"); + + // Verify Outer has 3 fields: position, innerData, texCoord + auto fieldCount = outerTypeLayout->getFieldCount(); + SLANG_CHECK_MSG(fieldCount == 3, "Outer struct does not have 3 fields"); + + // Find and check the innerData field + slang::VariableLayoutReflection* innerDataField = nullptr; + for (unsigned int i = 0; i < fieldCount; i++) + { + auto fieldLayout = outerTypeLayout->getFieldByIndex(i); + const char* fieldName = fieldLayout->getName(); + + if (fieldName && strcmp(fieldName, "innerData") == 0) + { + innerDataField = fieldLayout; + break; + } + } + + SLANG_CHECK_MSG(innerDataField != nullptr, "Could not find innerData field in Outer struct"); + + // Get the type layout of the innerData field + auto innerTypeLayout = innerDataField->getTypeLayout(); + SLANG_CHECK_MSG(innerTypeLayout != nullptr, "innerData field has no type layout"); + + // Verify Inner is a struct type + kind = innerTypeLayout->getKind(); + SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Inner is not a struct type"); + + // Verify Inner has 1 field (data) + fieldCount = innerTypeLayout->getFieldCount(); + SLANG_CHECK_MSG(fieldCount == 1, "Inner struct does not have 1 field"); + + // Find and check the data field in Inner + bool foundDataField = false; + for (unsigned int i = 0; i < fieldCount; i++) + { + auto fieldLayout = innerTypeLayout->getFieldByIndex(i); + const char* fieldName = fieldLayout->getName(); + + if (fieldName && strcmp(fieldName, "data") == 0) + { + foundDataField = true; + + // Check that it's a float4 type + auto fieldTypeLayout = fieldLayout->getTypeLayout(); + auto fieldTypeKind = fieldTypeLayout->getKind(); + + SLANG_CHECK_MSG( + fieldTypeKind == slang::TypeReflection::Kind::Vector, + "Field 'data' is not a vector type"); + + auto elementCount = fieldTypeLayout->getElementCount(); + SLANG_CHECK_MSG(elementCount == 4, "Field 'data' is not a 4-element vector"); + + break; + } + } + + SLANG_CHECK_MSG(foundDataField, "Could not find field 'data' in Inner struct"); +} + +void linkTimeTypeLayoutNestedImpl(gfx::IDevice* device, UnitTestContext* context) +{ + Slang::ComPtr<gfx::IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection = nullptr; + + auto result = loadProgram(device, shaderProgram, slangReflection); + SLANG_CHECK(SLANG_SUCCEEDED(result)); + + // Validate the nested struct layout + validateNestedExternStructLayout(context, slangReflection); + + // Create a graphics pipeline to verify everything works + GraphicsPipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + pipelineDesc.primitiveType = PrimitiveType::Triangle; + + ComPtr<gfx::IPipelineState> pipelineState; + auto pipelineResult = + device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef()); + SLANG_CHECK(SLANG_SUCCEEDED(pipelineResult)); +} + +// +// This test verifies that type layout information correctly propagates through +// the Slang compilation pipeline when a regular struct contains a field whose type +// is an extern struct defined in another module. +// Specifically, it tests that: +// +// 1. The Outer struct correctly includes the Inner extern struct as a field +// 2. After linking, the Inner struct's layout is properly resolved with its field +// 3. The complete type layout information is available in the reflection data +// + +SLANG_UNIT_TEST(linkTimeTypeLayoutNested) +{ + runTestImpl(linkTimeTypeLayoutNestedImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); +} + +} // namespace gfx_test