Skip to content

Commit 9ebc84c

Browse files
committed
Allow duplication of OpTypeNodePayloadArrayAMDX
This commit is mainly to allow OpTypeNodePayloadArrayAMDX to be duplicated in SPIR-V when using work-graphs. The AMD spec requires a dedicated type declaration with OpTypeNodePayloadArrayAMDX for each "node", and each of them should be decorated with PayloadNodeNameAMDX. https://docs.vulkan.org/features/latest/features/proposals/VK_AMDX_shader_enqueue.html This behavior is very different from how "types" are normally declared. Normally types are declared only once and they are shared across the whole program. But for work-graphs, each "RecordType" belongs to a node and it needs to be declared separately for each node. Consider the following HLSL example where a producer triggers two consumer nodes. Each node is differentiated by its "NodeID" attribute. ``` [Shader("node")] [NodeLaunch("broadcasting")] [NumThreads(4,1,1)] void myFancyNode(..., [NodeID("myNode1")] NodeOutput<MY_RECORD> myRecords1, [NodeID("myNode2")] NodeOutput<MY_RECORD> myRecords2, ...) { ThreadNodeOutputRecords<MY_RECORD> myRecord1 = myRecords1.GetThreadNodeOutputRecords(1); myRecord1.myData = 1; myRecord1.OutputComplete(); ThreadNodeOutputRecords<MY_RECORD> myRecord2 = myRecords2.GetThreadNodeOutputRecords(1); myRecord2.myData = 1; myRecord2.OutputComplete(); } ``` Even though both are using the same type, "MY_RECORD", each of them needs to be declared separately. The generated SPIR-V code for the HLSL above is something like the following, ``` %str_node1 = OpConstantStringAMDX "myNode1"; %str_node2 = OpConstantStringAMDX "myNode2"; %type_node1 = OpTypeNodePayloadArrayAMDX %type_myRecord; %type_node2 = OpTypeNodePayloadArrayAMDX %type_myRecord; OpDecorateId %type_node1 PayloadNodeNameAMDX %str_node1; OpDecorateId %type_node2 PayloadNodeNameAMDX %str_node2; ``` Note that `%type_node1` and `%type_node2` are structurally equivalent, but they must be declared separately because each of them is decorated with a different nodeID string.
1 parent 519e866 commit 9ebc84c

8 files changed

+215
-23
lines changed

source/slang/core.meta.slang

+21-2
Original file line numberDiff line numberDiff line change
@@ -4211,6 +4211,25 @@ attribute_syntax [QuadDerivatives] : QuadDerivativesAttribute;
42114211
__attributeTarget(FuncDecl)
42124212
attribute_syntax [RequireFullQuads] : RequireFullQuadsAttribute;
42134213

4214-
__generic<T>
4215-
typealias NodePayloadPtr = Ptr<T, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
4214+
// Work-graphs
4215+
4216+
/// @internal
4217+
/// __SPIRVNodePayloadArray is to emit OpTypeNodePayloadArrayAMDX and
4218+
/// SpvDecorationPayloadNodeNameAMDX for it.
4219+
///
4220+
__generic<RecordType, let nodeID : int>
4221+
__intrinsic_type($(kIROp_SPIRVNodePayloadArrayType))
4222+
struct __SPIRVNodePayloadArray;
4223+
4224+
/// @internal
4225+
/// NodePayloadArrayPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to an array of work-graph nodes
4226+
///
4227+
__generic<RecordType, let nodeID : int>
4228+
typealias NodePayloadArrayPtr = Ptr<__SPIRVNodePayloadArray<RecordType, nodeID>, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
4229+
4230+
/// @internal
4231+
/// NodePayloadPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to a `RecordType`.
4232+
///
4233+
__generic<RecordType>
4234+
typealias NodePayloadPtr = Ptr<RecordType, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
42164235

source/slang/hlsl.meta.slang

+68-11
Original file line numberDiff line numberDiff line change
@@ -24371,28 +24371,85 @@ int8_t4_packed pack_clamp_s8(int16_t4 unpackedValue)
2437124371

2437224372
// Work-graphs
2437324373

