Skip to content

Commit 9b45d1b

Browse files
committed
Fix work-graph consumer issue
1 parent 6df923f commit 9b45d1b

File tree

5 files changed

+184
-51
lines changed

5 files changed

+184
-51
lines changed

source/slang/core.meta.slang

+7-7
Original file line numberDiff line numberDiff line change
@@ -4217,19 +4217,19 @@ attribute_syntax [RequireFullQuads] : RequireFullQuadsAttribute;
42174217
/// __SPIRVNodePayloadArray is to emit OpTypeNodePayloadArrayAMDX and
42184218
/// SpvDecorationPayloadNodeNameAMDX for it.
42194219
///
4220-
__generic<RecordType, let nodeID : int>
4220+
__generic<BaseType, let nodeID : int>
42214221
__intrinsic_type($(kIROp_SPIRVNodePayloadArrayType))
42224222
struct __SPIRVNodePayloadArray;
42234223

42244224
/// @internal
42254225
/// NodePayloadArrayPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to an array of work-graph nodes
42264226
///
4227-
__generic<RecordType, let nodeID : int>
4228-
typealias NodePayloadArrayPtr = Ptr<__SPIRVNodePayloadArray<RecordType, nodeID>, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
4227+
__generic<BaseType, let nodeID : int>
4228+
typealias NodePayloadArrayPtr = Ptr<__SPIRVNodePayloadArray<BaseType, nodeID>, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
42294229

4230-
/// @internal
4231-
/// NodePayloadPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to a `RecordType`.
4230+
/// @public
4231+
/// FunctionPtr is a pointer type for "Function" storage-class and it points to a `BaseType`.
42324232
///
4233-
__generic<RecordType>
4234-
typealias NodePayloadPtr = Ptr<RecordType, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
4233+
__generic<BaseType>
4234+
typealias FunctionPtr = Ptr<BaseType, $( (uint64_t)AddressSpace::Function)>;
42354235

source/slang/hlsl.meta.slang

+77
Original file line numberDiff line numberDiff line change
@@ -24692,3 +24692,80 @@ uint packInt4x8Clamp(int16_t4 unpackedValue)
2469224692
return packInt4x8(clamp(unpackedValue, -128, 127));
2469324693
}
2469424694
}
24695+
// TODO: It seems that unexpected SPIRV code gets emitted here when
24696+
// assign from spriv_asm to `alloc` such as `get_field_addr`
24697+
alloc = spirv_asm
24698+
{
24699+
result = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr<RecordType, nodeID> Workgroup $recordCount $nodeIndex;
24700+
};
24701+
}
24702+
24703+
[ForceInline]
24704+
[require(hlsl_spirv)] // [require(hlsl_spirv, thread)]
24705+
void OutputComplete()
24706+
{
24707+
__target_switch
24708+
{
24709+
case hlsl: __intrinsic_asm ".OutputComplete";
24710+
case spirv: spirv_asm { OpEnqueueNodePayloadsAMDX $alloc; };
24711+
}
24712+
}
24713+
24714+
__subscript(uint index) -> FunctionPtr<RecordType>
24715+
{
24716+
[ForceInline]
24717+
get { return Get(index); }
24718+
}
24719+
24720+
FunctionPtr<RecordType> Get(int index = 0)
24721+
{
24722+
__target_switch
24723+
{
24724+
case hlsl: __intrinsic_asm ".Get";
24725+
case spirv:
24726+
let tmp = reinterpret<NodePayloadArrayPtr<RecordType, nodeID>>(alloc);
24727+
let ptr = spirv_asm
24728+
{
24729+
%tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray<RecordType, nodeID>> Function;
24730+
24731+
%loaded = OpLoad $$__SPIRVNodePayloadArray<RecordType, nodeID> $tmp;
24732+
OpStore %tempPtrVar %loaded;
24733+
24734+
result : $$FunctionPtr<RecordType> = OpAccessChain %tempPtrVar $index
24735+
};
24736+
return ptr;
24737+
}
24738+
}
24739+
};
24740+
24741+
//@public:
24742+
/// Read-only input to Broadcasting launch node.
24743+
///
24744+
__generic<RecordType, let nodeID : int>
24745+
struct DispatchNodeInputRecord
24746+
{
24747+
/// Provide an access to a record object that only holds a single record.
24748+
FunctionPtr<RecordType> Get()
24749+
{
24750+
int index = 0;
24751+
24752+
__target_switch
24753+
{
24754+
case hlsl: __intrinsic_asm ".Get";
24755+
case spirv:
24756+
let ptr = spirv_asm
24757+
{
24758+
%nodePtrType = OpTypePointer NodePayloadAMDX $$__SPIRVNodePayloadArray<RecordType, nodeID>;
24759+
%nodePtrVar = OpVariable %nodePtrType NodePayloadAMDX;
24760+
24761+
%tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray<RecordType, nodeID>> Function;
24762+
24763+
%loaded = OpLoad $$__SPIRVNodePayloadArray<RecordType, nodeID> %nodePtrVar;
24764+
OpStore %tempPtrVar %loaded;
24765+
24766+
result : $$FunctionPtr<RecordType> = OpAccessChain %tempPtrVar $index;
24767+
};
24768+
return ptr;
24769+
}
24770+
}
24771+
};

