diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index d3d4a764c4..c9f367c30a 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -131,6 +131,7 @@ class RenderTestApp ComPtr m_vertexBuffer; ComPtr m_shaderProgram; ComPtr m_pipeline; + ComPtr m_shaderTable; ComPtr m_depthBuffer; ComPtr m_depthBufferView; ComPtr m_colorBuffer; @@ -648,6 +649,7 @@ SlangResult RenderTestApp::initialize( m_pipeline = device->createRenderPipeline(desc); } break; + case Options::ShaderProgramType::GraphicsMeshCompute: case Options::ShaderProgramType::GraphicsTaskMeshCompute: { @@ -660,6 +662,33 @@ SlangResult RenderTestApp::initialize( desc.depthStencil.format = Format::D32_FLOAT; m_pipeline = device->createRenderPipeline(desc); } + break; + + case Options::ShaderProgramType::RayTracing: + { + RayTracingPipelineDesc desc; + desc.program = m_shaderProgram; + + m_pipeline = device->createRayTracingPipeline(desc); + + const char* raygenNames[] = {"raygenMain"}; + + // We don't define a miss shader for this test. OptiX allows + // passing nullptr to indicate no miss shader, but something in + // slang-rhi assumes that the miss shader always has a name. To + // work around that, use a dummy name. + const char* missNames[] = {"missNull"}; + + ShaderTableDesc shaderTableDesc = {}; + shaderTableDesc.program = m_shaderProgram; + shaderTableDesc.rayGenShaderCount = 1; + shaderTableDesc.rayGenShaderEntryPointNames = raygenNames; + shaderTableDesc.missShaderCount = 1; + shaderTableDesc.missShaderEntryPointNames = missNames; + SLANG_RETURN_ON_FAIL( + device->createShaderTable(shaderTableDesc, m_shaderTable.writeRef())); + } + break; } } // If success must have a pipeline state @@ -972,6 +1001,25 @@ Result RenderTestApp::update() m_options.computeDispatchSize[2]); passEncoder->end(); } + else if (m_options.shaderType == Options::ShaderProgramType::RayTracing) + { + auto rootObject = m_device->createRootShaderObject(m_pipeline); + applyBinding(rootObject); + rootObject->finalize(); + + auto passEncoder = encoder->beginRayTracingPass(); + RayTracingState state; + state.pipeline = static_cast(m_pipeline.get()); + state.rootObject = rootObject; + state.shaderTable = m_shaderTable; + passEncoder->setRayTracingState(state); + passEncoder->dispatchRays( + 0, + m_options.computeDispatchSize[0], + m_options.computeDispatchSize[1], + m_options.computeDispatchSize[2]); + passEncoder->end(); + } else { auto rootObject = m_device->createRootShaderObject(m_pipeline); @@ -1072,7 +1120,8 @@ Result RenderTestApp::update() if (m_options.shaderType == Options::ShaderProgramType::Compute || m_options.shaderType == Options::ShaderProgramType::GraphicsCompute || m_options.shaderType == Options::ShaderProgramType::GraphicsMeshCompute || - m_options.shaderType == Options::ShaderProgramType::GraphicsTaskMeshCompute) + m_options.shaderType == Options::ShaderProgramType::GraphicsTaskMeshCompute || + m_options.shaderType == Options::ShaderProgramType::RayTracing) { SLANG_RETURN_ON_FAIL(writeBindingOutput(m_options.outputPath)); }