Skip to content

Commit 0634684

Browse files
cheneym2slangbot
andauthored
Support SPIR-V deferred linking option (#6500)
The new option "SkipDownstreamLinking" will defer final downstream IR linking to the user application. This option only has an effect if there are modules that were precompiled to the target IR using precompileForTarget(). Until now, the default behavior for SPIR-V was to use deferred linking, and the default behavior for DXIL was to use immediate/internal linking in Slang. This change only affects the SPIR-V behavior such that both deferred and non-deferred linking is supported based on the new option. To support the non-deferred option, Slang will internally call into SPIRV-Tools-link to reconstitute a complete SPIR-V shader program when necessary (due to modules having been precompiled to target IR). Otherwise, if SkipDownstreamLinking is enabled, the shader returned by e.g. getTargetCode() or getEntryPointCode() may have import linkage to the SPIR-V embedded in the constituent modules. Closes #4994 Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
1 parent 5248a02 commit 0634684

15 files changed

+245
-24
lines changed

include/slang-gfx.h

+9
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ class IShaderProgram : public ISlangUnknown
163163
SeparateEntryPointCompilation
164164
};
165165

166+
enum class DownstreamLinkMode
167+
{
168+
None,
169+
Deferred,
170+
};
171+
166172
struct Desc
167173
{
168174
// TODO: Tess doesn't like this but doesn't know what to do about it
@@ -180,6 +186,9 @@ class IShaderProgram : public ISlangUnknown
180186
// An array of Slang entry points. The size of the array must be `entryPointCount`.
181187
// Each element must define only 1 Slang EntryPoint.
182188
slang::IComponentType** slangEntryPoints = nullptr;
189+
190+
// Indicates whether the app is responsible for final downstream linking.
191+
DownstreamLinkMode downstreamLinkMode = DownstreamLinkMode::None;
183192
};
184193

185194
struct CreateDesc2

include/slang.h

+3
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ typedef uint32_t SlangSizeT;
653653
SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt
654654
SLANG_PASS_THROUGH_METAL, ///< Metal compiler
655655
SLANG_PASS_THROUGH_TINT, ///< Tint WGSL compiler
656+
SLANG_PASS_THROUGH_SPIRV_LINK, ///< SPIRV-link
656657
SLANG_PASS_THROUGH_COUNT_OF,
657658
};
658659

@@ -1008,6 +1009,8 @@ typedef uint32_t SlangSizeT;
10081009

10091010
EmitReflectionJSON, // bool
10101011
SaveGLSLModuleBinSource,
1012+
1013+
SkipDownstreamLinking, // bool, experimental
10111014
CountOf,
10121015
};
10131016

source/compiler-core/slang-downstream-compiler.h

+13
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,19 @@ class IDownstreamCompiler : public ICastable
343343

344344
/// True if underlying compiler uses file system to communicate source
345345
virtual SLANG_NO_THROW bool SLANG_MCALL isFileBased() = 0;
346+
347+
virtual SLANG_NO_THROW int SLANG_MCALL link(
348+
const uint32_t** modules,
349+
const uint32_t* moduleSizes,
350+
const uint32_t moduleCount,
351+
IArtifact** outArtifact)
352+
{
353+
SLANG_UNREFERENCED_PARAMETER(modules);
354+
SLANG_UNREFERENCED_PARAMETER(moduleSizes);
355+
SLANG_UNREFERENCED_PARAMETER(moduleCount);
356+
SLANG_UNREFERENCED_PARAMETER(outArtifact);
357+
return 0;
358+
}
346359
};
347360

348361
class DownstreamCompilerBase : public ComBaseObject, public IDownstreamCompiler

source/compiler-core/slang-glslang-compiler.cpp

