Skip to content

Commit eefdd4a

Browse files
skallweitNV, jkwak-work, csyonghe
authored
add support for callable shaders in gfx (shader-slang#3460)
Co-authored-by: Jay Kwak <82421531+jkwak-work@users.noreply.github.com>
Co-authored-by: Yong He <yonghe@outlook.com>
1 parent d9443d6 commit eefdd4a

8 files changed

+60
-9
lines changed

slang-gfx.h

+4
Original file line numberDiff line numberDiff line change
@@ -1410,6 +1410,10 @@ class IShaderTable : public ISlangUnknown
14101410
const char** hitGroupNames;
14111411
const ShaderRecordOverwrite* hitGroupRecordOverwrites;
14121412

1413+
GfxCount callableShaderCount;
1414+
const char** callableShaderEntryPointNames;
1415+
const ShaderRecordOverwrite* callableShaderRecordOverwrites;
1416+
14131417
IShaderProgram* program;
14141418
};
14151419
};

tools/gfx/d3d12/d3d12-command-encoder.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -1409,6 +1409,15 @@ Result RayTracingCommandEncoderImpl::dispatchRays(
14091409
dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
14101410
}
14111411

1412+
if (shaderTableImpl->m_callableShaderCount > 0)
1413+
{
1414+
dispatchDesc.CallableShaderTable.StartAddress =
1415+
shaderTableAddr + shaderTableImpl->m_callableTableOffset;
1416+
dispatchDesc.CallableShaderTable.SizeInBytes =
1417+
shaderTableImpl->m_callableShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
1418+
dispatchDesc.CallableShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
1419+
}
1420+
14121421
dispatchDesc.Width = (UINT)width;
14131422
dispatchDesc.Height = (UINT)height;
14141423
dispatchDesc.Depth = (UINT)depth;

tools/gfx/d3d12/d3d12-shader-table.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
2020
uint32_t raygenTableSize = m_rayGenShaderCount * kRayGenRecordSize;
2121
uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
2222
uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
23+
uint32_t callableTableSize = m_callableShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
2324
m_rayGenTableOffset = 0;
2425
m_missTableOffset = raygenTableSize;
2526
m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned(
2627
m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
27-
uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize;
28+
m_callableTableOffset = (uint32_t)D3DUtil::calcAligned(
29+
m_hitGroupTableOffset + hitgroupTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
30+
uint32_t tableSize = m_callableTableOffset + callableTableSize;
2831

2932
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
3033
ComPtr<IBufferResource> bufferResource;
@@ -88,6 +91,13 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
8891
m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + i],
8992
m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + i]);
9093
}
94+
for (uint32_t i = 0; i < m_callableShaderCount; i++)
95+
{
96+
copyShaderIdInto(
97+
stagingBufferPtr + m_callableTableOffset + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
98+
m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + m_hitGroupCount + i],
99+
m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + m_hitGroupCount + i]);
100+
}
91101

92102
stagingBuffer->unmap(nullptr);
93103
encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);

tools/gfx/d3d12/d3d12-shader-table.h

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class ShaderTableImpl : public ShaderTableBase
1616
uint32_t m_rayGenTableOffset;
1717
uint32_t m_missTableOffset;
1818
uint32_t m_hitGroupTableOffset;
19+
uint32_t m_callableTableOffset;
1920

2021
DeviceImpl* m_device;
2122

tools/gfx/renderer-shared.cpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -1264,8 +1264,9 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc)
12641264
m_rayGenShaderCount = desc.rayGenShaderCount;
12651265
m_missShaderCount = desc.missShaderCount;
12661266
m_hitGroupCount = desc.hitGroupCount;
1267-
m_shaderGroupNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
1268-
m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
1267+
m_callableShaderCount = desc.callableShaderCount;
1268+
m_shaderGroupNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount + desc.callableShaderCount);
1269+
m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount + desc.callableShaderCount);
12691270
for (GfxIndex i = 0; i < desc.rayGenShaderCount; i++)
12701271
{
12711272
m_shaderGroupNames.add(desc.rayGenShaderEntryPointNames[i]);
@@ -1302,6 +1303,18 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc)
13021303
m_recordOverwrites.add(ShaderRecordOverwrite{});
13031304
}
13041305
}
1306+
for (GfxIndex i = 0; i < desc.callableShaderCount; i++)
1307+
{
1308+
m_shaderGroupNames.add(desc.callableShaderEntryPointNames[i]);
1309+
if (desc.callableShaderRecordOverwrites)
1310+
{
1311+
m_recordOverwrites.add(desc.callableShaderRecordOverwrites[i]);
1312+
}
1313+
else
1314+
{
1315+
m_recordOverwrites.add(ShaderRecordOverwrite{});
1316+
}
1317+
}
13051318
return SLANG_OK;
13061319
}
13071320

tools/gfx/renderer-shared.h

+1
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@ class ShaderTableBase
11831183
uint32_t m_rayGenShaderCount;
11841184
uint32_t m_missShaderCount;
11851185
uint32_t m_hitGroupCount;
1186+
uint32_t m_callableShaderCount;
11861187

11871188
Slang::Dictionary<PipelineStateBase*, Slang::RefPtr<BufferResource>> m_deviceBuffers;
11881189

tools/gfx/vulkan/vk-command-encoder.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -1480,11 +1480,10 @@ Result RayTracingCommandEncoder::dispatchRays(
14801480
hitSBT.stride = alignedHandleSize;
14811481
hitSBT.size = shaderTableImpl->m_hitTableSize;
14821482

1483-
// TODO: Are callable shaders needed?
14841483
VkStridedDeviceAddressRegionKHR callableSBT;
1485-
callableSBT.deviceAddress = 0;
1486-
callableSBT.stride = 0;
1487-
callableSBT.size = 0;
1484+
callableSBT.deviceAddress = hitSBT.deviceAddress + hitSBT.size;
1485+
callableSBT.stride = alignedHandleSize;
1486+
callableSBT.size = shaderTableImpl->m_callableTableSize;
14881487

14891488
vkApi.vkCmdTraceRaysKHR(
14901489
vkCommandBuffer,

tools/gfx/vulkan/vk-shader-table.cpp

+16-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
2727
m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
2828
m_hitTableSize = (uint32_t)VulkanUtil::calcAligned(
2929
m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment);
30-
m_callableTableSize = 0; // TODO: Are callable shaders needed?
30+
m_callableTableSize = (uint32_t)VulkanUtil::calcAligned(
31+
m_callableShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
3132
uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize;
3233

3334
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
@@ -122,7 +123,20 @@ RefPtr<BufferResource> ShaderTableImpl::createDeviceBuffer(
122123
}
123124
subTablePtr += m_hitTableSize;
124125

125-
// TODO: Callable shaders?
126+
for (uint32_t i = 0; i < m_callableShaderCount; i++)
127+
{
128+
auto dstHandlePtr = subTablePtr + i * handleSize;
129+
auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
130+
auto shaderGroupIndexPtr =
131+
pipelineImpl->shaderGroupNameToIndex.tryGetValue(shaderGroupName);
132+
if (!shaderGroupIndexPtr)
133+
continue;
134+
135+
auto shaderGroupIndex = *shaderGroupIndexPtr;
136+
auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
137+
memcpy(dstHandlePtr, srcHandlePtr, handleSize);
138+
}
139+
subTablePtr += m_callableTableSize;
126140

127141
stagingBuffer->unmap(nullptr);
128142
encoder->copyBuffer(bufferResource, 0, stagingBuffer, stagingBufferOffset, tableSize);

0 commit comments

Comments
 (0)