From 31fdf06f8384b1ae858218a7894528b0b5100432 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Tue, 25 Feb 2025 07:22:20 -0800 Subject: [PATCH 01/13] Allow duplication of OpTypeNodePayloadArrayAMDX This commit is mainly to allow OpTypeNodePayloadArrayAMDX to be duplicated in SPIRV when using WorkGraphs. AMD spec requires that there should be dedicated type declaration with OpTypeNodePayloadArrayAMDX for each "node". And each of them should be decorated with PayloadNodeNameAMDX. https://docs.vulkan.org/features/latest/features/proposals/VK_AMDX_shader_enqueue.html This behavior is very different from how "types" are declared. Normally types are declared only once and they are shared across the whole program. But for the work-graphs, each "RecordType" belongs to each node and they need to be declared for each node. Consider the following HLSL example where a producer triggers two consumer nodes. Each node will be differentiated by "NodeID" attribute. ``` [Shader("node")] [NodeLaunch("broadcasting")] [NumThreads(4,1,1)] void myFancyNode(..., [NodeID("myNode1")] NodeOutput myRecords1, [NodeID("myNode2")] NodeOutput myRecords2, ...) { ThreadNodeOutputRecords myRecord1 = myRecords1.GetThreadNodeOutputRecords(1); myRecord1.myData = 1; myRecord1.OutputCompelete(); ThreadNodeOutputRecords myRecord2 = myRecords2.GetThreadNodeOutputRecords(1); myRecord2.myData = 1; myRecord2.OutputCompelete(); } ``` Even though both are using the same type, "MY_RECORD", each of them needs to be declared separately. The generated SPIR-V code for the HLSL above is something like the following, ``` %str_node1 = OpConstantStringAMDX "myNode1"; %str_node2 = OpConstantStringAMDX "myNode1"; %type_node1 = OpTypeNodePayloadArrayAMDX %type_myRecord; %type_node2 = OpTypeNodePayloadArrayAMDX %type_myRecord; OpDecorateId %type_node1 PayloadNodeNameAMDX %str_node1; OpDecorateId %type_node2 PayloadNodeNameAMDX %str_node2; ``` Note that `%type_node1` and `%type_node2` are equivalent, but they must be declared separately because each of them will be decorated with different nodeID string. --- source/slang/core.meta.slang | 23 ++++++++- source/slang/hlsl.meta.slang | 79 +++++++++++++++++++++++++---- source/slang/slang-emit-spirv-ops.h | 26 +++++++++- source/slang/slang-emit-spirv.cpp | 22 +++++++- source/slang/slang-ir-inst-defs.h | 3 ++ source/slang/slang-ir.h | 8 +++ tests/workgraphs/consumer.slang | 24 ++++++--- tests/workgraphs/producer.slang | 53 +++++++++++++++++++ 8 files changed, 215 insertions(+), 23 deletions(-) create mode 100644 tests/workgraphs/producer.slang diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index e2fb8bbf27..9c75c55876 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -4211,6 +4211,25 @@ attribute_syntax [QuadDerivatives] : QuadDerivativesAttribute; __attributeTarget(FuncDecl) attribute_syntax [RequireFullQuads] : RequireFullQuadsAttribute; -__generic -typealias NodePayloadPtr = Ptr; +// Work-graphs + +/// @internal +/// __SPIRVNodePayloadArray is to emit OpTypeNodePayloadArrayAMDX and +/// SpvDecorationPayloadNodeNameAMDX for it. +/// +__generic +__intrinsic_type($(kIROp_SPIRVNodePayloadArrayType)) +struct __SPIRVNodePayloadArray; + +/// @internal +/// NodePayloadArrayPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to an array of work-graph nodes +/// +__generic +typealias NodePayloadArrayPtr = Ptr<__SPIRVNodePayloadArray, $( (uint64_t)AddressSpace::NodePayloadAMDX)>; + +/// @internal +/// NodePayloadPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to a `RecordType`. +/// +__generic +typealias NodePayloadPtr = Ptr; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 5f0edb3f90..54fe0f487a 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -24347,28 +24347,85 @@ int8_t4_packed pack_clamp_s8(int16_t4 unpackedValue) // Work-graphs +///@public +/// Set of zero or more records for each thread. +/// +__generic +struct ThreadNodeOutputRecords +{ + NodePayloadArrayPtr alloc; + + [ForceInline] + [require(spirv)] // [require(spirv, thread)] + __init(int recordCount, int nodeIndex = 0) + { + alloc = spirv_asm + { + result = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr Workgroup $recordCount $nodeIndex; + }; + } + + [ForceInline] + [require(hlsl_spirv)] // [require(hlsl_spirv, thread)] + void OutputComplete() + { + __target_switch + { + case hlsl: __intrinsic_asm ".OutputComplete"; + case spirv: spirv_asm { OpEnqueueNodePayloadsAMDX $alloc; }; + } + } + + __subscript(uint index) -> NodePayloadPtr + { + [ForceInline] + get { return Get(index); } + } + + [ForceInline] + [require(hlsl_spirv)] // [require(hlsl_spirv, thread)] + NodePayloadPtr Get(int index = 0) + { + __target_switch + { + case hlsl: __intrinsic_asm ".Get"; + case spirv: + NodePayloadArrayPtr ptr = spirv_asm + { + %ptr = OpTypePointer Function $$__SPIRVNodePayloadArray; + %var = OpVariable %ptr Function; + OpStore %var $alloc; + %Load = OpLoad $$NodePayloadArrayPtr %var; + result : $$NodePayloadArrayPtr = OpAccessChain %Load $index + }; + return reinterpret>(ptr); + } + } +}; + //@public: -/// read-only input to Broadcasting launch node. -__generic -//TODO: DispatchNodeInputRecord should be available only for broadcasting node shader. -//[require(broadcasting_node)] -[require(spirv)] +/// Read-only input to Broadcasting launch node. +/// +__generic struct DispatchNodeInputRecord { /// Provide an access to a record object that only holds a single record. - NodePayloadPtr Get() + [ForceInline] + [require(hlsl_spirv)] // [require(hlsl_spirv, broadcasting)] + NodePayloadPtr Get() { int index = 0; __target_switch { + case hlsl: __intrinsic_asm ".Get"; case spirv: - return spirv_asm + NodePayloadArrayPtr ptr = spirv_asm { - %in_payload_t = OpTypeNodePayloadArrayAMDX $$T; - %in_payload_ptr_t = OpTypePointer NodePayloadAMDX %in_payload_t; - %var = OpVariable %in_payload_ptr_t NodePayloadAMDX; - result : $$NodePayloadPtr = OpAccessChain %var $index; + %ptr = OpTypePointer NodePayloadAMDX $$__SPIRVNodePayloadArray; + %var = OpVariable %ptr NodePayloadAMDX; + result : $$NodePayloadArrayPtr = OpAccessChain %var $index; }; + return reinterpret>(ptr); } } }; diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 880f6b083c..68d545702a 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -2585,7 +2585,7 @@ template SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type) { static_assert(isSingular); - return emitInstMemoized( + return emitInst( getSection(SpvLogicalSectionID::ConstantsAndTypes), inst, SpvOpTypeNodePayloadArrayAMDX, @@ -2593,4 +2593,28 @@ SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type) type); } +// https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/AMD/SPV_AMDX_shader_enqueue.html#OpTypeNodePayloadArrayAMDX +SpvInst* emitOpConstantString(IRInst* inst, const UnownedStringSlice& str) +{ + return emitInstMemoized( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpConstantStringAMDX, + kResultID, + str); +} + +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorateId +template +SpvInst* emitOpDecoratePayloadNodeName(IRInst* inst, const T1& target, const T2& id) +{ + return emitInst( + getSection(SpvLogicalSectionID::Annotations), + inst, + SpvOpDecorateId, + target, + SpvDecorationPayloadNodeNameAMDX, + id); +} + #endif // SLANG_IN_SPIRV_EMIT_CONTEXT diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 18016d5fe7..40bb6e2cc9 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1585,8 +1585,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex else if (storageClass == SpvStorageClassNodePayloadAMDX) { auto spvValueType = ensureInst(valueType); - auto spvNodePayloadType = emitOpTypeNodePayloadArray(inst, spvValueType); - valueTypeId = getID(spvNodePayloadType); + valueTypeId = getID(spvValueType); } else { @@ -1929,6 +1928,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_IndicesType: case kIROp_PrimitivesType: return nullptr; + case kIROp_SPIRVNodePayloadArrayType: + if (auto nodePayloadArrayType = as(inst)) + { + auto newType = emitOpTypeNodePayloadArray( + inst, nodePayloadArrayType->getRecordType()); + + Slang::StringBuilder str; + str << "NodeID_" << uint32_t(nodePayloadArrayType->getNodeID()->getValue()); + SpvInst* spvStr = emitOpConstantString(nullptr, str.getUnownedSlice()); + (void)spvStr; + + auto r = emitOpDecoratePayloadNodeName( + nullptr, + newType, + spvStr); + (void)r; + return newType; + } default: { if (as(inst)) @@ -7878,6 +7895,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return getSection(SpvLogicalSectionID::Annotations); case SpvOpTypeNodePayloadArrayAMDX: return getSection(SpvLogicalSectionID::ConstantsAndTypes); + default: return defaultParent; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index bd590f08f1..6e35830f63 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -265,6 +265,9 @@ INST(ExpandTypeOrVal, ExpandTypeOrVal, 1, HOISTABLE) // A type that identifies it's contained type as being emittable as `spirv_literal. INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE) +// WorkGraphs +INST(SPIRVNodePayloadArrayType, spirvNodePayloadArrayType, 2, HOISTABLE) + // A TypeType-typed IRValue represents a IRType. // It is used to represent a type parameter/argument in a generics. INST(TypeType, type_t, 0, HOISTABLE) diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index dbc66c6a37..822bdd29b1 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1757,6 +1757,14 @@ struct IRSPIRVLiteralType : IRType IRType* getValueType() { return static_cast(getOperand(0)); } }; +struct IRSPIRVNodePayloadArrayType : IRType +{ + IR_LEAF_ISA(SPIRVNodePayloadArrayType) + + IRType* getRecordType() { return static_cast(getOperand(0)); } + IRIntLit* getNodeID() { return static_cast(getOperand(1)); } +}; + struct IRPtrTypeBase : IRType { IRType* getValueType() { return (IRType*)getOperand(0); } diff --git a/tests/workgraphs/consumer.slang b/tests/workgraphs/consumer.slang index 5e211a2a13..0de90fc176 100644 --- a/tests/workgraphs/consumer.slang +++ b/tests/workgraphs/consumer.slang @@ -1,4 +1,7 @@ -//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry main -skip-spirv-validation +//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -skip-spirv-validation + +RWStructuredBuffer outputBuffer; + struct RecordData { int myData; @@ -6,20 +9,27 @@ struct RecordData [shader("compute")] [numthreads(1, 1, 1)] -void main(uint3 dispatchThreadId : SV_GroupThreadID) +void computeMain(uint3 dispatchThreadId : SV_GroupThreadID) { spirv_asm { - OpExecutionMode $main ShaderIndexAMDX $(0); - OpExecutionMode $main StaticNumWorkgroupsAMDX $(1) $(1) $(1); + OpExecutionMode $computeMain ShaderIndexAMDX $(0); + OpExecutionMode $computeMain StaticNumWorkgroupsAMDX $(1) $(1) $(1); }; - DispatchNodeInputRecord inputData; + //TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute of the entry point. + // Until it is implemented properly, we will use an int-type generic argument. + // + #define myNodeID 0 + DispatchNodeInputRecord inputData; - let recordData = inputData.Get(); - int myData = recordData.myData; + int myData = inputData.Get().myData; + outputBuffer[dispatchThreadId.x] = myData; } +//CHK: OpCapability ShaderEnqueueAMDX +//CHK: OpExtension "SPV_AMDX_shader_enqueue" + //CHK: ; Types, variables and constants //CHK: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1 //CHK: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]] diff --git a/tests/workgraphs/producer.slang b/tests/workgraphs/producer.slang new file mode 100644 index 0000000000..d3e71320ba --- /dev/null +++ b/tests/workgraphs/producer.slang @@ -0,0 +1,53 @@ +//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -skip-spirv-validation + +RWStructuredBuffer inputBuffer; + +struct RecordData +{ + int myData; +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + spirv_asm + { + OpExecutionMode $computeMain ShaderIndexAMDX $(0); + OpExecutionMode $computeMain StaticNumWorkgroupsAMDX $(1) $(1) $(1); + }; + + //TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute on "node" variable. + // Until it is implemented properly, we will use an int-type generic argument. + // + #define myNodeID 0 + ThreadNodeOutputRecords node = { 1 }; + + node.Get().myData = inputBuffer[0]; + node.OutputComplete(); +} + +//CHK: OpCapability ShaderEnqueueAMDX +//CHK: OpExtension "SPV_AMDX_shader_enqueue" + +//CHK-DAG: [[RecordType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData +//CHK-DAG: [[PtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[RecordType]] +//CHK-DAG: [[NodeID:%[a-zA-Z_0-9]+]] = OpConstantStringAMDX "NodeID_0" +//CHK-DAG: OpDecorateId [[RecordType]] PayloadNodeNameAMDX [[NodeID]] + +//CHK-NOT: = OpTypeNodePayloadArrayAMDX +//CHK-NOT: = OpConstantStringAMDX +//CHK-NOT: OpDecorateID {{.*}} PayloadNodeNameAMDX + +// ThreadNodeOutputRecords::__init() +//CHK: [[Alloc:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[PtrType]] % +//CHK-NOT: = OpAllocateNodePayloadsAMDX + +// ThreadNodeOutputRecords::Get() +//CHK: [[Load:%[a-zA-Z_0-9]+]] = OpLoad [[PtrType]] %var +//CHK: = OpAccessChain [[PtrType]] [[Load]] % + +// ThreadNodeOutputRecords::OutputComplete() +//CHK: OpEnqueueNodePayloadsAMDX [[Alloc]] +//CHK-NOT: OpEnqueueNodePayloadsAMDX + From b3604a4c069a4960fc6a31f734291887ea232e7c Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Wed, 26 Feb 2025 13:33:55 -0800 Subject: [PATCH 02/13] Adding comments --- source/slang/slang-emit-spirv-ops.h | 2 +- source/slang/slang-emit-spirv.cpp | 3 ++- source/slang/slang-ir.h | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 68d545702a..c44229b512 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -2604,7 +2604,7 @@ SpvInst* emitOpConstantString(IRInst* inst, const UnownedStringSlice& str) str); } -// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorateId +// https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMDX_shader_enqueue.html#_decorations template SpvInst* emitOpDecoratePayloadNodeName(IRInst* inst, const T1& target, const T2& id) { diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 40bb6e2cc9..e2232a41fb 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1934,6 +1934,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex auto newType = emitOpTypeNodePayloadArray( inst, nodePayloadArrayType->getRecordType()); + // TODO: This is a temporary hack. + // The NodeID must come from an attribute [NodeID("name")]. Slang::StringBuilder str; str << "NodeID_" << uint32_t(nodePayloadArrayType->getNodeID()->getValue()); SpvInst* spvStr = emitOpConstantString(nullptr, str.getUnownedSlice()); @@ -7895,7 +7897,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return getSection(SpvLogicalSectionID::Annotations); case SpvOpTypeNodePayloadArrayAMDX: return getSection(SpvLogicalSectionID::ConstantsAndTypes); - default: return defaultParent; } diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 822bdd29b1..07120af681 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1762,6 +1762,8 @@ struct IRSPIRVNodePayloadArrayType : IRType IR_LEAF_ISA(SPIRVNodePayloadArrayType) IRType* getRecordType() { return static_cast(getOperand(0)); } + + // TODO: getNodeID needs to return `IRStringLit*`. IRIntLit* getNodeID() { return static_cast(getOperand(1)); } }; From 532d4c03398f56f4f18bd84d16ce41fa369f78ba Mon Sep 17 00:00:00 2001 From: slangbot Date: Thu, 27 Feb 2025 05:41:35 +0800 Subject: [PATCH 03/13] format code (#43) --- source/slang/slang-emit-spirv.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index e2232a41fb..b7d05f15c5 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1931,8 +1931,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_SPIRVNodePayloadArrayType: if (auto nodePayloadArrayType = as(inst)) { - auto newType = emitOpTypeNodePayloadArray( - inst, nodePayloadArrayType->getRecordType()); + auto newType = + emitOpTypeNodePayloadArray(inst, nodePayloadArrayType->getRecordType()); // TODO: This is a temporary hack. // The NodeID must come from an attribute [NodeID("name")]. @@ -1941,10 +1941,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvInst* spvStr = emitOpConstantString(nullptr, str.getUnownedSlice()); (void)spvStr; - auto r = emitOpDecoratePayloadNodeName( - nullptr, - newType, - spvStr); + auto r = emitOpDecoratePayloadNodeName(nullptr, newType, spvStr); (void)r; return newType; } From a9828d9b7b8d49188ca4d52d439968408877f3b1 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Wed, 26 Feb 2025 13:53:51 -0800 Subject: [PATCH 04/13] Add one more test for multi-nodes --- tests/workgraphs/producer-multi-nodes.slang | 71 +++++++++++++++++++++ tests/workgraphs/producer.slang | 4 ++ 2 files changed, 75 insertions(+) create mode 100644 tests/workgraphs/producer-multi-nodes.slang diff --git a/tests/workgraphs/producer-multi-nodes.slang b/tests/workgraphs/producer-multi-nodes.slang new file mode 100644 index 0000000000..733bc68786 --- /dev/null +++ b/tests/workgraphs/producer-multi-nodes.slang @@ -0,0 +1,71 @@ +//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -skip-spirv-validation + +RWStructuredBuffer inputBuffer; + +struct RecordData +{ + int myData; +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + spirv_asm + { + OpExecutionMode $computeMain ShaderIndexAMDX $(0); + OpExecutionMode $computeMain StaticNumWorkgroupsAMDX $(1) $(1) $(1); + }; + + //TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute on "node" variable. + // Until it is implemented properly, we will use an int-type generic argument. + // + #define myNodeID_1 0 + #define myNodeID_2 1 + + ThreadNodeOutputRecords node1 = { 1 }; + ThreadNodeOutputRecords node2 = { 1 }; + + node1.Get().myData = inputBuffer[0]; + node2.Get().myData = inputBuffer[0]; + + node1.OutputComplete(); + node2.OutputComplete(); +} + +//CHK: OpCapability ShaderEnqueueAMDX +//CHK: OpExtension "SPV_AMDX_shader_enqueue" + +//CHK: ; Annotations + +//CHK-DAG: [[RecordType1:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData +//CHK-DAG: [[RecordType2:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData +//CHK-DAG: [[PtrType1:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[RecordType1]] +//CHK-DAG: [[PtrType2:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[RecordType2]] +//CHK-DAG: [[NodeID_1:%[a-zA-Z_0-9]+]] = OpConstantStringAMDX "NodeID_0" +//CHK-DAG: [[NodeID_2:%[a-zA-Z_0-9]+]] = OpConstantStringAMDX "NodeID_1" +//CHK-DAG: OpDecorateId [[RecordType1]] PayloadNodeNameAMDX [[NodeID_1]] +//CHK-DAG: OpDecorateId [[RecordType2]] PayloadNodeNameAMDX [[NodeID_2]] + +//CHK-NOT: = OpTypeNodePayloadArrayAMDX +//CHK-NOT: = OpConstantStringAMDX +//CHK-NOT: OpDecorateID {{.*}} PayloadNodeNameAMDX + +//CHK: ; Function + +// ThreadNodeOutputRecords::__init() +//CHK-DAG: [[Alloc1:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[PtrType1]] % +//CHK-DAG: [[Alloc2:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[PtrType2]] % +//CHK-NOT: = OpAllocateNodePayloadsAMDX + +// ThreadNodeOutputRecords::Get() +//CHK-DAG: [[Load1:%[a-zA-Z_0-9]+]] = OpLoad [[PtrType1]] %var +//CHK-DAG: [[Load2:%[a-zA-Z_0-9]+]] = OpLoad [[PtrType2]] %var +//CHK-DAG: = OpAccessChain [[PtrType1]] [[Load1]] % +//CHK-DAG: = OpAccessChain [[PtrType2]] [[Load2]] % + +// ThreadNodeOutputRecords::OutputComplete() +//CHK-DAG: OpEnqueueNodePayloadsAMDX [[Alloc1]] +//CHK-DAG: OpEnqueueNodePayloadsAMDX [[Alloc2]] +//CHK-NOT: OpEnqueueNodePayloadsAMDX + diff --git a/tests/workgraphs/producer.slang b/tests/workgraphs/producer.slang index d3e71320ba..ed00b3af28 100644 --- a/tests/workgraphs/producer.slang +++ b/tests/workgraphs/producer.slang @@ -30,6 +30,8 @@ void computeMain() //CHK: OpCapability ShaderEnqueueAMDX //CHK: OpExtension "SPV_AMDX_shader_enqueue" +//CHK: ; Annotations + //CHK-DAG: [[RecordType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData //CHK-DAG: [[PtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[RecordType]] //CHK-DAG: [[NodeID:%[a-zA-Z_0-9]+]] = OpConstantStringAMDX "NodeID_0" @@ -39,6 +41,8 @@ void computeMain() //CHK-NOT: = OpConstantStringAMDX //CHK-NOT: OpDecorateID {{.*}} PayloadNodeNameAMDX +//CHK: ; Function + // ThreadNodeOutputRecords::__init() //CHK: [[Alloc:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[PtrType]] % //CHK-NOT: = OpAllocateNodePayloadsAMDX From e63074f97923149e7629508996e7616f02a24abf Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Wed, 26 Feb 2025 16:37:28 -0800 Subject: [PATCH 05/13] Do not emit the name decorate, it needs to be emitted manually --- source/slang/slang-emit-spirv.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b7d05f15c5..a1ced18163 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1934,6 +1934,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex auto newType = emitOpTypeNodePayloadArray(inst, nodePayloadArrayType->getRecordType()); + #if 0 // TODO: This is a temporary hack. // The NodeID must come from an attribute [NodeID("name")]. Slang::StringBuilder str; @@ -1941,8 +1942,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvInst* spvStr = emitOpConstantString(nullptr, str.getUnownedSlice()); (void)spvStr; - auto r = emitOpDecoratePayloadNodeName(nullptr, newType, spvStr); - (void)r; + emitOpDecoratePayloadNodeName(nullptr, newType, spvStr); +#endif return newType; } default: From 93507e8ee103969af1fb9fa5e7d175ed8f3603dc Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Wed, 12 Mar 2025 20:46:47 -0700 Subject: [PATCH 06/13] Fix work-graph consumer issue --- source/slang/core.meta.slang | 14 ++-- source/slang/hlsl.meta.slang | 45 ++++++----- source/slang/slang-emit-spirv.cpp | 122 ++++++++++++++++++++++-------- tests/workgraphs/consumer.slang | 17 ++--- tests/workgraphs/producer.slang | 5 -- 5 files changed, 133 insertions(+), 70 deletions(-) diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 9c75c55876..60ed85fe28 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -4217,19 +4217,19 @@ attribute_syntax [RequireFullQuads] : RequireFullQuadsAttribute; /// __SPIRVNodePayloadArray is to emit OpTypeNodePayloadArrayAMDX and /// SpvDecorationPayloadNodeNameAMDX for it. /// -__generic +__generic __intrinsic_type($(kIROp_SPIRVNodePayloadArrayType)) struct __SPIRVNodePayloadArray; /// @internal /// NodePayloadArrayPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to an array of work-graph nodes /// -__generic -typealias NodePayloadArrayPtr = Ptr<__SPIRVNodePayloadArray, $( (uint64_t)AddressSpace::NodePayloadAMDX)>; +__generic +typealias NodePayloadArrayPtr = Ptr<__SPIRVNodePayloadArray, $( (uint64_t)AddressSpace::NodePayloadAMDX)>; -/// @internal -/// NodePayloadPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to a `RecordType`. +/// @public +/// FunctionPtr is a pointer type for "Function" storage-class and it points to a `BaseType`. /// -__generic -typealias NodePayloadPtr = Ptr; +__generic +typealias FunctionPtr = Ptr; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 54fe0f487a..01d2c13f4b 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -24359,6 +24359,8 @@ struct ThreadNodeOutputRecords [require(spirv)] // [require(spirv, thread)] __init(int recordCount, int nodeIndex = 0) { + // TODO: It seems that unexpected SPIRV code gets emitted here when + // assign from spriv_asm to `alloc` such as `get_field_addr` alloc = spirv_asm { result = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr Workgroup $recordCount $nodeIndex; @@ -24376,29 +24378,29 @@ struct ThreadNodeOutputRecords } } - __subscript(uint index) -> NodePayloadPtr + __subscript(uint index) -> FunctionPtr { [ForceInline] get { return Get(index); } } - [ForceInline] - [require(hlsl_spirv)] // [require(hlsl_spirv, thread)] - NodePayloadPtr Get(int index = 0) + FunctionPtr Get(int index = 0) { __target_switch { case hlsl: __intrinsic_asm ".Get"; case spirv: - NodePayloadArrayPtr ptr = spirv_asm + let tmp = reinterpret>(alloc); + let ptr = spirv_asm { - %ptr = OpTypePointer Function $$__SPIRVNodePayloadArray; - %var = OpVariable %ptr Function; - OpStore %var $alloc; - %Load = OpLoad $$NodePayloadArrayPtr %var; - result : $$NodePayloadArrayPtr = OpAccessChain %Load $index + %tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray> Function; + + %loaded = OpLoad $$__SPIRVNodePayloadArray $tmp; + OpStore %tempPtrVar %loaded; + + result : $$FunctionPtr = OpAccessChain %tempPtrVar $index }; - return reinterpret>(ptr); + return ptr; } } }; @@ -24410,22 +24412,27 @@ __generic struct DispatchNodeInputRecord { /// Provide an access to a record object that only holds a single record. - [ForceInline] - [require(hlsl_spirv)] // [require(hlsl_spirv, broadcasting)] - NodePayloadPtr Get() + FunctionPtr Get() { int index = 0; + __target_switch { case hlsl: __intrinsic_asm ".Get"; case spirv: - NodePayloadArrayPtr ptr = spirv_asm + let ptr = spirv_asm { - %ptr = OpTypePointer NodePayloadAMDX $$__SPIRVNodePayloadArray; - %var = OpVariable %ptr NodePayloadAMDX; - result : $$NodePayloadArrayPtr = OpAccessChain %var $index; + %nodePtrType = OpTypePointer NodePayloadAMDX $$__SPIRVNodePayloadArray; + %nodePtrVar = OpVariable %nodePtrType NodePayloadAMDX; + + %tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray> Function; + + %loaded = OpLoad $$__SPIRVNodePayloadArray %nodePtrVar; + OpStore %tempPtrVar %loaded; + + result : $$FunctionPtr = OpAccessChain %tempPtrVar $index; }; - return reinterpret>(ptr); + return ptr; } } }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index a1ced18163..b025ad84e4 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -292,6 +292,8 @@ void SpvInstParent::addInst(SpvInst* inst) SLANG_ASSERT(inst); SLANG_ASSERT(!inst->nextSibling); + inst->parent = this; + if (m_firstChild == nullptr) { m_firstChild = m_lastChild = inst; @@ -304,7 +306,6 @@ void SpvInstParent::addInst(SpvInst* inst) // m_lastChild->nextSibling = inst; inst->prevSibling = m_lastChild; - inst->parent = this; m_lastChild = inst; } @@ -492,6 +493,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex /// The next destination `` to allocate. SpvWord m_nextID = 1; + // This keeps track of the named IDs used in the asm block + Dictionary> m_idMaps; + OrderedDictionary m_forwardDeclaredPointers; SpvInst* m_nullDwarfExpr = nullptr; @@ -1582,11 +1586,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { valueTypeId = getIRInstSpvID(valueType); } - else if (storageClass == SpvStorageClassNodePayloadAMDX) - { - auto spvValueType = ensureInst(valueType); - valueTypeId = getID(spvValueType); - } else { auto spvValueType = ensureInst(valueType); @@ -1933,17 +1932,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto newType = emitOpTypeNodePayloadArray(inst, nodePayloadArrayType->getRecordType()); - - #if 0 - // TODO: This is a temporary hack. - // The NodeID must come from an attribute [NodeID("name")]. - Slang::StringBuilder str; - str << "NodeID_" << uint32_t(nodePayloadArrayType->getNodeID()->getValue()); - SpvInst* spvStr = emitOpConstantString(nullptr, str.getUnownedSlice()); - (void)spvStr; - - emitOpDecoratePayloadNodeName(nullptr, newType, spvStr); -#endif return newType; } default: @@ -3007,7 +2995,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex switch (inst->getOp()) { case kIROp_Var: - emitLocalInst(spvBlock, inst); + emitLocalInst(spvBlock, spvBlock, inst); break; case kIROp_DebugVar: // Declare an ordinary local variable for debugDeclare association @@ -3015,6 +3003,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // values to upon a `kIROp_DebugValue` inst. emitDebugVarBackingLocalVarDeclaration(spvBlock, as(inst)); break; + case kIROp_SPIRVAsm: + emitLocalInst(spvBlock, spvBlock, inst); + break; } } } @@ -3095,7 +3086,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Skip vars because they are already emitted. if (as(irInst)) continue; - emitLocalInst(spvBlock, irInst); + emitLocalInst(spvBlock, nullptr, irInst); if (irInst->getOp() == kIROp_loop) pendingLoopInsts.add(as(irInst)); if (irInst->getOp() == kIROp_discard && !shouldEmitDiscardAsDemote()) @@ -3472,7 +3463,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // a known parent (the basic block that contains them). /// Emit an instruction that is local to the body of the given `parent`. - SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst) + SpvInst* emitLocalInst(SpvInstParent* parent, SpvInstParent* firstLabel, IRInst* inst) { SpvInst* result = nullptr; switch (inst->getOp()) @@ -3604,7 +3595,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_Geq: case kIROp_Rsh: case kIROp_Lsh: - result = emitArithmetic(parent, inst); + result = emitArithmetic(parent, firstLabel, inst); break; case kIROp_CastDescriptorHandleToUInt2: case kIROp_CastUInt2ToDescriptorHandle: @@ -3855,7 +3846,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex result = emitOpUndef(parent, inst, inst->getDataType()); break; case kIROp_SPIRVAsm: - result = emitSPIRVAsm(parent, as(inst)); + result = emitSPIRVAsm(parent, firstLabel, as(inst)); break; case kIROp_ImageLoad: result = emitImageLoad(parent, as(inst)); @@ -4466,6 +4457,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex break; } } + + // Pass in global OpVariable as interface to the entry point + // TODO: Pass in only when they are used by the entry point + for (auto entryInterface : m_entryPointInterfaces) + { + SpvInst* spvInst; + if (m_mapIRInstToSpvInst.tryGetValue(entryInterface, spvInst)) + { + params.add(spvInst); + } + } + emitOpEntryPoint(section, decoration, spvStage, dstID, name, params); // Stage specific execution mode and capability declarations. @@ -7265,7 +7268,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } - SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) + SpvInst* emitArithmetic(SpvInstParent* parent, SpvInstParent* firstLabel, IRInst* inst) { if (const auto matrixType = as(inst->getDataType())) { @@ -7284,7 +7287,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (as(originalOperand->getDataType())) { auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); + emitLocalInst(parent, firstLabel, operand); operands.add(operand); } else @@ -7850,12 +7853,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return debugFunc; } - SpvInst* emitSPIRVAsm(SpvInstParent* parent, IRSPIRVAsm* inst) + SpvInst* emitSPIRVAsm(SpvInstParent* parent, SpvInstParent* firstLabel, IRSPIRVAsm* inst) { SpvInst* last = nullptr; - // This keeps track of the named IDs used in the asm block - Dictionary idMap; + auto &idMap = m_idMaps.getOrAddValue(inst, Dictionary()); for (const auto spvInst : inst->getInsts()) { @@ -8033,6 +8035,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (spvInst->getOpcodeOperand()->getOp() == kIROp_SPIRVAsmOperandTruncate) { + // Nothing to emit to the first OpLabel + if (firstLabel) + continue; + const auto getSlangType = [&](IRSPIRVAsmOperand* operand) -> IRType* { switch (operand->getOp()) @@ -8196,7 +8202,46 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex default: break; } - const auto opParent = parentForOpCode(opcode, parent); + + IRStringLit* resultID = nullptr; + SpvInstParent* opParent = nullptr; + if (opcode == SpvOpVariable) + { + // SPIRV validator says, + // "All OpVariable instructions in a function must be the first instructions in the first block." + opParent = firstLabel; + + auto opStorageClass = spvInst->getOperand(3); + if (opStorageClass && opStorageClass->getOp() == kIROp_SPIRVAsmOperandEnum) + { + if (auto intLit = cast(opStorageClass->getOperand(0))) + { + switch (SpvStorageClass(intLit->getValue())) + { + case SpvStorageClassNodePayloadAMDX: + requireSPIRVCapability(SpvCapabilityShaderEnqueueAMDX); + ensureExtensionDeclaration( + UnownedStringSlice("SPV_AMDX_shader_enqueue")); + + opParent = getSection(SpvLogicalSectionID::ConstantsAndTypes); + + if (auto resultOperand = + cast(spvInst->getOperand(2))) + { + resultID = cast(resultOperand->getValue()); + } + + m_entryPointInterfaces.add(spvInst); + break; + } + } + } + } + if (opParent == nullptr) + { + opParent = parentForOpCode(opcode, parent); + } + const auto opInfo = m_grammarInfo->opInfos.lookup(opcode); // TODO: handle resultIdIndex == 1, for constants @@ -8204,6 +8249,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex opParent == getSection(SpvLogicalSectionID::ConstantsAndTypes) && opInfo && opInfo->resultIdIndex == 0; + // SpvOpVariable must appear at the first block of the function + // And it may depends on other memoized instructions. + if ((opcode == SpvOpVariable || memoize) == (firstLabel == nullptr)) + continue; + // We want the "result instruction" to refer to the top level // block which assumes its value, the others are free to refer // to whatever, so just use the internal spv inst rep @@ -8254,15 +8304,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex for (const auto operand : spvInst->getSPIRVOperands()) emitSpvAsmOperand(operand); }); + + // TODO: We may be able to simplify without checking the string. + if (resultID) + { + SpvWord id; + if (last->id == 0 && idMap.tryGetValue(resultID->getStringSlice(), id)) + last->id = id; + } } } } - for (const auto& [name, id] : idMap) - emitOpName(getSection(SpvLogicalSectionID::DebugNames), nullptr, id, name); + if (firstLabel == nullptr) + { + for (const auto& [name, id] : idMap) + emitOpName(getSection(SpvLogicalSectionID::DebugNames), nullptr, id, name); + } return last; } + HashSet m_entryPointInterfaces; OrderedHashSet m_capabilities; void requireSPIRVCapability(SpvCapability capability) diff --git a/tests/workgraphs/consumer.slang b/tests/workgraphs/consumer.slang index 0de90fc176..74488a7e0c 100644 --- a/tests/workgraphs/consumer.slang +++ b/tests/workgraphs/consumer.slang @@ -11,12 +11,6 @@ struct RecordData [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadId : SV_GroupThreadID) { - spirv_asm - { - OpExecutionMode $computeMain ShaderIndexAMDX $(0); - OpExecutionMode $computeMain StaticNumWorkgroupsAMDX $(1) $(1) $(1); - }; - //TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute of the entry point. // Until it is implemented properly, we will use an int-type generic argument. // @@ -34,9 +28,14 @@ void computeMain(uint3 dispatchThreadId : SV_GroupThreadID) //CHK: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1 //CHK: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]] //CHK: [[PayloadType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] -//CHK: [[PtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]] +//CHK: [[NodePtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]] +//CHK: [[NodePtrVar:%[a-zA-Z_0-9]+]] = OpVariable [[NodePtrType]] NodePayloadAMDX +//CHK: [[TempPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType]] +//CHK: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[StructType]] //CHK: ; Function -//CHK: [[VarName:%[a-zA-Z_0-9]+]] = OpVariable [[PtrType]] NodePayloadAMDX -//CHK: = OpAccessChain [[PtrType]] [[VarName]] +//CHK: [[TempPtrVar:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType]] Function +//CHK: [[Loaded:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType]] [[NodePtrVar]] +//CHK: OpStore [[TempPtrVar]] [[Loaded]] +//CHK: = OpAccessChain [[FuncPtrType]] [[TempPtrVar]] diff --git a/tests/workgraphs/producer.slang b/tests/workgraphs/producer.slang index ed00b3af28..a3b3f85f80 100644 --- a/tests/workgraphs/producer.slang +++ b/tests/workgraphs/producer.slang @@ -11,11 +11,6 @@ struct RecordData [numthreads(1, 1, 1)] void computeMain() { - spirv_asm - { - OpExecutionMode $computeMain ShaderIndexAMDX $(0); - OpExecutionMode $computeMain StaticNumWorkgroupsAMDX $(1) $(1) $(1); - }; //TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute on "node" variable. // Until it is implemented properly, we will use an int-type generic argument. From 95e5980c9b39e9a40fbf1341f330a9cfa1e03ba8 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Thu, 13 Mar 2025 11:28:52 -0700 Subject: [PATCH 07/13] Fix producer code --- source/slang/hlsl.meta.slang | 23 +++++++++++++---------- tests/workgraphs/producer.slang | 15 +++++++-------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 01d2c13f4b..8ac8478d91 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -24356,15 +24356,19 @@ struct ThreadNodeOutputRecords NodePayloadArrayPtr alloc; [ForceInline] - [require(spirv)] // [require(spirv, thread)] + [require(hlsl_spirv)] // [require(hlsl_spirv, thread)] __init(int recordCount, int nodeIndex = 0) { - // TODO: It seems that unexpected SPIRV code gets emitted here when - // assign from spriv_asm to `alloc` such as `get_field_addr` - alloc = spirv_asm + __target_switch { - result = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr Workgroup $recordCount $nodeIndex; - }; + case hlsl: return; + case spirv: + alloc = spirv_asm + { + result = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr Workgroup $recordCount $nodeIndex; + }; + return; + } } [ForceInline] @@ -24384,23 +24388,22 @@ struct ThreadNodeOutputRecords get { return Get(index); } } + [ForceInline] FunctionPtr Get(int index = 0) { __target_switch { case hlsl: __intrinsic_asm ".Get"; case spirv: - let tmp = reinterpret>(alloc); - let ptr = spirv_asm + return spirv_asm { %tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray> Function; - %loaded = OpLoad $$__SPIRVNodePayloadArray $tmp; + %loaded = OpLoad $$__SPIRVNodePayloadArray $alloc; OpStore %tempPtrVar %loaded; result : $$FunctionPtr = OpAccessChain %tempPtrVar $index }; - return ptr; } } }; diff --git a/tests/workgraphs/producer.slang b/tests/workgraphs/producer.slang index a3b3f85f80..a56e624716 100644 --- a/tests/workgraphs/producer.slang +++ b/tests/workgraphs/producer.slang @@ -1,5 +1,4 @@ //TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -skip-spirv-validation - RWStructuredBuffer inputBuffer; struct RecordData @@ -27,10 +26,9 @@ void computeMain() //CHK: ; Annotations -//CHK-DAG: [[RecordType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData -//CHK-DAG: [[PtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[RecordType]] -//CHK-DAG: [[NodeID:%[a-zA-Z_0-9]+]] = OpConstantStringAMDX "NodeID_0" -//CHK-DAG: OpDecorateId [[RecordType]] PayloadNodeNameAMDX [[NodeID]] +//CHK-DAG: [[PayloadRecordType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData +//CHK-DAG: [[NodePtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadRecordType]] +//CHK-DAG: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function %RecordData //CHK-NOT: = OpTypeNodePayloadArrayAMDX //CHK-NOT: = OpConstantStringAMDX @@ -39,12 +37,13 @@ void computeMain() //CHK: ; Function // ThreadNodeOutputRecords::__init() -//CHK: [[Alloc:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[PtrType]] % +//CHK: [[Alloc:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[NodePtrType]] % //CHK-NOT: = OpAllocateNodePayloadsAMDX // ThreadNodeOutputRecords::Get() -//CHK: [[Load:%[a-zA-Z_0-9]+]] = OpLoad [[PtrType]] %var -//CHK: = OpAccessChain [[PtrType]] [[Load]] % +//CHK: [[Load:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadRecordType]] +//CHK: OpStore [[Var:%[a-zA-Z_0-9]+]] [[Load]] +//CHK: = OpAccessChain [[FuncPtrType]] [[Var]] % // ThreadNodeOutputRecords::OutputComplete() //CHK: OpEnqueueNodePayloadsAMDX [[Alloc]] From c0f827ec796298a0b566d9d8d2de8a57bd71f1e4 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Fri, 14 Mar 2025 07:52:17 -0700 Subject: [PATCH 08/13] Fix the producer side problems --- source/slang/hlsl.meta.slang | 18 +++++---- tests/workgraphs/producer-multi-nodes.slang | 45 ++++++++++----------- tests/workgraphs/producer.slang | 26 ++++++------ 3 files changed, 43 insertions(+), 46 deletions(-) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 8ac8478d91..62ed280131 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -24353,7 +24353,7 @@ int8_t4_packed pack_clamp_s8(int16_t4 unpackedValue) __generic struct ThreadNodeOutputRecords { - NodePayloadArrayPtr alloc; + FunctionPtr<__SPIRVNodePayloadArray> alloc; [ForceInline] [require(hlsl_spirv)] // [require(hlsl_spirv, thread)] @@ -24365,7 +24365,14 @@ struct ThreadNodeOutputRecords case spirv: alloc = spirv_asm { - result = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr Workgroup $recordCount $nodeIndex; + %alloc = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr Workgroup $recordCount $nodeIndex; + + %loaded = OpLoad $$__SPIRVNodePayloadArray %alloc; + + %tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray> Function; + OpStore %tempPtrVar %loaded; + + result : $$FunctionPtr<__SPIRVNodePayloadArray> = OpCopyObject %tempPtrVar; }; return; } @@ -24397,12 +24404,7 @@ struct ThreadNodeOutputRecords case spirv: return spirv_asm { - %tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray> Function; - - %loaded = OpLoad $$__SPIRVNodePayloadArray $alloc; - OpStore %tempPtrVar %loaded; - - result : $$FunctionPtr = OpAccessChain %tempPtrVar $index + result : $$FunctionPtr = OpAccessChain $alloc $index }; } } diff --git a/tests/workgraphs/producer-multi-nodes.slang b/tests/workgraphs/producer-multi-nodes.slang index 733bc68786..bf2a4f8a04 100644 --- a/tests/workgraphs/producer-multi-nodes.slang +++ b/tests/workgraphs/producer-multi-nodes.slang @@ -36,36 +36,33 @@ void computeMain() //CHK: OpCapability ShaderEnqueueAMDX //CHK: OpExtension "SPV_AMDX_shader_enqueue" -//CHK: ; Annotations - -//CHK-DAG: [[RecordType1:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData -//CHK-DAG: [[RecordType2:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData -//CHK-DAG: [[PtrType1:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[RecordType1]] -//CHK-DAG: [[PtrType2:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[RecordType2]] -//CHK-DAG: [[NodeID_1:%[a-zA-Z_0-9]+]] = OpConstantStringAMDX "NodeID_0" -//CHK-DAG: [[NodeID_2:%[a-zA-Z_0-9]+]] = OpConstantStringAMDX "NodeID_1" -//CHK-DAG: OpDecorateId [[RecordType1]] PayloadNodeNameAMDX [[NodeID_1]] -//CHK-DAG: OpDecorateId [[RecordType2]] PayloadNodeNameAMDX [[NodeID_2]] - -//CHK-NOT: = OpTypeNodePayloadArrayAMDX -//CHK-NOT: = OpConstantStringAMDX -//CHK-NOT: OpDecorateID {{.*}} PayloadNodeNameAMDX +//CHK: ; Types, variables and constants + +//CHK-DAG: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1 +//CHK-DAG: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]] +//CHK-DAG: [[PayloadType1:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] +//CHK-DAG: [[PayloadType2:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] +//CHK-DAG: [[NodePtrType1:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType1]] +//CHK-DAG: [[NodePtrType2:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType2]] +//CHK-DAG: [[TempPtrType1:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType1]] +//CHK-DAG: [[TempPtrType2:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType2]] +//CHK-DAG: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[StructType]] //CHK: ; Function // ThreadNodeOutputRecords::__init() -//CHK-DAG: [[Alloc1:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[PtrType1]] % -//CHK-DAG: [[Alloc2:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[PtrType2]] % -//CHK-NOT: = OpAllocateNodePayloadsAMDX +//CHK-DAG: [[TempPtrVar1:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType1]] Function +//CHK-DAG: [[TempPtrVar2:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType2]] Function +//CHK-DAG: [[Alloc1:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[NodePtrType1]] % +//CHK-DAG: [[Alloc2:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[NodePtrType2]] % +//CHK-DAG: [[Loaded1:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType1]] [[Alloc1]] +//CHK-DAG: [[Loaded2:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType2]] [[Alloc2]] +//CHK-DAG: OpStore [[TempPtrVar1]] [[Loaded1]] +//CHK-DAG: OpStore [[TempPtrVar2]] [[Loaded2]] // ThreadNodeOutputRecords::Get() -//CHK-DAG: [[Load1:%[a-zA-Z_0-9]+]] = OpLoad [[PtrType1]] %var -//CHK-DAG: [[Load2:%[a-zA-Z_0-9]+]] = OpLoad [[PtrType2]] %var -//CHK-DAG: = OpAccessChain [[PtrType1]] [[Load1]] % -//CHK-DAG: = OpAccessChain [[PtrType2]] [[Load2]] % +//CHK-COUNT-2: = OpAccessChain [[FuncPtrType]] % // ThreadNodeOutputRecords::OutputComplete() -//CHK-DAG: OpEnqueueNodePayloadsAMDX [[Alloc1]] -//CHK-DAG: OpEnqueueNodePayloadsAMDX [[Alloc2]] -//CHK-NOT: OpEnqueueNodePayloadsAMDX +//CHK-COUNT-2: OpEnqueueNodePayloadsAMDX % diff --git a/tests/workgraphs/producer.slang b/tests/workgraphs/producer.slang index a56e624716..b513034479 100644 --- a/tests/workgraphs/producer.slang +++ b/tests/workgraphs/producer.slang @@ -24,28 +24,26 @@ void computeMain() //CHK: OpCapability ShaderEnqueueAMDX //CHK: OpExtension "SPV_AMDX_shader_enqueue" -//CHK: ; Annotations +//CHK: ; Types, variables and constants -//CHK-DAG: [[PayloadRecordType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX %RecordData -//CHK-DAG: [[NodePtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadRecordType]] -//CHK-DAG: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function %RecordData - -//CHK-NOT: = OpTypeNodePayloadArrayAMDX -//CHK-NOT: = OpConstantStringAMDX -//CHK-NOT: OpDecorateID {{.*}} PayloadNodeNameAMDX +//CHK-DAG: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1 +//CHK-DAG: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]] +//CHK-DAG: [[PayloadType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] +//CHK-DAG: [[NodePtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]] +//CHK-DAG: [[TempPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType]] +//CHK-DAG: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[StructType]] //CHK: ; Function // ThreadNodeOutputRecords::__init() +//CHK: [[TempPtrVar:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType]] Function //CHK: [[Alloc:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[NodePtrType]] % -//CHK-NOT: = OpAllocateNodePayloadsAMDX +//CHK: [[Loaded:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType]] [[Alloc]] +//CHK: OpStore [[TempPtrVar]] [[Loaded]] // ThreadNodeOutputRecords::Get() -//CHK: [[Load:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadRecordType]] -//CHK: OpStore [[Var:%[a-zA-Z_0-9]+]] [[Load]] -//CHK: = OpAccessChain [[FuncPtrType]] [[Var]] % +//CHK: = OpAccessChain [[FuncPtrType]] % // ThreadNodeOutputRecords::OutputComplete() -//CHK: OpEnqueueNodePayloadsAMDX [[Alloc]] -//CHK-NOT: OpEnqueueNodePayloadsAMDX +//CHK: OpEnqueueNodePayloadsAMDX % From 9ddc3de91851ba2b771928f3f3a913c37d32a35a Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Wed, 19 Mar 2025 10:48:47 -0700 Subject: [PATCH 09/13] Fix compile error on Linux --- source/slang/slang-emit-spirv.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b025ad84e4..5bd90a6002 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1928,8 +1928,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_PrimitivesType: return nullptr; case kIROp_SPIRVNodePayloadArrayType: - if (auto nodePayloadArrayType = as(inst)) { + auto nodePayloadArrayType = cast(inst); auto newType = emitOpTypeNodePayloadArray(inst, nodePayloadArrayType->getRecordType()); return newType; From dd31ed712b36b10f52b69f15da7735d71ba07602 Mon Sep 17 00:00:00 2001 From: slangbot Date: Thu, 20 Mar 2025 01:51:41 +0800 Subject: [PATCH 10/13] format code (#48) Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- source/slang/slang-emit-spirv.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 5bd90a6002..2f9de6cfdc 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -7857,7 +7857,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { SpvInst* last = nullptr; - auto &idMap = m_idMaps.getOrAddValue(inst, Dictionary()); + auto& idMap = m_idMaps.getOrAddValue(inst, Dictionary()); for (const auto spvInst : inst->getInsts()) { @@ -8208,7 +8208,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (opcode == SpvOpVariable) { // SPIRV validator says, - // "All OpVariable instructions in a function must be the first instructions in the first block." + // "All OpVariable instructions in a function must be the first instructions in + // the first block." opParent = firstLabel; auto opStorageClass = spvInst->getOperand(3); From d6e2c2de8567c5e7cbd0a5b9b2f19854f8cd2750 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Wed, 2 Apr 2025 19:39:24 -0700 Subject: [PATCH 11/13] Emit SpvCapabilityVariablePointers --- source/slang/slang-emit-spirv.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2f9de6cfdc..1e40b3c5bc 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1565,6 +1565,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex break; case SpvStorageClassNodePayloadAMDX: requireSPIRVCapability(SpvCapabilityShaderEnqueueAMDX); + requireSPIRVCapability(SpvCapabilityVariablePointers); // TODO: need to relocate to a more proper place. ensureExtensionDeclaration(UnownedStringSlice("SPV_AMDX_shader_enqueue")); break; } From 59f37d10019dce318db98cbe2fcc037ef0c19b20 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Fri, 4 Apr 2025 12:19:24 -0700 Subject: [PATCH 12/13] Avoid using Function storage-class --- source/slang/core.meta.slang | 8 +-- source/slang/hlsl.meta.slang | 62 ++++++++++++--------- source/slang/slang-emit-spirv.cpp | 50 +++++------------ tests/a.slang | 52 +++++++++++++++++ tests/workgraphs/consumer.slang | 8 +-- tests/workgraphs/producer-multi-nodes.slang | 22 ++------ tests/workgraphs/producer.slang | 16 +++--- 7 files changed, 120 insertions(+), 98 deletions(-) create mode 100644 tests/a.slang diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 60ed85fe28..018509ac29 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -4227,9 +4227,9 @@ struct __SPIRVNodePayloadArray; __generic typealias NodePayloadArrayPtr = Ptr<__SPIRVNodePayloadArray, $( (uint64_t)AddressSpace::NodePayloadAMDX)>; -/// @public -/// FunctionPtr is a pointer type for "Function" storage-class and it points to a `BaseType`. +/// @internal +/// NodePayloadPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to a `RecordType`. /// -__generic -typealias FunctionPtr = Ptr; +__generic +typealias NodePayloadPtr = Ptr; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 62ed280131..49f8cf4413 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -24353,11 +24353,11 @@ int8_t4_packed pack_clamp_s8(int16_t4 unpackedValue) __generic struct ThreadNodeOutputRecords { - FunctionPtr<__SPIRVNodePayloadArray> alloc; + NodePayloadArrayPtr alloc; [ForceInline] [require(hlsl_spirv)] // [require(hlsl_spirv, thread)] - __init(int recordCount, int nodeIndex = 0) + __init(int recordCount, int nodeIndex = 0) // , constexpr String nodeStr) { __target_switch { @@ -24365,16 +24365,11 @@ struct ThreadNodeOutputRecords case spirv: alloc = spirv_asm { - %alloc = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr Workgroup $recordCount $nodeIndex; - - %loaded = OpLoad $$__SPIRVNodePayloadArray %alloc; - - %tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray> Function; - OpStore %tempPtrVar %loaded; - - result : $$FunctionPtr<__SPIRVNodePayloadArray> = OpCopyObject %tempPtrVar; + //TODO: `String` doesn't seem to work inside of spirv_asm. + //%nodeStr = OpConstantStringAMDX $nodeStr; + //OpDecorateId $$__SPIRVNodePayloadArray PayloadNodeNameAMDX %nodeStr; + result = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr Workgroup $recordCount $nodeIndex; }; - return; } } @@ -24389,14 +24384,15 @@ struct ThreadNodeOutputRecords } } - __subscript(uint index) -> FunctionPtr + __subscript(uint index) -> NodePayloadPtr { [ForceInline] get { return Get(index); } } [ForceInline] - FunctionPtr Get(int index = 0) + [require(hlsl_spirv)] // [require(hlsl_spirv, thread)] + NodePayloadPtr Get(int index = 0) { __target_switch { @@ -24404,7 +24400,7 @@ struct ThreadNodeOutputRecords case spirv: return spirv_asm { - result : $$FunctionPtr = OpAccessChain $alloc $index + result : $$NodePayloadPtr = OpAccessChain $alloc $index }; } } @@ -24416,8 +24412,31 @@ struct ThreadNodeOutputRecords __generic struct DispatchNodeInputRecord { + NodePayloadArrayPtr payload; + + [ForceInline] + [require(hlsl_spirv)] // [require(hlsl_spirv, broadcasting)] + __init() + { + __target_switch + { + case hlsl: return; + case spirv: + //TODO: It may be simpler if we add a new __intrinsic_op for OpVariable with NodePayloadAMDX. + payload = spirv_asm + { + // The variable name, "%payload", is required here, + // because it needs to be send to OpEntryPoint as an interface. + %payload = OpVariable $$NodePayloadArrayPtr NodePayloadAMDX; + result : $$NodePayloadArrayPtr = OpCopyObject %payload; + }; + } + } + /// Provide an access to a record object that only holds a single record. - FunctionPtr Get() + [ForceInline] + [require(hlsl_spirv)] // [require(hlsl_spirv, broadcasting)] + NodePayloadPtr Get() { int index = 0; @@ -24425,19 +24444,10 @@ struct DispatchNodeInputRecord { case hlsl: __intrinsic_asm ".Get"; case spirv: - let ptr = spirv_asm + return spirv_asm { - %nodePtrType = OpTypePointer NodePayloadAMDX $$__SPIRVNodePayloadArray; - %nodePtrVar = OpVariable %nodePtrType NodePayloadAMDX; - - %tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray> Function; - - %loaded = OpLoad $$__SPIRVNodePayloadArray %nodePtrVar; - OpStore %tempPtrVar %loaded; - - result : $$FunctionPtr = OpAccessChain %tempPtrVar $index; + result : $$NodePayloadPtr = OpAccessChain $payload $index; }; - return ptr; } } }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 1e40b3c5bc..6fb8af5135 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -493,9 +493,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex /// The next destination `` to allocate. SpvWord m_nextID = 1; - // This keeps track of the named IDs used in the asm block - Dictionary> m_idMaps; - OrderedDictionary m_forwardDeclaredPointers; SpvInst* m_nullDwarfExpr = nullptr; @@ -2996,7 +2993,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex switch (inst->getOp()) { case kIROp_Var: - emitLocalInst(spvBlock, spvBlock, inst); + emitLocalInst(spvBlock, inst); break; case kIROp_DebugVar: // Declare an ordinary local variable for debugDeclare association @@ -3004,9 +3001,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // values to upon a `kIROp_DebugValue` inst. emitDebugVarBackingLocalVarDeclaration(spvBlock, as(inst)); break; - case kIROp_SPIRVAsm: - emitLocalInst(spvBlock, spvBlock, inst); - break; } } } @@ -3087,7 +3081,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Skip vars because they are already emitted. if (as(irInst)) continue; - emitLocalInst(spvBlock, nullptr, irInst); + emitLocalInst(spvBlock, irInst); if (irInst->getOp() == kIROp_loop) pendingLoopInsts.add(as(irInst)); if (irInst->getOp() == kIROp_discard && !shouldEmitDiscardAsDemote()) @@ -3464,7 +3458,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // a known parent (the basic block that contains them). /// Emit an instruction that is local to the body of the given `parent`. - SpvInst* emitLocalInst(SpvInstParent* parent, SpvInstParent* firstLabel, IRInst* inst) + SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst) { SpvInst* result = nullptr; switch (inst->getOp()) @@ -3596,7 +3590,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_Geq: case kIROp_Rsh: case kIROp_Lsh: - result = emitArithmetic(parent, firstLabel, inst); + result = emitArithmetic(parent, inst); break; case kIROp_CastDescriptorHandleToUInt2: case kIROp_CastUInt2ToDescriptorHandle: @@ -3847,7 +3841,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex result = emitOpUndef(parent, inst, inst->getDataType()); break; case kIROp_SPIRVAsm: - result = emitSPIRVAsm(parent, firstLabel, as(inst)); + result = emitSPIRVAsm(parent, as(inst)); break; case kIROp_ImageLoad: result = emitImageLoad(parent, as(inst)); @@ -7269,7 +7263,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } - SpvInst* emitArithmetic(SpvInstParent* parent, SpvInstParent* firstLabel, IRInst* inst) + SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { if (const auto matrixType = as(inst->getDataType())) { @@ -7288,7 +7282,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (as(originalOperand->getDataType())) { auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, firstLabel, operand); + emitLocalInst(parent, operand); operands.add(operand); } else @@ -7854,11 +7848,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return debugFunc; } - SpvInst* emitSPIRVAsm(SpvInstParent* parent, SpvInstParent* firstLabel, IRSPIRVAsm* inst) + SpvInst* emitSPIRVAsm(SpvInstParent* parent, IRSPIRVAsm* inst) { SpvInst* last = nullptr; - auto& idMap = m_idMaps.getOrAddValue(inst, Dictionary()); + // This keeps track of the named IDs used in the asm block + Dictionary idMap; for (const auto spvInst : inst->getInsts()) { @@ -8036,10 +8031,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (spvInst->getOpcodeOperand()->getOp() == kIROp_SPIRVAsmOperandTruncate) { - // Nothing to emit to the first OpLabel - if (firstLabel) - continue; - const auto getSlangType = [&](IRSPIRVAsmOperand* operand) -> IRType* { switch (operand->getOp()) @@ -8208,17 +8199,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvInstParent* opParent = nullptr; if (opcode == SpvOpVariable) { - // SPIRV validator says, - // "All OpVariable instructions in a function must be the first instructions in - // the first block." - opParent = firstLabel; - auto opStorageClass = spvInst->getOperand(3); if (opStorageClass && opStorageClass->getOp() == kIROp_SPIRVAsmOperandEnum) { - if (auto intLit = cast(opStorageClass->getOperand(0))) + if (auto enumStorageClass = cast(opStorageClass->getOperand(0))) { - switch (SpvStorageClass(intLit->getValue())) + switch (SpvStorageClass(enumStorageClass->getValue())) { case SpvStorageClassNodePayloadAMDX: requireSPIRVCapability(SpvCapabilityShaderEnqueueAMDX); @@ -8251,11 +8237,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex opParent == getSection(SpvLogicalSectionID::ConstantsAndTypes) && opInfo && opInfo->resultIdIndex == 0; - // SpvOpVariable must appear at the first block of the function - // And it may depends on other memoized instructions. - if ((opcode == SpvOpVariable || memoize) == (firstLabel == nullptr)) - continue; - // We want the "result instruction" to refer to the top level // block which assumes its value, the others are free to refer // to whatever, so just use the internal spv inst rep @@ -8318,11 +8299,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } - if (firstLabel == nullptr) - { - for (const auto& [name, id] : idMap) - emitOpName(getSection(SpvLogicalSectionID::DebugNames), nullptr, id, name); - } + for (const auto& [name, id] : idMap) + emitOpName(getSection(SpvLogicalSectionID::DebugNames), nullptr, id, name); return last; } diff --git a/tests/a.slang b/tests/a.slang new file mode 100644 index 0000000000..32c06ef427 --- /dev/null +++ b/tests/a.slang @@ -0,0 +1,52 @@ +//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -profile sm_6_0 -capability spirv_1_4 + +RWStructuredBuffer inputBuffer; + +struct RecordData +{ + int myData; +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + + //TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute on "node" variable. + // Until it is implemented properly, we will use an int-type generic argument. + // + #define myNodeID 0 + ThreadNodeOutputRecords node + = ThreadNodeOutputRecords(1); + + node.Get().myData = inputBuffer[0]; + + node.OutputComplete(); +} + +//CHK: OpCapability ShaderEnqueueAMDX +//CHK: OpExtension "SPV_AMDX_shader_enqueue" + +//CHK: ; Types, variables and constants + +//CHK-DAG: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1 +//CHK-DAG: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]] +//CHK-DAG: [[PayloadType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] +//CHK-DAG: [[NodePtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]] +//CHK-DAG: [[TempPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType]] +//CHK-DAG: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[StructType]] + +//CHK: ; Function + +// ThreadNodeOutputRecords::__init() +//CHK: [[TempPtrVar:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType]] Function +//CHK: [[Alloc:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[NodePtrType]] % +//CHK: [[Loaded:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType]] [[Alloc]] +//CHK: OpStore [[TempPtrVar]] [[Loaded]] + +// ThreadNodeOutputRecords::Get() +//CHK: = OpAccessChain [[FuncPtrType]] % + +// ThreadNodeOutputRecords::OutputComplete() +//CHK: OpEnqueueNodePayloadsAMDX % + diff --git a/tests/workgraphs/consumer.slang b/tests/workgraphs/consumer.slang index 74488a7e0c..2f8178844e 100644 --- a/tests/workgraphs/consumer.slang +++ b/tests/workgraphs/consumer.slang @@ -30,12 +30,8 @@ void computeMain(uint3 dispatchThreadId : SV_GroupThreadID) //CHK: [[PayloadType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] //CHK: [[NodePtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]] //CHK: [[NodePtrVar:%[a-zA-Z_0-9]+]] = OpVariable [[NodePtrType]] NodePayloadAMDX -//CHK: [[TempPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType]] -//CHK: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[StructType]] +//CHK: [[StructPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[StructType]] //CHK: ; Function -//CHK: [[TempPtrVar:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType]] Function -//CHK: [[Loaded:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType]] [[NodePtrVar]] -//CHK: OpStore [[TempPtrVar]] [[Loaded]] -//CHK: = OpAccessChain [[FuncPtrType]] [[TempPtrVar]] +//CHK: = OpAccessChain [[StructPtrType]] [[NodePtrVar]] diff --git a/tests/workgraphs/producer-multi-nodes.slang b/tests/workgraphs/producer-multi-nodes.slang index bf2a4f8a04..cd78137872 100644 --- a/tests/workgraphs/producer-multi-nodes.slang +++ b/tests/workgraphs/producer-multi-nodes.slang @@ -11,12 +11,6 @@ struct RecordData [numthreads(1, 1, 1)] void computeMain() { - spirv_asm - { - OpExecutionMode $computeMain ShaderIndexAMDX $(0); - OpExecutionMode $computeMain StaticNumWorkgroupsAMDX $(1) $(1) $(1); - }; - //TODO: "myNodeID" is supposed to come from [NodeID("name")] attribute on "node" variable. // Until it is implemented properly, we will use an int-type generic argument. // @@ -44,25 +38,19 @@ void computeMain() //CHK-DAG: [[PayloadType2:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] //CHK-DAG: [[NodePtrType1:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType1]] //CHK-DAG: [[NodePtrType2:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType2]] -//CHK-DAG: [[TempPtrType1:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType1]] -//CHK-DAG: [[TempPtrType2:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType2]] -//CHK-DAG: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[StructType]] +//CHK-DAG: [[StructPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[StructType]] //CHK: ; Function // ThreadNodeOutputRecords::__init() -//CHK-DAG: [[TempPtrVar1:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType1]] Function -//CHK-DAG: [[TempPtrVar2:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType2]] Function //CHK-DAG: [[Alloc1:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[NodePtrType1]] % //CHK-DAG: [[Alloc2:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[NodePtrType2]] % -//CHK-DAG: [[Loaded1:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType1]] [[Alloc1]] -//CHK-DAG: [[Loaded2:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType2]] [[Alloc2]] -//CHK-DAG: OpStore [[TempPtrVar1]] [[Loaded1]] -//CHK-DAG: OpStore [[TempPtrVar2]] [[Loaded2]] // ThreadNodeOutputRecords::Get() -//CHK-COUNT-2: = OpAccessChain [[FuncPtrType]] % +//CHK: = OpAccessChain [[StructPtrType]] [[Alloc1]] +//CHK: = OpAccessChain [[StructPtrType]] [[Alloc2]] // ThreadNodeOutputRecords::OutputComplete() -//CHK-COUNT-2: OpEnqueueNodePayloadsAMDX % +//CHK: OpEnqueueNodePayloadsAMDX [[Alloc1]] +//CHK: OpEnqueueNodePayloadsAMDX [[Alloc2]] diff --git a/tests/workgraphs/producer.slang b/tests/workgraphs/producer.slang index b513034479..f87b8a891d 100644 --- a/tests/workgraphs/producer.slang +++ b/tests/workgraphs/producer.slang @@ -1,4 +1,4 @@ -//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -skip-spirv-validation +//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry computeMain -profile sm_6_0 -capability spirv_1_4 RWStructuredBuffer inputBuffer; struct RecordData @@ -15,9 +15,11 @@ void computeMain() // Until it is implemented properly, we will use an int-type generic argument. // #define myNodeID 0 - ThreadNodeOutputRecords node = { 1 }; + ThreadNodeOutputRecords node + = ThreadNodeOutputRecords(1); node.Get().myData = inputBuffer[0]; + node.OutputComplete(); } @@ -30,20 +32,16 @@ void computeMain() //CHK-DAG: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]] //CHK-DAG: [[PayloadType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] //CHK-DAG: [[NodePtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]] -//CHK-DAG: [[TempPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[PayloadType]] -//CHK-DAG: [[FuncPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer Function [[StructType]] +//CHK-DAG: [[StructPtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[StructType]] //CHK: ; Function // ThreadNodeOutputRecords::__init() -//CHK: [[TempPtrVar:%[a-zA-Z_0-9]+]] = OpVariable [[TempPtrType]] Function //CHK: [[Alloc:%[a-zA-Z_0-9]+]] = OpAllocateNodePayloadsAMDX [[NodePtrType]] % -//CHK: [[Loaded:%[a-zA-Z_0-9]+]] = OpLoad [[PayloadType]] [[Alloc]] -//CHK: OpStore [[TempPtrVar]] [[Loaded]] // ThreadNodeOutputRecords::Get() -//CHK: = OpAccessChain [[FuncPtrType]] % +//CHK: = OpAccessChain [[StructPtrType]] [[Alloc]] // ThreadNodeOutputRecords::OutputComplete() -//CHK: OpEnqueueNodePayloadsAMDX % +//CHK: OpEnqueueNodePayloadsAMDX [[Alloc]] From ffe6d675d165bb045ae668d30d43355a378cea0a Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Fri, 11 Apr 2025 12:49:42 -0700 Subject: [PATCH 13/13] Add Count() to DispatspatchNodeInputRecord --- source/slang/hlsl.meta.slang | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 49f8cf4413..5a2c36d1ca 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -24450,6 +24450,15 @@ struct DispatchNodeInputRecord }; } } + + [ForceInline] + uint Count() + { + return spirv_asm + { + result: $$uint = OpNodePayloadArrayLengthAMDX $payload; + }; + } }; //