@@ -291,6 +291,8 @@ void SpvInstParent::addInst(SpvInst* inst)
291
291
SLANG_ASSERT (inst);
292
292
SLANG_ASSERT (!inst->nextSibling );
293
293
294
+ inst->parent = this ;
295
+
294
296
if (m_firstChild == nullptr )
295
297
{
296
298
m_firstChild = m_lastChild = inst;
@@ -303,7 +305,6 @@ void SpvInstParent::addInst(SpvInst* inst)
303
305
//
304
306
m_lastChild->nextSibling = inst;
305
307
inst->prevSibling = m_lastChild;
306
- inst->parent = this ;
307
308
m_lastChild = inst;
308
309
}
309
310
@@ -491,6 +492,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
491
492
// / The next destination `<id>` to allocate.
492
493
SpvWord m_nextID = 1 ;
493
494
495
+ // This keeps track of the named IDs used in the asm block
496
+ Dictionary<IRSPIRVAsm*, Dictionary<UnownedStringSlice, SpvWord>> m_idMaps;
497
+
494
498
OrderedHashSet<IRPtrTypeBase*> m_forwardDeclaredPointers;
495
499
496
500
SpvInst* m_nullDwarfExpr = nullptr ;
@@ -1537,11 +1541,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
1537
1541
{
1538
1542
valueTypeId = getIRInstSpvID (valueType);
1539
1543
}
1540
- else if (storageClass == SpvStorageClassNodePayloadAMDX)
1541
- {
1542
- auto spvValueType = ensureInst (valueType);
1543
- valueTypeId = getID (spvValueType);
1544
- }
1545
1544
else
1546
1545
{
1547
1546
auto spvValueType = ensureInst (valueType);
@@ -1911,17 +1910,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
1911
1910
{
1912
1911
auto newType =
1913
1912
emitOpTypeNodePayloadArray (inst, nodePayloadArrayType->getRecordType ());
1914
-
1915
- #if 0
1916
- // TODO: This is a temporary hack.
1917
- // The NodeID must come from an attribute [NodeID("name")].
1918
- Slang::StringBuilder str;
1919
- str << "NodeID_" << uint32_t(nodePayloadArrayType->getNodeID()->getValue());
1920
- SpvInst* spvStr = emitOpConstantString(nullptr, str.getUnownedSlice());
1921
- (void)spvStr;
1922
-
1923
- emitOpDecoratePayloadNodeName(nullptr, newType, spvStr);
1924
- #endif
1925
1913
return newType;
1926
1914
}
1927
1915
default :
@@ -2985,14 +2973,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
2985
2973
switch (inst->getOp ())
2986
2974
{
2987
2975
case kIROp_Var :
2988
- emitLocalInst (spvBlock, inst);
2976
+ emitLocalInst (spvBlock, spvBlock, inst);
2989
2977
break ;
2990
2978
case kIROp_DebugVar :
2991
2979
// Declare an ordinary local variable for debugDeclare association
2992
2980
// of a debug variable. This variable is what we will actually write
2993
2981
// values to upon a `kIROp_DebugValue` inst.
2994
2982
emitDebugVarBackingLocalVarDeclaration (spvBlock, as<IRDebugVar>(inst));
2995
2983
break ;
2984
+ case kIROp_SPIRVAsm :
2985
+ emitLocalInst (spvBlock, spvBlock, inst);
2986
+ break ;
2996
2987
}
2997
2988
}
2998
2989
}
@@ -3073,7 +3064,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
3073
3064
// Skip vars because they are already emitted.
3074
3065
if (as<IRVar>(irInst))
3075
3066
continue ;
3076
- emitLocalInst (spvBlock, irInst);
3067
+ emitLocalInst (spvBlock, nullptr , irInst);
3077
3068
if (irInst->getOp () == kIROp_loop )
3078
3069
pendingLoopInsts.add (as<IRLoop>(irInst));
3079
3070
if (irInst->getOp () == kIROp_discard && !shouldEmitDiscardAsDemote ())
@@ -3450,7 +3441,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
3450
3441
// a known parent (the basic block that contains them).
3451
3442
3452
3443
// / Emit an instruction that is local to the body of the given `parent`.
3453
- SpvInst* emitLocalInst (SpvInstParent* parent, IRInst* inst)
3444
+ SpvInst* emitLocalInst (SpvInstParent* parent, SpvInstParent* firstLabel, IRInst* inst)
3454
3445
{
3455
3446
SpvInst* result = nullptr ;
3456
3447
switch (inst->getOp ())
@@ -3582,7 +3573,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
3582
3573
case kIROp_Geq :
3583
3574
case kIROp_Rsh :
3584
3575
case kIROp_Lsh :
3585
- result = emitArithmetic (parent, inst);
3576
+ result = emitArithmetic (parent, firstLabel, inst);
3586
3577
break ;
3587
3578
case kIROp_CastDescriptorHandleToUInt2 :
3588
3579
case kIROp_CastUInt2ToDescriptorHandle :
@@ -3848,7 +3839,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
3848
3839
result = emitOpUndef (parent, inst, inst->getDataType ());
3849
3840
break ;
3850
3841
case kIROp_SPIRVAsm :
3851
- result = emitSPIRVAsm (parent, as<IRSPIRVAsm>(inst));
3842
+ result = emitSPIRVAsm (parent, firstLabel, as<IRSPIRVAsm>(inst));
3852
3843
break ;
3853
3844
case kIROp_ImageLoad :
3854
3845
result = emitImageLoad (parent, as<IRImageLoad>(inst));
@@ -4459,6 +4450,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
4459
4450
break ;
4460
4451
}
4461
4452
}
4453
+
4454
+ // Pass in global OpVariable as interface to the entry point
4455
+ // TODO: Pass in only when they are used by the entry point
4456
+ for (auto entryInterface : m_entryPointInterfaces)
4457
+ {
4458
+ SpvInst* spvInst;
4459
+ if (m_mapIRInstToSpvInst.tryGetValue (entryInterface, spvInst))
4460
+ {
4461
+ params.add (spvInst);
4462
+ }
4463
+ }
4464
+
4462
4465
emitOpEntryPoint (section, decoration, spvStage, dstID, name, params);
4463
4466
4464
4467
// Stage specific execution mode and capability declarations.
@@ -7238,7 +7241,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
7238
7241
}
7239
7242
7240
7243
7241
- SpvInst* emitArithmetic (SpvInstParent* parent, IRInst* inst)
7244
+ SpvInst* emitArithmetic (SpvInstParent* parent, SpvInstParent* firstLabel, IRInst* inst)
7242
7245
{
7243
7246
if (const auto matrixType = as<IRMatrixType>(inst->getDataType ()))
7244
7247
{
@@ -7257,7 +7260,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
7257
7260
if (as<IRMatrixType>(originalOperand->getDataType ()))
7258
7261
{
7259
7262
auto operand = builder.emitElementExtract (originalOperand, i);
7260
- emitLocalInst (parent, operand);
7263
+ emitLocalInst (parent, firstLabel, operand);
7261
7264
operands.add (operand);
7262
7265
}
7263
7266
else
@@ -7809,12 +7812,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
7809
7812
return debugFunc;
7810
7813
}
7811
7814
7812
- SpvInst* emitSPIRVAsm (SpvInstParent* parent, IRSPIRVAsm* inst)
7815
+ SpvInst* emitSPIRVAsm (SpvInstParent* parent, SpvInstParent* firstLabel, IRSPIRVAsm* inst)
7813
7816
{
7814
7817
SpvInst* last = nullptr ;
7815
7818
7816
- // This keeps track of the named IDs used in the asm block
7817
- Dictionary<UnownedStringSlice, SpvWord> idMap;
7819
+ auto &idMap = m_idMaps.getOrAddValue (inst, Dictionary<UnownedStringSlice, SpvWord>());
7818
7820
7819
7821
for (const auto spvInst : inst->getInsts ())
7820
7822
{
@@ -7992,6 +7994,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
7992
7994
7993
7995
if (spvInst->getOpcodeOperand ()->getOp () == kIROp_SPIRVAsmOperandTruncate )
7994
7996
{
7997
+ // Nothing to emit to the first OpLabel
7998
+ if (firstLabel)
7999
+ continue ;
8000
+
7995
8001
const auto getSlangType = [&](IRSPIRVAsmOperand* operand) -> IRType*
7996
8002
{
7997
8003
switch (operand->getOp ())
@@ -8155,14 +8161,58 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
8155
8161
default :
8156
8162
break ;
8157
8163
}
8158
- const auto opParent = parentForOpCode (opcode, parent);
8164
+
8165
+ IRStringLit* resultID = nullptr ;
8166
+ SpvInstParent* opParent = nullptr ;
8167
+ if (opcode == SpvOpVariable)
8168
+ {
8169
+ // SPIRV validator says,
8170
+ // "All OpVariable instructions in a function must be the first instructions in the first block."
8171
+ opParent = firstLabel;
8172
+
8173
+ auto opStorageClass = spvInst->getOperand (3 );
8174
+ if (opStorageClass && opStorageClass->getOp () == kIROp_SPIRVAsmOperandEnum )
8175
+ {
8176
+ if (auto intLit = cast<IRIntLit>(opStorageClass->getOperand (0 )))
8177
+ {
8178
+ switch (SpvStorageClass (intLit->getValue ()))
8179
+ {
8180
+ case SpvStorageClassNodePayloadAMDX:
8181
+ requireSPIRVCapability (SpvCapabilityShaderEnqueueAMDX);
8182
+ ensureExtensionDeclaration (
8183
+ UnownedStringSlice (" SPV_AMDX_shader_enqueue" ));
8184
+
8185
+ opParent = getSection (SpvLogicalSectionID::ConstantsAndTypes);
8186
+
8187
+ if (auto resultOperand =
8188
+ cast<IRSPIRVAsmOperand>(spvInst->getOperand (2 )))
8189
+ {
8190
+ resultID = cast<IRStringLit>(resultOperand->getValue ());
8191
+ }
8192
+
8193
+ m_entryPointInterfaces.add (spvInst);
8194
+ break ;
8195
+ }
8196
+ }
8197
+ }
8198
+ }
8199
+ if (opParent == nullptr )
8200
+ {
8201
+ opParent = parentForOpCode (opcode, parent);
8202
+ }
8203
+
8159
8204
const auto opInfo = m_grammarInfo->opInfos .lookup (opcode);
8160
8205
8161
8206
// TODO: handle resultIdIndex == 1, for constants
8162
8207
const bool memoize =
8163
8208
opParent == getSection (SpvLogicalSectionID::ConstantsAndTypes) && opInfo &&
8164
8209
opInfo->resultIdIndex == 0 ;
8165
8210
8211
+ // SpvOpVariable must appear at the first block of the function
8212
+ // And it may depends on other memoized instructions.
8213
+ if ((opcode == SpvOpVariable || memoize) == (firstLabel == nullptr ))
8214
+ continue ;
8215
+
8166
8216
// We want the "result instruction" to refer to the top level
8167
8217
// block which assumes its value, the others are free to refer
8168
8218
// to whatever, so just use the internal spv inst rep
@@ -8213,15 +8263,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
8213
8263
for (const auto operand : spvInst->getSPIRVOperands ())
8214
8264
emitSpvAsmOperand (operand);
8215
8265
});
8266
+
8267
+ // TODO: We may be able to simplify without checking the string.
8268
+ if (resultID)
8269
+ {
8270
+ SpvWord id;
8271
+ if (last->id == 0 && idMap.tryGetValue (resultID->getStringSlice (), id))
8272
+ last->id = id;
8273
+ }
8216
8274
}
8217
8275
}
8218
8276
}
8219
8277
8220
- for (const auto & [name, id] : idMap)
8221
- emitOpName (getSection (SpvLogicalSectionID::DebugNames), nullptr , id, name);
8278
+ if (firstLabel == nullptr )
8279
+ {
8280
+ for (const auto & [name, id] : idMap)
8281
+ emitOpName (getSection (SpvLogicalSectionID::DebugNames), nullptr , id, name);
8282
+ }
8222
8283
8223
8284
return last;
8224
8285
}
8286
+ HashSet<IRInst*> m_entryPointInterfaces;
8225
8287
8226
8288
OrderedHashSet<SpvCapability> m_capabilities;
8227
8289
void requireSPIRVCapability (SpvCapability capability)
0 commit comments