Skip to content

Commit b86925c

Browse files
mkeshavaNVslangbotcsyonghe
authored
Consolidate multiple inouts/outs into struct (#6435)
* Consolidate multiple inout/outs into struct Fixes #5121 VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700 requires that there be only one IncomingRayPayloadKHR per entry point. This change does two things: 1. If an entry point has the one inout or out, or has only ins, then stay with current implementation. 2. If there are multiple outs/inouts, then create a new structure to consolidate these fields and emit this structure. These two code paths are split into two separate functions for clarity. This patch also adds a new test: multipleinout.slang to test this. * Address review comments * Refactor code as per review comments * format code * fix failing tests --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> Co-authored-by: Yong He <yonghe@outlook.com>
1 parent dd9d24d commit b86925c

File tree

2 files changed

+188
-3
lines changed

2 files changed

+188
-3
lines changed

source/slang/slang-ir-glsl-legalize.cpp

+152-3
Original file line numberDiff line numberDiff line change
@@ -2390,7 +2390,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
23902390
}
23912391
}
23922392

2393-
void legalizeRayTracingEntryPointParameterForGLSL(
2393+
void handleSingleParam(
23942394
GLSLLegalizationContext* context,
23952395
IRFunc* func,
23962396
IRParam* pp,
@@ -2442,6 +2442,136 @@ void legalizeRayTracingEntryPointParameterForGLSL(
24422442
builder->addDependsOnDecoration(func, globalParam);
24432443
}
24442444

2445+
static void consolidateParameters(GLSLLegalizationContext* context, List<IRParam*>& params)
2446+
{
2447+
auto builder = context->getBuilder();
2448+
2449+
// Create a struct type to hold all parameters
2450+
IRInst* consolidatedVar = nullptr;
2451+
auto structType = builder->createStructType();
2452+
2453+
// Inside the structure, add fields for each parameter
2454+
for (auto _param : params)
2455+
{
2456+
auto _paramType = _param->getDataType();
2457+
IRType* valueType = _paramType;
2458+
2459+
if (as<IROutTypeBase>(_paramType))
2460+
valueType = as<IROutTypeBase>(_paramType)->getValueType();
2461+
2462+
auto key = builder->createStructKey();
2463+
if (auto nameDecor = _param->findDecoration<IRNameHintDecoration>())
2464+
builder->addNameHintDecoration(key, nameDecor->getName());
2465+
auto field = builder->createStructField(structType, key, valueType);
2466+
field->removeFromParent();
2467+
field->insertAtEnd(structType);
2468+
}
2469+
2470+
// Create a global variable to hold the consolidated struct
2471+
consolidatedVar = builder->createGlobalVar(structType);
2472+
auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload);
2473+
consolidatedVar->setFullType(ptrType);
2474+
consolidatedVar->moveToEnd();
2475+
2476+
// Add the ray payload decoration and assign location 0.
2477+
builder->addVulkanRayPayloadDecoration(consolidatedVar, 0);
2478+
2479+
// Replace each parameter with a field in the consolidated struct
2480+
for (Index i = 0; i < params.getCount(); ++i)
2481+
{
2482+
auto _param = params[i];
2483+
2484+
// Find the i-th field
2485+
IRStructField* targetField = nullptr;
2486+
Index fieldIndex = 0;
2487+
for (auto field : structType->getFields())
2488+
{
2489+
if (fieldIndex == i)
2490+
{
2491+
targetField = field;
2492+
break;
2493+
}
2494+
fieldIndex++;
2495+
}
2496+
SLANG_ASSERT(targetField);
2497+
2498+
// Create the field address with the correct type
2499+
auto _paramType = _param->getDataType();
2500+
auto fieldType = targetField->getFieldType();
2501+
2502+
// If the parameter is an out/inout type, we need to create a pointer type
2503+
IRType* fieldPtrType = nullptr;
2504+
if (as<IROutType>(_paramType))
2505+
{
2506+
fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType);
2507+
}
2508+
else if (as<IRInOutType>(_paramType))
2509+
{
2510+
fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType);
2511+
}
2512+
2513+
auto fieldAddr =
2514+
builder->emitFieldAddress(fieldPtrType, consolidatedVar, targetField->getKey());
2515+
2516+
// Replace parameter uses with field address
2517+
_param->replaceUsesWith(fieldAddr);
2518+
}
2519+
}
2520+
2521+
// Consolidate ray tracing parameters for an entry point function
2522+
void consolidateRayTracingParameters(GLSLLegalizationContext* context, IRFunc* func)
2523+
{
2524+
auto builder = context->getBuilder();
2525+
auto firstBlock = func->getFirstBlock();
2526+
if (!firstBlock)
2527+
return;
2528+
2529+
// Collect all out/inout parameters that need to be consolidated
2530+
List<IRParam*> outParams;
2531+
List<IRParam*> params;
2532+
2533+
for (auto param = firstBlock->getFirstParam(); param; param = param->getNextParam())
2534+
{
2535+
builder->setInsertBefore(firstBlock->getFirstOrdinaryInst());
2536+
if (as<IROutType>(param->getDataType()) || as<IRInOutType>(param->getDataType()))
2537+
{
2538+
outParams.add(param);
2539+
}
2540+
params.add(param);
2541+
}
2542+
2543+
// We don't need consolidation here.
2544+
if (outParams.getCount() <= 1)
2545+
{
2546+
for (auto param : params)
2547+
{
2548+
auto paramLayoutDecoration = param->findDecoration<IRLayoutDecoration>();
2549+
SLANG_ASSERT(paramLayoutDecoration);
2550+
auto paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout());
2551+
handleSingleParam(context, func, param, paramLayout);
2552+
}
2553+
return;
2554+
}
2555+
else
2556+
{
2557+
// We need consolidation here, but before that, handle parameters other than inout/out.
2558+
for (auto param : params)
2559+
{
2560+
if (outParams.contains(param))
2561+
{
2562+
continue;
2563+
}
2564+
auto paramLayoutDecoration = param->findDecoration<IRLayoutDecoration>();
2565+
SLANG_ASSERT(paramLayoutDecoration);
2566+
auto paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout());
2567+
handleSingleParam(context, func, param, paramLayout);
2568+
}
2569+
2570+
// Now, consolidate the inout/out parameters
2571+
consolidateParameters(context, outParams);
2572+
}
2573+
}
2574+
24452575
static void legalizeMeshPayloadInputParam(
24462576
GLSLLegalizationContext* context,
24472577
CodeGenContext* codeGenContext,
@@ -3129,7 +3259,6 @@ void legalizeEntryPointParameterForGLSL(
31293259
}
31303260
}
31313261

