Skip to content

Commit 3f2ed2a

Browse files
committed
Fix precompiledTargetModule tests
In the SPIR-V backend of Slang, compiling a shader that contains some modules with precompiled target blobs will produce only a "glue" SPIR-V output which needs to be linked with the assorted precompiled blobs to be complete. Closes shader-slang#6170
1 parent a5b1aa0 commit 3f2ed2a

10 files changed

+162
-33
lines changed

tools/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ if(SLANG_ENABLE_GFX)
158158
EXPORT_SET_NAME SlangTargets
159159
FOLDER gfx
160160
)
161+
target_link_libraries(gfx PUBLIC SPIRV-Tools-link)
161162
set(modules_dest_dir $<TARGET_FILE_DIR:slang-test>)
162163
add_custom_target(
163164
copy-gfx-slang-modules

tools/gfx/d3d12/d3d12-shader-program.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ using namespace Slang;
1010

1111
Result ShaderProgramImpl::createShaderModule(
1212
slang::EntryPointReflection* entryPointInfo,
13-
ComPtr<ISlangBlob> kernelCode)
13+
List<ComPtr<ISlangBlob> > kernelCodes)
1414
{
1515
ShaderBinary shaderBin;
1616
shaderBin.stage = entryPointInfo->getStage();
1717
shaderBin.entryPointInfo = entryPointInfo;
18+
assert(kernelCodes.getCount() == 1); // Only one kernel code is supported for now
1819
shaderBin.code.addRange(
19-
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
20-
(Index)kernelCode->getBufferSize());
20+
reinterpret_cast<const uint8_t*>(kernelCodes[0]->getBufferPointer()),
21+
(Index)kernelCodes[0]->getBufferSize());
2122
m_shaders.add(_Move(shaderBin));
2223
return SLANG_OK;
2324
}

tools/gfx/d3d12/d3d12-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class ShaderProgramImpl : public ShaderProgramBase
2727

2828
virtual Result createShaderModule(
2929
slang::EntryPointReflection* entryPointInfo,
30-
ComPtr<ISlangBlob> kernelCode) override;
30+
List<ComPtr<ISlangBlob> > kernelCodes) override;
3131
};
3232

3333
} // namespace d3d12

tools/gfx/metal/metal-shader-program.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ ShaderProgramImpl::~ShaderProgramImpl() {}
2121

2222
Result ShaderProgramImpl::createShaderModule(
2323
slang::EntryPointReflection* entryPointInfo,
24-
ComPtr<ISlangBlob> kernelCode)
24+
Slang::List<ComPtr<ISlangBlob> > kernelCodes)
2525
{
2626
Module module;
2727
module.stage = entryPointInfo->getStage();
2828
module.entryPointName = entryPointInfo->getNameOverride();
29-
module.code = kernelCode;
29+
assert(kernelCodes.getCount() == 1);
30+
module.code = kernelCodes[0];
31+
3032

3133
dispatch_data_t data = dispatch_data_create(
32-
kernelCode->getBufferPointer(),
33-
kernelCode->getBufferSize(),
34+
kernelCodes[0]->getBufferPointer(),
35+
kernelCodes[0]->getBufferSize(),
3436
dispatch_get_main_queue(),
3537
NULL);
3638
NS::Error* error;

tools/gfx/metal/metal-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ShaderProgramImpl : public ShaderProgramBase
3333

3434
virtual Result createShaderModule(
3535
slang::EntryPointReflection* entryPointInfo,
36-
ComPtr<ISlangBlob> kernelCode) override;
36+
Slang::List<ComPtr<ISlangBlob> > kernelCodes) override;
3737
};
3838

3939

tools/gfx/renderer-shared.cpp

