Skip to content

Commit 924e570

Browse files
committed
Consolidate multiple inout/outs into struct
Fixes #5121 VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700 requires that there be only one IncomingRayPayloadKHR per entry point. This change does two things: 1. If an entry point has the one inout or out, or has only ins, then stay with current implementation. 2. If there are multiple outs/inouts, then create a new structure to consolidate these fields and emit this structure. These two code paths are split into two separate functions for clarity. This patch also adds a new test: multipleinout.slang to test this.
1 parent 187ec44 commit 924e570

File tree

2 files changed

+239
-3
lines changed

2 files changed

+239
-3
lines changed

source/slang/slang-ir-glsl-legalize.cpp

+170-3
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ struct GLSLLegalizationContext
329329

330330
IRBuilder* builder;
331331
IRBuilder* getBuilder() { return builder; }
332+
333+
// For ray tracing shaders, we need to consolidate all parameters into a single structure
334+
Dictionary<IRFunc*, IRInst*> rayTracingConsolidatedVars;
335+
Dictionary<IRFunc*, List<IRParam*>> rayTracingProcessedParams;
332336
};
333337

334338
// This examines the passed type and determines the GLSL mesh shader indices
@@ -2302,7 +2306,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
23022306
}
23032307
}
23042308

2305-
void legalizeRayTracingEntryPointParameterForGLSL(
2309+
static void handleSingleParam(
23062310
GLSLLegalizationContext* context,
23072311
IRFunc* func,
23082312
IRParam* pp,
@@ -2354,6 +2358,149 @@ void legalizeRayTracingEntryPointParameterForGLSL(
23542358
builder->addDependsOnDecoration(func, globalParam);
23552359
}
23562360

2361+
static void consolidateParameters(
2362+
GLSLLegalizationContext* context,
2363+
IRFunc* func,
2364+
List<IRParam*>& params)
2365+
{
2366+
auto builder = context->getBuilder();
2367+
2368+
// Create a struct type to hold all parameters
2369+
IRInst* consolidatedVar = nullptr;
2370+
auto structType = builder->createStructType();
2371+
2372+
// Inside the structure, add fields for each parameter
2373+
for (auto _param : params)
2374+
{
2375+
auto _paramType = _param->getDataType();
2376+
IRType* valueType = _paramType;
2377+
2378+
if (as<IROutType>(_paramType))
2379+
valueType = as<IROutType>(_paramType)->getValueType();
2380+
else if (auto inOutType = as<IRInOutType>(_paramType))
2381+
valueType = inOutType->getValueType();
2382+
2383+
auto key = builder->createStructKey();
2384+
builder->addNameHintDecoration(key, UnownedStringSlice("field"));
2385+
auto field = builder->createStructField(structType, key, valueType);
2386+
field->removeFromParent();
2387+
field->insertAtEnd(structType);
2388+
}
2389+
2390+
// Create a global variable to hold the consolidated struct
2391+
consolidatedVar = builder->createGlobalVar(structType);
2392+
auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload);
2393+
consolidatedVar->setFullType(ptrType);
2394+
consolidatedVar->moveToEnd();
2395+
2396+
// Add the ray payload decoration and assign location 0.
2397+
builder->addVulkanRayPayloadDecoration(consolidatedVar, 0);
2398+
2399+
// Store the consolidated variable for this function
2400+
context->rayTracingConsolidatedVars[func] = consolidatedVar;
2401+
2402+
// Replace each parameter with a field in the consolidated struct
2403+
for (Index i = 0; i < params.getCount(); ++i)
2404+
{
2405+
auto _param = params[i];
2406+
2407+
// Find the i-th field
2408+
IRStructField* targetField = nullptr;
2409+
Index fieldIndex = 0;
2410+
for (auto field : structType->getFields())
2411+
{
2412+
if (fieldIndex == i)
2413+
{
2414+
targetField = field;
2415+
break;
2416+
}
2417+
fieldIndex++;
2418+
}
2419+
SLANG_ASSERT(targetField);
2420+
2421+
// Create the field address with the correct type
2422+
auto _paramType = _param->getDataType();
2423+
auto fieldType = targetField->getFieldType();
2424+
2425+
// If the parameter is an out/inout type, we need to create a pointer type
2426+
IRType* fieldPtrType = nullptr;
2427+
if (as<IROutType>(_paramType))
2428+
{
2429+
fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType);
2430+
}
2431+
else if (as<IRInOutType>(_paramType))
2432+
{
2433+
fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType);
2434+
}
2435+
2436+
auto fieldAddr =
2437+
builder->emitFieldAddress(fieldPtrType, consolidatedVar, targetField->getKey());
2438+
2439+
// Replace parameter uses with field address
2440+
_param->replaceUsesWith(fieldAddr);
2441+
}
2442+
}
2443+
2444+
static void handleMultipleParams(GLSLLegalizationContext* context, IRFunc* func, IRParam* pp)
2445+
{
2446+
auto firstBlock = func->getFirstBlock();
2447+
2448+
// Now we run the consolidation step, but if we've already
2449+
// processed this parameter, skip it.
2450+
List<IRParam*>* processedParams = nullptr;
2451+
if (auto foundList = context->rayTracingProcessedParams.tryGetValue(func))
2452+
{
2453+
processedParams = foundList;
2454+
if (processedParams->contains(pp))
2455+
return;
2456+
}
2457+
else
2458+
{
2459+
context->rayTracingProcessedParams[func] = List<IRParam*>();
2460+
processedParams = &context->rayTracingProcessedParams[func];
2461+
}
2462+
2463+
// Collect all parameters that need to be consolidated
2464+
List<IRParam*> params;
2465+
List<IRVarLayout*> paramLayouts;
2466+
2467+
for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam())
2468+
{
2469+
auto pLayoutDecoration = _param->findDecoration<IRLayoutDecoration>();
2470+
SLANG_ASSERT(pLayoutDecoration);
2471+
auto pLayout = as<IRVarLayout>(pLayoutDecoration->getLayout());
2472+
SLANG_ASSERT(pLayout);
2473+
2474+
// Only include parameters that haven't been processed yet
2475+
auto _paramType = _param->getDataType();
2476+
bool needsConsolidation = (as<IROutType>(_paramType) || as<IRInOutType>(_paramType));
2477+
if (!processedParams->contains(_param) && needsConsolidation)
2478+
{
2479+
params.add(_param);
2480+
paramLayouts.add(pLayout);
2481+
processedParams->add(_param);
2482+
}
2483+
}
2484+
2485+
consolidateParameters(context, func, params);
2486+
}
2487+
2488+
void legalizeRayTracingEntryPointParameterForGLSL(
2489+
GLSLLegalizationContext* context,
2490+
IRFunc* func,
2491+
IRParam* pp,
2492+
IRVarLayout* paramLayout,
2493+
bool hasSingleOutOrInOutParam)
2494+
{
2495+
if (hasSingleOutOrInOutParam)
2496+
{
2497+
handleSingleParam(context, func, pp, paramLayout);
2498+
return;
2499+
}
2500+
2501+
handleMultipleParams(context, func, pp);
2502+
}
2503+
23572504
static void legalizeMeshPayloadInputParam(
23582505
GLSLLegalizationContext* context,
23592506
CodeGenContext* codeGenContext,
@@ -3041,7 +3188,6 @@ void legalizeEntryPointParameterForGLSL(
30413188
}
30423189
}
30433190

