-
Notifications
You must be signed in to change notification settings - Fork 262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Consolidate multiple inouts/outs into struct #6435
Changes from 1 commit
924e570
029ecc1
ef9a937
35c35fd
c2ce125
5694fe5
0133540
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -329,6 +329,10 @@ struct GLSLLegalizationContext | |
|
||
IRBuilder* builder; | ||
IRBuilder* getBuilder() { return builder; } | ||
|
||
// For ray tracing shaders, we need to consolidate all parameters into a single structure | ||
Dictionary<IRFunc*, IRInst*> rayTracingConsolidatedVars; | ||
Dictionary<IRFunc*, List<IRParam*>> rayTracingProcessedParams; | ||
}; | ||
|
||
// This examines the passed type and determines the GLSL mesh shader indices | ||
|
@@ -2302,7 +2306,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val) | |
} | ||
} | ||
|
||
void legalizeRayTracingEntryPointParameterForGLSL( | ||
static void handleSingleParam( | ||
GLSLLegalizationContext* context, | ||
IRFunc* func, | ||
IRParam* pp, | ||
|
@@ -2354,6 +2358,149 @@ void legalizeRayTracingEntryPointParameterForGLSL( | |
builder->addDependsOnDecoration(func, globalParam); | ||
} | ||
|
||
// Builds one struct type with a field per entry of `params`, creates a global variable of that | ||
// struct type decorated as a Vulkan ray payload, and redirects every use of each parameter to | ||
// the address of its corresponding field in that variable. | ||
// | ||
// context: supplies the IRBuilder; the created variable is recorded in | ||
//          `rayTracingConsolidatedVars[func]`. | ||
// func:    key under which the consolidated variable is recorded. | ||
// params:  parameters to fold into the struct, one field each. | ||
static void consolidateParameters( | ||
GLSLLegalizationContext* context, | ||
IRFunc* func, | ||
List<IRParam*>& params) | ||
{ | ||
auto builder = context->getBuilder(); | ||
|
||
// Create a struct type to hold all parameters | ||
IRInst* consolidatedVar = nullptr; | ||
auto structType = builder->createStructType(); | ||
|
||
// Inside the structure, add fields for each parameter | ||
for (auto _param : params) | ||
{ | ||
auto _paramType = _param->getDataType(); | ||
IRType* valueType = _paramType; | ||
|
||
// The field holds the pointed-to value: unwrap out/inout types, otherwise keep the | ||
// parameter's type as-is. | ||
if (as<IROutType>(_paramType)) | ||
valueType = as<IROutType>(_paramType)->getValueType(); | ||
else if (auto inOutType = as<IRInOutType>(_paramType)) | ||
valueType = inOutType->getValueType(); | ||
|
||
// Every key receives the same fixed name hint, "field". | ||
auto key = builder->createStructKey(); | ||
builder->addNameHintDecoration(key, UnownedStringSlice("field")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is the name "field"? the name should be coming from
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
// NOTE(review): the field is detached and re-appended right after creation, presumably so | ||
// that field order matches the order of `params` -- confirm. | ||
auto field = builder->createStructField(structType, key, valueType); | ||
field->removeFromParent(); | ||
field->insertAtEnd(structType); | ||
} | ||
|
||
// Create a global variable to hold the consolidated struct | ||
// Its full type is then overridden with a pointer to the struct in the IncomingRayPayload | ||
// address space. | ||
consolidatedVar = builder->createGlobalVar(structType); | ||
auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload); | ||
consolidatedVar->setFullType(ptrType); | ||
consolidatedVar->moveToEnd(); | ||
|
||
// Add the ray payload decoration and assign location 0. | ||
builder->addVulkanRayPayloadDecoration(consolidatedVar, 0); | ||
|
||
// Store the consolidated variable for this function | ||
context->rayTracingConsolidatedVars[func] = consolidatedVar; | ||
|
||
// Replace each parameter with a field in the consolidated struct | ||
for (Index i = 0; i < params.getCount(); ++i) | ||
{ | ||
auto _param = params[i]; | ||
|
||
// Find the i-th field | ||
// (linear walk over the struct's fields, repeated for every parameter) | ||
IRStructField* targetField = nullptr; | ||
Index fieldIndex = 0; | ||
for (auto field : structType->getFields()) | ||
{ | ||
if (fieldIndex == i) | ||
{ | ||
targetField = field; | ||
break; | ||
} | ||
fieldIndex++; | ||
} | ||
SLANG_ASSERT(targetField); | ||
|
||
// Create the field address with the correct type | ||
auto _paramType = _param->getDataType(); | ||
auto fieldType = targetField->getFieldType(); | ||
|
||
// If the parameter is an out/inout type, we need to create a pointer type | ||
// NOTE(review): fieldPtrType stays null when the parameter is neither out nor inout; the | ||
// caller visible in this file only passes out/inout parameters -- confirm for other callers. | ||
IRType* fieldPtrType = nullptr; | ||
if (as<IROutType>(_paramType)) | ||
{ | ||
fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType); | ||
} | ||
else if (as<IRInOutType>(_paramType)) | ||
{ | ||
fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType); | ||
} | ||
|
||
auto fieldAddr = | ||
builder->emitFieldAddress(fieldPtrType, consolidatedVar, targetField->getKey()); | ||
|
||
// Replace parameter uses with field address | ||
_param->replaceUsesWith(fieldAddr); | ||
} | ||
} | ||
|
||
// Gathers the out/inout parameters of `func`'s first block that have not been handled yet and | ||
// passes them to consolidateParameters(). Handled parameters are remembered per function in | ||
// `context->rayTracingProcessedParams`; the call returns early when `pp` is already in that list. | ||
static void handleMultipleParams(GLSLLegalizationContext* context, IRFunc* func, IRParam* pp) | ||
{ | ||
auto firstBlock = func->getFirstBlock(); | ||
|
||
// Now we run the consolidation step, but if we've already | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This feels like we are doing consolidation too late. Maybe consolidation should be called directly from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for that suggestion. That will indeed simplify the code. Updated. |
||
// processed this parameter, skip it. | ||
List<IRParam*>* processedParams = nullptr; | ||
if (auto foundList = context->rayTracingProcessedParams.tryGetValue(func)) | ||
{ | ||
processedParams = foundList; | ||
if (processedParams->contains(pp)) | ||
return; | ||
} | ||
else | ||
{ | ||
// First visit for this function: store an empty list, then point at the stored copy. | ||
context->rayTracingProcessedParams[func] = List<IRParam*>(); | ||
processedParams = &context->rayTracingProcessedParams[func]; | ||
} | ||
|
||
// Collect all parameters that need to be consolidated | ||
// NOTE(review): paramLayouts is filled below but never read in this function. | ||
List<IRParam*> params; | ||
List<IRVarLayout*> paramLayouts; | ||
|
||
// Every parameter of the first block is asserted to carry a layout decoration holding an | ||
// IRVarLayout, including parameters that end up not being consolidated. | ||
for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam()) | ||
{ | ||
auto pLayoutDecoration = _param->findDecoration<IRLayoutDecoration>(); | ||
SLANG_ASSERT(pLayoutDecoration); | ||
auto pLayout = as<IRVarLayout>(pLayoutDecoration->getLayout()); | ||
SLANG_ASSERT(pLayout); | ||
|
||
// Only include parameters that haven't been processed yet | ||
auto _paramType = _param->getDataType(); | ||
bool needsConsolidation = (as<IROutType>(_paramType) || as<IRInOutType>(_paramType)); | ||
if (!processedParams->contains(_param) && needsConsolidation) | ||
{ | ||
params.add(_param); | ||
paramLayouts.add(pLayout); | ||
processedParams->add(_param); | ||
} | ||
} | ||
|
||
// NOTE(review): this call is unconditional. If `pp` is not out/inout it is never added to the | ||
// processed list, so a call made after all out/inout parameters were consumed appears to reach | ||
// here with an empty `params` -- confirm whether that can happen. | ||
consolidateParameters(context, func, params); | ||
} | ||
|
||
// Dispatches ray tracing entry point parameter `pp` of `func` to one of two helpers. | ||
// When `hasSingleOutOrInOutParam` is true, `pp` is handled on its own by handleSingleParam() | ||
// using `paramLayout`; otherwise it is routed to handleMultipleParams(), which does not take | ||
// `paramLayout`. | ||
void legalizeRayTracingEntryPointParameterForGLSL( | ||
GLSLLegalizationContext* context, | ||
IRFunc* func, | ||
IRParam* pp, | ||
IRVarLayout* paramLayout, | ||
bool hasSingleOutOrInOutParam) | ||
{ | ||
if (hasSingleOutOrInOutParam) | ||
{ | ||
handleSingleParam(context, func, pp, paramLayout); | ||
return; | ||
} | ||
|
||
// More than one out/inout parameter: take the consolidation path. | ||
handleMultipleParams(context, func, pp); | ||
} | ||
|
||
static void legalizeMeshPayloadInputParam( | ||
GLSLLegalizationContext* context, | ||
CodeGenContext* codeGenContext, | ||
|
@@ -3041,7 +3188,6 @@ void legalizeEntryPointParameterForGLSL( | |
} | ||
} | ||
|
||
|
||
// We need to create a global variable that will replace the parameter. | ||
// It seems superficially obvious that the variable should have | ||
// the same type as the parameter. | ||
|
@@ -3198,7 +3344,28 @@ void legalizeEntryPointParameterForGLSL( | |
case Stage::Intersection: | ||
case Stage::Miss: | ||
case Stage::RayGeneration: | ||
legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout); | ||
{ | ||
// Count the number of inout or out parameters | ||
int inoutOrOutParamCount = 0; | ||
auto firstBlock = func->getFirstBlock(); | ||
for (auto _param = firstBlock->getFirstParam(); _param; _param = _param->getNextParam()) | ||
{ | ||
auto _paramType = _param->getDataType(); | ||
if (as<IROutType>(_paramType) || as<IRInOutType>(_paramType)) | ||
{ | ||
inoutOrOutParamCount++; | ||
} | ||
} | ||
|
||
// If we have just one inout or out param, we don't need consolidation. | ||
bool hasSingleOutOrInOutParam = inoutOrOutParamCount <= 1; | ||
legalizeRayTracingEntryPointParameterForGLSL( | ||
context, | ||
func, | ||
pp, | ||
paramLayout, | ||
hasSingleOutOrInOutParam); | ||
} | ||
return; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly | ||
|
||
// This test checks that, when an entry point has multiple inout or out parameters, the generated | ||
// SPIR-V consolidates them all into one IncomingRayPayloadKHR struct. | ||
|
||
struct ReflectionRay | ||
{ | ||
float4 color; | ||
}; | ||
|
||
StructuredBuffer<float4> colors; | ||
|
||
// Closest-hit entry point declaring both an inout and an out parameter, i.e. more than one | ||
// out/inout parameter. | ||
[shader("closesthit")] | ||
void main( | ||
BuiltInTriangleIntersectionAttributes attributes, | ||
inout ReflectionRay ioPayload, | ||
out float3 dummy) | ||
{ | ||
// Combine several built-in values into an index into `colors`. | ||
uint materialID = (InstanceIndex() << 1) | ||
+ InstanceID() | ||
+ PrimitiveIndex() | ||
+ HitKind(); | ||
|
||
// Write to both the inout and the out parameter. | ||
ioPayload.color = colors[materialID]; | ||
dummy = HitTriangleVertexPosition(0); | ||
} | ||
|
||
// CHECK: OpCapability RayTracingKHR | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to check from line 28-34 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ack |
||
// CHECK: OpCapability RayTracingPositionFetchKHR | ||
// CHECK: OpCapability Shader | ||
// CHECK: OpExtension "SPV_KHR_ray_tracing" | ||
// CHECK: OpExtension "SPV_KHR_storage_buffer_storage_class" | ||
// CHECK: OpExtension "SPV_KHR_ray_tracing_position_fetch" | ||
// CHECK: OpMemoryModel Logical GLSL450 | ||
// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where is the synthesized struct in the entry point? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @kaizhangNV The generated SPIR-V looks like this: `OpEntryPoint ClosestHitKHR %main "main" %48 %31 %gl_PrimitiveID %25 %gl_InstanceID %colors %11`. So the entry point does include it, but I guess this can't be added as a check condition here, since the actual numeric IDs in the assembly would differ? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, I see. |
||
|
||
// CHECK: OpName %ReflectionRay "ReflectionRay" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to check 37-42 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ack |
||
// CHECK: OpMemberName %ReflectionRay 0 "color" | ||
// CHECK: OpName %materialID "materialID" | ||
// CHECK: OpName %StructuredBuffer "StructuredBuffer" | ||
// CHECK: OpName %colors "colors" | ||
// CHECK: OpName %main "main" | ||
|
||
// CHECK-DAG: OpDecorate %gl_InstanceID BuiltIn InstanceId | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to check 44-53 |
||
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn InstanceCustomIndexKHR | ||
// CHECK-DAG: OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId | ||
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitKindKHR | ||
// CHECK-DAG: OpDecorate %_runtimearr_v4float ArrayStride 16 | ||
// CHECK-DAG: OpDecorate %StructuredBuffer Block | ||
// CHECK-DAG: OpDecorate %colors Binding 0 | ||
// CHECK-DAG: OpDecorate %colors DescriptorSet 0 | ||
// CHECK-DAG: OpDecorate %colors NonWritable | ||
// CHECK-DAG: OpDecorate %{{.*}} BuiltIn HitTriangleVertexPositionsKHR | ||
|
||
// CHECK: %ReflectionRay = OpTypeStruct %v4float | ||
// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can keep this line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and 57 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ack |
||
// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}} | ||
// CHECK: %_ptr_IncomingRayPayloadKHR_ReflectionRay = OpTypePointer IncomingRayPayloadKHR %ReflectionRay | ||
// CHECK: %_ptr_IncomingRayPayloadKHR_v3float = OpTypePointer IncomingRayPayloadKHR %v3float | ||
// CHECK: %StructuredBuffer = OpTypeStruct %_runtimearr_v4float | ||
// CHECK: %_ptr_StorageBuffer_StructuredBuffer = OpTypePointer StorageBuffer %StructuredBuffer | ||
// CHECK: %_ptr_Input__arr_v3float_{{.*}} = OpTypePointer Input %_arr_v3float_{{.*}} | ||
|
||
// CHECK: %main = OpFunction %void None %{{.*}} | ||
// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}} | ||
// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID | ||
// CHECK: %{{.*}} = OpLoad %v4float %{{.*}} | ||
// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0 | ||
// CHECK: %{{.*}} = OpLoad %v3float %{{.*}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as<IROutTypeBase>
should cover both cases in this `if`. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.