+94-19
Original file line numberDiff line numberDiff line change
@@ -1109,26 +1109,101 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
11091109
SlangInt entryPointIndex)
11101110
{
11111111
auto stage = entryPointInfo->getStage();
1112+
1113+
List<ComPtr<ISlangBlob> > kernelCodes;
11121114
ComPtr<ISlangBlob> kernelCode;
1113-
ComPtr<ISlangBlob> diagnostics;
1114-
auto compileResult = device->getEntryPointCodeFromShaderCache(
1115-
entryPointComponent,
1116-
entryPointIndex,
1117-
0,
1118-
kernelCode.writeRef(),
1119-
diagnostics.writeRef());
1120-
if (diagnostics)
11211115
{
1122-
DebugMessageType msgType = DebugMessageType::Warning;
1123-
if (compileResult != SLANG_OK)
1124-
msgType = DebugMessageType::Error;
1125-
getDebugCallback()->handleMessage(
1126-
msgType,
1127-
DebugMessageSource::Slang,
1128-
(char*)diagnostics->getBufferPointer());
1116+
ComPtr<ISlangBlob> diagnostics;
1117+
auto compileResult = device->getEntryPointCodeFromShaderCache(
1118+
entryPointComponent,
1119+
entryPointIndex,
1120+
0,
1121+
kernelCode.writeRef(),
1122+
diagnostics.writeRef());
1123+
if (diagnostics)
1124+
{
1125+
DebugMessageType msgType = DebugMessageType::Warning;
1126+
if (compileResult != SLANG_OK)
1127+
msgType = DebugMessageType::Error;
1128+
getDebugCallback()->handleMessage(
1129+
msgType,
1130+
DebugMessageSource::Slang,
1131+
(char*)diagnostics->getBufferPointer());
1132+
}
1133+
SLANG_RETURN_ON_FAIL(compileResult);
1134+
kernelCodes.add(kernelCode);
1135+
}
1136+
1137+
// If target precompilation was used, kernelCode may only represent the
1138+
// glue code holding together the bits of precompiled target IR.
1139+
// Collect those dependency target IRs too.
1140+
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
1141+
if (entryPointComponent->queryInterface(
1142+
slang::IModulePrecompileService_Experimental::getTypeGuid(),
1143+
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
1144+
{
1145+
SlangInt dependencyCount = componentPrecompileService->getModuleDependencyCount();
1146+
if (dependencyCount > 0)
1147+
{
1148+
for (int dependencyIndex = 0; dependencyIndex < dependencyCount; dependencyIndex++)
1149+
{
1150+
ComPtr<slang::IModule> dependencyModule;
1151+
{
1152+
ComPtr<slang::IBlob> diagnosticsBlob;
1153+
auto result = componentPrecompileService->getModuleDependency(
1154+
dependencyIndex,
1155+
dependencyModule.writeRef(),
1156+
diagnosticsBlob.writeRef());
1157+
if (diagnosticsBlob)
1158+
{
1159+
DebugMessageType msgType = DebugMessageType::Warning;
1160+
if (result != SLANG_OK)
1161+
msgType = DebugMessageType::Error;
1162+
getDebugCallback()->handleMessage(
1163+
msgType,
1164+
DebugMessageSource::Slang,
1165+
(char*)diagnosticsBlob->getBufferPointer());
1166+
}
1167+
SLANG_RETURN_ON_FAIL(result);
1168+
}
1169+
1170+
ComPtr<slang::IBlob> spirv;
1171+
{
1172+
ComPtr<slang::IBlob> diagnosticsBlob;
1173+
SlangResult result = SLANG_OK;
1174+
ComPtr<slang::IModulePrecompileService_Experimental> precompileService;
1175+
result = dependencyModule->queryInterface(
1176+
slang::IModulePrecompileService_Experimental::getTypeGuid(),
1177+
(void**)precompileService.writeRef());
1178+
if (result == SLANG_OK)
1179+
{
1180+
ComPtr<slang::IBlob> diagnosticsBlob;
1181+
auto result = precompileService->getPrecompiledTargetCode(
1182+
SLANG_SPIRV,
1183+
spirv.writeRef(),
1184+
diagnosticsBlob.writeRef());
1185+
if (result == SLANG_OK)
1186+
{
1187+
kernelCodes.add(spirv);
1188+
}
1189+
if (diagnosticsBlob)
1190+
{
1191+
DebugMessageType msgType = DebugMessageType::Warning;
1192+
if (result != SLANG_OK)
1193+
msgType = DebugMessageType::Error;
1194+
getDebugCallback()->handleMessage(
1195+
msgType,
1196+
DebugMessageSource::Slang,
1197+
(char*)diagnosticsBlob->getBufferPointer());
1198+
}
1199+
}
1200+
SLANG_RETURN_ON_FAIL(result);
1201+
}
1202+
}
1203+
}
11291204
}
1130-
SLANG_RETURN_ON_FAIL(compileResult);
1131-
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode));
1205+
1206+
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCodes));
11321207
return SLANG_OK;
11331208
};
11341209

@@ -1160,10 +1235,10 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
11601235

11611236
Result ShaderProgramBase::createShaderModule(
11621237
slang::EntryPointReflection* entryPointInfo,
1163-
ComPtr<ISlangBlob> kernelCode)
1238+
Slang::List<ComPtr<ISlangBlob> > kernelCodes)
11641239
{
11651240
SLANG_UNUSED(entryPointInfo);
1166-
SLANG_UNUSED(kernelCode);
1241+
SLANG_UNUSED(kernelCodes);
11671242
return SLANG_OK;
11681243
}
11691244