source/slang/slang-emit-spirv.cpp

+92-30
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ void SpvInstParent::addInst(SpvInst* inst)
291291
SLANG_ASSERT(inst);
292292
SLANG_ASSERT(!inst->nextSibling);
293293

294+
inst->parent = this;
295+
294296
if (m_firstChild == nullptr)
295297
{
296298
m_firstChild = m_lastChild = inst;
@@ -303,7 +305,6 @@ void SpvInstParent::addInst(SpvInst* inst)
303305
//
304306
m_lastChild->nextSibling = inst;
305307
inst->prevSibling = m_lastChild;
306-
inst->parent = this;
307308
m_lastChild = inst;
308309
}
309310

@@ -491,6 +492,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
491492
/// The next destination `<id>` to allocate.
492493
SpvWord m_nextID = 1;
493494

495+
// This keeps track of the named IDs used in the asm block
496+
Dictionary<IRSPIRVAsm*, Dictionary<UnownedStringSlice, SpvWord>> m_idMaps;
497+
494498
OrderedHashSet<IRPtrTypeBase*> m_forwardDeclaredPointers;
495499

496500
SpvInst* m_nullDwarfExpr = nullptr;
@@ -1537,11 +1541,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
15371541
{
15381542
valueTypeId = getIRInstSpvID(valueType);
15391543
}
1540-
else if (storageClass == SpvStorageClassNodePayloadAMDX)
1541-
{
1542-
auto spvValueType = ensureInst(valueType);
1543-
valueTypeId = getID(spvValueType);
1544-
}
15451544
else
15461545
{
15471546
auto spvValueType = ensureInst(valueType);
@@ -1911,17 +1910,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
19111910
{
19121911
auto newType =
19131912
emitOpTypeNodePayloadArray(inst, nodePayloadArrayType->getRecordType());
1914-
1915-
#if 0
1916-
// TODO: This is a temporary hack.
1917-
// The NodeID must come from an attribute [NodeID("name")].
1918-
Slang::StringBuilder str;
1919-
str << "NodeID_" << uint32_t(nodePayloadArrayType->getNodeID()->getValue());
1920-
SpvInst* spvStr = emitOpConstantString(nullptr, str.getUnownedSlice());
1921-
(void)spvStr;
1922-
1923-
emitOpDecoratePayloadNodeName(nullptr, newType, spvStr);
1924-
#endif
19251913
return newType;
19261914
}
19271915
default:
@@ -2985,14 +2973,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
29852973
switch (inst->getOp())
29862974
{
29872975
case kIROp_Var:
2988-
emitLocalInst(spvBlock, inst);
2976+
emitLocalInst(spvBlock, spvBlock, inst);
29892977
break;
29902978
case kIROp_DebugVar:
29912979
// Declare an ordinary local variable for debugDeclare association
29922980
// of a debug variable. This variable is what we will actually write
29932981
// values to upon a `kIROp_DebugValue` inst.
29942982
emitDebugVarBackingLocalVarDeclaration(spvBlock, as<IRDebugVar>(inst));
29952983
break;
2984+
case kIROp_SPIRVAsm:
2985+
emitLocalInst(spvBlock, spvBlock, inst);
2986+
break;
29962987
}
29972988
}
29982989
}
@@ -3073,7 +3064,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
30733064
// Skip vars because they are already emitted.
30743065
if (as<IRVar>(irInst))
30753066
continue;
3076-
emitLocalInst(spvBlock, irInst);
3067+
emitLocalInst(spvBlock, nullptr, irInst);
30773068
if (irInst->getOp() == kIROp_loop)
30783069
pendingLoopInsts.add(as<IRLoop>(irInst));
30793070
if (irInst->getOp() == kIROp_discard && !shouldEmitDiscardAsDemote())
@@ -3450,7 +3441,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
34503441
// a known parent (the basic block that contains them).
34513442

