Skip to content

Commit 28a0994

Browse files
committed
Resolve 'extern' types during type layout generation if possible
Closes shader-slang#5994 Closes shader-slang#6437
1 parent f7b9745 commit 28a0994

File tree

6 files changed

+345
-8
lines changed

6 files changed

+345
-8
lines changed

cmake/SlangTarget.cmake

+4-8
Original file line numberDiff line numberDiff line change
@@ -507,14 +507,10 @@ function(slang_add_target dir type)
507507
endif()
508508
install(
509509
TARGETS ${target} ${export_args}
510-
ARCHIVE DESTINATION ${archive_subdir}
511-
${ARGN}
512-
LIBRARY DESTINATION ${library_subdir}
513-
${ARGN}
514-
RUNTIME DESTINATION ${runtime_subdir}
515-
${ARGN}
516-
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
517-
${ARGN}
510+
ARCHIVE DESTINATION ${archive_subdir} ${ARGN}
511+
LIBRARY DESTINATION ${library_subdir} ${ARGN}
512+
RUNTIME DESTINATION ${runtime_subdir} ${ARGN}
513+
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ${ARGN}
518514
)
519515
endmacro()
520516

source/slang/slang-parameter-binding.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -2382,7 +2382,14 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
23822382
// otherwise they will include all of the above cases...
23832383
else if (auto declRefType = as<DeclRefType>(type))
23842384
{
2385+
// If we are trying to get the layout of some extern type, do our best
2386+
// to look it up in other loaded modules and generate the type layout
2387+
// based on that.
2388+
declRefType = context->layoutContext.lookupExternDeclRefType(declRefType);
2389+
23852390
auto declRef = declRefType->getDeclRef();
2391+
2392+
23862393
if (auto structDeclRef = declRef.as<StructDecl>())
23872394
{
23882395
RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();

source/slang/slang-type-layout.cpp

+79
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "../compiler-core/slang-artifact-desc-util.h"
55
#include "slang-check-impl.h"
66
#include "slang-ir-insts.h"
7+
#include "slang-mangle.h"
78
#include "slang-syntax.h"
89

910
#include <assert.h>
@@ -5013,8 +5014,13 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type
50135014
}
50145015
else if (auto declRefType = as<DeclRefType>(type))
50155016
{
5017+
// If we are trying to get the layout of some extern type, do our best
5018+
// to look it up in other loaded modules and generate the type layout
5019+
// based on that.
5020+
declRefType = context.lookupExternDeclRefType(declRefType);
50165021
auto declRef = declRefType->getDeclRef();
50175022

5023+
50185024
if (auto structDeclRef = declRef.as<StructDecl>())
50195025
{
50205026
StructTypeLayoutBuilder typeLayoutBuilder;
@@ -5694,4 +5700,77 @@ GlobalGenericParamDecl* GenericParamTypeLayout::getGlobalGenericParamDecl()
56945700
return rsDeclRef.getDecl();
56955701
}
56965702

5703+
DeclRefType* TypeLayoutContext::lookupExternDeclRefType(DeclRefType* declRefType)
5704+
{
5705+
const auto declRef = declRefType->getDeclRef();
5706+
const auto decl = declRef.getDecl();
5707+
const auto isExtern =
5708+
decl->hasModifier<ExternAttribute>() || decl->hasModifier<ExternModifier>();
5709+
if (isExtern)
5710+
{
5711+
if (!externTypeMap)
5712+
buildExternTypeMap();
5713+
const auto mangledName = getMangledName(targetReq->getLinkage()->getASTBuilder(), decl);
5714+
externTypeMap->tryGetValue(mangledName, declRefType);
5715+
}
5716+
return declRefType;
5717+
}
5718+
5719+
void TypeLayoutContext::buildExternTypeMap()
5720+
{
5721+
externTypeMap.emplace();
5722+
const auto linkage = targetReq->getLinkage();
5723+
5724+
HashSet<String> externNames;
5725+
Dictionary<String, DeclRefType*> allTypes;
5726+
5727+
// Traverse the AST and keep track of all extern names and all type definitions
5728+
// We'll match them up later
5729+
auto processDecl = [&](auto&& go, Decl* decl) -> void
5730+
{
5731+
const auto isExtern =
5732+
decl->hasModifier<ExternAttribute>() || decl->hasModifier<ExternModifier>();
5733+
5734+
if (auto declRefType = as<DeclRefType>(DeclRefType::create(astBuilder, decl)))
5735+
{
5736+
String mangledName = getMangledName(astBuilder, decl);
5737+
5738+
if (isExtern)
5739+
{
5740+
externNames.add(mangledName);
5741+
}
5742+
else
5743+
{
5744+
allTypes[mangledName] = declRefType;
5745+
}
5746+
}
5747+
5748+
if (auto scopeDecl = as<ScopeDecl>(decl))
5749+
{
5750+
for (auto member : scopeDecl->members)
5751+
{
5752+
go(go, member);
5753+
}
5754+
}
5755+
};
5756+
5757+
for (const auto& m : linkage->loadedModulesList)
5758+
{
5759+
const auto& ast = m->getModuleDecl();
5760+
for (auto member : ast->members)
5761+
{
5762+
processDecl(processDecl, member);
5763+
}
5764+
}
5765+
5766+
// Only keep the types that have matching extern declarations
5767+
for (const auto& externName : externNames)
5768+
{
5769+
if (allTypes.containsKey(externName))
5770+
{
5771+
externTypeMap.value()[externName] = allTypes[externName];
5772+
}
5773+
}
5774+
}
5775+
56975776
} // namespace Slang

source/slang/slang-type-layout.h

+7
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,13 @@ struct TypeLayoutContext
11831183
// Options passed to object layout
11841184
ObjectLayoutRulesImpl::Options objectLayoutOptions;
11851185

1186+
// Mangled names to DeclRefType, this is used to match up 'extern' types to
1187+
// their linked in definitions during layout generation
1188+
std::optional<Dictionary<String, DeclRefType*>> externTypeMap;
1189+
1190+
DeclRefType* lookupExternDeclRefType(DeclRefType* declRefType);
1191+
void buildExternTypeMap();
1192+
11861193
LayoutRulesImpl* getRules() { return rules; }
11871194
LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); }
11881195

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#include "core/slang-blob.h"
2+
#include "gfx-test-util.h"
3+
#include "slang-gfx.h"
4+
#include "unit-test/slang-unit-test.h"
5+
6+
using namespace gfx;
7+
8+
namespace gfx_test
9+
{
10+
11+
static void diagnoseIfNeeded(Slang::ComPtr<slang::IBlob>& diagnosticsBlob)
12+
{
13+
if (diagnosticsBlob && diagnosticsBlob->getBufferSize() > 0)
14+
{
15+
fprintf(stderr, "%s\n", (const char*)diagnosticsBlob->getBufferPointer());
16+
}
17+
}
18+
19+
static Slang::Result loadSpirvProgram(
20+
gfx::IDevice* device,
21+
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
22+
slang::ProgramLayout*& slangReflection)
23+
{
24+
// main.slang: declares the interface and extern struct S, and the vertex shader.
25+
const char* mainSrc = R"(
26+
public interface IFoo
27+
{
28+
public float4 getFoo();
29+
};
30+
public extern struct S : IFoo;
31+
32+
[shader("vertex")]
33+
float4 vertexMain(S params) : SV_Position
34+
{
35+
return params.getFoo();
36+
}
37+
)";
38+
39+
// foo.slang: defines S with its field layout and its implementation of getFoo().
40+
const char* fooSrc = R"(
41+
import main;
42+
43+
export public struct S : IFoo
44+
{
45+
public float4 getFoo() { return this.foo; }
46+
float4 foo;
47+
}
48+
)";
49+
50+
Slang::ComPtr<slang::ISession> slangSession;
51+
SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
52+
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
53+
54+
// Create blobs for the two modules.
55+
auto mainBlob = Slang::UnownedRawBlob::create(mainSrc, strlen(mainSrc));
56+
auto fooBlob = Slang::UnownedRawBlob::create(fooSrc, strlen(fooSrc));
57+
58+
// Load modules from source.
59+
slang::IModule* mainModule = slangSession->loadModuleFromSource("main", "main.slang", mainBlob);
60+
slang::IModule* fooModule = slangSession->loadModuleFromSource("foo", "foo.slang", fooBlob);
61+
62+
// Find the entry point from main.slang
63+
Slang::ComPtr<slang::IEntryPoint> vsEntryPoint;
64+
SLANG_RETURN_ON_FAIL(mainModule->findEntryPointByName("vertexMain", vsEntryPoint.writeRef()));
65+
66+
// Compose the program from both modules and the entry point.
67+
Slang::List<slang::IComponentType*> componentTypes;
68+
componentTypes.add(mainModule);
69+
componentTypes.add(fooModule);
70+
componentTypes.add(vsEntryPoint);
71+
72+
Slang::ComPtr<slang::IComponentType> composedProgram;
73+
SLANG_RETURN_ON_FAIL(slangSession->createCompositeComponentType(
74+
componentTypes.getBuffer(),
75+
componentTypes.getCount(),
76+
composedProgram.writeRef(),
77+
diagnosticsBlob.writeRef()));
78+
diagnoseIfNeeded(diagnosticsBlob);
79+
80+
// Link the composite program.
81+
Slang::ComPtr<slang::IComponentType> linkedProgram;
82+
SLANG_RETURN_ON_FAIL(
83+
composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()));
84+
diagnoseIfNeeded(diagnosticsBlob);
85+
86+
// Retrieve the reflection information.
87+
composedProgram = linkedProgram;
88+
slangReflection = composedProgram->getLayout();
89+
90+
// Create a shader program that will generate SPIRV code.
91+
gfx::IShaderProgram::Desc programDesc = {};
92+
programDesc.slangGlobalScope = composedProgram.get();
93+
auto shaderProgram = device->createProgram(programDesc);
94+
outShaderProgram = shaderProgram;
95+
96+
// Force SPIRV generation by explicitly requesting it
97+
Slang::ComPtr<slang::IBlob> spirvBlob;
98+
Slang::ComPtr<slang::IBlob> spirvDiagnostics;
99+
100+
// Request SPIRV code generation for the vertex shader entry point
101+
auto targetIndex = 0; // Assuming this is the first/only target
102+
auto entryPointIndex = 0; // Assuming this is the first/only entry point
103+
104+
auto result = composedProgram->getEntryPointCode(
105+
entryPointIndex,
106+
targetIndex,
107+
spirvBlob.writeRef(),
108+
spirvDiagnostics.writeRef());
109+
110+
if (SLANG_FAILED(result))
111+
{
112+
if (spirvDiagnostics && spirvDiagnostics->getBufferSize() > 0)
113+
{
114+
fprintf(
115+
stderr,
116+
"SPIRV generation failed: %s\n",
117+
(const char*)spirvDiagnostics->getBufferPointer());
118+
}
119+
return result;
120+
}
121+
122+
// Verify we actually got SPIRV code
123+
if (!spirvBlob || spirvBlob->getBufferSize() == 0)
124+
{
125+
return SLANG_FAIL;
126+
}
127+
128+
return SLANG_OK;
129+
}
130+
131+
// Function to validate the type layout of struct S
132+
static void validateStructSLayout(UnitTestContext* context, slang::ProgramLayout* slangReflection)
133+
{
134+
// Check reflection is available
135+
SLANG_CHECK(slangReflection != nullptr);
136+
137+
// Get the entry point layout for vertexMain
138+
auto entryPointCount = slangReflection->getEntryPointCount();
139+
slang::EntryPointLayout* entryPointLayout = nullptr;
140+
141+
for (unsigned int i = 0; i < entryPointCount; i++)
142+
{
143+
auto currentEntryPoint = slangReflection->getEntryPointByIndex(i);
144+
const char* name = currentEntryPoint->getName();
145+
146+
if (strcmp(name, "vertexMain") == 0)
147+
{
148+
entryPointLayout = currentEntryPoint;
149+
break;
150+
}
151+
}
152+
153+
SLANG_CHECK_MSG(entryPointLayout != nullptr, "Could not find vertexMain entry point");
154+
155+
// Get the parameter count for the entry point
156+
auto paramCount = entryPointLayout->getParameterCount();
157+
SLANG_CHECK_MSG(paramCount >= 1, "Entry point has no parameters");
158+
159+
// Get the first parameter, which should be of type S
160+
auto paramLayout = entryPointLayout->getParameterByIndex(0);
161+
SLANG_CHECK_MSG(paramLayout != nullptr, "Could not get first parameter layout");
162+
163+
// Get the type layout of the parameter
164+
auto typeLayout = paramLayout->getTypeLayout();
165+
SLANG_CHECK_MSG(typeLayout != nullptr, "Parameter has no type layout");
166+
167+
// Check if it's a struct type
168+
auto kind = typeLayout->getKind();
169+
SLANG_CHECK_MSG(kind == slang::TypeReflection::Kind::Struct, "Parameter is not a struct type");
170+
171+
// Get the field count
172+
auto fieldCount = typeLayout->getFieldCount();
173+
SLANG_CHECK_MSG(fieldCount >= 1, "Struct has no fields");
174+
175+
// Check for the 'foo' field
176+
bool foundFooField = false;
177+
for (unsigned int i = 0; i < fieldCount; i++)
178+
{
179+
auto fieldLayout = typeLayout->getFieldByIndex(i);
180+
const char* fieldName = fieldLayout->getName();
181+
182+
if (fieldName && strcmp(fieldName, "foo") == 0)
183+
{
184+
foundFooField = true;
185+
186+
// Check that it's a float4 type
187+
auto fieldTypeLayout = fieldLayout->getTypeLayout();
188+
auto fieldTypeKind = fieldTypeLayout->getKind();
189+
190+
SLANG_CHECK_MSG(
191+
fieldTypeKind == slang::TypeReflection::Kind::Vector,
192+
"Field 'foo' is not a vector type");
193+
194+
auto elementCount = fieldTypeLayout->getElementCount();
195+
SLANG_CHECK_MSG(elementCount == 4, "Field 'foo' is not a 4-element vector");
196+
197+
break;
198+
}
199+
}
200+
201+
SLANG_CHECK_MSG(foundFooField, "Could not find field 'foo' in struct S");
202+
}
203+
204+
void linkTimeTypeLayoutImpl(gfx::IDevice* device, UnitTestContext* context)
205+
{
206+
Slang::ComPtr<gfx::IShaderProgram> shaderProgram;
207+
slang::ProgramLayout* slangReflection = nullptr;
208+
209+
auto result = loadSpirvProgram(device, shaderProgram, slangReflection);
210+
SLANG_CHECK(SLANG_SUCCEEDED(result));
211+
212+
// Validate the struct S layout
213+
validateStructSLayout(context, slangReflection);
214+
215+
// Create a graphics pipeline to verify SPIRV code generation works
216+
GraphicsPipelineStateDesc pipelineDesc = {};
217+
pipelineDesc.program = shaderProgram.get();
218+
219+
// We need to set up a minimal pipeline state for a vertex shader
220+
pipelineDesc.primitiveType = PrimitiveType::Triangle;
221+
222+
ComPtr<gfx::IPipelineState> pipelineState;
223+
auto pipelineResult =
224+
device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef());
225+
SLANG_CHECK(SLANG_SUCCEEDED(pipelineResult));
226+
}
227+
228+
//
229+
// This test verifies that type layout information correctly propagates through
230+
// the Slang compilation pipeline when types are defined in modules other than where they are used.
231+
// Specifically, it tests
232+
// that when using an extern struct that's defined in a separate module:
233+
//
234+
// 1. The struct definition is properly linked across module boundaries
235+
// 2. The complete type layout information is available in the reflection data
236+
// 3. SPIRV code generation succeeds with the linked type information (this
237+
// failed before when layout information was required during code generation)
238+
//
239+
240+
SLANG_UNIT_TEST(linkTimeTypeLayout)
241+
{
242+
runTestImpl(linkTimeTypeLayoutImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
243+
}
244+
245+
} // namespace gfx_test

0 commit comments

Comments
 (0)