Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate multiple inouts/outs into struct #6435

Merged
merged 7 commits into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 170 additions & 3 deletions source/slang/slang-ir-glsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ struct GLSLLegalizationContext

IRBuilder* builder;
IRBuilder* getBuilder() { return builder; }

// For ray tracing shaders, we need to consolidate all parameters into a single structure
Dictionary<IRFunc*, IRInst*> rayTracingConsolidatedVars;
Dictionary<IRFunc*, List<IRParam*>> rayTracingProcessedParams;
};

// This examines the passed type and determines the GLSL mesh shader indices
Expand Down Expand Up @@ -2302,7 +2306,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
}
}

void legalizeRayTracingEntryPointParameterForGLSL(
static void handleSingleParam(
GLSLLegalizationContext* context,
IRFunc* func,
IRParam* pp,
Expand Down Expand Up @@ -2354,6 +2358,149 @@ void legalizeRayTracingEntryPointParameterForGLSL(
builder->addDependsOnDecoration(func, globalParam);
}

static void consolidateParameters(
GLSLLegalizationContext* context,
IRFunc* func,
List<IRParam*>& params)
{
auto builder = context->getBuilder();

// Create a struct type to hold all parameters
IRInst* consolidatedVar = nullptr;
auto structType = builder->createStructType();

// Inside the structure, add fields for each parameter
for (auto _param : params)
{
auto _paramType = _param->getDataType();
IRType* valueType = _paramType;

if (as<IROutType>(_paramType))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as<IROutTypeBase> should cover both cases in this if.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

valueType = as<IROutType>(_paramType)->getValueType();
else if (auto inOutType = as<IRInOutType>(_paramType))
valueType = inOutType->getValueType();

auto key = builder->createStructKey();
builder->addNameHintDecoration(key, UnownedStringSlice("field"));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the name "field"? the name should be coming from _param.
just do:

if (auto nameDecor = _param->findDecoration<IRNameHintDecoration>())
     builder->addNameHintDecoration(key, nameDecor->getName());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

auto field = builder->createStructField(structType, key, valueType);
field->removeFromParent();
field->insertAtEnd(structType);
}

// Create a global variable to hold the consolidated struct
consolidatedVar = builder->createGlobalVar(structType);
auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload);
consolidatedVar->setFullType(ptrType);
consolidatedVar->moveToEnd();

// Add the ray payload decoration and assign location 0.
builder->addVulkanRayPayloadDecoration(consolidatedVar, 0);

// Store the consolidated variable for this function
context->rayTracingConsolidatedVars[func] = consolidatedVar;

// Replace each parameter with a field in the consolidated struct
for (Index i = 0; i < params.getCount(); ++i)
{
auto _param = params[i];

// Find the i-th field
IRStructField* targetField = nullptr;
Index fieldIndex = 0;
for (auto field : structType->getFields())
{
if (fieldIndex == i)
{
targetField = field;
break;
}
fieldIndex++;
}
SLANG_ASSERT(targetField);

// Create the field address with the correct type
auto _paramType = _param->getDataType();
auto fieldType = targetField->getFieldType();

// If the parameter is an out/inout type, we need to create a pointer type
IRType* fieldPtrType = nullptr;
if (as<IROutType>(_paramType))
{
fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType);
}
else if (as<IRInOutType>(_paramType))
{
fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType);
}

auto fieldAddr =
builder->emitFieldAddress(fieldPtrType, consolidatedVar, targetField->getKey());

// Replace parameter uses with field address
_param->replaceUsesWith(fieldAddr);
}
}

static void handleMultipleParams(GLSLLegalizationContext* context, IRFunc* func, IRParam* pp)
{
auto firstBlock = func->getFirstBlock();

// Now we run the consolidation step, but if we've already
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like we are doing consolidation too late. Maybe consolidation should be called directly from legalizeEntryPointForGLSL instead of from legalizeEntryPointParameterForGLSL so we don't need to track if a parameter has already been processed or not, and we also won't need that hasSingleOutOrInOutParam parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for that suggestion. That will indeed simplify the code. Updated.

// processed this parameter, skip it.
List<IRParam*>* processedParams = nullptr;
if (auto foundList = context->rayTracingProcessedParams.tryGetValue(func))
{
processedParams = foundList;
if (processedParams->contains(pp))
return;
}
else
{
context->rayTracingProcessedParams[func] = List<IRParam*>();
processedParams = &context->rayTracingProcessedParams[func];
}

// Collect all parameters that need to be consolidated
List<IRParam*> params;
List<IRVarLayout*> paramLayouts;

for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam())
{
auto pLayoutDecoration = _param->findDecoration<IRLayoutDecoration>();
SLANG_ASSERT(pLayoutDecoration);
auto pLayout = as<IRVarLayout>(pLayoutDecoration->getLayout());
SLANG_ASSERT(pLayout);

// Only include parameters that haven't been processed yet
auto _paramType = _param->getDataType();
bool needsConsolidation = (as<IROutType>(_paramType) || as<IRInOutType>(_paramType));
if (!processedParams->contains(_param) && needsConsolidation)
{
params.add(_param);
paramLayouts.add(pLayout);
processedParams->add(_param);
}
}

consolidateParameters(context, func, params);
}

