Skip to content

Commit 02706df

Browse files
cheneym2slangbotcsyonghe
authored
Fix precompiledTargetModule tests (#6455)
* Fix precompiledTargetModule tests Add SPIRV-Tool linker support to gfx unit tests and use the linker in precompileModule tests that use precompiled modules to reconstitute SPIRV shaders that were modularly compiled. Fix a Slang reference count bug in the precompile service. * Use sm_6_6 New DXC requires higher version for linkability. * Rename helper function, pass by reference * Link through slang-glslang * Add missing files * Fix metal * format code --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> Co-authored-by: Yong He <yonghe@outlook.com>
1 parent 6e862bb commit 02706df

20 files changed

+344
-40
lines changed

source/slang-glslang/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ if(SLANG_ENABLE_SLANG_GLSLANG)
66
.
77
MODULE
88
USE_FEWER_WARNINGS
9-
LINK_WITH_PRIVATE glslang SPIRV SPIRV-Tools-opt
9+
LINK_WITH_PRIVATE glslang SPIRV SPIRV-Tools-opt SPIRV-Tools-link
1010
INCLUDE_DIRECTORIES_PRIVATE ${slang_SOURCE_DIR}/include
1111
INSTALL
1212
EXPORT_SET_NAME SlangTargets

source/slang-glslang/slang-glslang.cpp

+65
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "glslang/Public/ShaderLang.h"
77
#include "slang.h"
88
#include "spirv-tools/libspirv.h"
9+
#include "spirv-tools/linker.hpp"
910
#include "spirv-tools/optimizer.hpp"
1011

1112
#ifdef _WIN32
@@ -979,3 +980,67 @@ extern "C"
979980
request.set(*inRequest);
980981
return glslang_compile_1_1(&request);
981982
}
983+
984+
extern "C"
985+
#ifdef _MSC_VER
986+
_declspec(dllexport)
987+
#else
988+
__attribute__((__visibility__("default")))
989+
#endif
990+
int glslang_linkSPIRV(glslang_LinkRequest* request)
991+
{
992+
if (!request || !request->modules || request->linkResult)
993+
return false;
994+
995+
try
996+
{
997+
spvtools::Context context(SPV_ENV_UNIVERSAL_1_5);
998+
spvtools::LinkerOptions options = {};
999+
1000+
spvtools::MessageConsumer consumer = [](spv_message_level_t level,
1001+
const char* source,
1002+
const spv_position_t& position,
1003+
const char* message)
1004+
{
1005+
printf("SPIRV-TOOLS: %s\n", message);
1006+
printf("SPIRV-TOOLS: %s\n", source);
1007+
printf("SPIRV-TOOLS: %zu:%zu\n", position.index, position.column);
1008+
};
1009+
context.SetMessageConsumer(consumer);
1010+
1011+
std::vector<std::vector<uint32_t>> moduleVecs(request->moduleCount);
1012+
std::vector<const uint32_t*> moduleData(request->moduleCount);
1013+
std::vector<size_t> moduleSizes(request->moduleCount);
1014+
1015+
for (size_t i = 0; i < request->moduleCount; ++i)
1016+
{
1017+
moduleData[i] = request->modules[i];
1018+
moduleSizes[i] = request->moduleSizes[i];
1019+
}
1020+
1021+
std::vector<uint32_t> linkedBinary;
1022+
spv_result_t success = spvtools::Link(
1023+
context,
1024+
moduleData.data(),
1025+
moduleSizes.data(),
1026+
request->moduleCount,
1027+
&linkedBinary,
1028+
options);
1029+
1030+
if (success == SPV_SUCCESS)
1031+
{
1032+
request->linkResult = new uint32_t[linkedBinary.size()];
1033+
memcpy(
1034+
(void*)request->linkResult,
1035+
linkedBinary.data(),
1036+
linkedBinary.size() * sizeof(uint32_t));
1037+
request->linkResultSize = linkedBinary.size();
1038+
}
1039+
1040+
return success;
1041+
}
1042+
catch (...)
1043+
{
1044+
return false;
1045+
}
1046+
}

source/slang-glslang/slang-glslang.h

+10-1
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,19 @@ inline void glslang_CompileRequest_1_2::set(const glslang_CompileRequest_1_1& in
152152
memcpy(this, &in, sizeof(in));
153153
}
154154

