Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SPIR-V deferred linking option #6500

Merged
merged 10 commits into from
Mar 5, 2025
9 changes: 9 additions & 0 deletions include/slang-gfx.h
Original file line number Diff line number Diff line change
@@ -163,6 +163,12 @@ class IShaderProgram : public ISlangUnknown
SeparateEntryPointCompilation
};

enum class DownstreamLinkMode
{
None,
Deferred,
};

struct Desc
{
// TODO: Tess doesn't like this but doesn't know what to do about it
@@ -180,6 +186,9 @@ class IShaderProgram : public ISlangUnknown
// An array of Slang entry points. The size of the array must be `entryPointCount`.
// Each element must define only 1 Slang EntryPoint.
slang::IComponentType** slangEntryPoints = nullptr;

// Indicates whether the app is responsible for final downstream linking.
DownstreamLinkMode downstreamLinkMode = DownstreamLinkMode::None;
};

struct CreateDesc2
3 changes: 3 additions & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
@@ -653,6 +653,7 @@ typedef uint32_t SlangSizeT;
SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt
SLANG_PASS_THROUGH_METAL, ///< Metal compiler
SLANG_PASS_THROUGH_TINT, ///< Tint WGSL compiler
SLANG_PASS_THROUGH_SPIRV_LINK, ///< SPIRV-link
SLANG_PASS_THROUGH_COUNT_OF,
};

@@ -1008,6 +1009,8 @@ typedef uint32_t SlangSizeT;

EmitReflectionJSON, // bool
SaveGLSLModuleBinSource,

SkipDownstreamLinking, // bool, experimental
CountOf,
};

13 changes: 13 additions & 0 deletions source/compiler-core/slang-downstream-compiler.h
Original file line number Diff line number Diff line change
@@ -343,6 +343,19 @@ class IDownstreamCompiler : public ICastable

/// True if underlying compiler uses file system to communicate source
virtual SLANG_NO_THROW bool SLANG_MCALL isFileBased() = 0;

virtual SLANG_NO_THROW int SLANG_MCALL link(
const uint32_t** modules,
const uint32_t* moduleSizes,
const uint32_t moduleCount,
IArtifact** outArtifact)
{
SLANG_UNREFERENCED_PARAMETER(modules);
SLANG_UNREFERENCED_PARAMETER(moduleSizes);
SLANG_UNREFERENCED_PARAMETER(moduleCount);
SLANG_UNREFERENCED_PARAMETER(outArtifact);
return 0;
}
};