24374+
///@public
24375+
/// Set of zero or more records for each thread.
24376+
///
24377+
__generic<RecordType, let nodeID : int>
24378+
struct ThreadNodeOutputRecords
24379+
{
24380+
NodePayloadArrayPtr<RecordType, nodeID> alloc;
24381+
24382+
[ForceInline]
24383+
[require(spirv)] // [require(spirv, thread)]
24384+
__init(int recordCount, int nodeIndex = 0)
24385+
{
24386+
alloc = spirv_asm
24387+
{
24388+
result = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr<RecordType, nodeID> Workgroup $recordCount $nodeIndex;
24389+
};
24390+
}
24391+
24392+
[ForceInline]
24393+
[require(hlsl_spirv)] // [require(hlsl_spirv, thread)]
24394+
void OutputComplete()
24395+
{
24396+
__target_switch
24397+
{
24398+
case hlsl: __intrinsic_asm ".OutputComplete";
24399+
case spirv: spirv_asm { OpEnqueueNodePayloadsAMDX $alloc; };
24400+
}
24401+
}
24402+
24403+
__subscript(uint index) -> NodePayloadPtr<RecordType>
24404+
{
24405+
[ForceInline]
24406+
get { return Get(index); }
24407+
}
24408+
24409+
[ForceInline]
24410+
[require(hlsl_spirv)] // [require(hlsl_spirv, thread)]
24411+
NodePayloadPtr<RecordType> Get(int index = 0)
24412+
{
24413+
__target_switch
24414+
{
24415+
case hlsl: __intrinsic_asm ".Get";
24416+
case spirv:
24417+
NodePayloadArrayPtr<RecordType, nodeID> ptr = spirv_asm
24418+
{
24419+
%ptr = OpTypePointer Function $$__SPIRVNodePayloadArray<RecordType, nodeID>;
24420+
%var = OpVariable %ptr Function;
24421+
OpStore %var $alloc;
24422+
%Load = OpLoad $$NodePayloadArrayPtr<RecordType, nodeID> %var;
24423+
result : $$NodePayloadArrayPtr<RecordType, nodeID> = OpAccessChain %Load $index
24424+
};
24425+
return reinterpret<NodePayloadPtr<RecordType>>(ptr);
24426+
}
24427+
}
24428+
};
24429+
2437424430
//@public:
24375-
/// read-only input to Broadcasting launch node.
24376-
__generic<T>
24377-
//TODO: DispatchNodeInputRecord should be available only for broadcasting node shader.
24378-
//[require(broadcasting_node)]
24379-
[require(spirv)]
24431+
/// Read-only input to Broadcasting launch node.
24432+
///
24433+
__generic<RecordType, let nodeID : int>
2438024434
struct DispatchNodeInputRecord
2438124435
{
2438224436
/// Provide an access to a record object that only holds a single record.
24383-
NodePayloadPtr<T> Get()
24437+
[ForceInline]
24438+
[require(hlsl_spirv)] // [require(hlsl_spirv, broadcasting)]
24439+
NodePayloadPtr<RecordType> Get()
2438424440
{
2438524441
int index = 0;
2438624442
__target_switch
2438724443
{
24444+
case hlsl: __intrinsic_asm ".Get";
2438824445
case spirv:
24389-
return spirv_asm
24446+
NodePayloadArrayPtr<RecordType, nodeID> ptr = spirv_asm
2439024447
{
24391-
%in_payload_t = OpTypeNodePayloadArrayAMDX $$T;
24392-
%in_payload_ptr_t = OpTypePointer NodePayloadAMDX %in_payload_t;
24393-
%var = OpVariable %in_payload_ptr_t NodePayloadAMDX;
24394-
result : $$NodePayloadPtr<T> = OpAccessChain %var $index;
24448+
%ptr = OpTypePointer NodePayloadAMDX $$__SPIRVNodePayloadArray<RecordType, nodeID>;
24449+
%var = OpVariable %ptr NodePayloadAMDX;
24450+
result : $$NodePayloadArrayPtr<RecordType, nodeID> = OpAccessChain %var $index;
2439524451
};
24452+
return reinterpret<NodePayloadPtr<RecordType>>(ptr);
2439624453
}
2439724454
}
2439824455
};

source/slang/slang-emit-spirv-ops.h