tools/gfx/renderer-shared.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ class ShaderProgramBase : public IShaderProgram, public Slang::ComObject
877877
Slang::Result compileShaders(RendererBase* device);
878878
virtual Slang::Result createShaderModule(
879879
slang::EntryPointReflection* entryPointInfo,
880-
Slang::ComPtr<ISlangBlob> kernelCode);
880+
Slang::List<Slang::ComPtr<ISlangBlob> > kernelCodes);
881881

882882
virtual SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL
883883
findTypeByName(const char* name) override

tools/gfx/vulkan/vk-pipeline-state.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ Result PipelineStateImpl::createVKComputePipelineState()
311311

312312
VkComputePipelineCreateInfo computePipelineInfo = {
313313
VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO};
314+
assert(programImpl->m_stageCreateInfos.getCount() == 1);
314315
computePipelineInfo.stage = programImpl->m_stageCreateInfos[0];
315316
computePipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
316317

tools/gfx/vulkan/vk-shader-program.cpp

+52-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "vk-device.h"
55
#include "vk-util.h"
66

7+
#include "external/spirv-tools/include/spirv-tools/linker.hpp"
8+
79
namespace gfx
810
{
911

@@ -69,19 +71,66 @@ VkPipelineShaderStageCreateInfo ShaderProgramImpl::compileEntryPoint(
6971
return shaderStageCreateInfo;
7072
}
7173

74+
static ComPtr<ISlangBlob> LinkWithSPIRVTools(List<ComPtr<ISlangBlob> > kernelCodes)
75+
{
76+
spvtools::Context context(SPV_ENV_UNIVERSAL_1_5);
77+
spvtools::LinkerOptions options;
78+
spvtools::MessageConsumer consumer = [](spv_message_level_t level,
79+
const char* source,
80+
const spv_position_t& position,
81+
const char* message)
82+
{
83+
printf("SPIRV-TOOLS: %s\n", message);
84+
printf("SPIRV-TOOLS: %s\n", source);
85+
printf("SPIRV-TOOLS: %d:%d\n", position.index, position.column);
86+
};
87+
context.SetMessageConsumer(consumer);
88+
std::vector<uint32_t*> binaries;
89+
std::vector<size_t> binary_sizes;
90+
for (auto kernelCode : kernelCodes)
91+
{
92+
binaries.push_back((uint32_t*)kernelCode->getBufferPointer());
93+
binary_sizes.push_back(kernelCode->getBufferSize() / sizeof(uint32_t));
94+
}
95+
96+
std::vector<uint32_t> linked_binary;
97+
98+
spvtools::Link(
99+
context,
100+
binaries.data(),
101+
binary_sizes.data(),
102+
binaries.size(),
103+
&linked_binary,
104+
options);
105+
106+
// Create a blob to hold the linked binary
107+
ComPtr<ISlangBlob> linkedKernelCode;
108+
109+
// Replace kernel code with linked binary
110+
// Creates a new blob with the linked binary
111+
linkedKernelCode = RawBlob::create(linked_binary.data(), linked_binary.size() * sizeof(uint32_t));
112+
113+
return linkedKernelCode;
114+
}
72115
Result ShaderProgramImpl::createShaderModule(
73116
slang::EntryPointReflection* entryPointInfo,
74-
ComPtr<ISlangBlob> kernelCode)
117+
List<ComPtr<ISlangBlob>> kernelCodes)
75118
{
76-
m_codeBlobs.add(kernelCode);
119+
//for (auto kernelCode : kernelCodes)
120+
// m_codeBlobs.add(kernelCode);
121+
122+
ComPtr<ISlangBlob> linkedKernel = LinkWithSPIRVTools(kernelCodes);
123+
m_codeBlobs.add(linkedKernel);
124+
77125
VkShaderModule shaderModule;
78126
auto realEntryPointName = entryPointInfo->getNameOverride();
79127
const char* spirvBinaryEntryPointName = "main";
80128
m_stageCreateInfos.add(compileEntryPoint(
81129
spirvBinaryEntryPointName,
82-
kernelCode,
130+
linkedKernel,
83131
(VkShaderStageFlagBits)VulkanUtil::getShaderStage(entryPointInfo->getStage()),
84132
shaderModule));
133+
85134
m_entryPointNames.add(realEntryPointName);
86135
m_modules.add(shaderModule);
87136
return SLANG_OK;

tools/gfx/vulkan/vk-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ShaderProgramImpl : public ShaderProgramBase
3737

3838
virtual Result createShaderModule(
3939
slang::EntryPointReflection* entryPointInfo,
40-
ComPtr<ISlangBlob> kernelCode) override;
40+
List<ComPtr<ISlangBlob> > kernelCodes) override;
4141
};
4242

4343

0 commit comments

Comments
 (0)