class DownstreamCompilerBase : public ComBaseObject, public IDownstreamCompiler
41 changes: 41 additions & 0 deletions source/compiler-core/slang-glslang-compiler.cpp
Original file line number Diff line number Diff line change
@@ -49,6 +49,11 @@ class GlslangDownstreamCompiler : public DownstreamCompilerBase
validate(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL
disassemble(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
int link(
const uint32_t** modules,
const uint32_t* moduleSizes,
const uint32_t moduleCount,
IArtifact** outArtifact) SLANG_OVERRIDE;

/// Must be called before use
SlangResult init(ISlangSharedLibrary* library);
@@ -66,6 +71,7 @@ class GlslangDownstreamCompiler : public DownstreamCompilerBase
glslang_CompileFunc_1_2 m_compile_1_2 = nullptr;
glslang_ValidateSPIRVFunc m_validate = nullptr;
glslang_DisassembleSPIRVFunc m_disassemble = nullptr;
glslang_LinkSPIRVFunc m_link = nullptr;

ComPtr<ISlangSharedLibrary> m_sharedLibrary;

@@ -80,6 +86,7 @@ SlangResult GlslangDownstreamCompiler::init(ISlangSharedLibrary* library)
m_validate = (glslang_ValidateSPIRVFunc)library->findFuncByName("glslang_validateSPIRV");
m_disassemble =
(glslang_DisassembleSPIRVFunc)library->findFuncByName("glslang_disassembleSPIRV");
m_link = (glslang_LinkSPIRVFunc)library->findFuncByName("glslang_linkSPIRV");

if (m_compile_1_0 == nullptr && m_compile_1_1 == nullptr && m_compile_1_2 == nullptr)
{
@@ -323,6 +330,32 @@ SlangResult GlslangDownstreamCompiler::disassemble(const uint32_t* contents, int
return SLANG_FAIL;
}

SlangResult GlslangDownstreamCompiler::link(
const uint32_t** modules,
const uint32_t* moduleSizes,
const uint32_t moduleCount,
IArtifact** outArtifact)
{
glslang_LinkRequest request;
memset(&request, 0, sizeof(request));

request.modules = modules;
request.moduleSizes = moduleSizes;
request.moduleCount = moduleCount;

if (!m_link(&request))
{
return SLANG_FAIL;
}

auto artifact = ArtifactUtil::createArtifactForCompileTarget(SLANG_SPIRV);
artifact->addRepresentationUnknown(
Slang::RawBlob::create(request.linkResult, request.linkResultSize * sizeof(uint32_t)));

*outArtifact = artifact.detach();
return SLANG_OK;
}

bool GlslangDownstreamCompiler::canConvert(const ArtifactDesc& from, const ArtifactDesc& to)
{
// Can only disassemble blobs that are SPIR-V
@@ -467,6 +500,14 @@ SlangResult SpirvDisDownstreamCompilerUtil::locateCompilers(
return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_DIS);
}

SlangResult SpirvLinkDownstreamCompilerUtil::locateCompilers(
const String& path,
ISlangSharedLibraryLoader* loader,
DownstreamCompilerSet* set)
{
return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_LINK);
}

#else // SLANG_ENABLE_GLSLANG_SUPPORT

/* static */ SlangResult GlslangDownstreamCompilerUtil::locateCompilers(
8 changes: 8 additions & 0 deletions source/compiler-core/slang-glslang-compiler.h
Original file line number Diff line number Diff line change
@@ -32,6 +32,14 @@ struct SpirvDisDownstreamCompilerUtil
DownstreamCompilerSet* set);
};

struct SpirvLinkDownstreamCompilerUtil
{
static SlangResult locateCompilers(
const String& path,
ISlangSharedLibraryLoader* loader,
DownstreamCompilerSet* set);
};

} // namespace Slang

#endif
2 changes: 1 addition & 1 deletion source/slang-glslang/slang-glslang.cpp
Original file line number Diff line number Diff line change
@@ -1037,7 +1037,7 @@ extern "C"
request->linkResultSize = linkedBinary.size();
}

return success;
return success == SPV_SUCCESS;
}
catch (...)
{
6 changes: 6 additions & 0 deletions source/slang/slang-compiler.cpp
Original file line number Diff line number Diff line change
@@ -2669,6 +2669,12 @@ bool CodeGenContext::shouldDumpIR()
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr);
}

bool CodeGenContext::shouldSkipDownstreamLinking()
{
return getTargetProgram()->getOptionSet().getBoolOption(
CompilerOptionName::SkipDownstreamLinking);
}

