Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix precompiledTargetModule tests #6455

Merged
merged 10 commits into from
Feb 27, 2025
2 changes: 1 addition & 1 deletion source/slang-glslang/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ if(SLANG_ENABLE_SLANG_GLSLANG)
.
MODULE
USE_FEWER_WARNINGS
LINK_WITH_PRIVATE glslang SPIRV SPIRV-Tools-opt
LINK_WITH_PRIVATE glslang SPIRV SPIRV-Tools-opt SPIRV-Tools-link
INCLUDE_DIRECTORIES_PRIVATE ${slang_SOURCE_DIR}/include
INSTALL
EXPORT_SET_NAME SlangTargets
65 changes: 65 additions & 0 deletions source/slang-glslang/slang-glslang.cpp
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
#include "glslang/Public/ShaderLang.h"
#include "slang.h"
#include "spirv-tools/libspirv.h"
#include "spirv-tools/linker.hpp"
#include "spirv-tools/optimizer.hpp"

#ifdef _WIN32
@@ -979,3 +980,67 @@ extern "C"
request.set(*inRequest);
return glslang_compile_1_1(&request);
}

extern "C"
#ifdef _MSC_VER
_declspec(dllexport)
#else
__attribute__((__visibility__("default")))
#endif
int glslang_linkSPIRV(glslang_LinkRequest* request)
{
if (!request || !request->modules || request->linkResult)
return false;

try
{
spvtools::Context context(SPV_ENV_UNIVERSAL_1_5);
spvtools::LinkerOptions options = {};

spvtools::MessageConsumer consumer = [](spv_message_level_t level,
const char* source,
const spv_position_t& position,
const char* message)
{
printf("SPIRV-TOOLS: %s\n", message);
printf("SPIRV-TOOLS: %s\n", source);
printf("SPIRV-TOOLS: %zu:%zu\n", position.index, position.column);
};
context.SetMessageConsumer(consumer);

std::vector<std::vector<uint32_t>> moduleVecs(request->moduleCount);
std::vector<const uint32_t*> moduleData(request->moduleCount);
std::vector<size_t> moduleSizes(request->moduleCount);

for (size_t i = 0; i < request->moduleCount; ++i)
{
moduleData[i] = request->modules[i];
moduleSizes[i] = request->moduleSizes[i];
}

std::vector<uint32_t> linkedBinary;
spv_result_t success = spvtools::Link(
context,
moduleData.data(),
moduleSizes.data(),
request->moduleCount,
&linkedBinary,
options);

if (success == SPV_SUCCESS)
{
request->linkResult = new uint32_t[linkedBinary.size()];
memcpy(
(void*)request->linkResult,
linkedBinary.data(),
linkedBinary.size() * sizeof(uint32_t));
request->linkResultSize = linkedBinary.size();
}

return success;
}
catch (...)
{
return false;
}
}
11 changes: 10 additions & 1 deletion source/slang-glslang/slang-glslang.h
Original file line number Diff line number Diff line change
@@ -152,10 +152,19 @@ inline void glslang_CompileRequest_1_2::set(const glslang_CompileRequest_1_1& in
memcpy(this, &in, sizeof(in));
}

typedef struct glslang_LinkRequest_t
{
const uint32_t** modules; // Input: array of pointers to SPIR-V modules
const uint32_t* moduleSizes; // Input: array of sizes of SPIR-V modules in 32-bit words
int moduleCount; // Input: number of modules in the array
const uint32_t* linkResult; // Output: pointer to linked SPIR-V module
size_t linkResultSize; // Output: size of the linked SPIR-V module in 32-bit words
} glslang_LinkRequest;

typedef int (*glslang_CompileFunc_1_0)(glslang_CompileRequest_1_0* request);
typedef int (*glslang_CompileFunc_1_1)(glslang_CompileRequest_1_1* request);
typedef int (*glslang_CompileFunc_1_2)(glslang_CompileRequest_1_2* request);
typedef bool (*glslang_ValidateSPIRVFunc)(const uint32_t* contents, int contentsSize);
typedef bool (*glslang_DisassembleSPIRVFunc)(const uint32_t* contents, int contentsSize);

