Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve 'extern' types during type layout generation if possible #6450

Merged
merged 4 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions source/slang/slang-parameter-binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2382,7 +2382,14 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
// otherwise they will include all of the above cases...
else if (auto declRefType = as<DeclRefType>(type))
{
// If we are trying to get the layout of some extern type, do our best
// to look it up in other loaded modules and generate the type layout
// based on that.
declRefType = context->layoutContext.lookupExternDeclRefType(declRefType);

auto declRef = declRefType->getDeclRef();


if (auto structDeclRef = declRef.as<StructDecl>())
{
RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
Expand Down
79 changes: 79 additions & 0 deletions source/slang/slang-type-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "../compiler-core/slang-artifact-desc-util.h"
#include "slang-check-impl.h"
#include "slang-ir-insts.h"
#include "slang-mangle.h"
#include "slang-syntax.h"

#include <assert.h>
Expand Down Expand Up @@ -5013,8 +5014,13 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type
}
else if (auto declRefType = as<DeclRefType>(type))
{
// If we are trying to get the layout of some extern type, do our best
// to look it up in other loaded modules and generate the type layout
// based on that.
declRefType = context.lookupExternDeclRefType(declRefType);
auto declRef = declRefType->getDeclRef();


if (auto structDeclRef = declRef.as<StructDecl>())
{
StructTypeLayoutBuilder typeLayoutBuilder;
Expand Down Expand Up @@ -5694,4 +5700,77 @@ GlobalGenericParamDecl* GenericParamTypeLayout::getGlobalGenericParamDecl()
return rsDeclRef.getDecl();
}

DeclRefType* TypeLayoutContext::lookupExternDeclRefType(DeclRefType* declRefType)
{
const auto declRef = declRefType->getDeclRef();
const auto decl = declRef.getDecl();
const auto isExtern =
decl->hasModifier<ExternAttribute>() || decl->hasModifier<ExternModifier>();
if (isExtern)
{
if (!externTypeMap)
buildExternTypeMap();
const auto mangledName = getMangledName(targetReq->getLinkage()->getASTBuilder(), decl);
externTypeMap->tryGetValue(mangledName, declRefType);
}
return declRefType;
}

void TypeLayoutContext::buildExternTypeMap()
{
externTypeMap.emplace();
const auto linkage = targetReq->getLinkage();

HashSet<String> externNames;
Dictionary<String, DeclRefType*> allTypes;

// Traverse the AST and keep track of all extern names and all type definitions
// We'll match them up later
auto processDecl = [&](auto&& go, Decl* decl) -> void
{
const auto isExtern =
decl->hasModifier<ExternAttribute>() || decl->hasModifier<ExternModifier>();

if (auto declRefType = as<DeclRefType>(DeclRefType::create(astBuilder, decl)))
{
String mangledName = getMangledName(astBuilder, decl);

if (isExtern)
{
externNames.add(mangledName);
}
else
{
allTypes[mangledName] = declRefType;
}
}

if (auto scopeDecl = as<ScopeDecl>(decl))
{
for (auto member : scopeDecl->members)
{
go(go, member);
}
}
};

for (const auto& m : linkage->loadedModulesList)
{
const auto& ast = m->getModuleDecl();
for (auto member : ast->members)
{
processDecl(processDecl, member);
}
}

// Only keep the types that have matching extern declarations
for (const auto& externName : externNames)
{
if (allTypes.containsKey(externName))
{
externTypeMap.value()[externName] = allTypes[externName];
}
}
}

} // namespace Slang
7 changes: 7 additions & 0 deletions source/slang/slang-type-layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,13 @@ struct TypeLayoutContext
// Options passed to object layout
ObjectLayoutRulesImpl::Options objectLayoutOptions;

// Mangled names to DeclRefType, this is used to match up 'extern' types to
// their linked in definitions during layout generation
std::optional<Dictionary<String, DeclRefType*>> externTypeMap;

DeclRefType* lookupExternDeclRefType(DeclRefType* declRefType);
void buildExternTypeMap();

LayoutRulesImpl* getRules() { return rules; }
LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); }

Expand Down
245 changes: 245 additions & 0 deletions tools/gfx-unit-test/link-time-type-layout.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
#include "core/slang-blob.h"
#include "gfx-test-util.h"
#include "slang-gfx.h"
#include "unit-test/slang-unit-test.h"