3132-
31333262
// We need to create a global variable that will replace the parameter.
31343263
// It seems superficially obvious that the variable should have
31353264
// the same type as the parameter.
@@ -3286,7 +3415,6 @@ void legalizeEntryPointParameterForGLSL(
32863415
case Stage::Intersection:
32873416
case Stage::Miss:
32883417
case Stage::RayGeneration:
3289-
legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout);
32903418
return;
32913419
}
32923420

@@ -3916,12 +4044,33 @@ void legalizeEntryPointForGLSL(
39164044
invokePathConstantFuncInHullShader(&context, codeGenContext, scalarizedGlobalOutput);
39174045
}
39184046

4047+
// Special handling for ray tracing shaders
4048+
bool isRayTracingShader = false;
4049+
switch (stage)
4050+
{
4051+
case Stage::AnyHit:
4052+
case Stage::Callable:
4053+
case Stage::ClosestHit:
4054+
case Stage::Intersection:
4055+
case Stage::Miss:
4056+
case Stage::RayGeneration:
4057+
isRayTracingShader = true;
4058+
consolidateRayTracingParameters(&context, func);
4059+
break;
4060+
default:
4061+
break;
4062+
}
4063+
39194064
// Next we will walk through any parameters of the entry-point function,
39204065
// and turn them into global variables.
39214066
if (auto firstBlock = func->getFirstBlock())
39224067
{
39234068
for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam())
39244069
{
4070+
if (isRayTracingShader)
4071+
{
4072+
continue;
4073+
}
39254074
// Any initialization code we insert for parameters needs
39264075
// to be at the start of the "ordinary" instructions in the block:
39274076
builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());

tests/vkray/multipleinout.slang

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly
2+
3+
// This test checks whether the spirv generated when there are multiple inout or out variables, they
4+
// all get consolidated into one IncomingRayPayloadKHR.
5+
6+
struct ReflectionRay
7+
{
8+
float4 color;
9+
};
10+
11+
StructuredBuffer<float4> colors;
12+
13+
[shader("closesthit")]
14+
void main(
15+
BuiltInTriangleIntersectionAttributes attributes,
16+
inout ReflectionRay ioPayload,
17+
out float3 dummy)
18+
{
19+
uint materialID = (InstanceIndex() << 1)
20+
+ InstanceID()
21+
+ PrimitiveIndex()
22+
+ HitKind();
23+
24+
ioPayload.color = colors[materialID];
25+
dummy = HitTriangleVertexPosition(0);
26+
}
27+
28+
// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}}
29+
// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float
30+
// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}}
31+
// CHECK: %main = OpFunction %void None %{{.*}}
32+
// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}}
33+
// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID
34+
// CHECK: %{{.*}} = OpLoad %v4float %{{.*}}
35+
// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0
36+
// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}

0 commit comments

Comments
 (0)