Skip to content

Commit 14211ec

Browse files
Remove unnecessary parameters from Metal entry point signature (#6131)
* fix metal entry point global params * address review comments, cleanup and test * remove dead code * undo accidental change * address review comments and cleanup * minor fix and cleanup --------- Co-authored-by: Yong He <yonghe@outlook.com>
1 parent ea98e24 commit 14211ec

12 files changed

+277
-160
lines changed

source/slang/slang-emit.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,8 @@
9393
#include "slang-ir-ssa-simplification.h"
9494
#include "slang-ir-ssa.h"
9595
#include "slang-ir-string-hash.h"
96-
#include "slang-ir-strip-cached-dict.h"
9796
#include "slang-ir-strip-default-construct.h"
98-
#include "slang-ir-strip-witness-tables.h"
99-
#include "slang-ir-strip.h"
97+
#include "slang-ir-strip-legalization-insts.h"
10098
#include "slang-ir-synthesize-active-mask.h"
10199
#include "slang-ir-translate-glsl-global-var.h"
102100
#include "slang-ir-uniformity.h"
@@ -1501,12 +1499,10 @@ Result linkAndOptimizeIR(
15011499
break;
15021500
}
15031501

1504-
stripCachedDictionaries(irModule);
1505-
15061502
// TODO: our current dynamic dispatch pass will remove all uses of witness tables.
15071503
// If we are going to support function-pointer based, "real" modular dynamic dispatch,
15081504
// we will need to disable this pass.
1509-
stripWitnessTables(irModule);
1505+
stripLegalizationOnlyInstructions(irModule);
15101506

15111507
switch (target)
15121508
{

source/slang/slang-ir-entry-point-uniforms.cpp

+16-8
Original file line numberDiff line numberDiff line change
@@ -494,19 +494,27 @@ struct MoveEntryPointUniformParametersToGlobalScope : PerEntryPointPass
494494
// for CPU/CUDA) that might want to treat entry-point parameters
495495
// different from other cases.
496496
//
497-
// TODO: Once we have support for multiple entry points to be emitted
498-
// at once, we need a way to associate these per-entry-point parameters
499-
// more closely with the original entry point. The two easiest options
500-
// are:
497+
// We need a way to associate these per-entry-point parameters
498+
// more closely with the original entry point. The two current
499+
// methods are:
501500
//
502501
// 1. Don't move the new aggregate parameter to the global scope
503502
// on those targets, and instead keep it as a parameter of the
504-
// entry point.
503+
// entry point. This is used for CPU/CUDA targets.
505504
//
506-
// 2. Use a decoration on the entry point itself to point at the
507-
// global parameter for its per-entry-point parameter data.
505+
// 2. Use a decoration on the global param itself to point at the
506+
// entry point for its per-entry-point parameter data, without moving
507+
// the parameter to the global scope. This is used for Metal targets, as
508+
// Metal does not have global parameters at the global scope.
508509
//
509-
builder->addDecoration(globalParam, kIROp_EntryPointParamDecoration);
510+
// Method (1) is not used because Metal contains shading language concepts
511+
// such as binding offsets, similar to other shading language targets.
512+
// We want to reuse code from other shading language targets for Metal, hence
513+
// we move parameters to the global scope, and then move the parameters back to
514+
// the entry points that they originate from. The originating entry points are
515+
// tracked through this decoration.
516+
//
517+
builder->addEntryPointParamDecoration(globalParam, entryPointFunc);
510518

511519
param->replaceUsesWith(globalParam);
512520
param->removeAndDeallocate();

source/slang/slang-ir-explicit-global-context.cpp

+52-17
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,19 @@ struct IntroduceExplicitGlobalContextPass
140140
IRStructType* m_contextStructType = nullptr;
141141
IRPtrType* m_contextStructPtrType = nullptr;
142142

143-
List<IRGlobalParam*> m_globalParams;
143+
struct GlobalParamInfo
144+
{
145+
// Original global param inst.
146+
IRGlobalParam* globalParam = nullptr;
147+
148+
// New entry point param that is created by this pass.
149+
IRParam* entryPointParam = nullptr;
150+
151+
// Originating entry point obtained from entry point param decoration, if it exists.
152+
IRFunc* originatingEntryPoint = nullptr;
153+
};
154+
155+
List<GlobalParamInfo> m_globalParams;
144156
List<IRGlobalVar*> m_globalVars;
145157
List<IRFunc*> m_entryPoints;
146158

@@ -237,7 +249,22 @@ struct IntroduceExplicitGlobalContextPass
237249
if (m_target == CodeGenTarget::CUDASource)
238250
continue;
239251

240-
m_globalParams.add(globalParam);
252+
GlobalParamInfo globalParamInfo;
253+
globalParamInfo.globalParam = globalParam;
254+
255+
// Entry point param decorations are not required anymore after this pass and
256+
// must be removed for entry point param emit. Removing it here prevents the
257+
// decoration from being cloned when creating struct keys and entry point
258+
// parameters.
259+
if (const auto entryPointParamDecoration =
260+
globalParam->findDecoration<IREntryPointParamDecoration>())
261+
{
262+
globalParamInfo.originatingEntryPoint =
263+
entryPointParamDecoration->getEntryPoint();
264+
entryPointParamDecoration->removeAndDeallocate();
265+
}
266+
267+
m_globalParams.add(globalParamInfo);
241268
}
242269
break;
243270

@@ -305,11 +332,10 @@ struct IntroduceExplicitGlobalContextPass
305332
// For the parameter representing all the global uniform shader
306333
// parameters, we create a field that exactly matches its type.
307334
//
308-
309335
createContextStructField(
310-
globalParam,
336+
globalParam.globalParam,
311337
GlobalObjectKind::GlobalParam,
312-
globalParam->getFullType());
338+
globalParam.globalParam->getFullType());
313339
}
314340
for (auto globalVar : m_globalVars)
315341
{
@@ -347,7 +373,7 @@ struct IntroduceExplicitGlobalContextPass
347373
//
348374
for (auto globalParam : m_globalParams)
349375
{
350-
replaceUsesOfGlobalParam(globalParam);
376+
replaceUsesOfGlobalParam(globalParam.globalParam);
351377
}
352378
for (auto globalVar : m_globalVars)
353379
{
@@ -444,23 +470,32 @@ struct IntroduceExplicitGlobalContextPass
444470
// then we need to introduce an explicit parameter onto
445471
// each entry-point function to represent it.
446472
//
447-
struct GlobalParamInfo
448-
{
449-
IRGlobalParam* globalParam;
450-
IRParam* entryPointParam;
451-
};
452-
List<GlobalParamInfo> entryPointParams;
473+
474+
List<GlobalParamInfo> entryPointParamsToAdd;
453475
for (auto globalParam : m_globalParams)
454476
{
455-
auto entryPointParam = builder.createParam(globalParam->getFullType());
477+
// Do not add global param to current entry point if global param
478+
// explicitly originates from a different entry point.
479+
if (globalParam.originatingEntryPoint &&
480+
globalParam.originatingEntryPoint != entryPointFunc)
481+
{
482+
continue;
483+
}
484+
485+
globalParam.entryPointParam =
486+
builder.createParam(globalParam.globalParam->getFullType());
456487
IRCloneEnv cloneEnv;
457-
cloneInstDecorationsAndChildren(&cloneEnv, m_module, globalParam, entryPointParam);
458-
entryPointParams.add({globalParam, entryPointParam});
488+
cloneInstDecorationsAndChildren(
489+
&cloneEnv,
490+
m_module,
491+
globalParam.globalParam,
492+
globalParam.entryPointParam);
493+
entryPointParamsToAdd.add(globalParam);
459494

460495
// The new parameter will be the last one in the
461496
// parameter list of the entry point.
462497
//
463-
entryPointParam->insertBefore(firstOrdinary);
498+
globalParam.entryPointParam->insertBefore(firstOrdinary);
464499
}
465500

466501
if (m_target == CodeGenTarget::CPPSource && m_globalParams.getCount() == 0)
@@ -485,7 +520,7 @@ struct IntroduceExplicitGlobalContextPass
485520
// to initialize the corresponding field of the `KernelContext`
486521
// before moving on with execution of the kernel body.
487522
//
488-
for (auto entryPointParam : entryPointParams)
523+
for (auto entryPointParam : entryPointParamsToAdd)
489524
{
490525
auto fieldInfo = m_mapInstToContextFieldInfo[entryPointParam.globalParam];
491526
auto fieldType = entryPointParam.globalParam->getFullType();

source/slang/slang-ir-insts.h

+13
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,14 @@ struct IRKnownBuiltinDecoration : IRDecoration
812812
UnownedStringSlice getName() { return getNameOperand()->getStringSlice(); }
813813
};
814814

815+
struct IREntryPointParamDecoration : IRDecoration
816+
{
817+
IR_LEAF_ISA(EntryPointParamDecoration)
818+
819+
/// Get the entry point that this parameter originates from.
820+
IRFunc* getEntryPoint() { return cast<IRFunc>(getOperand(0)); }
821+
};
822+
815823
struct IRFormatDecoration : IRDecoration
816824
{
817825
enum
@@ -5226,6 +5234,11 @@ struct IRBuilder
52265234
{
52275235
addDecoration(inst, kIROp_CheckpointIntermediateDecoration, func);
52285236
}
5237+
5238+
void addEntryPointParamDecoration(IRInst* inst, IRFunc* entryPointFunc)
5239+
{
5240+
addDecoration(inst, kIROp_EntryPointParamDecoration, entryPointFunc);
5241+
}
52295242
};
52305243

52315244
// Helper to establish the source location that will be used

source/slang/slang-ir-legalize-types.cpp

+34-23
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,11 @@ static void registerLegalizedValue(
110110
context->mapValToLegalVal[irValue] = legalVal;
111111
}
112112

113-
struct IRGlobalNameInfo
113+
/// Structure to pass information from the original/old global param to
114+
/// composite members during tuple flavored global param legalization.
115+
struct IRGlobalParamInfo
114116
{
115-
IRInst* globalVar;
116-
UInt counter;
117+
IRFunc* originatingEntryPoint = nullptr;
117118
};
118119

119120
static LegalVal declareVars(
@@ -124,7 +125,7 @@ static LegalVal declareVars(
124125
LegalVarChain const& varChain,
125126
UnownedStringSlice nameHint,
126127
IRInst* leafVar,
127-
IRGlobalNameInfo* globalNameInfo,
128+
IRGlobalParamInfo* globalParamInfo,
128129
bool isSpecial);
129130

130131
/// Unwrap a value with flavor `wrappedBuffer`
@@ -2727,10 +2728,8 @@ static LegalVal declareSimpleVar(
27272728
LegalVarChain const& varChain,
27282729
UnownedStringSlice nameHint,
27292730
IRInst* leafVar,
2730-
IRGlobalNameInfo* globalNameInfo)
2731+
IRGlobalParamInfo* globalParamInfo)
27312732
{
2732-
SLANG_UNUSED(globalNameInfo);
2733-
27342733
IRVarLayout* varLayout = createVarLayout(context->builder, varChain, typeLayout);
27352734

27362735
IRBuilder* builder = context->builder;
@@ -2757,6 +2756,19 @@ static LegalVal declareSimpleVar(
27572756
globalParam->removeFromParent();
27582757
globalParam->insertBefore(context->insertBeforeGlobal);
27592758

2759+
// Add originating entry point decoration if original global param
2760+
// comes from an entry point parameter. This is required in cases where the global
2761+
// param has to be linked back to the originating entry point, such as when
2762+
// emitting Metal where the global params have to be moved back to the
2763+
// entry point parameter.
2764+
SLANG_ASSERT(globalParamInfo);
2765+
if (globalParamInfo->originatingEntryPoint)
2766+
{
2767+
builder->addEntryPointParamDecoration(
2768+
globalParam,
2769+
globalParamInfo->originatingEntryPoint);
2770+
}
2771+
27602772
irVar = globalParam;
27612773
legalVarVal = LegalVal::simple(globalParam);
27622774
}
@@ -3416,7 +3428,7 @@ static LegalVal declareVars(
34163428
LegalVarChain const& inVarChain,
34173429
UnownedStringSlice nameHint,
34183430
IRInst* leafVar,
3419-
IRGlobalNameInfo* globalNameInfo,
3431+
IRGlobalParamInfo* globalParamInfo,
34203432
bool isSpecial)
34213433
{
34223434
LegalVarChain varChain = inVarChain;
@@ -3451,7 +3463,7 @@ static LegalVal declareVars(
34513463
varChain,
34523464
nameHint,
34533465
leafVar,
3454-
globalNameInfo);
3466+
globalParamInfo);
34553467
break;
34563468

34573469
case LegalType::Flavor::implicitDeref:
@@ -3466,7 +3478,7 @@ static LegalVal declareVars(
34663478
varChain,
34673479
nameHint,
34683480
leafVar,
3469-
globalNameInfo,
3481+
globalParamInfo,
34703482
isSpecial);
34713483
return LegalVal::implicitDeref(val);
34723484
}
@@ -3483,7 +3495,7 @@ static LegalVal declareVars(
34833495
varChain,
34843496
nameHint,
34853497
leafVar,
3486-
globalNameInfo,
3498+
globalParamInfo,
34873499
false);
34883500
auto specialVal = declareVars(
34893501
context,
@@ -3493,7 +3505,7 @@ static LegalVal declareVars(
34933505
varChain,
34943506
nameHint,
34953507
leafVar,
3496-
globalNameInfo,
3508+
globalParamInfo,
34973509
true);
34983510
return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo);
34993511
}
@@ -3545,7 +3557,7 @@ static LegalVal declareVars(
35453557
newVarChain,
35463558
fieldNameHint,
35473559
ee.key,
3548-
globalNameInfo,
3560+
globalParamInfo,
35493561
true);
35503562

35513563
TuplePseudoVal::Element element;
@@ -3600,7 +3612,7 @@ static LegalVal declareVars(
36003612
varChain,
36013613
nameHint,
36023614
leafVar,
3603-
globalNameInfo);
3615+
globalParamInfo);
36043616