typedef bool (*glslang_LinkSPIRVFunc)(glslang_LinkRequest* request);
#endif
1 change: 1 addition & 0 deletions source/slang/slang-compiler-tu.cpp
Original file line number Diff line number Diff line change
@@ -291,6 +291,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getModuleDependency(
{
return SLANG_E_INVALID_ARG;
}
getModuleDependencies()[dependencyIndex]->addRef();
*outModule = getModuleDependencies()[dependencyIndex];
return SLANG_OK;
}
3 changes: 1 addition & 2 deletions tests/expected-failure-github.txt
Original file line number Diff line number Diff line change
@@ -10,5 +10,4 @@ tests/autodiff/custom-intrinsic.slang.2 syn (wgpu)
tests/bugs/buffer-swizzle-store.slang.3 syn (wgpu)
tests/compute/interface-shader-param-in-struct.slang.4 syn (wgpu)
tests/compute/interface-shader-param.slang.5 syn (wgpu)
tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu)
gfx-unit-test-tool/precompiledTargetModule2Vulkan.internal
tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu)
1 change: 0 additions & 1 deletion tests/expected-failure.txt
Original file line number Diff line number Diff line change
@@ -5,4 +5,3 @@ tests/language-feature/saturated-cooperation/fuse.slang (vk)
tests/bugs/byte-address-buffer-interlocked-add-f32.slang (vk)
tests/ir/loop-unroll-0.slang.1 (vk)
tests/hlsl-intrinsic/texture/float-atomics.slang (vk)
gfx-unit-test-tool/precompiledTargetModule2Vulkan.internal
2 changes: 1 addition & 1 deletion tools/gfx-unit-test/precompiled-module-2.cpp
Original file line number Diff line number Diff line change
@@ -112,7 +112,7 @@ void precompiledModule2TestImplCommon(
{
case gfx::DeviceType::DirectX12:
targetDesc.format = SLANG_DXIL;
targetDesc.profile = device->getSlangSession()->getGlobalSession()->findProfile("sm_6_1");
targetDesc.profile = device->getSlangSession()->getGlobalSession()->findProfile("sm_6_6");
break;
case gfx::DeviceType::Vulkan:
targetDesc.format = SLANG_SPIRV;
6 changes: 3 additions & 3 deletions tools/gfx/d3d12/d3d12-shader-program.cpp
Original file line number Diff line number Diff line change
@@ -10,14 +10,14 @@ using namespace Slang;

Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
List<ComPtr<ISlangBlob>>& kernelCodes)
{
ShaderBinary shaderBin;
shaderBin.stage = entryPointInfo->getStage();
shaderBin.entryPointInfo = entryPointInfo;
shaderBin.code.addRange(
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
(Index)kernelCode->getBufferSize());
reinterpret_cast<const uint8_t*>(kernelCodes[0]->getBufferPointer()),
(Index)kernelCodes[0]->getBufferSize());
m_shaders.add(_Move(shaderBin));
return SLANG_OK;
}
2 changes: 1 addition & 1 deletion tools/gfx/d3d12/d3d12-shader-program.h
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ class ShaderProgramImpl : public ShaderProgramBase

virtual Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode) override;
List<ComPtr<ISlangBlob>>& kernelCodes) override;
};

} // namespace d3d12
8 changes: 4 additions & 4 deletions tools/gfx/metal/metal-shader-program.cpp
Original file line number Diff line number Diff line change
@@ -21,16 +21,16 @@ ShaderProgramImpl::~ShaderProgramImpl() {}

Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
List<ComPtr<ISlangBlob>>& kernelCodes)
{
Module module;
module.stage = entryPointInfo->getStage();
module.entryPointName = entryPointInfo->getNameOverride();
module.code = kernelCode;
module.code = kernelCodes[0];

dispatch_data_t data = dispatch_data_create(
kernelCode->getBufferPointer(),
kernelCode->getBufferSize(),
kernelCodes[0]->getBufferPointer(),
kernelCodes[0]->getBufferSize(),
dispatch_get_main_queue(),
NULL);
NS::Error* error;
2 changes: 1 addition & 1 deletion tools/gfx/metal/metal-shader-program.h
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ class ShaderProgramImpl : public ShaderProgramBase

virtual Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode) override;
List<ComPtr<ISlangBlob>>& kernelCodes) override;
};


114 changes: 94 additions & 20 deletions tools/gfx/renderer-shared.cpp
Original file line number Diff line number Diff line change
@@ -1103,32 +1103,106 @@ void ShaderProgramBase::init(const IShaderProgram::Desc& inDesc)