155+
typedef struct glslang_LinkRequest_t
156+
{
157+
const uint32_t** modules; // Input: array of pointers to SPIR-V modules
158+
const uint32_t* moduleSizes; // Input: array of sizes of SPIR-V modules in 32-bit words
159+
int moduleCount; // Input: number of modules in the array
160+
const uint32_t* linkResult; // Output: pointer to linked SPIR-V module
161+
size_t linkResultSize; // Output: size of the linked SPIR-V module in 32-bit words
162+
} glslang_LinkRequest;
163+
155164
typedef int (*glslang_CompileFunc_1_0)(glslang_CompileRequest_1_0* request);
156165
typedef int (*glslang_CompileFunc_1_1)(glslang_CompileRequest_1_1* request);
157166
typedef int (*glslang_CompileFunc_1_2)(glslang_CompileRequest_1_2* request);
158167
typedef bool (*glslang_ValidateSPIRVFunc)(const uint32_t* contents, int contentsSize);
159168
typedef bool (*glslang_DisassembleSPIRVFunc)(const uint32_t* contents, int contentsSize);
160-
169+
typedef bool (*glslang_LinkSPIRVFunc)(glslang_LinkRequest* request);
161170
#endif

source/slang/slang-compiler-tu.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getModuleDependency(
291291
{
292292
return SLANG_E_INVALID_ARG;
293293
}
294+
getModuleDependencies()[dependencyIndex]->addRef();
294295
*outModule = getModuleDependencies()[dependencyIndex];
295296
return SLANG_OK;
296297
}

tests/expected-failure-github.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,4 @@ tests/autodiff/custom-intrinsic.slang.2 syn (wgpu)
1010
tests/bugs/buffer-swizzle-store.slang.3 syn (wgpu)
1111
tests/compute/interface-shader-param-in-struct.slang.4 syn (wgpu)
1212
tests/compute/interface-shader-param.slang.5 syn (wgpu)
13-
tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu)
14-
gfx-unit-test-tool/precompiledTargetModule2Vulkan.internal
13+
tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu)

tests/expected-failure.txt

-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@ tests/language-feature/saturated-cooperation/fuse.slang (vk)
55
tests/bugs/byte-address-buffer-interlocked-add-f32.slang (vk)
66
tests/ir/loop-unroll-0.slang.1 (vk)
77
tests/hlsl-intrinsic/texture/float-atomics.slang (vk)
8-
gfx-unit-test-tool/precompiledTargetModule2Vulkan.internal

tools/gfx-unit-test/precompiled-module-2.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ void precompiledModule2TestImplCommon(
112112
{
113113
case gfx::DeviceType::DirectX12:
114114
targetDesc.format = SLANG_DXIL;
115-
targetDesc.profile = device->getSlangSession()->getGlobalSession()->findProfile("sm_6_1");
115+
targetDesc.profile = device->getSlangSession()->getGlobalSession()->findProfile("sm_6_6");
116116
break;
117117
case gfx::DeviceType::Vulkan:
118118
targetDesc.format = SLANG_SPIRV;

tools/gfx/d3d12/d3d12-shader-program.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ using namespace Slang;
1010

1111
Result ShaderProgramImpl::createShaderModule(
1212
slang::EntryPointReflection* entryPointInfo,
13-
ComPtr<ISlangBlob> kernelCode)
13+
List<ComPtr<ISlangBlob>>& kernelCodes)
1414
{
1515
ShaderBinary shaderBin;
1616
shaderBin.stage = entryPointInfo->getStage();
1717
shaderBin.entryPointInfo = entryPointInfo;
1818
shaderBin.code.addRange(
19-
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
20-
(Index)kernelCode->getBufferSize());
19+
reinterpret_cast<const uint8_t*>(kernelCodes[0]->getBufferPointer()),
20+
(Index)kernelCodes[0]->getBufferSize());
2121
m_shaders.add(_Move(shaderBin));
2222
return SLANG_OK;
2323
}

tools/gfx/d3d12/d3d12-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class ShaderProgramImpl : public ShaderProgramBase
2727

2828
virtual Result createShaderModule(
2929
slang::EntryPointReflection* entryPointInfo,
30-
ComPtr<ISlangBlob> kernelCode) override;
30+
List<ComPtr<ISlangBlob>>& kernelCodes) override;
3131
};
3232

3333
} // namespace d3d12

tools/gfx/metal/metal-shader-program.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ ShaderProgramImpl::~ShaderProgramImpl() {}
2121

