@@ -329,6 +329,10 @@ struct GLSLLegalizationContext
329
329
330
330
IRBuilder* builder;
331
331
IRBuilder* getBuilder () { return builder; }
332
+
333
+ // For ray tracing shaders, we need to consolidate all parameters into a single structure
334
+ Dictionary<IRFunc*, IRInst*> rayTracingConsolidatedVars;
335
+ Dictionary<IRFunc*, List<IRParam*>> rayTracingProcessedParams;
332
336
};
333
337
334
338
// This examines the passed type and determines the GLSL mesh shader indices
@@ -2302,7 +2306,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
2302
2306
}
2303
2307
}
2304
2308
2305
- void legalizeRayTracingEntryPointParameterForGLSL (
2309
+ static void handleSingleParam (
2306
2310
GLSLLegalizationContext* context,
2307
2311
IRFunc* func,
2308
2312
IRParam* pp,
@@ -2354,6 +2358,149 @@ void legalizeRayTracingEntryPointParameterForGLSL(
2354
2358
builder->addDependsOnDecoration (func, globalParam);
2355
2359
}
2356
2360
2361
+ static void consolidateParameters (
2362
+ GLSLLegalizationContext* context,
2363
+ IRFunc* func,
2364
+ List<IRParam*>& params)
2365
+ {
2366
+ auto builder = context->getBuilder ();
2367
+
2368
+ // Create a struct type to hold all parameters
2369
+ IRInst* consolidatedVar = nullptr ;
2370
+ auto structType = builder->createStructType ();
2371
+
2372
+ // Inside the structure, add fields for each parameter
2373
+ for (auto _param : params)
2374
+ {
2375
+ auto _paramType = _param->getDataType ();
2376
+ IRType* valueType = _paramType;
2377
+
2378
+ if (as<IROutType>(_paramType))
2379
+ valueType = as<IROutType>(_paramType)->getValueType ();
2380
+ else if (auto inOutType = as<IRInOutType>(_paramType))
2381
+ valueType = inOutType->getValueType ();
2382
+
2383
+ auto key = builder->createStructKey ();
2384
+ builder->addNameHintDecoration (key, UnownedStringSlice (" field" ));
2385
+ auto field = builder->createStructField (structType, key, valueType);
2386
+ field->removeFromParent ();
2387
+ field->insertAtEnd (structType);
2388
+ }
2389
+
2390
+ // Create a global variable to hold the consolidated struct
2391
+ consolidatedVar = builder->createGlobalVar (structType);
2392
+ auto ptrType = builder->getPtrType (kIROp_PtrType , structType, AddressSpace::IncomingRayPayload);
2393
+ consolidatedVar->setFullType (ptrType);
2394
+ consolidatedVar->moveToEnd ();
2395
+
2396
+ // Add the ray payload decoration and assign location 0.
2397
+ builder->addVulkanRayPayloadDecoration (consolidatedVar, 0 );
2398
+
2399
+ // Store the consolidated variable for this function
2400
+ context->rayTracingConsolidatedVars [func] = consolidatedVar;
2401
+
2402
+ // Replace each parameter with a field in the consolidated struct
2403
+ for (Index i = 0 ; i < params.getCount (); ++i)
2404
+ {
2405
+ auto _param = params[i];
2406
+
2407
+ // Find the i-th field
2408
+ IRStructField* targetField = nullptr ;
2409
+ Index fieldIndex = 0 ;
2410
+ for (auto field : structType->getFields ())
2411
+ {
2412
+ if (fieldIndex == i)
2413
+ {
2414
+ targetField = field;
2415
+ break ;
2416
+ }
2417
+ fieldIndex++;
2418
+ }
2419
+ SLANG_ASSERT (targetField);
2420
+
2421
+ // Create the field address with the correct type
2422
+ auto _paramType = _param->getDataType ();
2423
+ auto fieldType = targetField->getFieldType ();
2424
+
2425
+ // If the parameter is an out/inout type, we need to create a pointer type
2426
+ IRType* fieldPtrType = nullptr ;
2427
+ if (as<IROutType>(_paramType))
2428
+ {
2429
+ fieldPtrType = builder->getPtrType (kIROp_OutType , fieldType);
2430
+ }
2431
+ else if (as<IRInOutType>(_paramType))
2432
+ {
2433
+ fieldPtrType = builder->getPtrType (kIROp_InOutType , fieldType);
2434
+ }
2435
+
2436
+ auto fieldAddr =
2437
+ builder->emitFieldAddress (fieldPtrType, consolidatedVar, targetField->getKey ());
2438
+
2439
+ // Replace parameter uses with field address
2440
+ _param->replaceUsesWith (fieldAddr);
2441
+ }
2442
+ }
2443
+
2444
+ static void handleMultipleParams (GLSLLegalizationContext* context, IRFunc* func, IRParam* pp)
2445
+ {
2446
+ auto firstBlock = func->getFirstBlock ();
2447
+
2448
+ // Now we run the consolidation step, but if we've already
2449
+ // processed this parameter, skip it.
2450
+ List<IRParam*>* processedParams = nullptr ;
2451
+ if (auto foundList = context->rayTracingProcessedParams .tryGetValue (func))
2452
+ {
2453
+ processedParams = foundList;
2454
+ if (processedParams->contains (pp))
2455
+ return ;
2456
+ }
2457
+ else
2458
+ {
2459
+ context->rayTracingProcessedParams [func] = List<IRParam*>();
2460
+ processedParams = &context->rayTracingProcessedParams [func];
2461
+ }
2462
+
2463
+ // Collect all parameters that need to be consolidated
2464
+ List<IRParam*> params;
2465
+ List<IRVarLayout*> paramLayouts;
2466
+
2467
+ for (auto _param = firstBlock->getFirstParam (); _param; _param = _param->getNextParam ())
2468
+ {
2469
+ auto pLayoutDecoration = _param->findDecoration <IRLayoutDecoration>();
2470
+ SLANG_ASSERT (pLayoutDecoration);
2471
+ auto pLayout = as<IRVarLayout>(pLayoutDecoration->getLayout ());
2472
+ SLANG_ASSERT (pLayout);
2473
+
2474
+ // Only include parameters that haven't been processed yet
2475
+ auto _paramType = _param->getDataType ();
2476
+ bool needsConsolidation = (as<IROutType>(_paramType) || as<IRInOutType>(_paramType));
2477
+ if (!processedParams->contains (_param) && needsConsolidation)
2478
+ {
2479
+ params.add (_param);
2480
+ paramLayouts.add (pLayout);
2481
+ processedParams->add (_param);
2482
+ }
2483
+ }
2484
+
2485
+ consolidateParameters (context, func, params);
2486
+ }
2487
+
2488
+ void legalizeRayTracingEntryPointParameterForGLSL (
2489
+ GLSLLegalizationContext* context,
2490
+ IRFunc* func,
2491
+ IRParam* pp,
2492
+ IRVarLayout* paramLayout,
2493
+ bool hasSingleOutOrInOutParam)
2494
+ {
2495
+ if (hasSingleOutOrInOutParam)
2496
+ {
2497
+ handleSingleParam (context, func, pp, paramLayout);
2498
+ return ;
2499
+ }
2500
+
2501
+ handleMultipleParams (context, func, pp);
2502
+ }
2503
+
2357
2504
static void legalizeMeshPayloadInputParam (
2358
2505
GLSLLegalizationContext* context,
2359
2506
CodeGenContext* codeGenContext,
@@ -3041,7 +3188,6 @@ void legalizeEntryPointParameterForGLSL(
3041
3188
}
3042
3189
}
3043
3190
3044
-
3045
3191
// We need to create a global variable that will replace the parameter.
3046
3192
// It seems superficially obvious that the variable should have
3047
3193
// the same type as the parameter.
@@ -3198,7 +3344,28 @@ void legalizeEntryPointParameterForGLSL(
3198
3344
case Stage::Intersection:
3199
3345
case Stage::Miss:
3200
3346
case Stage::RayGeneration:
3201
- legalizeRayTracingEntryPointParameterForGLSL (context, func, pp, paramLayout);
3347
+ {
3348
+ // Count the number of inout or out parameters
3349
+ int inoutOrOutParamCount = 0 ;
3350
+ auto firstBlock = func->getFirstBlock ();
3351
+ for (auto _param = firstBlock->getFirstParam (); _param; _param = _param->getNextParam ())
3352
+ {
3353
+ auto _paramType = _param->getDataType ();
3354
+ if (as<IROutType>(_paramType) || as<IRInOutType>(_paramType))
3355
+ {
3356
+ inoutOrOutParamCount++;
3357
+ }
3358
+ }
3359
+
3360
+ // If we have just one inout or out param, we don't need consolidation.
3361
+ bool hasSingleOutOrInOutParam = inoutOrOutParamCount <= 1 ;
3362
+ legalizeRayTracingEntryPointParameterForGLSL (
3363
+ context,
3364
+ func,
3365
+ pp,
3366
+ paramLayout,
3367
+ hasSingleOutOrInOutParam);
3368
+ }
3202
3369
return ;
3203
3370
}
3204
3371
0 commit comments