Skip to content

Commit 64690fb

Browse files
committed
Fix precompiledTargetModule tests
1 parent a70113c commit 64690fb

11 files changed

+155
-37
lines changed

tests/expected-failure-github.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,4 @@ tests/autodiff/custom-intrinsic.slang.2 syn (wgpu)
1212
tests/bugs/buffer-swizzle-store.slang.3 syn (wgpu)
1313
tests/compute/interface-shader-param-in-struct.slang.4 syn (wgpu)
1414
tests/compute/interface-shader-param.slang.5 syn (wgpu)
15-
tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu)
16-
gfx-unit-test-tool/precompiledTargetModule2Vulkan.internal
15+
tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu)

tests/expected-failure.txt

-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@ tests/language-feature/saturated-cooperation/fuse.slang (vk)
55
tests/bugs/byte-address-buffer-interlocked-add-f32.slang (vk)
66
tests/ir/loop-unroll-0.slang.1 (vk)
77
tests/hlsl-intrinsic/texture/float-atomics.slang (vk)
8-
gfx-unit-test-tool/precompiledTargetModule2Vulkan.internal

tools/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ if(SLANG_ENABLE_GFX)
141141
$<$<BOOL:${SLANG_ENABLE_XLIB}>:X11::X11>
142142
$<$<BOOL:${SLANG_ENABLE_CUDA}>:CUDA::cuda_driver>
143143
$<$<BOOL:${SLANG_ENABLE_NVAPI}>:${NVAPI_LIBRARIES}>
144+
SPIRV-Tools-link
144145
LINK_WITH_FRAMEWORK Foundation Cocoa QuartzCore Metal
145146
EXTRA_COMPILE_DEFINITIONS_PRIVATE
146147
$<$<BOOL:${SLANG_ENABLE_CUDA}>:GFX_ENABLE_CUDA>

tools/gfx/d3d12/d3d12-shader-program.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ using namespace Slang;
1010

1111
Result ShaderProgramImpl::createShaderModule(
1212
slang::EntryPointReflection* entryPointInfo,
13-
ComPtr<ISlangBlob> kernelCode)
13+
List<ComPtr<ISlangBlob>> kernelCodes)
1414
{
1515
ShaderBinary shaderBin;
1616
shaderBin.stage = entryPointInfo->getStage();
1717
shaderBin.entryPointInfo = entryPointInfo;
1818
shaderBin.code.addRange(
19-
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
20-
(Index)kernelCode->getBufferSize());
19+
reinterpret_cast<const uint8_t*>(kernelCodes[0]->getBufferPointer()),
20+
(Index)kernelCodes[0]->getBufferSize());
2121
m_shaders.add(_Move(shaderBin));
2222
return SLANG_OK;
2323
}

tools/gfx/d3d12/d3d12-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class ShaderProgramImpl : public ShaderProgramBase
2727

2828
virtual Result createShaderModule(
2929
slang::EntryPointReflection* entryPointInfo,
30-
ComPtr<ISlangBlob> kernelCode) override;
30+
List<ComPtr<ISlangBlob>> kernelCodes) override;
3131
};
3232

3333
} // namespace d3d12

tools/gfx/metal/metal-shader-program.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ ShaderProgramImpl::~ShaderProgramImpl() {}
2121

2222
Result ShaderProgramImpl::createShaderModule(
2323
slang::EntryPointReflection* entryPointInfo,
24-
ComPtr<ISlangBlob> kernelCode)
24+
List<ComPtr<ISlangBlob>> kernelCodes)
2525
{
2626
Module module;
2727
module.stage = entryPointInfo->getStage();
2828
module.entryPointName = entryPointInfo->getNameOverride();
29-
module.code = kernelCode;
29+
module.code = kernelCodes[0];
3030

3131
dispatch_data_t data = dispatch_data_create(
32-
kernelCode->getBufferPointer(),
33-
kernelCode->getBufferSize(),
32+
kernelCodes[0]->getBufferPointer(),
33+
kernelCodes[0]->getBufferSize(),
3434
dispatch_get_main_queue(),
3535
NULL);
3636
NS::Error* error;

tools/gfx/metal/metal-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ShaderProgramImpl : public ShaderProgramBase
3333

3434
virtual Result createShaderModule(
3535
slang::EntryPointReflection* entryPointInfo,
36-
ComPtr<ISlangBlob> kernelCode) override;
36+
List<ComPtr<ISlangBlob>> kernelCodes) override;
3737
};
3838

3939

tools/gfx/renderer-shared.cpp