using namespace gfx;

namespace gfx_test
{

static void diagnoseIfNeeded(Slang::ComPtr<slang::IBlob>& diagnosticsBlob)
{
if (diagnosticsBlob && diagnosticsBlob->getBufferSize() > 0)
{
fprintf(stderr, "%s\n", (const char*)diagnosticsBlob->getBufferPointer());
}
}

static Slang::Result loadSpirvProgram(
gfx::IDevice* device,
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
slang::ProgramLayout*& slangReflection)
{
// main.slang: declares the interface and extern struct S, and the vertex shader.
const char* mainSrc = R"(
public interface IFoo
{
public float4 getFoo();
};
public extern struct S : IFoo;

[shader("vertex")]
float4 vertexMain(S params) : SV_Position
{
return params.getFoo();
}
)";

// foo.slang: defines S with its field layout and its implementation of getFoo().
const char* fooSrc = R"(
import main;

export public struct S : IFoo
{
public float4 getFoo() { return this.foo; }
float4 foo;
}
)";

Slang::ComPtr<slang::ISession> slangSession;
SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
Slang::ComPtr<slang::IBlob> diagnosticsBlob;

// Create blobs for the two modules.
auto mainBlob = Slang::UnownedRawBlob::create(mainSrc, strlen(mainSrc));
auto fooBlob = Slang::UnownedRawBlob::create(fooSrc, strlen(fooSrc));

// Load modules from source.
slang::IModule* mainModule = slangSession->loadModuleFromSource("main", "main.slang", mainBlob);
slang::IModule* fooModule = slangSession->loadModuleFromSource("foo", "foo.slang", fooBlob);

// Find the entry point from main.slang
Slang::ComPtr<slang::IEntryPoint> vsEntryPoint;
SLANG_RETURN_ON_FAIL(mainModule->findEntryPointByName("vertexMain", vsEntryPoint.writeRef()));

// Compose the program from both modules and the entry point.
Slang::List<slang::IComponentType*> componentTypes;
componentTypes.add(mainModule);
componentTypes.add(fooModule);
componentTypes.add(vsEntryPoint);

Slang::ComPtr<slang::IComponentType> composedProgram;
SLANG_RETURN_ON_FAIL(slangSession->createCompositeComponentType(
componentTypes.getBuffer(),
componentTypes.getCount(),
composedProgram.writeRef(),
diagnosticsBlob.writeRef()));
diagnoseIfNeeded(diagnosticsBlob);

// Link the composite program.
Slang::ComPtr<slang::IComponentType> linkedProgram;
SLANG_RETURN_ON_FAIL(
composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()));
diagnoseIfNeeded(diagnosticsBlob);

// Retrieve the reflection information.
composedProgram = linkedProgram;
slangReflection = composedProgram->getLayout();

// Create a shader program that will generate SPIRV code.
gfx::IShaderProgram::Desc programDesc = {};
programDesc.slangGlobalScope = composedProgram.get();
auto shaderProgram = device->createProgram(programDesc);
outShaderProgram = shaderProgram;

// Force SPIRV generation by explicitly requesting it
Slang::ComPtr<slang::IBlob> spirvBlob;
Slang::ComPtr<slang::IBlob> spirvDiagnostics;

// Request SPIRV code generation for the vertex shader entry point
auto targetIndex = 0; // Assuming this is the first/only target
auto entryPointIndex = 0; // Assuming this is the first/only entry point

auto result = composedProgram->getEntryPointCode(
entryPointIndex,
targetIndex,
spirvBlob.writeRef(),
spirvDiagnostics.writeRef());

if (SLANG_FAILED(result))
{
if (spirvDiagnostics && spirvDiagnostics->getBufferSize() > 0)
{
fprintf(
stderr,
"SPIRV generation failed: %s\n",
(const char*)spirvDiagnostics->getBufferPointer());
}
return result;
}

// Verify we actually got SPIRV code
if (!spirvBlob || spirvBlob->getBufferSize() == 0)
{
return SLANG_FAIL;
}

return SLANG_OK;
}

