Skip to content

Commit 2cdc1cd

Browse files
committed
Fix precompiledTargetModule tests
Add SPIRV-Tool linker support to gfx unit tests and use the linker in precompileModule tests that use precompiled modules to reconstitute SPIRV shaders that were modularly compiled. Fix a Slang reference count bug in the precompile service.
1 parent a70113c commit 2cdc1cd

13 files changed

+159
-37
lines changed

source/slang/slang-compiler-tu.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getModuleDependency(
291291
{
292292
return SLANG_E_INVALID_ARG;
293293
}
294+
getModuleDependencies()[dependencyIndex]->addRef();
294295
*outModule = getModuleDependencies()[dependencyIndex];
295296
return SLANG_OK;
296297
}

tests/expected-failure-github.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,4 @@ tests/autodiff/custom-intrinsic.slang.2 syn (wgpu)
1212
tests/bugs/buffer-swizzle-store.slang.3 syn (wgpu)
1313
tests/compute/interface-shader-param-in-struct.slang.4 syn (wgpu)
1414
tests/compute/interface-shader-param.slang.5 syn (wgpu)
15-
tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu)
16-
gfx-unit-test-tool/precompiledTargetModule2Vulkan.internal
15+
tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu)

tests/expected-failure.txt

-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@ tests/language-feature/saturated-cooperation/fuse.slang (vk)
55
tests/bugs/byte-address-buffer-interlocked-add-f32.slang (vk)
66
tests/ir/loop-unroll-0.slang.1 (vk)
77
tests/hlsl-intrinsic/texture/float-atomics.slang (vk)
8-
gfx-unit-test-tool/precompiledTargetModule2Vulkan.internal

tools/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ if(SLANG_ENABLE_GFX)
141141
$<$<BOOL:${SLANG_ENABLE_XLIB}>:X11::X11>
142142
$<$<BOOL:${SLANG_ENABLE_CUDA}>:CUDA::cuda_driver>
143143
$<$<BOOL:${SLANG_ENABLE_NVAPI}>:${NVAPI_LIBRARIES}>
144+
SPIRV-Tools-link
144145
LINK_WITH_FRAMEWORK Foundation Cocoa QuartzCore Metal
145146
EXTRA_COMPILE_DEFINITIONS_PRIVATE
146147
$<$<BOOL:${SLANG_ENABLE_CUDA}>:GFX_ENABLE_CUDA>

tools/gfx/d3d12/d3d12-shader-program.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ using namespace Slang;
1010

1111
Result ShaderProgramImpl::createShaderModule(
1212
slang::EntryPointReflection* entryPointInfo,
13-
ComPtr<ISlangBlob> kernelCode)
13+
List<ComPtr<ISlangBlob>> kernelCodes)
1414
{
1515
ShaderBinary shaderBin;
1616
shaderBin.stage = entryPointInfo->getStage();
1717
shaderBin.entryPointInfo = entryPointInfo;
1818
shaderBin.code.addRange(
19-
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
20-
(Index)kernelCode->getBufferSize());
19+
reinterpret_cast<const uint8_t*>(kernelCodes[0]->getBufferPointer()),
20+
(Index)kernelCodes[0]->getBufferSize());
2121
m_shaders.add(_Move(shaderBin));
2222
return SLANG_OK;
2323
}

tools/gfx/d3d12/d3d12-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class ShaderProgramImpl : public ShaderProgramBase
2727

2828
virtual Result createShaderModule(
2929
slang::EntryPointReflection* entryPointInfo,
30-
ComPtr<ISlangBlob> kernelCode) override;
30+
List<ComPtr<ISlangBlob>> kernelCodes) override;
3131
};
3232

3333
} // namespace d3d12

tools/gfx/metal/metal-shader-program.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ ShaderProgramImpl::~ShaderProgramImpl() {}
2121