+25-1
Original file line numberDiff line numberDiff line change
@@ -2571,12 +2571,36 @@ template<typename T>
25712571
SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type)
25722572
{
25732573
static_assert(isSingular<T>);
2574-
return emitInstMemoized(
2574+
return emitInst(
25752575
getSection(SpvLogicalSectionID::ConstantsAndTypes),
25762576
inst,
25772577
SpvOpTypeNodePayloadArrayAMDX,
25782578
kResultID,
25792579
type);
25802580
}
25812581

2582+
// https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/AMD/SPV_AMDX_shader_enqueue.html#OpConstantStringAMDX
2583+
SpvInst* emitOpConstantString(IRInst* inst, const UnownedStringSlice& str)
2584+
{
2585+
return emitInstMemoized(
2586+
getSection(SpvLogicalSectionID::ConstantsAndTypes),
2587+
inst,
2588+
SpvOpConstantStringAMDX,
2589+
kResultID,
2590+
str);
2591+
}
2592+
2593+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorateId
2594+
template<typename T1, typename T2>
2595+
SpvInst* emitOpDecoratePayloadNodeName(IRInst* inst, const T1& target, const T2& id)
2596+
{
2597+
return emitInst(
2598+
getSection(SpvLogicalSectionID::Annotations),
2599+
inst,
2600+
SpvOpDecorateId,
2601+
target,
2602+
SpvDecorationPayloadNodeNameAMDX,
2603+
id);
2604+
}
2605+
25822606
#endif // SLANG_IN_SPIRV_EMIT_CONTEXT

source/slang/slang-emit-spirv.cpp

+20-2
Original file line numberDiff line numberDiff line change
@@ -1542,8 +1542,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
15421542
else if (storageClass == SpvStorageClassNodePayloadAMDX)
15431543
{
15441544
auto spvValueType = ensureInst(valueType);
1545-
auto spvNodePayloadType = emitOpTypeNodePayloadArray(inst, spvValueType);
1546-
valueTypeId = getID(spvNodePayloadType);
1545+
valueTypeId = getID(spvValueType);
15471546
}
15481547
else
15491548
{
@@ -1909,6 +1908,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
19091908
case kIROp_IndicesType:
19101909
case kIROp_PrimitivesType:
19111910
return nullptr;
1911+
case kIROp_SPIRVNodePayloadArrayType:
1912+
if (auto nodePayloadArrayType = as<IRSPIRVNodePayloadArrayType>(inst))
1913+
{
1914+
auto newType = emitOpTypeNodePayloadArray(
1915+
inst, nodePayloadArrayType->getRecordType());
1916+
1917+
Slang::StringBuilder str;
1918+
str << "NodeID_" << uint32_t(nodePayloadArrayType->getNodeID()->getValue());
1919+
SpvInst* spvStr = emitOpConstantString(nullptr, str.getUnownedSlice());
1920+
(void)spvStr;
1921+
1922+
auto r = emitOpDecoratePayloadNodeName(
1923+
nullptr,
1924+
newType,
1925+
spvStr);
1926+
(void)r;
1927+
return newType;
1928+
}
19121929
default:
19131930
{
19141931
if (as<IRSPIRVAsmOperand>(inst))
@@ -7750,6 +7767,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
77507767
return getSection(SpvLogicalSectionID::Annotations);
77517768
case SpvOpTypeNodePayloadArrayAMDX:
77527769
return getSection(SpvLogicalSectionID::ConstantsAndTypes);
7770+
77537771
default:
77547772
return defaultParent;
77557773
}

source/slang/slang-ir-inst-defs.h

+3
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,9 @@ INST(ExpandTypeOrVal, ExpandTypeOrVal, 1, HOISTABLE)
265265
// A type that identifies its contained type as being emittable as `spirv_literal`.
266266
INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE)
267267

268+
// WorkGraphs
269+
INST(SPIRVNodePayloadArrayType, spirvNodePayloadArrayType, 2, HOISTABLE)
270+
268271
// A TypeType-typed IRValue represents a IRType.
269272
// It is used to represent a type parameter/argument in a generics.
270273
INST(TypeType, type_t, 0, HOISTABLE)

source/slang/slang-ir.h

+8
Original file line numberDiff line numberDiff line change
@@ -1752,6 +1752,14 @@ struct IRSPIRVLiteralType : IRType
17521752
IRType* getValueType() { return static_cast<IRType*>(getOperand(0)); }
17531753
};
17541754