34523443
/// Emit an instruction that is local to the body of the given `parent`.
3453-
SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst)
3444+
SpvInst* emitLocalInst(SpvInstParent* parent, SpvInstParent* firstLabel, IRInst* inst)
34543445
{
34553446
SpvInst* result = nullptr;
34563447
switch (inst->getOp())
@@ -3582,7 +3573,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
35823573
case kIROp_Geq:
35833574
case kIROp_Rsh:
35843575
case kIROp_Lsh:
3585-
result = emitArithmetic(parent, inst);
3576+
result = emitArithmetic(parent, firstLabel, inst);
35863577
break;
35873578
case kIROp_CastDescriptorHandleToUInt2:
35883579
case kIROp_CastUInt2ToDescriptorHandle:
@@ -3848,7 +3839,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
38483839
result = emitOpUndef(parent, inst, inst->getDataType());
38493840
break;
38503841
case kIROp_SPIRVAsm:
3851-
result = emitSPIRVAsm(parent, as<IRSPIRVAsm>(inst));
3842+
result = emitSPIRVAsm(parent, firstLabel, as<IRSPIRVAsm>(inst));
38523843
break;
38533844
case kIROp_ImageLoad:
38543845
result = emitImageLoad(parent, as<IRImageLoad>(inst));
@@ -4459,6 +4450,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
44594450
break;
44604451
}
44614452
}
4453+
4454+
// Pass in global OpVariable as interface to the entry point
4455+
// TODO: Pass in only when they are used by the entry point
4456+
for (auto entryInterface : m_entryPointInterfaces)
4457+
{
4458+
SpvInst* spvInst;
4459+
if (m_mapIRInstToSpvInst.tryGetValue(entryInterface, spvInst))
4460+
{
4461+
params.add(spvInst);
4462+
}
4463+
}
4464+
44624465
emitOpEntryPoint(section, decoration, spvStage, dstID, name, params);
44634466

44644467
// Stage specific execution mode and capability declarations.
@@ -7238,7 +7241,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
72387241
}
72397242

72407243

