Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate multiple inouts/outs into struct #6435

Merged
merged 7 commits into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 152 additions & 3 deletions source/slang/slang-ir-glsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2390,7 +2390,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
}
}

void legalizeRayTracingEntryPointParameterForGLSL(
void handleSingleParam(
GLSLLegalizationContext* context,
IRFunc* func,
IRParam* pp,
Expand Down Expand Up @@ -2442,6 +2442,136 @@ void legalizeRayTracingEntryPointParameterForGLSL(
builder->addDependsOnDecoration(func, globalParam);
}

// Merge several `out`/`inout` entry-point parameters into one global struct variable.
//
// A fresh struct type is created with one field per entry of `params`, in the same
// order. Each field has the parameter's value type (the pointee of its `out`/`inout`
// type) and inherits the parameter's name hint, if it has one. A single global
// variable of that struct type is then created in the `IncomingRayPayload` address
// space, decorated as a Vulkan ray payload at location 0, and every use of each
// parameter is replaced by the address of its corresponding field.
//
// `context` supplies the IR builder that is used for everything created here.
// `params`  is the list of parameters to consolidate; every entry is expected to have
//           an `out` or `inout` type.
//
// This function only rewrites uses; it does not remove the parameters themselves.
// The field-address instructions are emitted at whatever insertion point the builder
// currently has, so the caller is expected to have positioned it beforehand.
static void consolidateParameters(GLSLLegalizationContext* context, List<IRParam*>& params)
{
    auto builder = context->getBuilder();

    // Create a struct type to hold all parameters.
    auto structType = builder->createStructType();

    // Remember each field as it is created so the second pass can look it up by
    // index instead of re-walking the struct's field list for every parameter.
    List<IRStructField*> fields;

    // Inside the structure, add one field per parameter.
    for (auto param : params)
    {
        // For `out`/`inout` parameters the field stores the pointee, not the pointer.
        IRType* valueType = param->getDataType();
        if (auto outTypeBase = as<IROutTypeBase>(valueType))
            valueType = outTypeBase->getValueType();

        auto key = builder->createStructKey();
        if (auto nameDecor = param->findDecoration<IRNameHintDecoration>())
            builder->addNameHintDecoration(key, nameDecor->getName());

        // Re-insert the field at the end of the struct so that the field order
        // follows the order of `params`.
        auto field = builder->createStructField(structType, key, valueType);
        field->removeFromParent();
        field->insertAtEnd(structType);
        fields.add(field);
    }

    // Create a global variable to hold the consolidated struct.
    IRInst* consolidatedVar = builder->createGlobalVar(structType);
    auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload);
    consolidatedVar->setFullType(ptrType);
    consolidatedVar->moveToEnd();

    // Add the ray payload decoration and assign location 0.
    builder->addVulkanRayPayloadDecoration(consolidatedVar, 0);

    // Replace each parameter with the address of its field in the consolidated struct.
    for (Index i = 0; i < params.getCount(); ++i)
    {
        auto param = params[i];
        auto field = fields[i];

        auto paramType = param->getDataType();
        auto fieldType = field->getFieldType();

        // The field address keeps the same pointer flavor (`out` vs. `inout`) as the
        // parameter it replaces, so existing uses still see the type they expect.
        IRType* fieldPtrType = nullptr;
        if (as<IROutType>(paramType))
        {
            fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType);
        }
        else if (as<IRInOutType>(paramType))
        {
            fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType);
        }
        // Anything other than `out`/`inout` would leave the result type unset.
        SLANG_ASSERT(fieldPtrType);

        auto fieldAddr = builder->emitFieldAddress(fieldPtrType, consolidatedVar, field->getKey());

        // Replace parameter uses with the field address.
        param->replaceUsesWith(fieldAddr);
    }
}

