Skip to content

Commit 1b2e686

Browse files
committed
Consolidate multiple inout/outs into struct
Fixes #5121. VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700 requires that there be only one IncomingRayPayloadKHR per entry point. This change does two things: 1. If an entry point has at most one inout or out parameter (including the case where it has only in parameters), the current implementation is kept. 2. If there are multiple out/inout parameters, a new structure is created to consolidate them as fields, and that structure is emitted instead. These two code paths are split into two separate functions for clarity. This patch also adds a new test, multipleinout.slang, to cover this.
1 parent 187ec44 commit 1b2e686

File tree

2 files changed

+237
-37
lines changed

2 files changed

+237
-37
lines changed

source/slang/slang-ir-glsl-legalize.cpp

+166-37
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ struct GLSLLegalizationContext
329329

330330
IRBuilder* builder;
331331
IRBuilder* getBuilder() { return builder; }
332+
333+
// For ray tracing shaders, we need to consolidate all parameters into a single structure
334+
Dictionary<IRFunc*, IRInst*> rayTracingConsolidatedVars;
335+
Dictionary<IRFunc*, List<IRParam*>> rayTracingProcessedParams;
332336
};
333337

334338
// This examines the passed type and determines the GLSL mesh shader indices
@@ -2302,7 +2306,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
23022306
}
23032307
}
23042308