+41
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ class GlslangDownstreamCompiler : public DownstreamCompilerBase
4949
validate(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
5050
virtual SLANG_NO_THROW SlangResult SLANG_MCALL
5151
disassemble(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
52+
int link(
53+
const uint32_t** modules,
54+
const uint32_t* moduleSizes,
55+
const uint32_t moduleCount,
56+
IArtifact** outArtifact) SLANG_OVERRIDE;
5257

5358
/// Must be called before use
5459
SlangResult init(ISlangSharedLibrary* library);
@@ -66,6 +71,7 @@ class GlslangDownstreamCompiler : public DownstreamCompilerBase
6671
glslang_CompileFunc_1_2 m_compile_1_2 = nullptr;
6772
glslang_ValidateSPIRVFunc m_validate = nullptr;
6873
glslang_DisassembleSPIRVFunc m_disassemble = nullptr;
74+
glslang_LinkSPIRVFunc m_link = nullptr;
6975

7076
ComPtr<ISlangSharedLibrary> m_sharedLibrary;
7177

@@ -80,6 +86,7 @@ SlangResult GlslangDownstreamCompiler::init(ISlangSharedLibrary* library)
8086
m_validate = (glslang_ValidateSPIRVFunc)library->findFuncByName("glslang_validateSPIRV");
8187
m_disassemble =
8288
(glslang_DisassembleSPIRVFunc)library->findFuncByName("glslang_disassembleSPIRV");
89+
m_link = (glslang_LinkSPIRVFunc)library->findFuncByName("glslang_linkSPIRV");
8390

8491
if (m_compile_1_0 == nullptr && m_compile_1_1 == nullptr && m_compile_1_2 == nullptr)
8592
{
@@ -323,6 +330,32 @@ SlangResult GlslangDownstreamCompiler::disassemble(const uint32_t* contents, int
323330
return SLANG_FAIL;
324331
}
325332

333+
SlangResult GlslangDownstreamCompiler::link(
334+
const uint32_t** modules,
335+
const uint32_t* moduleSizes,
336+
const uint32_t moduleCount,
337+
IArtifact** outArtifact)
338+
{
339+
glslang_LinkRequest request;
340+
memset(&request, 0, sizeof(request));
341+
342+
request.modules = modules;
343+
request.moduleSizes = moduleSizes;
344+
request.moduleCount = moduleCount;
345+
346+
if (!m_link(&request))
347+
{
348+
return SLANG_FAIL;
349+
}
350+
351+
auto artifact = ArtifactUtil::createArtifactForCompileTarget(SLANG_SPIRV);
352+
artifact->addRepresentationUnknown(
353+
Slang::RawBlob::create(request.linkResult, request.linkResultSize * sizeof(uint32_t)));
354+
355+
*outArtifact = artifact.detach();
356+
return SLANG_OK;
357+
}
358+
326359
bool GlslangDownstreamCompiler::canConvert(const ArtifactDesc& from, const ArtifactDesc& to)
327360
{
328361
// Can only disassemble blobs that are SPIR-V
@@ -467,6 +500,14 @@ SlangResult SpirvDisDownstreamCompilerUtil::locateCompilers(
467500
return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_DIS);
468501
}
469502

503+
SlangResult SpirvLinkDownstreamCompilerUtil::locateCompilers(
504+
const String& path,
505+
ISlangSharedLibraryLoader* loader,
506+
DownstreamCompilerSet* set)
507+
{
508+
return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_LINK);
509+
}
510+
470511
#else // SLANG_ENABLE_GLSLANG_SUPPORT
471512

472513
/* static */ SlangResult GlslangDownstreamCompilerUtil::locateCompilers(

source/compiler-core/slang-glslang-compiler.h

+8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ struct SpirvDisDownstreamCompilerUtil
3232
DownstreamCompilerSet* set);
3333
};
3434

35+
struct SpirvLinkDownstreamCompilerUtil
36+
{
37+
static SlangResult locateCompilers(
38+
const String& path,
39+
ISlangSharedLibraryLoader* loader,
40+
DownstreamCompilerSet* set);
41+
};
42+
3543
} // namespace Slang
3644

3745
#endif

source/slang-glslang/slang-glslang.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ extern "C"
10371037
request->linkResultSize = linkedBinary.size();
10381038
}
10391039

1040-
return success;
1040+
return success == SPV_SUCCESS;
10411041
}
10421042
catch (...)
10431043
{

source/slang/slang-compiler.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -2669,6 +2669,12 @@ bool CodeGenContext::shouldDumpIR()
26692669
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr);
26702670
}
26712671