// Legalize the parameters of the ray tracing entry point `func`.
//
// All parameters of the first block are collected, and those with an `out` or `inout`
// type are additionally tracked as output parameters.
//
//  * With at most one output parameter, every parameter is passed to
//    `handleSingleParam` and nothing is consolidated.
//  * With two or more output parameters, only the remaining parameters are passed to
//    `handleSingleParam`; the output parameters are then handed to
//    `consolidateParameters` as a group.
//
// Does nothing if `func` has no first block.
void consolidateRayTracingParameters(GLSLLegalizationContext* context, IRFunc* func)
{
    auto entryBlock = func->getFirstBlock();
    if (!entryBlock)
        return;

    auto builder = context->getBuilder();

    List<IRParam*> allParams;
    List<IRParam*> outputParams;

    for (auto p = entryBlock->getFirstParam(); p; p = p->getNextParam())
    {
        // Position the builder before the first ordinary instruction of the block.
        builder->setInsertBefore(entryBlock->getFirstOrdinaryInst());

        auto paramType = p->getDataType();
        if (as<IROutType>(paramType) || as<IRInOutType>(paramType))
            outputParams.add(p);
        allParams.add(p);
    }

    // Consolidation is only required once there is more than one out/inout parameter.
    const bool needsConsolidation = outputParams.getCount() > 1;

    for (auto p : allParams)
    {
        // When consolidating, the out/inout parameters are dealt with as a group below.
        if (needsConsolidation && outputParams.contains(p))
            continue;

        auto layoutDecor = p->findDecoration<IRLayoutDecoration>();
        SLANG_ASSERT(layoutDecor);
        auto varLayout = as<IRVarLayout>(layoutDecor->getLayout());
        handleSingleParam(context, func, p, varLayout);
    }

    if (needsConsolidation)
    {
        // Now, consolidate the inout/out parameters.
        consolidateParameters(context, outputParams);
    }
}

static void legalizeMeshPayloadInputParam(
GLSLLegalizationContext* context,
CodeGenContext* codeGenContext,
Expand Down Expand Up @@ -3129,7 +3259,6 @@ void legalizeEntryPointParameterForGLSL(
}
}


// We need to create a global variable that will replace the parameter.
// It seems superficially obvious that the variable should have
// the same type as the parameter.
Expand Down Expand Up @@ -3286,7 +3415,6 @@ void legalizeEntryPointParameterForGLSL(
case Stage::Intersection:
case Stage::Miss:
case Stage::RayGeneration:
legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout);
return;
}

Expand Down Expand Up @@ -3916,12 +4044,33 @@ void legalizeEntryPointForGLSL(
invokePathConstantFuncInHullShader(&context, codeGenContext, scalarizedGlobalOutput);
}

// Special handling for ray tracing shaders
bool isRayTracingShader = false;
switch (stage)
{
case Stage::AnyHit:
case Stage::Callable:
case Stage::ClosestHit:
case Stage::Intersection:
case Stage::Miss:
case Stage::RayGeneration:
isRayTracingShader = true;
consolidateRayTracingParameters(&context, func);
break;
default:
break;
}

// Next we will walk through any parameters of the entry-point function,
// and turn them into global variables.
if (auto firstBlock = func->getFirstBlock())
{
for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam())
{
if (isRayTracingShader)
{
continue;
}
// Any initialization code we insert for parameters needs
// to be at the start of the "ordinary" instructions in the block:
builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
Expand Down
36 changes: 36 additions & 0 deletions tests/vkray/multipleinout.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly

// This test checks that, when an entry point has multiple inout or out parameters, they are all
// consolidated into a single IncomingRayPayloadKHR variable in the generated SPIR-V.

// Payload type used for the `inout` parameter of the entry point; holds a single color.
struct ReflectionRay
{
float4 color;
};

// Buffer of colors that the entry point reads from, indexed by the computed material ID.
StructuredBuffer<float4> colors;

// Closest-hit entry point that declares both an `inout` and an `out` parameter,
// i.e. more than one output-style parameter on the same entry point.
[shader("closesthit")]
void main(
BuiltInTriangleIntersectionAttributes attributes,
inout ReflectionRay ioPayload,
out float3 dummy)
{
// Combine several ray tracing built-in values into one index.
uint materialID = (InstanceIndex() << 1)
+ InstanceID()
+ PrimitiveIndex()
+ HitKind();

// Write through the `inout` payload.
ioPayload.color = colors[materialID];
// Write the additional `out` parameter as well.
dummy = HitTriangleVertexPosition(0);
}

// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the synthesized struct in the entry point?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kaizhangNV The spirv code generated is like so. So it does include the entry point

OpEntryPoint ClosestHitKHR %main "main" %48 %31 %gl_PrimitiveID %25 %gl_InstanceID %colors %11
...
%11 = OpVariable %_ptr_IncomingRayPayloadKHR__struct_5 IncomingRayPayloadKHR ; Location 0

So it does cover it, but I guess this can't be added as a check condition here since the actual assembly instruction number would differ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see.
Then it's fine, just keep it as it is.

// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can keep this line

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and 57

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}}
// CHECK: %main = OpFunction %void None %{{.*}}
// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}}
// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID
// CHECK: %{{.*}} = OpLoad %v4float %{{.*}}
// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0
// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}
Loading