2305-
void legalizeRayTracingEntryPointParameterForGLSL(
2309+
static void handleSingleParam(
23062310
GLSLLegalizationContext* context,
23072311
IRFunc* func,
23082312
IRParam* pp,
@@ -2311,49 +2315,174 @@ void legalizeRayTracingEntryPointParameterForGLSL(
23112315
auto builder = context->getBuilder();
23122316
auto paramType = pp->getDataType();
23132317

2314-
// The parameter might be either an `in` parameter,
2315-
// or an `out` or `in out` parameter, and in those
2316-
// latter cases its IR-level type will include a
2317-
// wrapping "pointer-like" type (e.g., `Out<Float>`
2318-
// instead of just `Float`).
2319-
//
2320-
// Because global shader parameters are read-only
2321-
// in the same way function types are, we can take
2322-
// care of that detail here just by allocating a
2323-
// global shader parameter with exactly the type
2324-
// of the original function parameter.
2325-
//
23262318
auto globalParam = addGlobalParam(builder->getModule(), paramType);
23272319
builder->addLayoutDecoration(globalParam, paramLayout);
23282320
moveValueBefore(globalParam, builder->getFunc());
23292321
pp->replaceUsesWith(globalParam);
2330-
2331-
// Because linkage between ray-tracing shaders is
2332-
// based on the type of incoming/outgoing payload
2333-
// and attribute parameters, it would be an error to
2334-
// eliminate the global parameter *even if* it is
2335-
// not actually used inside the entry point.
2336-
//
2337-
// We attach a decoration to the entry point that
2338-
// makes note of the dependency, so that steps
2339-
// like dead code elimination cannot get rid of
2340-
// the parameter.
2341-
//
2342-
// TODO: We could consider using a structure like
2343-
// this for *all* of the entry point parameters
2344-
// that get moved to the global scope, since SPIR-V
2345-
// ends up requiring such information on an `OpEntryPoint`.
2346-
//
2347-
// As a further alternative, we could decide to
2348-
// keep entry point varying input/outtput attached
2349-
// to the parameter list through all of the Slang IR
2350-
// steps, and only declare it as global variables at
2351-
// the last minute when emitting a GLSL `main` or
2352-
// SPIR-V for an entry point.
2353-
//
23542322
builder->addDependsOnDecoration(func, globalParam);
23552323
}
23562324

2325+
/// Consolidates several `out`/`inout` ray tracing entry point parameters into
/// a single global variable of a freshly created struct type.
///
/// VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700 allows only one
/// IncomingRayPayloadKHR variable per entry point, so every parameter in
/// `params` becomes one field of the struct, and all uses of the parameter are
/// replaced by the address of that field.
///
/// @param context Legalization context; provides the IR builder and records
///                the consolidated variable in `rayTracingConsolidatedVars`.
/// @param func    Entry point whose parameters are being consolidated.
/// @param params  Parameters to consolidate. Each one is expected to have an
///                `Out<T>` or `InOut<T>` type. An empty list is a no-op.
static void consolidateParameters(
    GLSLLegalizationContext* context,
    IRFunc* func,
    List<IRParam*>& params)
{
    // With nothing to consolidate, do not emit an empty payload struct: that
    // would add a spurious extra ray payload variable to the entry point.
    if (params.getCount() == 0)
        return;

    auto builder = context->getBuilder();

    // Create a struct type to hold all parameters.
    auto structType = builder->createStructType();

    // Add one field per parameter. Each field is remembered (in parameter
    // order) so it can be paired with its parameter below without having to
    // re-scan the struct for the i-th field.
    List<IRStructField*> fields;
    for (auto param : params)
    {
        auto paramType = param->getDataType();

        // The field stores the value type, i.e. `T` for `Out<T>`/`InOut<T>`.
        IRType* valueType = paramType;
        if (auto outType = as<IROutType>(paramType))
            valueType = outType->getValueType();
        else if (auto inOutType = as<IRInOutType>(paramType))
            valueType = inOutType->getValueType();

        auto key = builder->createStructKey();
        builder->addNameHintDecoration(key, UnownedStringSlice("field"));
        auto field = builder->createStructField(structType, key, valueType);

        // Make sure the field ends up as the last child of the struct so that
        // field order matches parameter order.
        field->removeFromParent();
        field->insertAtEnd(structType);
        fields.add(field);
    }

    // Create a global variable to hold the consolidated struct, typed as a
    // pointer in the incoming-ray-payload address space.
    IRInst* consolidatedVar = builder->createGlobalVar(structType);
    auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload);
    consolidatedVar->setFullType(ptrType);
    consolidatedVar->moveToEnd();

    // Add the ray payload decoration and assign location 0.
    builder->addVulkanRayPayloadDecoration(consolidatedVar, 0);

    // Store the consolidated variable for this function.
    context->rayTracingConsolidatedVars[func] = consolidatedVar;

    // Replace each parameter with the address of its field in the struct.
    for (Index i = 0; i < params.getCount(); ++i)
    {
        auto param = params[i];
        auto field = fields[i];

        auto paramType = param->getDataType();
        auto fieldType = field->getFieldType();

        // The field address keeps the pointer flavor (`Out` vs. `InOut`) of
        // the parameter it replaces.
        IRType* fieldPtrType = nullptr;
        if (as<IROutType>(paramType))
        {
            fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType);
        }
        else if (as<IRInOutType>(paramType))
        {
            fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType);
        }
        // Only `out`/`inout` parameters may be consolidated; anything else
        // would leave us without a result type for the field address.
        SLANG_ASSERT(fieldPtrType);

        auto fieldAddr = builder->emitFieldAddress(fieldPtrType, consolidatedVar, field->getKey());

        // Replace parameter uses with the field address.
        param->replaceUsesWith(fieldAddr);
    }
}
2409+
2410+
/// Handles a ray tracing entry point that has more than one `out`/`inout`
/// parameter by consolidating all of them into one struct-typed variable.
///
/// This is invoked once per entry point parameter. The first invocation for a
/// given `func` gathers every `out`/`inout` parameter and consolidates them in
/// one go; the parameters handled that way are recorded in
/// `context->rayTracingProcessedParams` so that later invocations for the same
/// parameters return immediately.
///
/// @param context Legalization context holding the per-function bookkeeping.
/// @param func    Entry point being legalized.
/// @param pp      The parameter this invocation was triggered for.
static void handleMultipleParams(
    GLSLLegalizationContext* context,
    IRFunc* func,
    IRParam* pp)
{
    // Look up (or create) the list of parameters already consolidated for
    // this entry point. If `pp` is already in it, there is nothing to do.
    List<IRParam*>* processedParams = context->rayTracingProcessedParams.tryGetValue(func);
    if (processedParams)
    {
        if (processedParams->contains(pp))
            return;
    }
    else
    {
        context->rayTracingProcessedParams[func] = List<IRParam*>();
        processedParams = &context->rayTracingProcessedParams[func];
    }

    // Collect all `out`/`inout` parameters that have not been processed yet.
    List<IRParam*> params;
    auto firstBlock = func->getFirstBlock();
    for (auto param = firstBlock->getFirstParam(); param; param = param->getNextParam())
    {
        auto paramType = param->getDataType();
        const bool needsConsolidation = as<IROutType>(paramType) || as<IRInOutType>(paramType);
        if (!needsConsolidation || processedParams->contains(param))
            continue;

        params.add(param);
        processedParams->add(param);
    }

    // If everything was already consolidated by an earlier invocation (e.g.
    // `pp` is an `in` parameter visited after the `out`/`inout` ones), do not
    // consolidate again: that would emit a second, empty payload variable.
    if (params.getCount() == 0)
        return;

    consolidateParameters(context, func, params);
}
2456+
2457+
/// Legalizes one parameter of a ray tracing entry point.
///
/// - A plain `in` parameter is always turned into its own global parameter.
/// - If the entry point has at most one `out`/`inout` parameter, that
///   parameter is likewise turned into its own global parameter.
/// - If the entry point has several `out`/`inout` parameters, they are
///   consolidated into a single struct-typed variable, because
///   VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700 permits only one
///   IncomingRayPayloadKHR variable per entry point.
///
/// @param context     Legalization context (IR builder and bookkeeping).
/// @param func        Entry point that owns the parameter.
/// @param pp          Parameter to legalize.
/// @param paramLayout Layout to attach to the global parameter created for
///                    `pp` when it is handled on its own.
void legalizeRayTracingEntryPointParameterForGLSL(
    GLSLLegalizationContext* context,
    IRFunc* func,
    IRParam* pp,
    IRVarLayout* paramLayout)
{
    auto isOutOrInOut = [](IRType* type) -> bool
    { return as<IROutType>(type) || as<IRInOutType>(type); };

    // `in` parameters never take part in consolidation. They must still be
    // legalized individually even when the entry point has multiple
    // `out`/`inout` parameters, otherwise their uses are never replaced.
    if (!isOutOrInOut(pp->getDataType()))
    {
        handleSingleParam(context, func, pp, paramLayout);
        return;
    }

    // Count the `out`/`inout` parameters; we only need to know whether there
    // is more than one, so stop as soon as the second one is found.
    int inoutOrOutParamCount = 0;
    auto firstBlock = func->getFirstBlock();
    for (auto param = firstBlock->getFirstParam(); param; param = param->getNextParam())
    {
        if (!isOutOrInOut(param->getDataType()))
            continue;
        if (++inoutOrOutParamCount > 1)
            break;
    }

    // If we have just one inout or out param, we don't need consolidation.
    if (inoutOrOutParamCount <= 1)
    {
        handleSingleParam(context, func, pp, paramLayout);
        return;
    }

    handleMultipleParams(context, func, pp);
}
2485+
23572486
static void legalizeMeshPayloadInputParam(
23582487
GLSLLegalizationContext* context,
23592488
CodeGenContext* codeGenContext,

tests/vkray/multipleinout.slang

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// multipleinout.slang
2+
3+
//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly
4+
5+
// This test checks that, when an entry point has multiple inout or out parameters, the generated
6+
// SPIR-V consolidates all of them into a single IncomingRayPayloadKHR variable.
7+
8+
// Payload type used for the `inout` parameter of the closest-hit entry point in this test.
struct ReflectionRay
9+
{
10+
float4 color;
11+
};
12+
13+
// Color table read by the entry point, indexed by the material ID it computes.
StructuredBuffer<float4> colors;
14+
15+
// Closest-hit entry point that declares both an `inout` payload and an extra `out` parameter,
// i.e. the "multiple inout/out parameters" case this test exercises.
[shader("closesthit")]
16+
void main(
17+
BuiltInTriangleIntersectionAttributes attributes,
18+
inout ReflectionRay ioPayload,
19+
out float3 dummy)
20+
{
21+
uint materialID = (InstanceIndex() << 1)
22+
+ InstanceID()
23+
+ PrimitiveIndex()
24+
+ HitKind();
25+
26+
// Write to both the `inout` and the `out` parameter.
ioPayload.color = colors[materialID];
27+
dummy = HitTriangleVertexPosition(0);
28+
}
29+
30+
// CHECK: OpCapability RayTracingKHR
31+
// CHECK: OpCapability RayTracingPositionFetchKHR
32+
// CHECK: OpCapability Shader
33+
// CHECK: OpExtension "SPV_KHR_ray_tracing"
34+
// CHECK: OpExtension "SPV_KHR_storage_buffer_storage_class"
35+
// CHECK: OpExtension "SPV_KHR_ray_tracing_position_fetch"
36+
// CHECK: OpMemoryModel Logical GLSL450
37+
// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}}
38+
39+
// CHECK: OpName %ReflectionRay "ReflectionRay"
40+
// CHECK: OpMemberName %ReflectionRay 0 "color"
41+
// CHECK: OpName %materialID "materialID"
42+
// CHECK: OpName %StructuredBuffer "StructuredBuffer"
43+
// CHECK: OpName %colors "colors"
44+
// CHECK: OpName %main "main"
45+
46+
// CHECK-DAG: OpDecorate %gl_InstanceID BuiltIn InstanceId
47+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn InstanceCustomIndexKHR
48+
// CHECK-DAG: OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
49+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitKindKHR
50+
// CHECK-DAG: OpDecorate %_runtimearr_v4float ArrayStride 16
51+
// CHECK-DAG: OpDecorate %StructuredBuffer Block
52+
// CHECK-DAG: OpDecorate %colors Binding 0
53+
// CHECK-DAG: OpDecorate %colors DescriptorSet 0
54+
// CHECK-DAG: OpDecorate %colors NonWritable
55+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitTriangleVertexPositionsKHR
56+
57+
// CHECK: %ReflectionRay = OpTypeStruct %v4float
58+
// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float
59+
// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}}
60+
// CHECK: %_ptr_IncomingRayPayloadKHR_ReflectionRay = OpTypePointer IncomingRayPayloadKHR %ReflectionRay
61+
// CHECK: %_ptr_IncomingRayPayloadKHR_v3float = OpTypePointer IncomingRayPayloadKHR %v3float
62+
// CHECK: %StructuredBuffer = OpTypeStruct %_runtimearr_v4float
63+
// CHECK: %_ptr_StorageBuffer_StructuredBuffer = OpTypePointer StorageBuffer %StructuredBuffer
64+
// CHECK: %_ptr_Input__arr_v3float_{{.*}} = OpTypePointer Input %_arr_v3float_{{.*}}
65+
66+
// CHECK: %main = OpFunction %void None %{{.*}}
67+
// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}}
68+
// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID
69+
// CHECK: %{{.*}} = OpLoad %v4float %{{.*}}
70+
// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0
71+
// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}

0 commit comments

Comments
 (0)