1755+
struct IRSPIRVNodePayloadArrayType : IRType
1756+
{
1757+
IR_LEAF_ISA(SPIRVNodePayloadArrayType)
1758+
1759+
IRType* getRecordType() { return static_cast<IRType*>(getOperand(0)); }
1760+
IRIntLit* getNodeID() { return static_cast<IRIntLit*>(getOperand(1)); }
1761+
};
1762+
17551763
struct IRPtrTypeBase : IRType
17561764
{
17571765
IRType* getValueType() { return (IRType*)getOperand(0); }

tests/workgraphs/consumer.slang

+17-7
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,35 @@
1-
//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry main -skip-spirv-validation
1+
//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -skip-spirv-validation
2+
3+
RWStructuredBuffer<int> outputBuffer;
4+
25
struct RecordData
36
{
47
int myData;
58
};
69

710
[shader("compute")]
811
[numthreads(1, 1, 1)]
9-
void main(uint3 dispatchThreadId : SV_GroupThreadID)
12+
void computeMain(uint3 dispatchThreadId : SV_GroupThreadID)
1013
{
1114
spirv_asm
1215
{
13-
OpExecutionMode $main ShaderIndexAMDX $(0);
14-
OpExecutionMode $main StaticNumWorkgroupsAMDX $(1) $(1) $(1);
16+
OpExecutionMode $computeMain ShaderIndexAMDX $(0);
17+
OpExecutionMode $computeMain StaticNumWorkgroupsAMDX $(1) $(1) $(1);
1518
};
1619

17-
DispatchNodeInputRecord<RecordData> inputData;
20+
//TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute of the entry point.
21+
// Until it is implemented properly, we will use an int-type generic argument.
22+
//
23+
#define myNodeID 0
24+
DispatchNodeInputRecord<RecordData, myNodeID> inputData;
1825

19-
let recordData = inputData.Get();
20-
int myData = recordData.myData;
26+
int myData = inputData.Get().myData;
27+
outputBuffer[dispatchThreadId.x] = myData;
2128
}
2229

30+
//CHK: OpCapability ShaderEnqueueAMDX
31+
//CHK: OpExtension "SPV_AMDX_shader_enqueue"
32+
2333
//CHK: ; Types, variables and constants
2434
//CHK: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1
2535
//CHK: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]]

tests/workgraphs/producer.slang

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -skip-spirv-validation
2+
3+
RWStructuredBuffer<int> inputBuffer;
4+
5+
struct RecordData
6+
{
7+
int myData;
8+
}
9+
10+
[shader("compute")]
11+
[numthreads(1, 1, 1)]
12+
void computeMain()
13+
{
14+
spirv_asm
15+
{
16+
OpExecutionMode $computeMain ShaderIndexAMDX $(0);
17+
OpExecutionMode $computeMain StaticNumWorkgroupsAMDX $(1) $(1) $(1);
18+
};
19+
20+
//TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute on "node" variable.
21+
// Until it is implemented properly, we will use an int-type generic argument.
22+
//
23+
#define myNodeID 0
24+
ThreadNodeOutputRecords<RecordData, myNodeID> node = { 1 };
25+
26+
node.Get().myData = inputBuffer[0];
27+
node.OutputComplete();
28+
}
29+
30+
//CHK: OpCapability ShaderEnqueueAMDX
31+
//CHK: OpExtension "SPV_AMDX_shader_enqueue"
32+
33+
//CHK-DAG: [[RecordType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData
34+
//CHK-DAG: [[PtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[RecordType]]
35+
//CHK-DAG: [[NodeID:%[a-zA-Z_0-9]+]] = OpConstantStringAMDX "NodeID_0"
36+
//CHK-DAG: OpDecorateId [[RecordType]] PayloadNodeNameAMDX [[NodeID]]
37+
38+
//CHK-NOT: = OpTypeNodePayloadArrayAMDX
39+
//CHK-NOT: = OpConstantStringAMDX
40+
//CHK-NOT: OpDecorateId {{.*}} PayloadNodeNameAMDX
41+
42+
// ThreadNodeOutputRecords::__init()
43+
//CHK: [[Alloc:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[PtrType]] %
44+
//CHK-NOT: = OpAllocateNodePayloadsAMDX
45+
46+
// ThreadNodeOutputRecords::Get()
47+
//CHK: [[Load:%[a-zA-Z_0-9]+]] = OpLoad [[PtrType]] %var
48+
//CHK: = OpAccessChain [[PtrType]] [[Load]] %
49+
50+
// ThreadNodeOutputRecords::OutputComplete()
51+
//CHK: OpEnqueueNodePayloadsAMDX [[Alloc]]
52+
//CHK-NOT: OpEnqueueNodePayloadsAMDX
53+

0 commit comments

Comments
 (0)