2222
Result ShaderProgramImpl::createShaderModule(
2323
slang::EntryPointReflection* entryPointInfo,
24-
ComPtr<ISlangBlob> kernelCode)
24+
List<ComPtr<ISlangBlob>> kernelCodes)
2525
{
2626
Module module;
2727
module.stage = entryPointInfo->getStage();
2828
module.entryPointName = entryPointInfo->getNameOverride();
29-
module.code = kernelCode;
29+
module.code = kernelCodes[0];
3030

3131
dispatch_data_t data = dispatch_data_create(
32-
kernelCode->getBufferPointer(),
33-
kernelCode->getBufferSize(),
32+
kernelCodes[0]->getBufferPointer(),
33+
kernelCodes[0]->getBufferSize(),
3434
dispatch_get_main_queue(),
3535
NULL);
3636
NS::Error* error;

tools/gfx/metal/metal-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ShaderProgramImpl : public ShaderProgramBase
3333

3434
virtual Result createShaderModule(
3535
slang::EntryPointReflection* entryPointInfo,
36-
ComPtr<ISlangBlob> kernelCode) override;
36+
List<ComPtr<ISlangBlob>> kernelCodes) override;
3737
};
3838

3939

tools/gfx/renderer-shared.cpp

+94-20
Original file line numberDiff line numberDiff line change
@@ -1103,32 +1103,106 @@ void ShaderProgramBase::init(const IShaderProgram::Desc& inDesc)
11031103

11041104
Result ShaderProgramBase::compileShaders(RendererBase* device)
11051105
{
1106+
auto compileTarget = device->slangContext.compileTarget;
11061107
// For a fully specialized program, read and store its kernel code in `shaderProgram`.
11071108
auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
11081109
slang::IComponentType* entryPointComponent,
11091110
SlangInt entryPointIndex)
11101111
{
11111112
auto stage = entryPointInfo->getStage();
1112-
ComPtr<ISlangBlob> kernelCode;
1113-
ComPtr<ISlangBlob> diagnostics;
1114-
auto compileResult = device->getEntryPointCodeFromShaderCache(
1115-
entryPointComponent,
1116-
entryPointIndex,
1117-
0,
1118-
kernelCode.writeRef(),
1119-
diagnostics.writeRef());
1120-
if (diagnostics)
1113+
List<ComPtr<ISlangBlob>> kernelCodes;
11211114
{
1122-
DebugMessageType msgType = DebugMessageType::Warning;
1123-
if (compileResult != SLANG_OK)
1124-
msgType = DebugMessageType::Error;
1125-
getDebugCallback()->handleMessage(
1126-
msgType,
1127-
DebugMessageSource::Slang,
1128-
(char*)diagnostics->getBufferPointer());
1115+
ComPtr<ISlangBlob> downstreamIR;
1116+
ComPtr<ISlangBlob> diagnostics;
1117+
auto compileResult = device->getEntryPointCodeFromShaderCache(
1118+
entryPointComponent,
1119+
entryPointIndex,
1120+
0,
1121+
downstreamIR.writeRef(),
1122+
diagnostics.writeRef());
1123+
if (diagnostics)
1124+
{
1125+
DebugMessageType msgType = DebugMessageType::Warning;
1126+
if (compileResult != SLANG_OK)
1127+
msgType = DebugMessageType::Error;
1128+
getDebugCallback()->handleMessage(
1129+
msgType,
1130+
DebugMessageSource::Slang,
1131+
(char*)diagnostics->getBufferPointer());
1132+
}
1133+
kernelCodes.add(downstreamIR);
11291134
}
1130-
SLANG_RETURN_ON_FAIL(compileResult);
1131-
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode));
1135+
1136+
// If target precompilation was used, kernelCode may only represent the
1137+
// glue code holding together the bits of precompiled target IR.
1138+
// Collect those dependency target IRs too.
1139+
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
1140+
if (entryPointComponent->queryInterface(
1141+
slang::IModulePrecompileService_Experimental::getTypeGuid(),
1142+
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
1143+
{
1144+
SlangInt dependencyCount = componentPrecompileService->getModuleDependencyCount();
1145+
if (dependencyCount > 0)
1146+
{
1147+
for (int dependencyIndex = 0; dependencyIndex < dependencyCount; dependencyIndex++)
1148+
{
1149+
ComPtr<slang::IModule> dependencyModule;
1150+
{
1151+
ComPtr<slang::IBlob> diagnosticsBlob;
1152+
auto result = componentPrecompileService->getModuleDependency(
1153+
dependencyIndex,
1154+
dependencyModule.writeRef(),
1155+
diagnosticsBlob.writeRef());
1156+
if (diagnosticsBlob)
1157+
{
1158+
DebugMessageType msgType = DebugMessageType::Warning;
1159+
if (result != SLANG_OK)
1160+
msgType = DebugMessageType::Error;
1161+
getDebugCallback()->handleMessage(
1162+
msgType,
1163+
DebugMessageSource::Slang,
1164+
(char*)diagnosticsBlob->getBufferPointer());
1165+
}
1166+
SLANG_RETURN_ON_FAIL(result);
1167+
}
1168+
1169+
ComPtr<slang::IBlob> downstreamIR;
1170+
{
1171+
ComPtr<slang::IBlob> diagnosticsBlob;
1172+
SlangResult result = SLANG_OK;
1173+
ComPtr<slang::IModulePrecompileService_Experimental> precompileService;
1174+
result = dependencyModule->queryInterface(
1175+
slang::IModulePrecompileService_Experimental::getTypeGuid(),
1176+
(void**)precompileService.writeRef());
1177+
if (result == SLANG_OK)
1178+
{
1179+
ComPtr<slang::IBlob> diagnosticsBlob;
1180+
auto result = precompileService->getPrecompiledTargetCode(
1181+
compileTarget,
1182+
downstreamIR.writeRef(),
1183+
diagnosticsBlob.writeRef());
1184+
if (result == SLANG_OK)
1185+
{
1186+
kernelCodes.add(downstreamIR);
1187+
}
1188+
if (diagnosticsBlob)
1189+
{
1190+
DebugMessageType msgType = DebugMessageType::Warning;
1191+
if (result != SLANG_OK)
1192+
msgType = DebugMessageType::Error;
1193+
getDebugCallback()->handleMessage(
1194+
msgType,
1195+
DebugMessageSource::Slang,
1196+
(char*)diagnosticsBlob->getBufferPointer());
1197+
}
1198+
}
1199+
SLANG_RETURN_ON_FAIL(result);
1200+
}
1201+
}
1202+
}
1203+
}
1204+
1205+
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCodes));
11321206
return SLANG_OK;
11331207
};
11341208