+93-20
Original file line numberDiff line numberDiff line change
@@ -1109,26 +1109,99 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
11091109
SlangInt entryPointIndex)
11101110
{
11111111
auto stage = entryPointInfo->getStage();
1112-
ComPtr<ISlangBlob> kernelCode;
1113-
ComPtr<ISlangBlob> diagnostics;
1114-
auto compileResult = device->getEntryPointCodeFromShaderCache(
1115-
entryPointComponent,
1116-
entryPointIndex,
1117-
0,
1118-
kernelCode.writeRef(),
1119-
diagnostics.writeRef());
1120-
if (diagnostics)
1112+
List<ComPtr<ISlangBlob>> kernelCodes;
11211113
{
1122-
DebugMessageType msgType = DebugMessageType::Warning;
1123-
if (compileResult != SLANG_OK)
1124-
msgType = DebugMessageType::Error;
1125-
getDebugCallback()->handleMessage(
1126-
msgType,
1127-
DebugMessageSource::Slang,
1128-
(char*)diagnostics->getBufferPointer());
1114+
ComPtr<ISlangBlob> spirv;
1115+
ComPtr<ISlangBlob> diagnostics;
1116+
auto compileResult = device->getEntryPointCodeFromShaderCache(
1117+
entryPointComponent,
1118+
entryPointIndex,
1119+
0,
1120+
spirv.writeRef(),
1121+
diagnostics.writeRef());
1122+
if (diagnostics)
1123+
{
1124+
DebugMessageType msgType = DebugMessageType::Warning;
1125+
if (compileResult != SLANG_OK)
1126+
msgType = DebugMessageType::Error;
1127+
getDebugCallback()->handleMessage(
1128+
msgType,
1129+
DebugMessageSource::Slang,
1130+
(char*)diagnostics->getBufferPointer());
1131+
}
1132+
kernelCodes.add(spirv);
11291133
}
1130-
SLANG_RETURN_ON_FAIL(compileResult);
1131-
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode));
1134+
1135+
// If target precompilation was used, kernelCode may only represent the
1136+
// glue code holding together the bits of precompiled target IR.
1137+
// Collect those dependency target IRs too.
1138+
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
1139+
if (entryPointComponent->queryInterface(
1140+
slang::IModulePrecompileService_Experimental::getTypeGuid(),
1141+
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
1142+
{
1143+
SlangInt dependencyCount = componentPrecompileService->getModuleDependencyCount();
1144+
if (dependencyCount > 0)
1145+
{
1146+
for (int dependencyIndex = 0; dependencyIndex < dependencyCount; dependencyIndex++)
1147+
{
1148+
ComPtr<slang::IModule> dependencyModule;
1149+
{
1150+
ComPtr<slang::IBlob> diagnosticsBlob;
1151+
auto result = componentPrecompileService->getModuleDependency(
1152+
dependencyIndex,
1153+
dependencyModule.writeRef(),
1154+
diagnosticsBlob.writeRef());
1155+
if (diagnosticsBlob)
1156+
{
1157+
DebugMessageType msgType = DebugMessageType::Warning;
1158+
if (result != SLANG_OK)
1159+
msgType = DebugMessageType::Error;
1160+
getDebugCallback()->handleMessage(
1161+
msgType,
1162+
DebugMessageSource::Slang,
1163+
(char*)diagnosticsBlob->getBufferPointer());
1164+
}
1165+
SLANG_RETURN_ON_FAIL(result);
1166+
}
1167+
1168+
ComPtr<slang::IBlob> spirv;
1169+
{
1170+
ComPtr<slang::IBlob> diagnosticsBlob;
1171+
SlangResult result = SLANG_OK;
1172+
ComPtr<slang::IModulePrecompileService_Experimental> precompileService;
1173+
result = dependencyModule->queryInterface(
1174+
slang::IModulePrecompileService_Experimental::getTypeGuid(),
1175+
(void**)precompileService.writeRef());
1176+
if (result == SLANG_OK)
1177+
{
1178+
ComPtr<slang::IBlob> diagnosticsBlob;
1179+
auto result = precompileService->getPrecompiledTargetCode(
1180+
SLANG_SPIRV,
1181+
spirv.writeRef(),
1182+
diagnosticsBlob.writeRef());
1183+
if (result == SLANG_OK)
1184+
{
1185+
kernelCodes.add(spirv);
1186+
}
1187+
if (diagnosticsBlob)
1188+
{
1189+
DebugMessageType msgType = DebugMessageType::Warning;
1190+
if (result != SLANG_OK)
1191+
msgType = DebugMessageType::Error;
1192+
getDebugCallback()->handleMessage(
1193+
msgType,
1194+
DebugMessageSource::Slang,
1195+
(char*)diagnosticsBlob->getBufferPointer());
1196+
}
1197+
}
1198+
SLANG_RETURN_ON_FAIL(result);
1199+
}
1200+
}
1201+
}
1202+
}
1203+
1204+
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCodes));
11321205
return SLANG_OK;
11331206
};
11341207

