Skip to content

Commit 9d8ec3e

Browse files
committed
Consolidate multiple inout/outs into struct
Fixes #5121. VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700 requires that there be only one IncomingRayPayloadKHR variable per entry point. This change does two things: 1. If an entry point has at most one inout or out parameter (including the case where it has only in parameters), the current implementation is kept. 2. If there are multiple out/inout parameters, a new structure is created to consolidate them as fields, and that structure is emitted instead. These two code paths are split into two separate functions for clarity. This patch also adds a new test, multipleinout.slang, to cover this.
1 parent 187ec44 commit 9d8ec3e

File tree

2 files changed

+239
-39
lines changed

2 files changed

+239
-39
lines changed

source/slang/slang-ir-glsl-legalize.cpp

+170-39
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ struct GLSLLegalizationContext
329329

330330
IRBuilder* builder;
331331
IRBuilder* getBuilder() { return builder; }
332+
333+
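// For ray tracing shaders with multiple out/inout entry-point parameters, those parameters are consolidated into a single structure; these maps record, per entry point, the consolidated variable and the parameters already processed.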
334+
Dictionary<IRFunc*, IRInst*> rayTracingConsolidatedVars;
335+
Dictionary<IRFunc*, List<IRParam*>> rayTracingProcessedParams;
332336
};
333337

334338
// This examines the passed type and determines the GLSL mesh shader indices
@@ -2302,7 +2306,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
23022306
}
23032307
}
23042308

2305-
void legalizeRayTracingEntryPointParameterForGLSL(
2309+
static void handleSingleParam(
23062310
GLSLLegalizationContext* context,
23072311
IRFunc* func,
23082312
IRParam* pp,
@@ -2311,49 +2315,156 @@ void legalizeRayTracingEntryPointParameterForGLSL(
23112315
auto builder = context->getBuilder();
23122316
auto paramType = pp->getDataType();
23132317

2314-
// The parameter might be either an `in` parameter,
2315-
// or an `out` or `in out` parameter, and in those
2316-
// latter cases its IR-level type will include a
2317-
// wrapping "pointer-like" type (e.g., `Out<Float>`
2318-
// instead of just `Float`).
2319-
//
2320-
// Because global shader parameters are read-only
2321-
// in the same way function types are, we can take
2322-
// care of that detail here just by allocating a
2323-
// global shader parameter with exactly the type
2324-
// of the original function parameter.
2325-
//
23262318
auto globalParam = addGlobalParam(builder->getModule(), paramType);
23272319
builder->addLayoutDecoration(globalParam, paramLayout);
23282320
moveValueBefore(globalParam, builder->getFunc());
23292321
pp->replaceUsesWith(globalParam);
2330-
2331-
// Because linkage between ray-tracing shaders is
2332-
// based on the type of incoming/outgoing payload
2333-
// and attribute parameters, it would be an error to
2334-
// eliminate the global parameter *even if* it is
2335-
// not actually used inside the entry point.
2336-
//
2337-
// We attach a decoration to the entry point that
2338-
// makes note of the dependency, so that steps
2339-
// like dead code elimination cannot get rid of
2340-
// the parameter.
2341-
//
2342-
// TODO: We could consider using a structure like
2343-
// this for *all* of the entry point parameters
2344-
// that get moved to the global scope, since SPIR-V
2345-
// ends up requiring such information on an `OpEntryPoint`.
2346-
//
2347-
// As a further alternative, we could decide to
2348-
// keep entry point varying input/outtput attached
2349-
// to the parameter list through all of the Slang IR
2350-
// steps, and only declare it as global variables at
2351-
// the last minute when emitting a GLSL `main` or
2352-
// SPIR-V for an entry point.
2353-
//
23542322
builder->addDependsOnDecoration(func, globalParam);
23552323
}
23562324

2325+
/// Consolidates several `out`/`inout` ray-tracing entry-point parameters into
/// one global payload variable.
///
/// A fresh struct type is created with one field per entry in `params`; each
/// field holds the value type wrapped by the parameter's `Out<>`/`InOut<>`
/// type. A single global variable of that struct type is then created, given
/// a pointer type in the `IncomingRayPayload` address space, decorated as the
/// ray payload at location 0, and recorded in
/// `context->rayTracingConsolidatedVars[func]`. Finally every use of each
/// parameter is replaced by the address of its corresponding field.
///
/// @param context Legalization context; supplies the IR builder and receives
///                the consolidated variable for `func`.
/// @param func    Entry point whose parameters are being consolidated.
/// @param params  Parameters to fold into the struct. Each is expected to be
///                of `out` or `inout` type. If the list is empty nothing is
///                created.
static void consolidateParameters(
    GLSLLegalizationContext* context,
    IRFunc* func,
    List<IRParam*>& params)
{
    // With nothing to consolidate, do not emit an empty payload struct: that
    // would only add one more IncomingRayPayload variable for the entry point.
    if (params.getCount() == 0)
        return;

    auto builder = context->getBuilder();

    // Create a struct type to hold all parameters.
    auto structType = builder->createStructType();

    // `fields[i]` is the struct field created for `params[i]`. Remembering the
    // fields here avoids re-walking the struct's field list for every
    // parameter in the replacement loop below.
    List<IRStructField*> fields;

    // Inside the structure, add one field per parameter.
    for (auto _param : params)
    {
        auto _paramType = _param->getDataType();

        // Strip the `Out<>` / `InOut<>` wrapper so the field stores the value.
        IRType* valueType = _paramType;
        if (auto outType = as<IROutType>(_paramType))
            valueType = outType->getValueType();
        else if (auto inOutType = as<IRInOutType>(_paramType))
            valueType = inOutType->getValueType();

        auto key = builder->createStructKey();
        builder->addNameHintDecoration(key, UnownedStringSlice("field"));
        auto field = builder->createStructField(structType, key, valueType);

        // Make sure the field is the last child of the struct so that field
        // order matches parameter order.
        // NOTE(review): presumably redundant with what createStructField
        // already does - confirm before removing.
        field->removeFromParent();
        field->insertAtEnd(structType);

        fields.add(field);
    }

    // Create a global variable to hold the consolidated struct, and retype it
    // as a pointer in the incoming-ray-payload address space.
    IRInst* consolidatedVar = builder->createGlobalVar(structType);
    auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload);
    consolidatedVar->setFullType(ptrType);
    consolidatedVar->moveToEnd();

    // Add the ray payload decoration and assign location 0.
    builder->addVulkanRayPayloadDecoration(consolidatedVar, 0);

    // Store the consolidated variable for this function.
    context->rayTracingConsolidatedVars[func] = consolidatedVar;

    // Replace each parameter with the address of its field in the struct.
    for (Index i = 0; i < params.getCount(); ++i)
    {
        auto _param = params[i];
        auto targetField = fields[i];

        auto _paramType = _param->getDataType();
        auto fieldType = targetField->getFieldType();

        // The field address keeps the same pointer flavor (`Out` / `InOut`)
        // as the parameter it replaces.
        IRType* fieldPtrType = nullptr;
        if (as<IROutType>(_paramType))
        {
            fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType);
        }
        else if (as<IRInOutType>(_paramType))
        {
            fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType);
        }
        // Only `out` / `inout` parameters are expected here; anything else
        // would leave the address untyped.
        SLANG_ASSERT(fieldPtrType);

        auto fieldAddr =
            builder->emitFieldAddress(fieldPtrType, consolidatedVar, targetField->getKey());

        // Replace parameter uses with the field address.
        _param->replaceUsesWith(fieldAddr);
    }
}
2407+
2408+
/// Handles one parameter of a ray-tracing entry point that has more than one
/// `out`/`inout` parameter.
///
/// The first time this runs for `func` with work to do, it gathers every
/// not-yet-processed `out`/`inout` parameter of the entry point, marks them
/// as processed in `context->rayTracingProcessedParams[func]`, and hands them
/// to `consolidateParameters`. Later calls for a parameter that is already
/// marked return immediately, and calls that find nothing new to consolidate
/// do nothing (so no additional, empty payload struct is emitted).
///
/// @param context Legalization context holding the per-function bookkeeping.
/// @param func    Entry point being legalized.
/// @param pp      The parameter the caller is currently legalizing.
static void handleMultipleParams(GLSLLegalizationContext* context, IRFunc* func, IRParam* pp)
{
    auto firstBlock = func->getFirstBlock();
    if (!firstBlock)
        return;

    // Look up (or create) the list of parameters already consolidated for
    // this function; if `pp` is in it there is nothing left to do.
    List<IRParam*>* processedParams = context->rayTracingProcessedParams.tryGetValue(func);
    if (processedParams)
    {
        if (processedParams->contains(pp))
            return;
    }
    else
    {
        context->rayTracingProcessedParams[func] = List<IRParam*>();
        processedParams = &context->rayTracingProcessedParams[func];
    }

    // Collect all parameters that still need to be consolidated.
    List<IRParam*> params;

    for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam())
    {
        // Sanity check: every entry-point parameter is expected to carry a
        // variable layout.
        auto pLayoutDecoration = _param->findDecoration<IRLayoutDecoration>();
        SLANG_ASSERT(pLayoutDecoration);
        auto pLayout = as<IRVarLayout>(pLayoutDecoration->getLayout());
        SLANG_ASSERT(pLayout);
        (void)pLayout; // only inspected by the assertion above

        // Only `out` / `inout` parameters that haven't been processed yet
        // take part in consolidation.
        auto _paramType = _param->getDataType();
        bool needsConsolidation = (as<IROutType>(_paramType) || as<IRInOutType>(_paramType));
        if (needsConsolidation && !processedParams->contains(_param))
        {
            params.add(_param);
            processedParams->add(_param);
        }
    }

    // Nothing new was found (e.g. we were called for an `in` parameter after
    // the out/inout parameters were already consolidated): emitting another
    // struct here would create a second, empty payload variable.
    if (params.getCount() == 0)
        return;

    consolidateParameters(context, func, params);
}
2451+
2452+
/// Legalizes one parameter of a ray-tracing entry point for GLSL/SPIR-V.
///
/// Two strategies exist:
///  - `handleSingleParam` turns the parameter into its own global shader
///    parameter carrying `paramLayout`.
///  - `handleMultipleParams` folds all `out`/`inout` parameters of the entry
///    point into a single consolidated payload struct.
///
/// Consolidation only ever collects `out`/`inout` parameters, so any other
/// parameter (a plain `in`) always goes through the single-parameter path;
/// otherwise it would never be replaced by a global.
///
/// @param context     Legalization context.
/// @param func        Entry point that owns `pp`.
/// @param pp          Parameter being legalized.
/// @param paramLayout Layout of `pp`, applied when it becomes a global parameter.
/// @param hasSingleOutOrInOutParam
///                    True when the entry point has at most one `out`/`inout`
///                    parameter, i.e. no consolidation is needed.
void legalizeRayTracingEntryPointParameterForGLSL(
    GLSLLegalizationContext* context,
    IRFunc* func,
    IRParam* pp,
    IRVarLayout* paramLayout,
    bool hasSingleOutOrInOutParam)
{
    auto paramType = pp->getDataType();
    const bool isOutOrInOut = as<IROutType>(paramType) || as<IRInOutType>(paramType);

    if (hasSingleOutOrInOutParam || !isOutOrInOut)
    {
        handleSingleParam(context, func, pp, paramLayout);
        return;
    }

    handleMultipleParams(context, func, pp);
}
2467+
23572468
static void legalizeMeshPayloadInputParam(
23582469
GLSLLegalizationContext* context,
23592470
CodeGenContext* codeGenContext,
@@ -3041,7 +3152,6 @@ void legalizeEntryPointParameterForGLSL(
30413152
}
30423153
}
30433154

3044-
30453155
// We need to create a global variable that will replace the parameter.
30463156
// It seems superficially obvious that the variable should have
30473157
// the same type as the parameter.
@@ -3198,7 +3308,28 @@ void legalizeEntryPointParameterForGLSL(
31983308
case Stage::Intersection:
31993309
case Stage::Miss:
32003310
case Stage::RayGeneration:
3201-
legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout);
3311+
{
3312+
// Count the number of inout or out parameters
3313+
int inoutOrOutParamCount = 0;
3314+
auto firstBlock = func->getFirstBlock();
3315+
for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam())
3316+
{
3317+
auto _paramType = _param->getDataType();
3318+
if (as<IROutType>(_paramType) || as<IRInOutType>(_paramType))
3319+
{
3320+
inoutOrOutParamCount++;
3321+
}
3322+
}
3323+
3324+
// If we have just one inout or out param, we don't need consolidation.
3325+
bool hasSingleOutOrInOutParam = inoutOrOutParamCount <= 1;
3326+
legalizeRayTracingEntryPointParameterForGLSL(
3327+
context,
3328+
func,
3329+
pp,
3330+
paramLayout,
3331+
hasSingleOutOrInOutParam);
3332+
}
32023333
return;
32033334
}
32043335

tests/vkray/multipleinout.slang

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly
2+
3+
// This test checks that when an entry point has multiple inout or out parameters, the generated
4+
// SPIR-V consolidates them all into a single IncomingRayPayloadKHR variable.
5+
6+
struct ReflectionRay
7+
{
8+
float4 color;
9+
};
10+
11+
StructuredBuffer<float4> colors;
12+
13+
[shader("closesthit")]
14+
void main(
15+
BuiltInTriangleIntersectionAttributes attributes,
16+
inout ReflectionRay ioPayload,
17+
out float3 dummy)
18+
{
19+
uint materialID = (InstanceIndex() << 1)
20+
+ InstanceID()
21+
+ PrimitiveIndex()
22+
+ HitKind();
23+
24+
ioPayload.color = colors[materialID];
25+
dummy = HitTriangleVertexPosition(0);
26+
}
27+
28+
// CHECK: OpCapability RayTracingKHR
29+
// CHECK: OpCapability RayTracingPositionFetchKHR
30+
// CHECK: OpCapability Shader
31+
// CHECK: OpExtension "SPV_KHR_ray_tracing"
32+
// CHECK: OpExtension "SPV_KHR_storage_buffer_storage_class"
33+
// CHECK: OpExtension "SPV_KHR_ray_tracing_position_fetch"
34+
// CHECK: OpMemoryModel Logical GLSL450
35+
// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}}
36+
37+
// CHECK: OpName %ReflectionRay "ReflectionRay"
38+
// CHECK: OpMemberName %ReflectionRay 0 "color"
39+
// CHECK: OpName %materialID "materialID"
40+
// CHECK: OpName %StructuredBuffer "StructuredBuffer"
41+
// CHECK: OpName %colors "colors"
42+
// CHECK: OpName %main "main"
43+
44+
// CHECK-DAG: OpDecorate %gl_InstanceID BuiltIn InstanceId
45+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn InstanceCustomIndexKHR
46+
// CHECK-DAG: OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
47+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitKindKHR
48+
// CHECK-DAG: OpDecorate %_runtimearr_v4float ArrayStride 16
49+
// CHECK-DAG: OpDecorate %StructuredBuffer Block
50+
// CHECK-DAG: OpDecorate %colors Binding 0
51+
// CHECK-DAG: OpDecorate %colors DescriptorSet 0
52+
// CHECK-DAG: OpDecorate %colors NonWritable
53+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitTriangleVertexPositionsKHR
54+
55+
// CHECK: %ReflectionRay = OpTypeStruct %v4float
56+
// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float
57+
// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}}
58+
// CHECK: %_ptr_IncomingRayPayloadKHR_ReflectionRay = OpTypePointer IncomingRayPayloadKHR %ReflectionRay
59+
// CHECK: %_ptr_IncomingRayPayloadKHR_v3float = OpTypePointer IncomingRayPayloadKHR %v3float
60+
// CHECK: %StructuredBuffer = OpTypeStruct %_runtimearr_v4float
61+
// CHECK: %_ptr_StorageBuffer_StructuredBuffer = OpTypePointer StorageBuffer %StructuredBuffer
62+
// CHECK: %_ptr_Input__arr_v3float_{{.*}} = OpTypePointer Input %_arr_v3float_{{.*}}
63+
64+
// CHECK: %main = OpFunction %void None %{{.*}}
65+
// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}}
66+
// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID
67+
// CHECK: %{{.*}} = OpLoad %v4float %{{.*}}
68+
// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0
69+
// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}

0 commit comments

Comments
 (0)