7241-
SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
7244+
SpvInst* emitArithmetic(SpvInstParent* parent, SpvInstParent* firstLabel, IRInst* inst)
72427245
{
72437246
if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
72447247
{
@@ -7257,7 +7260,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
72577260
if (as<IRMatrixType>(originalOperand->getDataType()))
72587261
{
72597262
auto operand = builder.emitElementExtract(originalOperand, i);
7260-
emitLocalInst(parent, operand);
7263+
emitLocalInst(parent, firstLabel, operand);
72617264
operands.add(operand);
72627265
}
72637266
else
@@ -7809,12 +7812,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
78097812
return debugFunc;
78107813
}
78117814

7812-
SpvInst* emitSPIRVAsm(SpvInstParent* parent, IRSPIRVAsm* inst)
7815+
SpvInst* emitSPIRVAsm(SpvInstParent* parent, SpvInstParent* firstLabel, IRSPIRVAsm* inst)
78137816
{
78147817
SpvInst* last = nullptr;
78157818

7816-
// This keeps track of the named IDs used in the asm block
7817-
Dictionary<UnownedStringSlice, SpvWord> idMap;
7819+
auto &idMap = m_idMaps.getOrAddValue(inst, Dictionary<UnownedStringSlice, SpvWord>());
78187820

78197821
for (const auto spvInst : inst->getInsts())
78207822
{
@@ -7992,6 +7994,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
79927994

79937995
if (spvInst->getOpcodeOperand()->getOp() == kIROp_SPIRVAsmOperandTruncate)
79947996
{
7997+
// Nothing to emit to the first OpLabel
7998+
if (firstLabel)
7999+
continue;
8000+
79958001
const auto getSlangType = [&](IRSPIRVAsmOperand* operand) -> IRType*
79968002
{
79978003
switch (operand->getOp())
@@ -8155,14 +8161,58 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
81558161
default:
81568162
break;
81578163
}
8158-
const auto opParent = parentForOpCode(opcode, parent);
8164+
8165+
IRStringLit* resultID = nullptr;
8166+
SpvInstParent* opParent = nullptr;
8167+
if (opcode == SpvOpVariable)
8168+
{
8169+
// SPIRV validator says,
8170+
// "All OpVariable instructions in a function must be the first instructions in the first block."
8171+
opParent = firstLabel;
8172+
8173+
auto opStorageClass = spvInst->getOperand(3);
8174+
if (opStorageClass && opStorageClass->getOp() == kIROp_SPIRVAsmOperandEnum)
8175+
{
8176+
if (auto intLit = cast<IRIntLit>(opStorageClass->getOperand(0)))
8177+
{
8178+
switch (SpvStorageClass(intLit->getValue()))
8179+
{
8180+
case SpvStorageClassNodePayloadAMDX:
8181+
requireSPIRVCapability(SpvCapabilityShaderEnqueueAMDX);
8182+
ensureExtensionDeclaration(
8183+
UnownedStringSlice("SPV_AMDX_shader_enqueue"));
8184+
8185+
opParent = getSection(SpvLogicalSectionID::ConstantsAndTypes);
8186+
8187+
if (auto resultOperand =
8188+
cast<IRSPIRVAsmOperand>(spvInst->getOperand(2)))
8189+
{
8190+
resultID = cast<IRStringLit>(resultOperand->getValue());
8191+
}
8192+
8193+
m_entryPointInterfaces.add(spvInst);
8194+
break;
8195+
}
8196+
}
8197+
}
8198+
}
8199+
if (opParent == nullptr)
8200+
{
8201+
opParent = parentForOpCode(opcode, parent);
8202+
}
8203+
81598204
const auto opInfo = m_grammarInfo->opInfos.lookup(opcode);
81608205

81618206
// TODO: handle resultIdIndex == 1, for constants
81628207
const bool memoize =
81638208
opParent == getSection(SpvLogicalSectionID::ConstantsAndTypes) && opInfo &&
81648209
opInfo->resultIdIndex == 0;
81658210

8211+
// SpvOpVariable must appear at the first block of the function
8212+
// And it may depends on other memoized instructions.
8213+
if ((opcode == SpvOpVariable || memoize) == (firstLabel == nullptr))
8214+
continue;
8215+
81668216
// We want the "result instruction" to refer to the top level
81678217
// block which assumes its value, the others are free to refer
81688218
// to whatever, so just use the internal spv inst rep
@@ -8213,15 +8263,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
82138263
for (const auto operand : spvInst->getSPIRVOperands())
82148264
emitSpvAsmOperand(operand);
82158265
});
8266+
8267+
// TODO: We may be able to simplify without checking the string.
8268+
if (resultID)
8269+
{
8270+
SpvWord id;
8271+
if (last->id == 0 && idMap.tryGetValue(resultID->getStringSlice(), id))
8272+
last->id = id;
8273+
}
82168274
}
82178275
}
82188276
}
82198277

8220-
for (const auto& [name, id] : idMap)
8221-
emitOpName(getSection(SpvLogicalSectionID::DebugNames), nullptr, id, name);
8278+
if (firstLabel == nullptr)
8279+
{
8280+
for (const auto& [name, id] : idMap)
8281+
emitOpName(getSection(SpvLogicalSectionID::DebugNames), nullptr, id, name);
8282+
}
82228283

82238284
return last;
82248285
}
8286+
HashSet<IRInst*> m_entryPointInterfaces;
82258287

82268288
OrderedHashSet<SpvCapability> m_capabilities;
82278289
void requireSPIRVCapability(SpvCapability capability)

0 commit comments

Comments
 (0)