void legalizeRayTracingEntryPointParameterForGLSL(
GLSLLegalizationContext* context,
IRFunc* func,
IRParam* pp,
IRVarLayout* paramLayout,
bool hasSingleOutOrInOutParam)
{
if (hasSingleOutOrInOutParam)
{
handleSingleParam(context, func, pp, paramLayout);
return;
}

handleMultipleParams(context, func, pp);
}

static void legalizeMeshPayloadInputParam(
GLSLLegalizationContext* context,
CodeGenContext* codeGenContext,
Expand Down Expand Up @@ -3041,7 +3188,6 @@ void legalizeEntryPointParameterForGLSL(
}
}


// We need to create a global variable that will replace the parameter.
// It seems superficially obvious that the variable should have
// the same type as the parameter.
Expand Down Expand Up @@ -3198,7 +3344,28 @@ void legalizeEntryPointParameterForGLSL(
case Stage::Intersection:
case Stage::Miss:
case Stage::RayGeneration:
legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout);
{
// Count the number of inout or out parameters
int inoutOrOutParamCount = 0;
auto firstBlock = func->getFirstBlock();
for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam())
{
auto _paramType = _param->getDataType();
if (as<IROutType>(_paramType) || as<IRInOutType>(_paramType))
{
inoutOrOutParamCount++;
}
}

// If we have just one inout or out param, we don't need consolidation.
bool hasSingleOutOrInOutParam = inoutOrOutParamCount <= 1;
legalizeRayTracingEntryPointParameterForGLSL(
context,
func,
pp,
paramLayout,
hasSingleOutOrInOutParam);
}
return;
}

Expand Down
69 changes: 69 additions & 0 deletions tests/vkray/multipleinout.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly

// This test checks whether the spirv generated when there are multiple inout or out variables, they
// all get consolidated into one IncomingRayPayloadKHR.

struct ReflectionRay
{
float4 color;
};

StructuredBuffer<float4> colors;

[shader("closesthit")]
void main(
BuiltInTriangleIntersectionAttributes attributes,
inout ReflectionRay ioPayload,
out float3 dummy)
{
uint materialID = (InstanceIndex() << 1)
+ InstanceID()
+ PrimitiveIndex()
+ HitKind();

ioPayload.color = colors[materialID];
dummy = HitTriangleVertexPosition(0);
}

// CHECK: OpCapability RayTracingKHR
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to check from line 28-34

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

// CHECK: OpCapability RayTracingPositionFetchKHR
// CHECK: OpCapability Shader
// CHECK: OpExtension "SPV_KHR_ray_tracing"
// CHECK: OpExtension "SPV_KHR_storage_buffer_storage_class"
// CHECK: OpExtension "SPV_KHR_ray_tracing_position_fetch"
// CHECK: OpMemoryModel Logical GLSL450
// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the synthesized struct in the entry point?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kaizhangNV The spirv code generated is like so. So it does include the entry point

OpEntryPoint ClosestHitKHR %main "main" %48 %31 %gl_PrimitiveID %25 %gl_InstanceID %colors %11
...
%11 = OpVariable %_ptr_IncomingRayPayloadKHR__struct_5 IncomingRayPayloadKHR ; Location 0

So it does cover it, but I guess this can't be added as a check condition here since the actual assembly instruction number would differ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see.
Then it's fine, just keep it as it.


// CHECK: OpName %ReflectionRay "ReflectionRay"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to check 37-42

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

// CHECK: OpMemberName %ReflectionRay 0 "color"
// CHECK: OpName %materialID "materialID"
// CHECK: OpName %StructuredBuffer "StructuredBuffer"
// CHECK: OpName %colors "colors"
// CHECK: OpName %main "main"

// CHECK-DAG: OpDecorate %gl_InstanceID BuiltIn InstanceId
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to check 44-53

// CHECK-DAG: OpDecorate %{{.*}} BuiltIn InstanceCustomIndexKHR
// CHECK-DAG: OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitKindKHR
// CHECK-DAG: OpDecorate %_runtimearr_v4float ArrayStride 16
// CHECK-DAG: OpDecorate %StructuredBuffer Block
// CHECK-DAG: OpDecorate %colors Binding 0
// CHECK-DAG: OpDecorate %colors DescriptorSet 0
// CHECK-DAG: OpDecorate %colors NonWritable
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitTriangleVertexPositionsKHR

// CHECK: %ReflectionRay = OpTypeStruct %v4float
// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can keep this line

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and 57

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}}
// CHECK: %_ptr_IncomingRayPayloadKHR_ReflectionRay = OpTypePointer IncomingRayPayloadKHR %ReflectionRay
// CHECK: %_ptr_IncomingRayPayloadKHR_v3float = OpTypePointer IncomingRayPayloadKHR %v3float
// CHECK: %StructuredBuffer = OpTypeStruct %_runtimearr_v4float
// CHECK: %_ptr_StorageBuffer_StructuredBuffer = OpTypePointer StorageBuffer %StructuredBuffer
// CHECK: %_ptr_Input__arr_v3float_{{.*}} = OpTypePointer Input %_arr_v3float_{{.*}}

// CHECK: %main = OpFunction %void None %{{.*}}
// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}}
// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID
// CHECK: %{{.*}} = OpLoad %v4float %{{.*}}
// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0
// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}
Loading