// Function to validate the type layout of struct S
static void validateStructSLayout(UnitTestContext* context, slang::ProgramLayout* slangReflection)
{
// Check reflection is available
SLANG_CHECK(slangReflection != nullptr);

// Get the entry point layout for vertexMain
auto entryPointCount = slangReflection->getEntryPointCount();
slang::EntryPointLayout* entryPointLayout = nullptr;

for (unsigned int i = 0; i < entryPointCount; i++)
{
auto currentEntryPoint = slangReflection->getEntryPointByIndex(i);
const char* name = currentEntryPoint->getName();

if (strcmp(name, "vertexMain") == 0)
{
entryPointLayout = currentEntryPoint;
break;
}
}

SLANG_CHECK_MSG(entryPointLayout != nullptr, "Could not find vertexMain entry point");

// Get the parameter count for the entry point
auto paramCount = entryPointLayout->getParameterCount();
SLANG_CHECK_MSG(paramCount >= 1, "Entry point has no parameters");

// Get the first parameter, which should be of type S
auto paramLayout = entryPointLayout->getParameterByIndex(0);
SLANG_CHECK_MSG(paramLayout != nullptr, "Could not get first parameter layout");

// Get the type layout of the parameter
auto typeLayout = paramLayout->getTypeLayout();
SLANG_CHECK_MSG(typeLayout != nullptr, "Parameter has no type layout");

// Check if it's a struct type
auto kind = typeLayout->getKind();
SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Parameter is not a struct type");

// Get the field count
auto fieldCount = typeLayout->getFieldCount();
SLANG_CHECK_MSG(fieldCount >= 1, "Struct has no fields");

// Check for the 'foo' field
bool foundFooField = false;
for (unsigned int i = 0; i < fieldCount; i++)
{
auto fieldLayout = typeLayout->getFieldByIndex(i);
const char* fieldName = fieldLayout->getName();

if (fieldName && strcmp(fieldName, "foo") == 0)
{
foundFooField = true;

// Check that it's a float4 type
auto fieldTypeLayout = fieldLayout->getTypeLayout();
auto fieldTypeKind = fieldTypeLayout->getKind();

SLANG_CHECK_MSG(
fieldTypeKind == slang::TypeReflection::Kind::Vector,
"Field 'foo' is not a vector type");

auto elementCount = fieldTypeLayout->getElementCount();
SLANG_CHECK_MSG(elementCount == 4, "Field 'foo' is not a 4-element vector");

break;
}
}

SLANG_CHECK_MSG(foundFooField, "Could not find field 'foo' in struct S");
}

void linkTimeTypeLayoutImpl(gfx::IDevice* device, UnitTestContext* context)
{
Slang::ComPtr<gfx::IShaderProgram> shaderProgram;
slang::ProgramLayout* slangReflection = nullptr;

auto result = loadSpirvProgram(device, shaderProgram, slangReflection);
SLANG_CHECK(SLANG_SUCCEEDED(result));

// Validate the struct S layout
validateStructSLayout(context, slangReflection);

// Create a graphics pipeline to verify SPIRV code generation works
GraphicsPipelineStateDesc pipelineDesc = {};
pipelineDesc.program = shaderProgram.get();

// We need to set up a minimal pipeline state for a vertex shader
pipelineDesc.primitiveType = PrimitiveType::Triangle;

ComPtr<gfx::IPipelineState> pipelineState;
auto pipelineResult =
device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef());
SLANG_CHECK(SLANG_SUCCEEDED(pipelineResult));
}

//
// This test verifies that type layout information correctly propagates through
// the Slang compilation pipeline when types are defined in modules other than where they are used.
// Specifically, it tests
// that when using an extern struct that's defined in a separate module:
//
// 1. The struct definition is properly linked across module boundaries
// 2. The complete type layout information is available in the reflection data
// 3. SPIRV code generation succeeds with the linked type information (this
// failed before when layout information was required during code generation)
//

SLANG_UNIT_TEST(linkTimeTypeLayout)
{
runTestImpl(linkTimeTypeLayoutImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
}

} // namespace gfx_test
3 changes: 3 additions & 0 deletions tools/unit-test/slang-unit-test.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,6 @@ typedef IUnitTestModule* (*UnitTestGetModuleFunc)();
#define SLANG_IGNORE_TEST \
getTestReporter()->addResult(TestResult::Ignored); \
throw AbortTestException();
#define SLANG_CHECK_MSG(condition, message) \
getTestReporter() \
->addResultWithLocation((condition), #condition " " message, __FILE__, __LINE__)
Loading