@@ -1160,10 +1234,10 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
11601234

11611235
Result ShaderProgramBase::createShaderModule(
11621236
slang::EntryPointReflection* entryPointInfo,
1163-
ComPtr<ISlangBlob> kernelCode)
1237+
List<ComPtr<ISlangBlob>> kernelCodes)
11641238
{
11651239
SLANG_UNUSED(entryPointInfo);
1166-
SLANG_UNUSED(kernelCode);
1240+
SLANG_UNUSED(kernelCodes);
11671241
return SLANG_OK;
11681242
}
11691243

tools/gfx/renderer-shared.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ class ShaderProgramBase : public IShaderProgram, public Slang::ComObject
877877
Slang::Result compileShaders(RendererBase* device);
878878
virtual Slang::Result createShaderModule(
879879
slang::EntryPointReflection* entryPointInfo,
880-
Slang::ComPtr<ISlangBlob> kernelCode);
880+
Slang::List<Slang::ComPtr<ISlangBlob>> kernelCodes);
881881

882882
virtual SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL
883883
findTypeByName(const char* name) override

tools/gfx/slang-context.h

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class SlangContext
1010
public:
1111
Slang::ComPtr<slang::IGlobalSession> globalSession;
1212
Slang::ComPtr<slang::ISession> session;
13+
SlangCompileTarget compileTarget;
1314
Result initialize(
1415
const gfx::IDevice::SlangDesc& desc,
1516
uint32_t extendedDescCount,
@@ -27,6 +28,7 @@ class SlangContext
2728
SLANG_RETURN_ON_FAIL(slang::createGlobalSession(globalSession.writeRef()));
2829
}
2930