2222
Result ShaderProgramImpl::createShaderModule(
2323
slang::EntryPointReflection* entryPointInfo,
24-
ComPtr<ISlangBlob> kernelCode)
24+
List<ComPtr<ISlangBlob>>& kernelCodes)
2525
{
2626
Module module;
2727
module.stage = entryPointInfo->getStage();
2828
module.entryPointName = entryPointInfo->getNameOverride();
29-
module.code = kernelCode;
29+
module.code = kernelCodes[0];
3030

3131
dispatch_data_t data = dispatch_data_create(
32-
kernelCode->getBufferPointer(),
33-
kernelCode->getBufferSize(),
32+
kernelCodes[0]->getBufferPointer(),
33+
kernelCodes[0]->getBufferSize(),
3434
dispatch_get_main_queue(),
3535
NULL);
3636
NS::Error* error;

tools/gfx/metal/metal-shader-program.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ShaderProgramImpl : public ShaderProgramBase
3333

3434
virtual Result createShaderModule(
3535
slang::EntryPointReflection* entryPointInfo,
36-
ComPtr<ISlangBlob> kernelCode) override;
36+
List<ComPtr<ISlangBlob>>& kernelCodes) override;
3737
};
3838

3939

tools/gfx/renderer-shared.cpp

+94-20
Original file line numberDiff line numberDiff line change
@@ -1103,32 +1103,106 @@ void ShaderProgramBase::init(const IShaderProgram::Desc& inDesc)
11031103

11041104
Result ShaderProgramBase::compileShaders(RendererBase* device)
11051105
{
1106+
auto compileTarget = device->slangContext.compileTarget;
11061107
// For a fully specialized program, read and store its kernel code in `shaderProgram`.
11071108
auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
11081109
slang::IComponentType* entryPointComponent,
11091110
SlangInt entryPointIndex)
11101111
{
11111112
auto stage = entryPointInfo->getStage();
1112-
ComPtr<ISlangBlob> kernelCode;
1113-
ComPtr<ISlangBlob> diagnostics;
1114-
auto compileResult = device->getEntryPointCodeFromShaderCache(
1115-
entryPointComponent,
1116-
entryPointIndex,
1117-
0,
1118-
kernelCode.writeRef(),
1119-
diagnostics.writeRef());
1120-
if (diagnostics)
1113+
List<ComPtr<ISlangBlob>> kernelCodes;
11211114
{
1122-
DebugMessageType msgType = DebugMessageType::Warning;
1123-
if (compileResult != SLANG_OK)
1124-
msgType = DebugMessageType::Error;
1125-
getDebugCallback()->handleMessage(
1126-
msgType,
1127-
DebugMessageSource::Slang,
1128-
(char*)diagnostics->getBufferPointer());
1115+
ComPtr<ISlangBlob> downstreamIR;
1116+
ComPtr<ISlangBlob> diagnostics;
1117+
auto compileResult = device->getEntryPointCodeFromShaderCache(
1118+
entryPointComponent,
1119+
entryPointIndex,
1120+
0,
1121+
downstreamIR.writeRef(),
1122+
diagnostics.writeRef());
1123+
if (diagnostics)
1124+
{
1125+
DebugMessageType msgType = DebugMessageType::Warning;
1126+
if (compileResult != SLANG_OK)
1127+
msgType = DebugMessageType::Error;
1128+
getDebugCallback()->handleMessage(
1129+
msgType,
1130+
DebugMessageSource::Slang,
1131+
(char*)diagnostics->getBufferPointer());
1132+
}
1133+
kernelCodes.add(downstreamIR);
11291134
}
1130-
SLANG_RETURN_ON_FAIL(compileResult);
1131-
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode));
1135+
1136+
// If target precompilation was used, kernelCode may only represent the
1137+
// glue code holding together the bits of precompiled target IR.
1138+
// Collect those dependency target IRs too.
1139+
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
1140+
if (entryPointComponent->queryInterface(
1141+
slang::IModulePrecompileService_Experimental::getTypeGuid(),
1142+
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
1143+
{
1144+
SlangInt dependencyCount = componentPrecompileService->getModuleDependencyCount();
1145+
if (dependencyCount > 0)
1146+
{
1147+
for (int dependencyIndex = 0; dependencyIndex < dependencyCount; dependencyIndex++)
1148+
{
1149+
ComPtr<slang::IModule> dependencyModule;
1150+
{
1151+
ComPtr<slang::IBlob> diagnosticsBlob;
1152+
auto result = componentPrecompileService->getModuleDependency(
1153+
dependencyIndex,
1154+
dependencyModule.writeRef(),
1155+
diagnosticsBlob.writeRef());
1156+
if (diagnosticsBlob)
1157+
{
1158+
DebugMessageType msgType = DebugMessageType::Warning;
1159+
if (result != SLANG_OK)
1160+
msgType = DebugMessageType::Error;
1161+
getDebugCallback()->handleMessage(
1162+
msgType,
1163+
DebugMessageSource::Slang,
1164+
(char*)diagnosticsBlob->getBufferPointer());
1165+
}
1166+
SLANG_RETURN_ON_FAIL(result);
1167+
}
1168+
1169+
ComPtr<slang::IBlob> downstreamIR;
1170+
{
1171+
ComPtr<slang::IBlob> diagnosticsBlob;
1172+
SlangResult result = SLANG_OK;
1173+
ComPtr<slang::IModulePrecompileService_Experimental> precompileService;
1174+
result = dependencyModule->queryInterface(
1175+
slang::IModulePrecompileService_Experimental::getTypeGuid(),
1176+
(void**)precompileService.writeRef());
1177+
if (result == SLANG_OK)
1178+
{
1179+
ComPtr<slang::IBlob> diagnosticsBlob;
1180+
auto result = precompileService->getPrecompiledTargetCode(
1181+
compileTarget,
1182+
downstreamIR.writeRef(),
1183+
diagnosticsBlob.writeRef());
1184+
if (result == SLANG_OK)
1185+
{
1186+
kernelCodes.add(downstreamIR);
1187+
}
1188+
if (diagnosticsBlob)
1189+
{
1190+
DebugMessageType msgType = DebugMessageType::Warning;
1191+
if (result != SLANG_OK)
1192+
msgType = DebugMessageType::Error;
1193+
getDebugCallback()->handleMessage(
1194+
msgType,
1195+
DebugMessageSource::Slang,
1196+
(char*)diagnosticsBlob->getBufferPointer());
1197+
}
1198+
}
1199+
SLANG_RETURN_ON_FAIL(result);
1200+
}
1201+
}
1202+
}
1203+
}
1204+
1205+
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCodes));
11321206
return SLANG_OK;
11331207
};
11341208