@@ -1160,10 +1233,10 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
11601233

11611234
Result ShaderProgramBase::createShaderModule(
11621235
slang::EntryPointReflection* entryPointInfo,
1163-
ComPtr<ISlangBlob> kernelCode)
1236+
List<ComPtr<ISlangBlob>> kernelCodes)
11641237
{
11651238
SLANG_UNUSED(entryPointInfo);
1166-
SLANG_UNUSED(kernelCode);
1239+
SLANG_UNUSED(kernelCodes);
11671240
return SLANG_OK;
11681241
}
11691242

tools/gfx/renderer-shared.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ class ShaderProgramBase : public IShaderProgram, public Slang::ComObject
877877
Slang::Result compileShaders(RendererBase* device);
878878
virtual Slang::Result createShaderModule(
879879
slang::EntryPointReflection* entryPointInfo,
880-
Slang::ComPtr<ISlangBlob> kernelCode);
880+
Slang::List<Slang::ComPtr<ISlangBlob>> kernelCodes);
881881

882882
virtual SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL
883883
findTypeByName(const char* name) override

tools/gfx/vulkan/vk-shader-program.cpp

+49-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "vk-device.h"
55
#include "vk-util.h"
66

7+
#include "external/spirv-tools/include/spirv-tools/linker.hpp"
8+
79
namespace gfx
810
{
911

@@ -69,17 +71,61 @@ VkPipelineShaderStageCreateInfo ShaderProgramImpl::compileEntryPoint(
6971
return shaderStageCreateInfo;
7072
}
7173

74+
static ComPtr<ISlangBlob> LinkUsingSPIRVTools(List<ComPtr<ISlangBlob> > kernelCodes)
75+
{
76+
spvtools::Context context(SPV_ENV_UNIVERSAL_1_5);
77+
spvtools::LinkerOptions options;
78+
spvtools::MessageConsumer consumer = [](spv_message_level_t level,
79+
const char* source,
80+
const spv_position_t& position,
81+
const char* message)
82+
{
83+
printf("SPIRV-TOOLS: %s\n", message);
84+
printf("SPIRV-TOOLS: %s\n", source);
85+
printf("SPIRV-TOOLS: %zu:%zu\n", position.index, position.column);
86+
};
87+
context.SetMessageConsumer(consumer);
88+
std::vector<uint32_t*> binaries;
89+
std::vector<size_t> binary_sizes;
90+
for (auto kernelCode : kernelCodes)
91+
{
92+
binaries.push_back((uint32_t*)kernelCode->getBufferPointer());
93+
binary_sizes.push_back(kernelCode->getBufferSize() / sizeof(uint32_t));
94+
}
95+
96+
std::vector<uint32_t> linked_binary;
97+
98+
spvtools::Link(
99+
context,
100+
binaries.data(),
101+
binary_sizes.data(),
102+
binaries.size(),
103+
&linked_binary,
104+
options);
105+
106+
// Create a blob to hold the linked binary
107+
ComPtr<ISlangBlob> linkedKernelCode;
108+
109+
// Replace kernel code with linked binary
110+
// Creates a new blob with the linked binary
111+
linkedKernelCode = RawBlob::create(linked_binary.data(), linked_binary.size() * sizeof(uint32_t));
112+
113+
return linkedKernelCode;
114+
}
115+
72116
Result ShaderProgramImpl::createShaderModule(
73117
slang::EntryPointReflection* entryPointInfo,
74-
ComPtr<ISlangBlob> kernelCode)
118+
List<ComPtr<ISlangBlob>> kernelCodes)
75119
{
76-
m_codeBlobs.add(kernelCode);
120+
ComPtr<ISlangBlob> linkedKernel = LinkUsingSPIRVTools(kernelCodes);
121+
m_codeBlobs.add(linkedKernel);
122+
77123
VkShaderModule shaderModule;
78124
auto realEntryPointName = entryPointInfo->getNameOverride();
79125
const char* spirvBinaryEntryPointName = "main";
80126
m_stageCreateInfos.add(compileEntryPoint(
81127
spirvBinaryEntryPointName,
82-
kernelCode,
128+
linkedKernel,
83129
(VkShaderStageFlagBits)VulkanUtil::getShaderStage(entryPointInfo->getStage()),
84130
shaderModule));
85131
m_entryPointNames.add(realEntryPointName);

tools/gfx/vulkan/vk-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ShaderProgramImpl : public ShaderProgramBase
3737

3838
virtual Result createShaderModule(
3939
slang::EntryPointReflection* entryPointInfo,
40-
ComPtr<ISlangBlob> kernelCode) override;
40+
List<ComPtr<ISlangBlob>> kernelCodes) override;
4141
};
4242

4343

0 commit comments

Comments
 (0)