31+
this->compileTarget = compileTarget;
3032
slang::SessionDesc slangSessionDesc = {};
3133
slangSessionDesc.defaultMatrixLayoutMode = desc.defaultMatrixLayoutMode;
3234
slangSessionDesc.searchPathCount = desc.searchPathCount;

tools/gfx/vulkan/vk-shader-program.cpp

+49-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "vk-device.h"
55
#include "vk-util.h"
66

7+
#include "external/spirv-tools/include/spirv-tools/linker.hpp"
8+
79
namespace gfx
810
{
911

@@ -69,17 +71,61 @@ VkPipelineShaderStageCreateInfo ShaderProgramImpl::compileEntryPoint(
6971
return shaderStageCreateInfo;
7072
}
7173

74+
static ComPtr<ISlangBlob> LinkUsingSPIRVTools(List<ComPtr<ISlangBlob> > kernelCodes)
75+
{
76+
spvtools::Context context(SPV_ENV_UNIVERSAL_1_5);
77+
spvtools::LinkerOptions options;
78+
spvtools::MessageConsumer consumer = [](spv_message_level_t level,
79+
const char* source,
80+
const spv_position_t& position,
81+
const char* message)
82+
{
83+
printf("SPIRV-TOOLS: %s\n", message);
84+
printf("SPIRV-TOOLS: %s\n", source);
85+
printf("SPIRV-TOOLS: %zu:%zu\n", position.index, position.column);
86+
};
87+
context.SetMessageConsumer(consumer);
88+
std::vector<uint32_t*> binaries;
89+
std::vector<size_t> binary_sizes;
90+
for (auto kernelCode : kernelCodes)
91+
{
92+
binaries.push_back((uint32_t*)kernelCode->getBufferPointer());
93+
binary_sizes.push_back(kernelCode->getBufferSize() / sizeof(uint32_t));
94+
}
95+
96+
std::vector<uint32_t> linked_binary;
97+
98+
spvtools::Link(
99+
context,
100+
binaries.data(),
101+
binary_sizes.data(),
102+
binaries.size(),
103+
&linked_binary,
104+
options);
105+
106+
// Create a blob to hold the linked binary
107+
ComPtr<ISlangBlob> linkedKernelCode;
108+
109+
// Replace kernel code with linked binary
110+
// Creates a new blob with the linked binary
111+
linkedKernelCode = RawBlob::create(linked_binary.data(), linked_binary.size() * sizeof(uint32_t));
112+
113+
return linkedKernelCode;
114+
}
115+
72116
Result ShaderProgramImpl::createShaderModule(
73117
slang::EntryPointReflection* entryPointInfo,
74-
ComPtr<ISlangBlob> kernelCode)
118+
List<ComPtr<ISlangBlob>> kernelCodes)
75119
{
76-
m_codeBlobs.add(kernelCode);
120+
ComPtr<ISlangBlob> linkedKernel = LinkUsingSPIRVTools(kernelCodes);
121+
m_codeBlobs.add(linkedKernel);
122+
77123
VkShaderModule shaderModule;
78124
auto realEntryPointName = entryPointInfo->getNameOverride();
79125
const char* spirvBinaryEntryPointName = "main";
80126
m_stageCreateInfos.add(compileEntryPoint(
81127
spirvBinaryEntryPointName,
82-
kernelCode,
128+
linkedKernel,
83129
(VkShaderStageFlagBits)VulkanUtil::getShaderStage(entryPointInfo->getStage()),
84130
shaderModule));
85131
m_entryPointNames.add(realEntryPointName);

tools/gfx/vulkan/vk-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ShaderProgramImpl : public ShaderProgramBase
3737

3838
virtual Result createShaderModule(
3939
slang::EntryPointReflection* entryPointInfo,
40-
ComPtr<ISlangBlob> kernelCode) override;
40+
List<ComPtr<ISlangBlob>> kernelCodes) override;
4141
};
4242

4343

0 commit comments

Comments
 (0)