diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 1123e1f2ad..455c924ca5 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -2390,7 +2390,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val) } } -void legalizeRayTracingEntryPointParameterForGLSL( +void handleSingleParam( GLSLLegalizationContext* context, IRFunc* func, IRParam* pp, @@ -2442,6 +2442,136 @@ void legalizeRayTracingEntryPointParameterForGLSL( builder->addDependsOnDecoration(func, globalParam); } +static void consolidateParameters(GLSLLegalizationContext* context, List& params) +{ + auto builder = context->getBuilder(); + + // Create a struct type to hold all parameters + IRInst* consolidatedVar = nullptr; + auto structType = builder->createStructType(); + + // Inside the structure, add fields for each parameter + for (auto _param : params) + { + auto _paramType = _param->getDataType(); + IRType* valueType = _paramType; + + if (as(_paramType)) + valueType = as(_paramType)->getValueType(); + + auto key = builder->createStructKey(); + if (auto nameDecor = _param->findDecoration()) + builder->addNameHintDecoration(key, nameDecor->getName()); + auto field = builder->createStructField(structType, key, valueType); + field->removeFromParent(); + field->insertAtEnd(structType); + } + + // Create a global variable to hold the consolidated struct + consolidatedVar = builder->createGlobalVar(structType); + auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload); + consolidatedVar->setFullType(ptrType); + consolidatedVar->moveToEnd(); + + // Add the ray payload decoration and assign location 0. + builder->addVulkanRayPayloadDecoration(consolidatedVar, 0); + + // Replace each parameter with a field in the consolidated struct + for (Index i = 0; i < params.getCount(); ++i) + { + auto _param = params[i]; + + // Find the i-th field + IRStructField* targetField = nullptr; + Index fieldIndex = 0; + for (auto field : structType->getFields()) + { + if (fieldIndex == i) + { + targetField = field; + break; + } + fieldIndex++; + } + SLANG_ASSERT(targetField); + + // Create the field address with the correct type + auto _paramType = _param->getDataType(); + auto fieldType = targetField->getFieldType(); + + // If the parameter is an out/inout type, we need to create a pointer type + IRType* fieldPtrType = nullptr; + if (as(_paramType)) + { + fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType); + } + else if (as(_paramType)) + { + fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType); + } + + auto fieldAddr = + builder->emitFieldAddress(fieldPtrType, consolidatedVar, targetField->getKey()); + + // Replace parameter uses with field address + _param->replaceUsesWith(fieldAddr); + } +} + +// Consolidate ray tracing parameters for an entry point function +void consolidateRayTracingParameters(GLSLLegalizationContext* context, IRFunc* func) +{ + auto builder = context->getBuilder(); + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + return; + + // Collect all out/inout parameters that need to be consolidated + List outParams; + List params; + + for (auto param = firstBlock->getFirstParam(); param; param = param->getNextParam()) + { + builder->setInsertBefore(firstBlock->getFirstOrdinaryInst()); + if (as(param->getDataType()) || as(param->getDataType())) + { + outParams.add(param); + } + params.add(param); + } + + // We don't need consolidation here. + if (outParams.getCount() <= 1) + { + for (auto param : params) + { + auto paramLayoutDecoration = param->findDecoration(); + SLANG_ASSERT(paramLayoutDecoration); + auto paramLayout = as(paramLayoutDecoration->getLayout()); + handleSingleParam(context, func, param, paramLayout); + } + return; + } + else + { + // We need consolidation here, but before that, handle parameters other than inout/out. + for (auto param : params) + { + if (outParams.contains(param)) + { + continue; + } + auto paramLayoutDecoration = param->findDecoration(); + SLANG_ASSERT(paramLayoutDecoration); + auto paramLayout = as(paramLayoutDecoration->getLayout()); + handleSingleParam(context, func, param, paramLayout); + } + + // Now, consolidate the inout/out parameters + consolidateParameters(context, outParams); + } +} + static void legalizeMeshPayloadInputParam( GLSLLegalizationContext* context, CodeGenContext* codeGenContext, @@ -3129,7 +3259,6 @@ void legalizeEntryPointParameterForGLSL( } } - // We need to create a global variable that will replace the parameter. // It seems superficially obvious that the variable should have // the same type as the parameter. @@ -3286,7 +3415,6 @@ void legalizeEntryPointParameterForGLSL( case Stage::Intersection: case Stage::Miss: case Stage::RayGeneration: - legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout); return; } @@ -3916,12 +4044,33 @@ void legalizeEntryPointForGLSL( invokePathConstantFuncInHullShader(&context, codeGenContext, scalarizedGlobalOutput); } + // Special handling for ray tracing shaders + bool isRayTracingShader = false; + switch (stage) + { + case Stage::AnyHit: + case Stage::Callable: + case Stage::ClosestHit: + case Stage::Intersection: + case Stage::Miss: + case Stage::RayGeneration: + isRayTracingShader = true; + consolidateRayTracingParameters(&context, func); + break; + default: + break; + } + // Next we will walk through any parameters of the entry-point function, // and turn them into global variables. if (auto firstBlock = func->getFirstBlock()) { for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam()) { + if (isRayTracingShader) + { + continue; + } // Any initialization code we insert for parameters needs // to be at the start of the "ordinary" instructions in the block: builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); diff --git a/tests/vkray/multipleinout.slang b/tests/vkray/multipleinout.slang new file mode 100644 index 0000000000..52e1758b03 --- /dev/null +++ b/tests/vkray/multipleinout.slang @@ -0,0 +1,36 @@ +//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly + +// This test checks whether the spirv generated when there are multiple inout or out variables, they +// all get consolidated into one IncomingRayPayloadKHR. + +struct ReflectionRay +{ + float4 color; +}; + +StructuredBuffer colors; + +[shader("closesthit")] +void main( + BuiltInTriangleIntersectionAttributes attributes, + inout ReflectionRay ioPayload, + out float3 dummy) +{ + uint materialID = (InstanceIndex() << 1) + + InstanceID() + + PrimitiveIndex() + + HitKind(); + + ioPayload.color = colors[materialID]; + dummy = HitTriangleVertexPosition(0); +} + +// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}} +// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float +// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}} +// CHECK: %main = OpFunction %void None %{{.*}} +// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}} +// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID +// CHECK: %{{.*}} = OpLoad %v4float %{{.*}} +// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0 +// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}