2672+
bool CodeGenContext::shouldSkipDownstreamLinking()
2673+
{
2674+
return getTargetProgram()->getOptionSet().getBoolOption(
2675+
CompilerOptionName::SkipDownstreamLinking);
2676+
}
2677+
26722678
bool CodeGenContext::shouldReportCheckpointIntermediates()
26732679
{
26742680
return getTargetProgram()->getOptionSet().getBoolOption(

source/slang/slang-compiler.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -1384,7 +1384,8 @@ enum class PassThroughMode : SlangPassThroughIntegral
13841384
LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler'
13851385
SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt
13861386
MetalC = SLANG_PASS_THROUGH_METAL,
1387-
Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
1387+
Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
1388+
SpirvLink = SLANG_PASS_THROUGH_SPIRV_LINK, ///< pass through spirv to spirv-link
13881389
CountOf = SLANG_PASS_THROUGH_COUNT_OF,
13891390
};
13901391
void printDiagnosticArg(StringBuilder& sb, PassThroughMode val);
@@ -2886,6 +2887,12 @@ struct CodeGenContext
28862887
// removed between IR linking and target source generation.
28872888
bool removeAvailableInDownstreamIR = false;
28882889

2890+
// Determines if program level compilation like getTargetCode() or getEntryPointCode()
2891+
// should return a fully linked downstream program or just the glue SPIR-V/DXIL that
2892+
// imports and uses the precompiled SPIR-V/DXIL from constituent modules.
2893+
// This is a no-op if modules are not precompiled.
2894+
bool shouldSkipDownstreamLinking();
2895+
28892896
protected:
28902897
CodeGenTarget m_targetFormat = CodeGenTarget::Unknown;
28912898
ExtensionTracker* m_extensionTracker = nullptr;

source/slang/slang-emit.cpp

+62-1
Original file line numberDiff line numberDiff line change
@@ -2093,10 +2093,71 @@ SlangResult emitSPIRVForEntryPointsDirectly(
20932093
if (compiler)
20942094
{
20952095
#if 0
2096-
// Dump the unoptimized SPIRV after lowering from slang IR -> SPIRV
2096+
// Dump the unoptimized/unlinked SPIRV after lowering from slang IR -> SPIRV
20972097
compiler->disassemble((uint32_t*)spirv.getBuffer(), int(spirv.getCount() / 4));
20982098
#endif
20992099

2100+
bool isPrecompilation = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
2101+
CompilerOptionName::EmbedDownstreamIR);
2102+
2103+
if (!isPrecompilation && !codeGenContext->shouldSkipDownstreamLinking())
2104+
{
2105+
ComPtr<IArtifact> linkedArtifact;
2106+
2107+
// collect spirv files
2108+
List<uint32_t*> spirvFiles;
2109+
List<uint32_t> spirvSizes;
2110+
2111+
// Start with the SPIR-V we just generated.
2112+
// SPIRV-Tools-link expects the size in 32-bit words
2113+
// whereas the spirv blob size is in bytes.
2114+
spirvFiles.add((uint32_t*)spirv.getBuffer());
2115+
spirvSizes.add(int(spirv.getCount()) / 4);
2116+
2117+
// Iterate over all modules in the linkedIR. For each module, if it
2118+
// contains an embedded downstream ir instruction, add it to the list
2119+
// of spirv files.
2120+
auto program = codeGenContext->getProgram();
2121+
2122+
program->enumerateIRModules(
2123+
[&](IRModule* irModule)
2124+
{
2125+
for (auto globalInst : irModule->getModuleInst()->getChildren())
2126+
{
2127+
if (auto inst = as<IREmbeddedDownstreamIR>(globalInst))
2128+
{
2129+
if (inst->getTarget() == CodeGenTarget::SPIRV)
2130+
{
2131+
auto slice = inst->getBlob()->getStringSlice();
2132+
spirvFiles.add((uint32_t*)slice.begin());
2133+
spirvSizes.add(int(slice.getLength()) / 4);
2134+
}
2135+
}
2136+
}
2137+
});
2138+
2139+
SLANG_ASSERT(int(spirv.getCount()) % 4 == 0);
2140+
SLANG_ASSERT(spirvFiles.getCount() == spirvSizes.getCount());
2141+
2142+
if (spirvFiles.getCount() > 1)
2143+
{
2144+
SlangResult linkresult = compiler->link(
2145+
(const uint32_t**)spirvFiles.getBuffer(),
2146+
(const uint32_t*)spirvSizes.getBuffer(),
2147+
(uint32_t)spirvFiles.getCount(),
2148+
linkedArtifact.writeRef());
2149+
2150+
if (linkresult != SLANG_OK)
2151+
{
2152+
return SLANG_FAIL;
2153+
}
2154+
2155+
ComPtr<ISlangBlob> blob;
2156+
linkedArtifact->loadBlob(ArtifactKeep::No, blob.writeRef());
2157+
artifact = _Move(linkedArtifact);
2158+
}
2159+
}
2160+
21002161
if (!codeGenContext->shouldSkipSPIRVValidation())
21012162
{
21022163
StringBuilder runSpirvValEnvVar;

tools/gfx-unit-test/gfx-test-util.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ Slang::Result loadComputeProgram(
8080
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
8181
const char* shaderModuleName,
8282
const char* entryPointName,
83-
slang::ProgramLayout*& slangReflection)
83+
slang::ProgramLayout*& slangReflection,
84+
PrecompilationMode precompilationMode)
8485
{
8586
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
8687
slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
@@ -115,6 +116,14 @@ Slang::Result loadComputeProgram(
115116

116117
gfx::IShaderProgram::Desc programDesc = {};
117118
programDesc.slangGlobalScope = composedProgram.get();
119+
if (precompilationMode == PrecompilationMode::ExternalLink)
120+
{
121+
programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::Deferred;
122+
}
123+
else
124+
{
125+
programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::None;
126+
}
118127

119128
auto shaderProgram = device->createProgram(programDesc);
120129

tools/gfx-unit-test/gfx-test-util.h

+9-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77

88
namespace gfx_test
99
{
10+
enum class PrecompilationMode
11+
{
12+
None,
13+
SlangIR,
14+
InternalLink,
15+
ExternalLink,
16+
};
1017
/// Helper function for print out diagnostic messages output by Slang compiler.
1118
void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob);
1219

@@ -24,7 +31,8 @@ Slang::Result loadComputeProgram(
2431
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
2532
const char* shaderModuleName,
2633
const char* entryPointName,
27-
slang::ProgramLayout*& slangReflection);
34+
slang::ProgramLayout*& slangReflection,
35+
PrecompilationMode precompilationMode = PrecompilationMode::None);
2836

2937
Slang::Result loadComputeProgramFromSource(
3038
gfx::IDevice* device,

0 commit comments

Comments
 (0)