Skip to content

Commit 259b608

Browse files
committed
Consolidate multiple inout/outs into struct
Fixes #5121 VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700 requires that there be only one IncomingRayPayloadKHR per entry point. This change does two things: 1. If an entry point has the one inout or out, or has only ins, then stay with current implementation. 2. If there are multiple outs/inouts, then create a new structure to consolidate these fields and emit this structure. These two code paths are split into two separate functions for clarity. This patch also adds a new test: multipleinout.slang to test this.
1 parent 187ec44 commit 259b608

File tree

2 files changed

+254
-37
lines changed

2 files changed

+254
-37
lines changed

source/slang/slang-ir-glsl-legalize.cpp

+186-37
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ struct GLSLLegalizationContext
329329

330330
IRBuilder* builder;
331331
IRBuilder* getBuilder() { return builder; }
332+
333+
// For ray tracing shaders, we need to consolidate all parameters into a single structure
334+
Dictionary<IRFunc*, IRInst*> rayTracingConsolidatedVars;
335+
Dictionary<IRFunc*, List<IRParam*>> rayTracingProcessedParams;
332336
};
333337

334338
// This examines the passed type and determines the GLSL mesh shader indices
@@ -2302,7 +2306,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
23022306
}
23032307
}
23042308

2305-
void legalizeRayTracingEntryPointParameterForGLSL(
2309+
static void handleSingleParam(
23062310
GLSLLegalizationContext* context,
23072311
IRFunc* func,
23082312
IRParam* pp,
@@ -2311,49 +2315,194 @@ void legalizeRayTracingEntryPointParameterForGLSL(
23112315
auto builder = context->getBuilder();
23122316
auto paramType = pp->getDataType();
23132317

2314-
// The parameter might be either an `in` parameter,
2315-
// or an `out` or `in out` parameter, and in those
2316-
// latter cases its IR-level type will include a
2317-
// wrapping "pointer-like" type (e.g., `Out<Float>`
2318-
// instead of just `Float`).
2319-
//
2320-
// Because global shader parameters are read-only
2321-
// in the same way function types are, we can take
2322-
// care of that detail here just by allocating a
2323-
// global shader parameter with exactly the type
2324-
// of the original function parameter.
2325-
//
23262318
auto globalParam = addGlobalParam(builder->getModule(), paramType);
23272319
builder->addLayoutDecoration(globalParam, paramLayout);
23282320
moveValueBefore(globalParam, builder->getFunc());
23292321
pp->replaceUsesWith(globalParam);
2330-
2331-
// Because linkage between ray-tracing shaders is
2332-
// based on the type of incoming/outgoing payload
2333-
// and attribute parameters, it would be an error to
2334-
// eliminate the global parameter *even if* it is
2335-
// not actually used inside the entry point.
2336-
//
2337-
// We attach a decoration to the entry point that
2338-
// makes note of the dependency, so that steps
2339-
// like dead code elimination cannot get rid of
2340-
// the parameter.
2341-
//
2342-
// TODO: We could consider using a structure like
2343-
// this for *all* of the entry point parameters
2344-
// that get moved to the global scope, since SPIR-V
2345-
// ends up requiring such information on an `OpEntryPoint`.
2346-
//
2347-
// As a further alternative, we could decide to
2348-
// keep entry point varying input/outtput attached
2349-
// to the parameter list through all of the Slang IR
2350-
// steps, and only declare it as global variables at
2351-
// the last minute when emitting a GLSL `main` or
2352-
// SPIR-V for an entry point.
2353-
//
23542322
builder->addDependsOnDecoration(func, globalParam);
23552323
}
23562324

2325+
static void consolidateParameters(
2326+
GLSLLegalizationContext* context,
2327+
IRFunc* func,
2328+
List<IRParam*>& params)
2329+
{
2330+
auto builder = context->getBuilder();
2331+
2332+
// Check if we already have a consolidated variable for this function
2333+
IRInst* consolidatedVar = nullptr;
2334+
if (auto foundVar = context->rayTracingConsolidatedVars.tryGetValue(func))
2335+
{
2336+
consolidatedVar = *foundVar;
2337+
}
2338+
else
2339+
{
2340+
// Create a struct type to hold all parameters
2341+
auto structType = builder->createStructType();
2342+
2343+
// Add fields for each parameter
2344+
for (Index i = 0; i < params.getCount(); ++i)
2345+
{
2346+
auto _param = params[i];
2347+
auto _paramType = _param->getDataType();
2348+
IRType* valueType = _paramType;
2349+
2350+
if (as<IROutType>(_paramType))
2351+
valueType = as<IROutType>(_paramType)->getValueType();
2352+
else if (auto inOutType = as<IRInOutType>(_paramType))
2353+
valueType = inOutType->getValueType();
2354+
2355+
auto key = builder->createStructKey();
2356+
builder->addNameHintDecoration(key, UnownedStringSlice("field"));
2357+
auto field = builder->createStructField(structType, key, valueType);
2358+
field->removeFromParent();
2359+
field->insertAtEnd(structType);
2360+
}
2361+
2362+
// Create a global variable to hold the consolidated struct
2363+
consolidatedVar = builder->createGlobalVar(structType);
2364+
auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::RayPayloadKHR);
2365+
consolidatedVar->setFullType(ptrType);
2366+
consolidatedVar->moveToEnd();
2367+
2368+
// Add the ray payload decoration
2369+
builder->addVulkanRayPayloadDecoration(consolidatedVar, 0);
2370+
2371+
// Store the consolidated variable for this function
2372+
context->rayTracingConsolidatedVars[func] = consolidatedVar;
2373+
}
2374+
2375+
// Replace each parameter with a field in the consolidated struct
2376+
for (Index i = 0; i < params.getCount(); ++i)
2377+
{
2378+
auto _param = params[i];
2379+
2380+
// Get the struct type from the consolidated variable's type
2381+
auto ptrType = as<IRPtrTypeBase>(consolidatedVar->getDataType());
2382+
SLANG_ASSERT(ptrType);
2383+
auto structType = as<IRStructType>(ptrType->getValueType());
2384+
SLANG_ASSERT(structType);
2385+
2386+
// Find the i-th field
2387+
IRStructField* targetField = nullptr;
2388+
Index fieldIndex = 0;
2389+
for (auto field : structType->getFields())
2390+
{
2391+
if (fieldIndex == i)
2392+
{
2393+
targetField = field;
2394+
break;
2395+
}
2396+
fieldIndex++;
2397+
}
2398+
SLANG_ASSERT(targetField);
2399+
2400+
// Create the field address with the correct type
2401+
auto _paramType = _param->getDataType();
2402+
auto fieldType = targetField->getFieldType();
2403+
2404+
// If the parameter is an out/inout type, we need to create a pointer type
2405+
IRType* fieldPtrType = nullptr;
2406+
if (as<IROutType>(_paramType))
2407+
{
2408+
fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType);
2409+
}
2410+
else if (as<IRInOutType>(_paramType))
2411+
{
2412+
fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType);
2413+
}
2414+
else
2415+
{
2416+
fieldPtrType = builder->getPtrType(kIROp_PtrType, fieldType, AddressSpace::RayPayloadKHR);
2417+
}
2418+
2419+
auto fieldAddr = builder->emitFieldAddress(
2420+
fieldPtrType,
2421+
consolidatedVar,
2422+
targetField->getKey());
2423+
2424+
// Replace parameter uses with field address
2425+
_param->replaceUsesWith(fieldAddr);
2426+
}
2427+
}
2428+
2429+
static void handleMultipleParams(
2430+
GLSLLegalizationContext* context,
2431+
IRFunc* func,
2432+
IRParam* pp)
2433+
{
2434+
auto firstBlock = func->getFirstBlock();
2435+
2436+
// Now we run the consolidation step, but if we've already
2437+
// processed this parameter, skip it.
2438+
List<IRParam*>* processedParams = nullptr;
2439+
if (auto foundList = context->rayTracingProcessedParams.tryGetValue(func))
2440+
{
2441+
processedParams = foundList;
2442+
if (processedParams->contains(pp))
2443+
return;
2444+
}
2445+
else
2446+
{
2447+
context->rayTracingProcessedParams[func] = List<IRParam*>();
2448+
processedParams = &context->rayTracingProcessedParams[func];
2449+
}
2450+
2451+
// Collect all parameters that need to be consolidated
2452+
List<IRParam*> params;
2453+
List<IRVarLayout*> paramLayouts;
2454+
2455+
// Only consolidate if there is more than one inout or out parameter
2456+
for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam())
2457+
{
2458+
auto pLayoutDecoration = _param->findDecoration<IRLayoutDecoration>();
2459+
SLANG_ASSERT(pLayoutDecoration);
2460+
auto pLayout = as<IRVarLayout>(pLayoutDecoration->getLayout());
2461+
SLANG_ASSERT(pLayout);
2462+
2463+
// Only include parameters that haven't been processed yet
2464+
auto _paramType = _param->getDataType();
2465+
bool needsConsolidation = (as<IROutType>(_paramType) || as<IRInOutType>(_paramType));
2466+
if (!processedParams->contains(_param) && needsConsolidation)
2467+
{
2468+
params.add(_param);
2469+
paramLayouts.add(pLayout);
2470+
processedParams->add(_param);
2471+
}
2472+
}
2473+
2474+
consolidateParameters(context, func, params);
2475+
}
2476+
2477+
void legalizeRayTracingEntryPointParameterForGLSL(
2478+
GLSLLegalizationContext* context,
2479+
IRFunc* func,
2480+
IRParam* pp,
2481+
IRVarLayout* paramLayout)
2482+
{
2483+
auto firstBlock = func->getFirstBlock();
2484+
2485+
// Count the number of inout or out parameters
2486+
int inoutOrOutParamCount = 0;
2487+
for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam())
2488+
{
2489+
auto _paramType = _param->getDataType();
2490+
if (as<IROutType>(_paramType) || as<IRInOutType>(_paramType))
2491+
{
2492+
inoutOrOutParamCount++;
2493+
}
2494+
}
2495+
2496+
// If we have just one inout or out param, we don't need consolidation.
2497+
if (inoutOrOutParamCount <= 1)
2498+
{
2499+
handleSingleParam(context, func, pp, paramLayout);
2500+
return;
2501+
}
2502+
2503+
handleMultipleParams(context, func, pp);
2504+
}
2505+
23572506
static void legalizeMeshPayloadInputParam(
23582507
GLSLLegalizationContext* context,
23592508
CodeGenContext* codeGenContext,

tests/vkray/multipleinout.slang

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// slang-repro.slang
2+
3+
//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly
4+
5+
struct ReflectionRay
6+
{
7+
float4 color;
8+
};
9+
10+
StructuredBuffer<float4> colors;
11+
12+
[shader("closesthit")]
13+
void main(
14+
BuiltInTriangleIntersectionAttributes attributes,
15+
inout ReflectionRay ioPayload,
16+
out float3 dummy)
17+
{
18+
uint materialID = (InstanceIndex() << 1)
19+
+ InstanceID()
20+
+ PrimitiveIndex()
21+
+ HitKind();
22+
23+
ioPayload.color = colors[materialID];
24+
dummy = HitTriangleVertexPosition(0);
25+
}
26+
27+
// CHECK: OpCapability RayTracingKHR
28+
// CHECK: OpCapability RayTracingPositionFetchKHR
29+
// CHECK: OpCapability Shader
30+
// CHECK: OpExtension "SPV_KHR_ray_tracing"
31+
// CHECK: OpExtension "SPV_KHR_storage_buffer_storage_class"
32+
// CHECK: OpExtension "SPV_KHR_ray_tracing_position_fetch"
33+
// CHECK: OpMemoryModel Logical GLSL450
34+
// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}}
35+
36+
// CHECK: OpName %ReflectionRay "ReflectionRay"
37+
// CHECK: OpMemberName %ReflectionRay 0 "color"
38+
// CHECK: OpName %materialID "materialID"
39+
// CHECK: OpName %StructuredBuffer "StructuredBuffer"
40+
// CHECK: OpName %colors "colors"
41+
// CHECK: OpName %main "main"
42+
43+
// CHECK-DAG: OpDecorate %gl_InstanceID BuiltIn InstanceId
44+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn InstanceCustomIndexKHR
45+
// CHECK-DAG: OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
46+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitKindKHR
47+
// CHECK-DAG: OpDecorate %_runtimearr_v4float ArrayStride 16
48+
// CHECK-DAG: OpDecorate %StructuredBuffer Block
49+
// CHECK-DAG: OpDecorate %colors Binding 0
50+
// CHECK-DAG: OpDecorate %colors DescriptorSet 0
51+
// CHECK-DAG: OpDecorate %colors NonWritable
52+
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitTriangleVertexPositionsKHR
53+
54+
// CHECK: %ReflectionRay = OpTypeStruct %v4float
55+
// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float
56+
// CHECK: %_ptr_RayPayloadKHR__struct_{{.*}} = OpTypePointer RayPayloadKHR %_struct_{{.*}}
57+
// CHECK: %_ptr_RayPayloadKHR_ReflectionRay = OpTypePointer RayPayloadKHR %ReflectionRay
58+
// CHECK: %_ptr_RayPayloadKHR_v3float = OpTypePointer RayPayloadKHR %v3float
59+
// CHECK: %StructuredBuffer = OpTypeStruct %_runtimearr_v4float
60+
// CHECK: %_ptr_StorageBuffer_StructuredBuffer = OpTypePointer StorageBuffer %StructuredBuffer
61+
// CHECK: %_ptr_Input__arr_v3float_{{.*}} = OpTypePointer Input %_arr_v3float_{{.*}}
62+
63+
// CHECK: %main = OpFunction %void None %{{.*}}
64+
// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}}
65+
// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID
66+
// CHECK: %{{.*}} = OpLoad %v4float %{{.*}}
67+
// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0
68+
// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}

0 commit comments

Comments
 (0)