3044-
30453191
// We need to create a global variable that will replace the parameter.
30463192
// It seems superficially obvious that the variable should have
30473193
// the same type as the parameter.
@@ -3198,7 +3344,28 @@ void legalizeEntryPointParameterForGLSL(
31983344
case Stage::Intersection:
31993345
case Stage::Miss:
32003346
case Stage::RayGeneration:
3201-
legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout);
3347+
{
3348+
// Count the number of inout or out parameters
3349+
int inoutOrOutParamCount = 0;
3350+
auto firstBlock = func->getFirstBlock();
3351+
for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam())
3352+
{
3353+
auto _paramType = _param->getDataType();
3354+
if (as<IROutType>(_paramType) || as<IRInOutType>(_paramType))
3355+
{
3356+
inoutOrOutParamCount++;
3357+
}
3358+
}
3359+
3360+
// If we have just one inout or out param, we don't need consolidation.
3361+
bool hasSingleOutOrInOutParam = inoutOrOutParamCount <= 1;
3362+
legalizeRayTracingEntryPointParameterForGLSL(
3363+
context,
3364+
func,
3365+
pp,
3366+
paramLayout,
3367+
hasSingleOutOrInOutParam);
3368+
}
32023369
return;
32033370
}
32043371

tests/vkray/multipleinout.slang

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly
2+
3+
// This test checks whether the spirv generated when there are multiple inout or out variables, they
4+
// all get consolidated into one IncomingRayPayloadKHR.
5+
6+
struct ReflectionRay
7+
{
8+
float4 color;
9+
};
10+
11+
StructuredBuffer<float4> colors;
12+
13+
[shader("closesthit")]
14+
void main(
15+
BuiltInTriangleIntersectionAttributes attributes,
16+
inout ReflectionRay ioPayload,
17+
out float3 dummy)
18+
{
19+
uint materialID = (InstanceIndex() << 1)
20+
+ InstanceID()
21+
+ PrimitiveIndex()
22+
+ HitKind();
23+
24+
ioPayload.color = colors[materialID];
25+
dummy = HitTriangleVertexPosition(0);
26+
}
27+
28+
// CHECK: OpCapability RayTracingKHR
29+
// CHECK: OpCapability RayTracingPositionFetchKHR
30+
// CHECK: OpCapability Shader
31+
// CHECK: OpExtension "SPV_KHR_ray_tracing"
32+
// CHECK: OpExtension "SPV_KHR_storage_buffer_storage_class"
33+
// CHECK: OpExtension "SPV_KHR_ray_tracing_position_fetch"
34+
// CHECK: OpMemoryModel Logical GLSL450
35+
// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}}
36+
37+
// CHECK: OpName %ReflectionRay "ReflectionRay"
38+
// CHECK: OpMemberName %ReflectionRay 0 "color"
39+
// CHECK: OpName %materialID "materialID"
40+
// CHECK: OpName %StructuredBuffer "StructuredBuffer"
41+
// CHECK: OpName %colors "colors"
42+
// CHECK: OpName %main "main"
43+
44+
// CHECK-DAG: OpDecorate %gl_InstanceID BuiltIn InstanceId
45+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn InstanceCustomIndexKHR
46+
// CHECK-DAG: OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
47+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitKindKHR
48+
// CHECK-DAG: OpDecorate %_runtimearr_v4float ArrayStride 16
49+
// CHECK-DAG: OpDecorate %StructuredBuffer Block
50+
// CHECK-DAG: OpDecorate %colors Binding 0
51+
// CHECK-DAG: OpDecorate %colors DescriptorSet 0
52+
// CHECK-DAG: OpDecorate %colors NonWritable
53+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitTriangleVertexPositionsKHR
54+
55+
// CHECK: %ReflectionRay = OpTypeStruct %v4float
56+
// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float
57+
// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}}
58+
// CHECK: %_ptr_IncomingRayPayloadKHR_ReflectionRay = OpTypePointer IncomingRayPayloadKHR %ReflectionRay
59+
// CHECK: %_ptr_IncomingRayPayloadKHR_v3float = OpTypePointer IncomingRayPayloadKHR %v3float
60+
// CHECK: %StructuredBuffer = OpTypeStruct %_runtimearr_v4float
61+
// CHECK: %_ptr_StorageBuffer_StructuredBuffer = OpTypePointer StorageBuffer %StructuredBuffer
62+
// CHECK: %_ptr_Input__arr_v3float_{{.*}} = OpTypePointer Input %_arr_v3float_{{.*}}
63+
64+
// CHECK: %main = OpFunction %void None %{{.*}}
65+
// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}}
66+
// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID
67+
// CHECK: %{{.*}} = OpLoad %v4float %{{.*}}
68+
// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0
69+
// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}

0 commit comments

Comments
 (0)