@@ -1160,10 +1234,10 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
11601234

11611235
Result ShaderProgramBase::createShaderModule(
11621236
slang::EntryPointReflection* entryPointInfo,
1163-
ComPtr<ISlangBlob> kernelCode)
1237+
List<ComPtr<ISlangBlob>>& kernelCodes)
11641238
{
11651239
SLANG_UNUSED(entryPointInfo);
1166-
SLANG_UNUSED(kernelCode);
1240+
SLANG_UNUSED(kernelCodes);
11671241
return SLANG_OK;
11681242
}
11691243

tools/gfx/renderer-shared.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ class ShaderProgramBase : public IShaderProgram, public Slang::ComObject
877877
Slang::Result compileShaders(RendererBase* device);
878878
virtual Slang::Result createShaderModule(
879879
slang::EntryPointReflection* entryPointInfo,
880-
Slang::ComPtr<ISlangBlob> kernelCode);
880+
Slang::List<Slang::ComPtr<ISlangBlob>>& kernelCodes);
881881

882882
virtual SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL
883883
findTypeByName(const char* name) override

tools/gfx/slang-context.h

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class SlangContext
1010
public:
1111
Slang::ComPtr<slang::IGlobalSession> globalSession;
1212
Slang::ComPtr<slang::ISession> session;
13+
SlangCompileTarget compileTarget;
1314
Result initialize(
1415
const gfx::IDevice::SlangDesc& desc,
1516
uint32_t extendedDescCount,
@@ -27,6 +28,7 @@ class SlangContext
2728
SLANG_RETURN_ON_FAIL(slang::createGlobalSession(globalSession.writeRef()));
2829
}
2930

31+
this->compileTarget = compileTarget;
3032
slang::SessionDesc slangSessionDesc = {};
3133
slangSessionDesc.defaultMatrixLayoutMode = desc.defaultMatrixLayoutMode;
3234
slangSessionDesc.searchPathCount = desc.searchPathCount;

0 commit comments

Comments
 (0)