From 28a09948b938dd6e6e7feab76b1f5fb7ab162a3d Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Tue, 25 Feb 2025 11:31:07 +0800 Subject: [PATCH 1/2] Resolve 'extern' types during type layout generation if possible Closes https://github.com/shader-slang/slang/issues/5994 Closes https://github.com/shader-slang/slang/issues/6437 --- cmake/SlangTarget.cmake | 12 +- source/slang/slang-parameter-binding.cpp | 7 + source/slang/slang-type-layout.cpp | 79 ++++++ source/slang/slang-type-layout.h | 7 + tools/gfx-unit-test/link-time-type-layout.cpp | 245 ++++++++++++++++++ tools/unit-test/slang-unit-test.h | 3 + 6 files changed, 345 insertions(+), 8 deletions(-) create mode 100644 tools/gfx-unit-test/link-time-type-layout.cpp diff --git a/cmake/SlangTarget.cmake b/cmake/SlangTarget.cmake index 85d09be897..ebec67f6b4 100644 --- a/cmake/SlangTarget.cmake +++ b/cmake/SlangTarget.cmake @@ -507,14 +507,10 @@ function(slang_add_target dir type) endif() install( TARGETS ${target} ${export_args} - ARCHIVE DESTINATION ${archive_subdir} - ${ARGN} - LIBRARY DESTINATION ${library_subdir} - ${ARGN} - RUNTIME DESTINATION ${runtime_subdir} - ${ARGN} - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} - ${ARGN} + ARCHIVE DESTINATION ${archive_subdir} ${ARGN} + LIBRARY DESTINATION ${library_subdir} ${ARGN} + RUNTIME DESTINATION ${runtime_subdir} ${ARGN} + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ${ARGN} ) endmacro() diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 0e47ff56d7..86d76fc6d4 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2382,7 +2382,14 @@ static RefPtr processEntryPointVaryingParameter( // otherwise they will include all of the above cases... else if (auto declRefType = as(type)) { + // If we are trying to get the layout of some extern type, do our best + // to look it up in other loaded modules and generate the type layout + // based on that. + declRefType = context->layoutContext.lookupExternDeclRefType(declRefType); + auto declRef = declRefType->getDeclRef(); + + if (auto structDeclRef = declRef.as()) { RefPtr structLayout = new StructTypeLayout(); diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index a412bf5b28..7968450b49 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -4,6 +4,7 @@ #include "../compiler-core/slang-artifact-desc-util.h" #include "slang-check-impl.h" #include "slang-ir-insts.h" +#include "slang-mangle.h" #include "slang-syntax.h" #include @@ -5013,8 +5014,13 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type } else if (auto declRefType = as(type)) { + // If we are trying to get the layout of some extern type, do our best + // to look it up in other loaded modules and generate the type layout + // based on that. + declRefType = context.lookupExternDeclRefType(declRefType); auto declRef = declRefType->getDeclRef(); + if (auto structDeclRef = declRef.as()) { StructTypeLayoutBuilder typeLayoutBuilder; @@ -5694,4 +5700,77 @@ GlobalGenericParamDecl* GenericParamTypeLayout::getGlobalGenericParamDecl() return rsDeclRef.getDecl(); } +DeclRefType* TypeLayoutContext::lookupExternDeclRefType(DeclRefType* declRefType) +{ + const auto declRef = declRefType->getDeclRef(); + const auto decl = declRef.getDecl(); + const auto isExtern = + decl->hasModifier() || decl->hasModifier(); + if (isExtern) + { + if (!externTypeMap) + buildExternTypeMap(); + const auto mangledName = getMangledName(targetReq->getLinkage()->getASTBuilder(), decl); + externTypeMap->tryGetValue(mangledName, declRefType); + } + return declRefType; +} + +void TypeLayoutContext::buildExternTypeMap() +{ + externTypeMap.emplace(); + const auto linkage = targetReq->getLinkage(); + + HashSet externNames; + Dictionary allTypes; + + // Traverse the AST and keep track of all extern names and all type definitions + // We'll match them up later + auto processDecl = [&](auto&& go, Decl* decl) -> void + { + const auto isExtern = + decl->hasModifier() || decl->hasModifier(); + + if (auto declRefType = as(DeclRefType::create(astBuilder, decl))) + { + String mangledName = getMangledName(astBuilder, decl); + + if (isExtern) + { + externNames.add(mangledName); + } + else + { + allTypes[mangledName] = declRefType; + } + } + + if (auto scopeDecl = as(decl)) + { + for (auto member : scopeDecl->members) + { + go(go, member); + } + } + }; + + for (const auto& m : linkage->loadedModulesList) + { + const auto& ast = m->getModuleDecl(); + for (auto member : ast->members) + { + processDecl(processDecl, member); + } + } + + // Only keep the types that have matching extern declarations + for (const auto& externName : externNames) + { + if (allTypes.containsKey(externName)) + { + externTypeMap.value()[externName] = allTypes[externName]; + } + } +} + } // namespace Slang diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index 7d3dd8369d..bcb8420133 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -1183,6 +1183,13 @@ struct TypeLayoutContext // Options passed to object layout ObjectLayoutRulesImpl::Options objectLayoutOptions; + // Mangled names to DeclRefType, this is used to match up 'extern' types to + // their linked in definitions during layout generation + std::optional> externTypeMap; + + DeclRefType* lookupExternDeclRefType(DeclRefType* declRefType); + void buildExternTypeMap(); + LayoutRulesImpl* getRules() { return rules; } LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); } diff --git a/tools/gfx-unit-test/link-time-type-layout.cpp b/tools/gfx-unit-test/link-time-type-layout.cpp new file mode 100644 index 0000000000..c1997b18a4 --- /dev/null +++ b/tools/gfx-unit-test/link-time-type-layout.cpp @@ -0,0 +1,245 @@ +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "slang-gfx.h" +#include "unit-test/slang-unit-test.h" + +using namespace gfx; + +namespace gfx_test +{ + +static void diagnoseIfNeeded(Slang::ComPtr& diagnosticsBlob) +{ + if (diagnosticsBlob && diagnosticsBlob->getBufferSize() > 0) + { + fprintf(stderr, "%s\n", (const char*)diagnosticsBlob->getBufferPointer()); + } +} + +static Slang::Result loadSpirvProgram( + gfx::IDevice* device, + Slang::ComPtr& outShaderProgram, + slang::ProgramLayout*& slangReflection) +{ + // main.slang: declares the interface and extern struct S, and the vertex shader. + const char* mainSrc = R"( + public interface IFoo + { + public float4 getFoo(); + }; + public extern struct S : IFoo; + + [shader("vertex")] + float4 vertexMain(S params) : SV_Position + { + return params.getFoo(); + } + )"; + + // foo.slang: defines S with its field layout and its implementation of getFoo(). + const char* fooSrc = R"( + import main; + + export public struct S : IFoo + { + public float4 getFoo() { return this.foo; } + float4 foo; + } + )"; + + Slang::ComPtr slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr diagnosticsBlob; + + // Create blobs for the two modules. + auto mainBlob = Slang::UnownedRawBlob::create(mainSrc, strlen(mainSrc)); + auto fooBlob = Slang::UnownedRawBlob::create(fooSrc, strlen(fooSrc)); + + // Load modules from source. + slang::IModule* mainModule = slangSession->loadModuleFromSource("main", "main.slang", mainBlob); + slang::IModule* fooModule = slangSession->loadModuleFromSource("foo", "foo.slang", fooBlob); + + // Find the entry point from main.slang + Slang::ComPtr vsEntryPoint; + SLANG_RETURN_ON_FAIL(mainModule->findEntryPointByName("vertexMain", vsEntryPoint.writeRef())); + + // Compose the program from both modules and the entry point. + Slang::List componentTypes; + componentTypes.add(mainModule); + componentTypes.add(fooModule); + componentTypes.add(vsEntryPoint); + + Slang::ComPtr composedProgram; + SLANG_RETURN_ON_FAIL(slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef())); + diagnoseIfNeeded(diagnosticsBlob); + + // Link the composite program. + Slang::ComPtr linkedProgram; + SLANG_RETURN_ON_FAIL( + composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef())); + diagnoseIfNeeded(diagnosticsBlob); + + // Retrieve the reflection information. + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + // Create a shader program that will generate SPIRV code. + gfx::IShaderProgram::Desc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + auto shaderProgram = device->createProgram(programDesc); + outShaderProgram = shaderProgram; + + // Force SPIRV generation by explicitly requesting it + Slang::ComPtr spirvBlob; + Slang::ComPtr spirvDiagnostics; + + // Request SPIRV code generation for the vertex shader entry point + auto targetIndex = 0; // Assuming this is the first/only target + auto entryPointIndex = 0; // Assuming this is the first/only entry point + + auto result = composedProgram->getEntryPointCode( + entryPointIndex, + targetIndex, + spirvBlob.writeRef(), + spirvDiagnostics.writeRef()); + + if (SLANG_FAILED(result)) + { + if (spirvDiagnostics && spirvDiagnostics->getBufferSize() > 0) + { + fprintf( + stderr, + "SPIRV generation failed: %s\n", + (const char*)spirvDiagnostics->getBufferPointer()); + } + return result; + } + + // Verify we actually got SPIRV code + if (!spirvBlob || spirvBlob->getBufferSize() == 0) + { + return SLANG_FAIL; + } + + return SLANG_OK; +} + +// Function to validate the type layout of struct S +static void validateStructSLayout(UnitTestContext* context, slang::ProgramLayout* slangReflection) +{ + // Check reflection is available + SLANG_CHECK(slangReflection != nullptr); + + // Get the entry point layout for vertexMain + auto entryPointCount = slangReflection->getEntryPointCount(); + slang::EntryPointLayout* entryPointLayout = nullptr; + + for (unsigned int i = 0; i < entryPointCount; i++) + { + auto currentEntryPoint = slangReflection->getEntryPointByIndex(i); + const char* name = currentEntryPoint->getName(); + + if (strcmp(name, "vertexMain") == 0) + { + entryPointLayout = currentEntryPoint; + break; + } + } + + SLANG_CHECK_MSG(entryPointLayout != nullptr, "Could not find vertexMain entry point"); + + // Get the parameter count for the entry point + auto paramCount = entryPointLayout->getParameterCount(); + SLANG_CHECK_MSG(paramCount >= 1, "Entry point has no parameters"); + + // Get the first parameter, which should be of type S + auto paramLayout = entryPointLayout->getParameterByIndex(0); + SLANG_CHECK_MSG(paramLayout != nullptr, "Could not get first parameter layout"); + + // Get the type layout of the parameter + auto typeLayout = paramLayout->getTypeLayout(); + SLANG_CHECK_MSG(typeLayout != nullptr, "Parameter has no type layout"); + + // Check if it's a struct type + auto kind = typeLayout->getKind(); + SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Parameter is not a struct type"); + + // Get the field count + auto fieldCount = typeLayout->getFieldCount(); + SLANG_CHECK_MSG(fieldCount >= 1, "Struct has no fields"); + + // Check for the 'foo' field + bool foundFooField = false; + for (unsigned int i = 0; i < fieldCount; i++) + { + auto fieldLayout = typeLayout->getFieldByIndex(i); + const char* fieldName = fieldLayout->getName(); + + if (fieldName && strcmp(fieldName, "foo") == 0) + { + foundFooField = true; + + // Check that it's a float4 type + auto fieldTypeLayout = fieldLayout->getTypeLayout(); + auto fieldTypeKind = fieldTypeLayout->getKind(); + + SLANG_CHECK_MSG( + fieldTypeKind == slang::TypeReflection::Kind::Vector, + "Field 'foo' is not a vector type"); + + auto elementCount = fieldTypeLayout->getElementCount(); + SLANG_CHECK_MSG(elementCount == 4, "Field 'foo' is not a 4-element vector"); + + break; + } + } + + SLANG_CHECK_MSG(foundFooField, "Could not find field 'foo' in struct S"); +} + +void linkTimeTypeLayoutImpl(gfx::IDevice* device, UnitTestContext* context) +{ + Slang::ComPtr shaderProgram; + slang::ProgramLayout* slangReflection = nullptr; + + auto result = loadSpirvProgram(device, shaderProgram, slangReflection); + SLANG_CHECK(SLANG_SUCCEEDED(result)); + + // Validate the struct S layout + validateStructSLayout(context, slangReflection); + + // Create a graphics pipeline to verify SPIRV code generation works + GraphicsPipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + + // We need to set up a minimal pipeline state for a vertex shader + pipelineDesc.primitiveType = PrimitiveType::Triangle; + + ComPtr pipelineState; + auto pipelineResult = + device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef()); + SLANG_CHECK(SLANG_SUCCEEDED(pipelineResult)); +} + +// +// This test verifies that type layout information correctly propagates through +// the Slang compilation pipeline when types are defined in modules other than where they are used. +// Specifically, it tests +// that when using an extern struct that's defined in a separate module: +// +// 1. The struct definition is properly linked across module boundaries +// 2. The complete type layout information is available in the reflection data +// 3. SPIRV code generation succeeds with the linked type information (this +// failed before when layout information was required during code generation) +// + +SLANG_UNIT_TEST(linkTimeTypeLayout) +{ + runTestImpl(linkTimeTypeLayoutImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); +} + +} // namespace gfx_test diff --git a/tools/unit-test/slang-unit-test.h b/tools/unit-test/slang-unit-test.h index 3484b30204..8f0f0a445c 100644 --- a/tools/unit-test/slang-unit-test.h +++ b/tools/unit-test/slang-unit-test.h @@ -95,3 +95,6 @@ typedef IUnitTestModule* (*UnitTestGetModuleFunc)(); #define SLANG_IGNORE_TEST \ getTestReporter()->addResult(TestResult::Ignored); \ throw AbortTestException(); +#define SLANG_CHECK_MSG(condition, message) \ + getTestReporter() \ + ->addResultWithLocation((condition), #condition " " message, __FILE__, __LINE__) From 106fbd88fd53a5354106cf76be7d7bb48cb94026 Mon Sep 17 00:00:00 2001 From: slangbot <186143334+slangbot@users.noreply.github.com> Date: Fri, 28 Feb 2025 01:34:22 +0000 Subject: [PATCH 2/2] format code --- cmake/SlangTarget.cmake | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cmake/SlangTarget.cmake b/cmake/SlangTarget.cmake index ebec67f6b4..85d09be897 100644 --- a/cmake/SlangTarget.cmake +++ b/cmake/SlangTarget.cmake @@ -507,10 +507,14 @@ function(slang_add_target dir type) endif() install( TARGETS ${target} ${export_args} - ARCHIVE DESTINATION ${archive_subdir} ${ARGN} - LIBRARY DESTINATION ${library_subdir} ${ARGN} - RUNTIME DESTINATION ${runtime_subdir} ${ARGN} - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ${ARGN} + ARCHIVE DESTINATION ${archive_subdir} + ${ARGN} + LIBRARY DESTINATION ${library_subdir} + ${ARGN} + RUNTIME DESTINATION ${runtime_subdir} + ${ARGN} + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + ${ARGN} ) endmacro()