Result ShaderProgramBase::compileShaders(RendererBase* device)
{
auto compileTarget = device->slangContext.compileTarget;
// For a fully specialized program, read and store its kernel code in `shaderProgram`.
auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
slang::IComponentType* entryPointComponent,
SlangInt entryPointIndex)
{
auto stage = entryPointInfo->getStage();
ComPtr<ISlangBlob> kernelCode;
ComPtr<ISlangBlob> diagnostics;
auto compileResult = device->getEntryPointCodeFromShaderCache(
entryPointComponent,
entryPointIndex,
0,
kernelCode.writeRef(),
diagnostics.writeRef());
if (diagnostics)
List<ComPtr<ISlangBlob>> kernelCodes;
{
DebugMessageType msgType = DebugMessageType::Warning;
if (compileResult != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnostics->getBufferPointer());
ComPtr<ISlangBlob> downstreamIR;
ComPtr<ISlangBlob> diagnostics;
auto compileResult = device->getEntryPointCodeFromShaderCache(
entryPointComponent,
entryPointIndex,
0,
downstreamIR.writeRef(),
diagnostics.writeRef());
if (diagnostics)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (compileResult != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnostics->getBufferPointer());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is error here, we should either return error, or not add downstreamIR into kernerlCodes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is error here, we should either return error, or not add downstreamIR into kernerlCodes.
https://github.com/shader-slang/slang/pull/6570/files

}
kernelCodes.add(downstreamIR);
}
SLANG_RETURN_ON_FAIL(compileResult);
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode));

// If target precompilation was used, kernelCode may only represent the
// glue code holding together the bits of precompiled target IR.
// Collect those dependency target IRs too.
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
if (entryPointComponent->queryInterface(
slang::IModulePrecompileService_Experimental::getTypeGuid(),
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
{
SlangInt dependencyCount = componentPrecompileService->getModuleDependencyCount();
if (dependencyCount > 0)
{
for (int dependencyIndex = 0; dependencyIndex < dependencyCount; dependencyIndex++)
{
ComPtr<slang::IModule> dependencyModule;
{
ComPtr<slang::IBlob> diagnosticsBlob;
auto result = componentPrecompileService->getModuleDependency(
dependencyIndex,
dependencyModule.writeRef(),
diagnosticsBlob.writeRef());
if (diagnosticsBlob)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (result != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnosticsBlob->getBufferPointer());
}
SLANG_RETURN_ON_FAIL(result);
}

ComPtr<slang::IBlob> downstreamIR;
{
ComPtr<slang::IBlob> diagnosticsBlob;
SlangResult result = SLANG_OK;
ComPtr<slang::IModulePrecompileService_Experimental> precompileService;
result = dependencyModule->queryInterface(
slang::IModulePrecompileService_Experimental::getTypeGuid(),
(void**)precompileService.writeRef());
if (result == SLANG_OK)
{
ComPtr<slang::IBlob> diagnosticsBlob;
auto result = precompileService->getPrecompiledTargetCode(
compileTarget,
downstreamIR.writeRef(),
diagnosticsBlob.writeRef());
if (result == SLANG_OK)
{
kernelCodes.add(downstreamIR);
}
if (diagnosticsBlob)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (result != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnosticsBlob->getBufferPointer());
}
}
SLANG_RETURN_ON_FAIL(result);
}
}
}
}

SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCodes));
return SLANG_OK;
};

@@ -1160,10 +1234,10 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)

Result ShaderProgramBase::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
List<ComPtr<ISlangBlob>>& kernelCodes)
{
SLANG_UNUSED(entryPointInfo);
SLANG_UNUSED(kernelCode);
SLANG_UNUSED(kernelCodes);
return SLANG_OK;
}

2 changes: 1 addition & 1 deletion tools/gfx/renderer-shared.h
Original file line number Diff line number Diff line change
@@ -877,7 +877,7 @@ class ShaderProgramBase : public IShaderProgram, public Slang::ComObject
Slang::Result compileShaders(RendererBase* device);
virtual Slang::Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
Slang::ComPtr<ISlangBlob> kernelCode);
Slang::List<Slang::ComPtr<ISlangBlob>>& kernelCodes);

virtual SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL
findTypeByName(const char* name) override
2 changes: 2 additions & 0 deletions tools/gfx/slang-context.h
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ class SlangContext
public:
Slang::ComPtr<slang::IGlobalSession> globalSession;
Slang::ComPtr<slang::ISession> session;
SlangCompileTarget compileTarget;
Result initialize(
const gfx::IDevice::SlangDesc& desc,
uint32_t extendedDescCount,
@@ -27,6 +28,7 @@ class SlangContext
SLANG_RETURN_ON_FAIL(slang::createGlobalSession(globalSession.writeRef()));
}

this->compileTarget = compileTarget;
slang::SessionDesc slangSessionDesc = {};
slangSessionDesc.defaultMatrixLayoutMode = desc.defaultMatrixLayoutMode;
slangSessionDesc.searchPathCount = desc.searchPathCount;
Loading
Loading