36053617
return LegalVal::wrappedBuffer(innerVal, wrappedBuffer->elementInfo);
36063618
}
@@ -3634,10 +3646,6 @@ static LegalVal legalizeGlobalVar(IRTypeLegalizationContext* context, IRGlobalVa
36343646
{
36353647
context->insertBeforeGlobal = irGlobalVar;
36363648

3637-
IRGlobalNameInfo globalNameInfo;
3638-
globalNameInfo.globalVar = irGlobalVar;
3639-
globalNameInfo.counter = 0;
3640-
36413649
UnownedStringSlice nameHint = findNameHint(irGlobalVar);
36423650
context->builder->setInsertBefore(irGlobalVar);
36433651
LegalVal newVal = declareVars(
@@ -3648,7 +3656,7 @@ static LegalVal legalizeGlobalVar(IRTypeLegalizationContext* context, IRGlobalVa
36483656
LegalVarChain(),
36493657
nameHint,
36503658
irGlobalVar,
3651-
&globalNameInfo,
3659+
nullptr,
36523660
context->isSpecialType(originalValueType));
36533661

36543662
// Register the new value as the replacement for the old
@@ -3689,9 +3697,12 @@ static LegalVal legalizeGlobalParam(
36893697

36903698
LegalVarChainLink varChain(LegalVarChain(), varLayout);
36913699

3692-
IRGlobalNameInfo globalNameInfo;
3693-
globalNameInfo.globalVar = irGlobalParam;
3694-
globalNameInfo.counter = 0;
3700+
IRGlobalParamInfo globalParamInfo;
3701+
if (auto entryPointParamDecoration =
3702+
irGlobalParam->findDecoration<IREntryPointParamDecoration>())
3703+
{
3704+
globalParamInfo.originatingEntryPoint = entryPointParamDecoration->getEntryPoint();
3705+
}
36953706

36963707
// TODO: need to handle initializer here!
36973708

@@ -3705,7 +3716,7 @@ static LegalVal legalizeGlobalParam(
37053716
varChain,
37063717
nameHint,
37073718
irGlobalParam,
3708-
&globalNameInfo,
3719+
&globalParamInfo,
37093720
context->isSpecialType(irGlobalParam->getDataType()));
37103721

37113722
// Register the new value as the replacement for the old

0 commit comments

Comments
 (0)