@@ -2390,7 +2390,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
2390
2390
}
2391
2391
}
2392
2392
2393
- void legalizeRayTracingEntryPointParameterForGLSL (
2393
+ void handleSingleParam (
2394
2394
GLSLLegalizationContext* context,
2395
2395
IRFunc* func,
2396
2396
IRParam* pp,
@@ -2442,6 +2442,136 @@ void legalizeRayTracingEntryPointParameterForGLSL(
2442
2442
builder->addDependsOnDecoration (func, globalParam);
2443
2443
}
2444
2444
2445
+ static void consolidateParameters (GLSLLegalizationContext* context, List<IRParam*>& params)
2446
+ {
2447
+ auto builder = context->getBuilder ();
2448
+
2449
+ // Create a struct type to hold all parameters
2450
+ IRInst* consolidatedVar = nullptr ;
2451
+ auto structType = builder->createStructType ();
2452
+
2453
+ // Inside the structure, add fields for each parameter
2454
+ for (auto _param : params)
2455
+ {
2456
+ auto _paramType = _param->getDataType ();
2457
+ IRType* valueType = _paramType;
2458
+
2459
+ if (as<IROutTypeBase>(_paramType))
2460
+ valueType = as<IROutTypeBase>(_paramType)->getValueType ();
2461
+
2462
+ auto key = builder->createStructKey ();
2463
+ if (auto nameDecor = _param->findDecoration <IRNameHintDecoration>())
2464
+ builder->addNameHintDecoration (key, nameDecor->getName ());
2465
+ auto field = builder->createStructField (structType, key, valueType);
2466
+ field->removeFromParent ();
2467
+ field->insertAtEnd (structType);
2468
+ }
2469
+
2470
+ // Create a global variable to hold the consolidated struct
2471
+ consolidatedVar = builder->createGlobalVar (structType);
2472
+ auto ptrType = builder->getPtrType (kIROp_PtrType , structType, AddressSpace::IncomingRayPayload);
2473
+ consolidatedVar->setFullType (ptrType);
2474
+ consolidatedVar->moveToEnd ();
2475
+
2476
+ // Add the ray payload decoration and assign location 0.
2477
+ builder->addVulkanRayPayloadDecoration (consolidatedVar, 0 );
2478
+
2479
+ // Replace each parameter with a field in the consolidated struct
2480
+ for (Index i = 0 ; i < params.getCount (); ++i)
2481
+ {
2482
+ auto _param = params[i];
2483
+
2484
+ // Find the i-th field
2485
+ IRStructField* targetField = nullptr ;
2486
+ Index fieldIndex = 0 ;
2487
+ for (auto field : structType->getFields ())
2488
+ {
2489
+ if (fieldIndex == i)
2490
+ {
2491
+ targetField = field;
2492
+ break ;
2493
+ }
2494
+ fieldIndex++;
2495
+ }
2496
+ SLANG_ASSERT (targetField);
2497
+
2498
+ // Create the field address with the correct type
2499
+ auto _paramType = _param->getDataType ();
2500
+ auto fieldType = targetField->getFieldType ();
2501
+
2502
+ // If the parameter is an out/inout type, we need to create a pointer type
2503
+ IRType* fieldPtrType = nullptr ;
2504
+ if (as<IROutType>(_paramType))
2505
+ {
2506
+ fieldPtrType = builder->getPtrType (kIROp_OutType , fieldType);
2507
+ }
2508
+ else if (as<IRInOutType>(_paramType))
2509
+ {
2510
+ fieldPtrType = builder->getPtrType (kIROp_InOutType , fieldType);
2511
+ }
2512
+
2513
+ auto fieldAddr =
2514
+ builder->emitFieldAddress (fieldPtrType, consolidatedVar, targetField->getKey ());
2515
+
2516
+ // Replace parameter uses with field address
2517
+ _param->replaceUsesWith (fieldAddr);
2518
+ }
2519
+ }
2520
+
2521
+ // Consolidate ray tracing parameters for an entry point function
2522
+ void consolidateRayTracingParameters (GLSLLegalizationContext* context, IRFunc* func)
2523
+ {
2524
+ auto builder = context->getBuilder ();
2525
+ auto firstBlock = func->getFirstBlock ();
2526
+ if (!firstBlock)
2527
+ return ;
2528
+
2529
+ // Collect all out/inout parameters that need to be consolidated
2530
+ List<IRParam*> outParams;
2531
+ List<IRParam*> params;
2532
+
2533
+ for (auto param = firstBlock->getFirstParam (); param; param = param->getNextParam ())
2534
+ {
2535
+ builder->setInsertBefore (firstBlock->getFirstOrdinaryInst ());
2536
+ if (as<IROutType>(param->getDataType ()) || as<IRInOutType>(param->getDataType ()))
2537
+ {
2538
+ outParams.add (param);
2539
+ }
2540
+ params.add (param);
2541
+ }
2542
+
2543
+ // We don't need consolidation here.
2544
+ if (outParams.getCount () <= 1 )
2545
+ {
2546
+ for (auto param : params)
2547
+ {
2548
+ auto paramLayoutDecoration = param->findDecoration <IRLayoutDecoration>();
2549
+ SLANG_ASSERT (paramLayoutDecoration);
2550
+ auto paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout ());
2551
+ handleSingleParam (context, func, param, paramLayout);
2552
+ }
2553
+ return ;
2554
+ }
2555
+ else
2556
+ {
2557
+ // We need consolidation here, but before that, handle parameters other than inout/out.
2558
+ for (auto param : params)
2559
+ {
2560
+ if (outParams.contains (param))
2561
+ {
2562
+ continue ;
2563
+ }
2564
+ auto paramLayoutDecoration = param->findDecoration <IRLayoutDecoration>();
2565
+ SLANG_ASSERT (paramLayoutDecoration);
2566
+ auto paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout ());
2567
+ handleSingleParam (context, func, param, paramLayout);
2568
+ }
2569
+
2570
+ // Now, consolidate the inout/out parameters
2571
+ consolidateParameters (context, outParams);
2572
+ }
2573
+ }
2574
+
2445
2575
static void legalizeMeshPayloadInputParam (
2446
2576
GLSLLegalizationContext* context,
2447
2577
CodeGenContext* codeGenContext,
@@ -3129,7 +3259,6 @@ void legalizeEntryPointParameterForGLSL(
3129
3259
}
3130
3260
}
3131
3261
3132
-
3133
3262
// We need to create a global variable that will replace the parameter.
3134
3263
// It seems superficially obvious that the variable should have
3135
3264
// the same type as the parameter.
@@ -3286,7 +3415,6 @@ void legalizeEntryPointParameterForGLSL(
3286
3415
case Stage::Intersection:
3287
3416
case Stage::Miss:
3288
3417
case Stage::RayGeneration:
3289
- legalizeRayTracingEntryPointParameterForGLSL (context, func, pp, paramLayout);
3290
3418
return ;
3291
3419
}
3292
3420
@@ -3916,12 +4044,33 @@ void legalizeEntryPointForGLSL(
3916
4044
invokePathConstantFuncInHullShader (&context, codeGenContext, scalarizedGlobalOutput);
3917
4045
}
3918
4046
4047
+ // Special handling for ray tracing shaders
4048
+ bool isRayTracingShader = false ;
4049
+ switch (stage)
4050
+ {
4051
+ case Stage::AnyHit:
4052
+ case Stage::Callable:
4053
+ case Stage::ClosestHit:
4054
+ case Stage::Intersection:
4055
+ case Stage::Miss:
4056
+ case Stage::RayGeneration:
4057
+ isRayTracingShader = true ;
4058
+ consolidateRayTracingParameters (&context, func);
4059
+ break ;
4060
+ default :
4061
+ break ;
4062
+ }
4063
+
3919
4064
// Next we will walk through any parameters of the entry-point function,
3920
4065
// and turn them into global variables.
3921
4066
if (auto firstBlock = func->getFirstBlock ())
3922
4067
{
3923
4068
for (auto pp = firstBlock->getFirstParam (); pp; pp = pp->getNextParam ())
3924
4069
{
4070
+ if (isRayTracingShader)
4071
+ {
4072
+ continue ;
4073
+ }
3925
4074
// Any initialization code we insert for parameters needs
3926
4075
// to be at the start of the "ordinary" instructions in the block:
3927
4076
builder.setInsertBefore (firstBlock->getFirstOrdinaryInst ());
0 commit comments