bool CodeGenContext::shouldReportCheckpointIntermediates()
{
return getTargetProgram()->getOptionSet().getBoolOption(
9 changes: 8 additions & 1 deletion source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
@@ -1384,7 +1384,8 @@ enum class PassThroughMode : SlangPassThroughIntegral
LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler'
SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt
MetalC = SLANG_PASS_THROUGH_METAL,
Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
SpirvLink = SLANG_PASS_THROUGH_SPIRV_LINK, ///< pass through spirv to spirv-link
CountOf = SLANG_PASS_THROUGH_COUNT_OF,
};
void printDiagnosticArg(StringBuilder& sb, PassThroughMode val);
@@ -2886,6 +2887,12 @@ struct CodeGenContext
// removed between IR linking and target source generation.
bool removeAvailableInDownstreamIR = false;

// Determines if program level compilation like getTargetCode() or getEntryPointCode()
// should return a fully linked downstream program or just the glue SPIR-V/DXIL that
// imports and uses the precompiled SPIR-V/DXIL from constituent modules.
// This is a no-op if modules are not precompiled.
bool shouldSkipDownstreamLinking();

protected:
CodeGenTarget m_targetFormat = CodeGenTarget::Unknown;
ExtensionTracker* m_extensionTracker = nullptr;
63 changes: 62 additions & 1 deletion source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
@@ -2093,10 +2093,71 @@ SlangResult emitSPIRVForEntryPointsDirectly(
if (compiler)
{
#if 0
// Dump the unoptimized SPIRV after lowering from slang IR -> SPIRV
// Dump the unoptimized/unlinked SPIRV after lowering from slang IR -> SPIRV
compiler->disassemble((uint32_t*)spirv.getBuffer(), int(spirv.getCount() / 4));
#endif

bool isPrecompilation = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
CompilerOptionName::EmbedDownstreamIR);

if (!isPrecompilation && !codeGenContext->shouldSkipDownstreamLinking())
{
ComPtr<IArtifact> linkedArtifact;

// collect spirv files
List<uint32_t*> spirvFiles;
List<uint32_t> spirvSizes;

// Start with the SPIR-V we just generated.
// SPIRV-Tools-link expects the size in 32-bit words
// whereas the spirv blob size is in bytes.
spirvFiles.add((uint32_t*)spirv.getBuffer());
spirvSizes.add(int(spirv.getCount()) / 4);

// Iterate over all modules in the linkedIR. For each module, if it
// contains an embedded downstream ir instruction, add it to the list
// of spirv files.
auto program = codeGenContext->getProgram();

program->enumerateIRModules(
[&](IRModule* irModule)
{
for (auto globalInst : irModule->getModuleInst()->getChildren())
{
if (auto inst = as<IREmbeddedDownstreamIR>(globalInst))
{
if (inst->getTarget() == CodeGenTarget::SPIRV)
{
auto slice = inst->getBlob()->getStringSlice();
spirvFiles.add((uint32_t*)slice.begin());
spirvSizes.add(int(slice.getLength()) / 4);
}
}
}
});

SLANG_ASSERT(int(spirv.getCount()) % 4 == 0);
SLANG_ASSERT(spirvFiles.getCount() == spirvSizes.getCount());

if (spirvFiles.getCount() > 1)
{
SlangResult linkresult = compiler->link(
(const uint32_t**)spirvFiles.getBuffer(),
(const uint32_t*)spirvSizes.getBuffer(),
(uint32_t)spirvFiles.getCount(),
linkedArtifact.writeRef());

if (linkresult != SLANG_OK)
{
return SLANG_FAIL;
}

ComPtr<ISlangBlob> blob;
linkedArtifact->loadBlob(ArtifactKeep::No, blob.writeRef());
artifact = _Move(linkedArtifact);
}
}

if (!codeGenContext->shouldSkipSPIRVValidation())
{
StringBuilder runSpirvValEnvVar;
11 changes: 10 additions & 1 deletion tools/gfx-unit-test/gfx-test-util.cpp
Original file line number Diff line number Diff line change
@@ -80,7 +80,8 @@ Slang::Result loadComputeProgram(
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
const char* shaderModuleName,
const char* entryPointName,
slang::ProgramLayout*& slangReflection)
slang::ProgramLayout*& slangReflection,
PrecompilationMode precompilationMode)
{
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
@@ -115,6 +116,14 @@ Slang::Result loadComputeProgram(

gfx::IShaderProgram::Desc programDesc = {};
programDesc.slangGlobalScope = composedProgram.get();
if (precompilationMode == PrecompilationMode::ExternalLink)
{
programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::Deferred;
}
else
{
programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::None;
}

auto shaderProgram = device->createProgram(programDesc);

10 changes: 9 additions & 1 deletion tools/gfx-unit-test/gfx-test-util.h
Original file line number Diff line number Diff line change
@@ -7,6 +7,13 @@

namespace gfx_test
{
enum class PrecompilationMode
{
None,
SlangIR,
InternalLink,
ExternalLink,
};
/// Helper function for print out diagnostic messages output by Slang compiler.
void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob);

@@ -24,7 +31,8 @@ Slang::Result loadComputeProgram(
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
const char* shaderModuleName,
const char* entryPointName,
slang::ProgramLayout*& slangReflection);
slang::ProgramLayout*& slangReflection,
PrecompilationMode precompilationMode = PrecompilationMode::None);

Slang::Result loadComputeProgramFromSource(
gfx::IDevice* device,
Loading
Loading