Skip to content

Commit 55ff468

Browse files
jkwak-workcsyonghe
andauthored
Support a storage class, NodePayloadAMDX, for SPIRV work-graphs (shader-slang#6052)
In order to unblock experiments with SPIRV work-graphs, Slang needs to support the storage class, `NodePayloadAMDX`. Note that this commit is only to support a storage class, `NodePayloadAMDX`. There are many parts required for work-graphs hasn't been implemented yet. The implementation of `DispatchNodeInputRecord` is not required, but it is implemented mostly for a testing purpose. Closes shader-slang#6049 Co-authored-by: Yong He <yonghe@outlook.com>
1 parent fce63c2 commit 55ff468

6 files changed

+108
-12
lines changed

source/slang/core.meta.slang

+4
Original file line numberDiff line numberDiff line change
@@ -3965,3 +3965,7 @@ attribute_syntax [DerivativeGroupQuad] : DerivativeGroupQuadAttribute;
39653965
/// effect on other targets.
39663966
__attributeTarget(FuncDecl)
39673967
attribute_syntax [DerivativeGroupLinear] : DerivativeGroupLinearAttribute;
3968+
3969+
__generic<T>
3970+
typealias NodePayloadPtr = Ptr<T, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
3971+

source/slang/hlsl.meta.slang

+28
Original file line numberDiff line numberDiff line change
@@ -21365,3 +21365,31 @@ int8_t4_packed pack_clamp_s8(int16_t4 unpackedValue)
2136521365
}
2136621366
}
2136721367

21368+
// Work-graphs
21369+
21370+
//@public:
21371+
/// read-only input to Broadcasting launch node.
21372+
__generic<T>
21373+
//TODO: DispatchNodeInputRecord should be available only for broadcasting node shader.
21374+
//[require(broadcasting_node)]
21375+
[require(spirv)]
21376+
struct DispatchNodeInputRecord
21377+
{
21378+
/// Provide an access to a record object that only holds a single record.
21379+
NodePayloadPtr<T> Get()
21380+
{
21381+
int index = 0;
21382+
__target_switch
21383+
{
21384+
case spirv:
21385+
return spirv_asm
21386+
{
21387+
%in_payload_t = OpTypeNodePayloadArrayAMDX $$T;
21388+
%in_payload_ptr_t = OpTypePointer NodePayloadAMDX %in_payload_t;
21389+
%var = OpVariable %in_payload_ptr_t NodePayloadAMDX;
21390+
result : $$NodePayloadPtr<T> = OpAccessChain %var $index;
21391+
};
21392+
}
21393+
}
21394+
};
21395+

source/slang/slang-emit-spirv-ops.h

+14
Original file line numberDiff line numberDiff line change
@@ -2552,4 +2552,18 @@ SpvInst* emitOpAtomicIDecrement(
25522552
memory,
25532553
semantics);
25542554
}
2555+
2556+
// https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/AMD/SPV_AMDX_shader_enqueue.html#OpTypeNodePayloadArrayAMDX
2557+
template<typename T>
2558+
SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type)
2559+
{
2560+
static_assert(isSingular<T>);
2561+
return emitInstMemoized(
2562+
getSection(SpvLogicalSectionID::ConstantsAndTypes),
2563+
inst,
2564+
SpvOpTypeNodePayloadArrayAMDX,
2565+
kResultID,
2566+
type);
2567+
}
2568+
25552569
#endif // SLANG_IN_SPIRV_EMIT_CONTEXT

source/slang/slang-emit-spirv.cpp

+28-12
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
13111311
return SpvStorageClassImage;
13121312
case AddressSpace::UserPointer:
13131313
return SpvStorageClassPhysicalStorageBuffer;
1314+
case AddressSpace::NodePayloadAMDX:
1315+
return SpvStorageClassNodePayloadAMDX;
13141316
case AddressSpace::Global:
13151317
case AddressSpace::MetalObjectData:
13161318
case AddressSpace::SpecializationConstant:
@@ -1504,13 +1506,22 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
15041506
SLANG_ASSERT(ptrType);
15051507
if (ptrType->hasAddressSpace())
15061508
storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
1507-
if (storageClass == SpvStorageClassStorageBuffer)
1509+
1510+
switch (storageClass)
1511+
{
1512+
case SpvStorageClassStorageBuffer:
15081513
ensureExtensionDeclaration(
15091514
UnownedStringSlice("SPV_KHR_storage_buffer_storage_class"));
1510-
if (storageClass == SpvStorageClassPhysicalStorageBuffer)
1511-
{
1515+
break;
1516+
case SpvStorageClassPhysicalStorageBuffer:
15121517
requirePhysicalStorageAddressing();
1518+
break;
1519+
case SpvStorageClassNodePayloadAMDX:
1520+
requireSPIRVCapability(SpvCapabilityShaderEnqueueAMDX);
1521+
ensureExtensionDeclaration(UnownedStringSlice("SPV_AMDX_shader_enqueue"));
1522+
break;
15131523
}
1524+
15141525
auto valueType = ptrType->getValueType();
15151526
// If we haven't emitted the inner type yet, we need to emit a forward declaration.
15161527
bool useForwardDeclaration =
@@ -1524,17 +1535,20 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
15241535
builder.setInsertBefore(valueType);
15251536
valueTypeId = getID(ensureInst(builder.getUIntType()));
15261537
}
1538+
else if (useForwardDeclaration)
1539+
{
1540+
valueTypeId = getIRInstSpvID(valueType);
1541+
}
1542+
else if (storageClass == SpvStorageClassNodePayloadAMDX)
1543+
{
1544+
auto spvValueType = ensureInst(valueType);
1545+
auto spvNodePayloadType = emitOpTypeNodePayloadArray(inst, spvValueType);
1546+
valueTypeId = getID(spvNodePayloadType);
1547+
}
15271548
else
15281549
{
1529-
if (useForwardDeclaration)
1530-
{
1531-
valueTypeId = getIRInstSpvID(valueType);
1532-
}
1533-
else
1534-
{
1535-
auto spvValueType = ensureInst(valueType);
1536-
valueTypeId = getID(spvValueType);
1537-
}
1550+
auto spvValueType = ensureInst(valueType);
1551+
valueTypeId = getID(spvValueType);
15381552
}
15391553

15401554
auto resultSpvType = emitOpTypePointer(inst, storageClass, valueTypeId);
@@ -7564,6 +7578,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
75647578
case SpvOpMemberDecorate:
75657579
case SpvOpMemberDecorateString:
75667580
return getSection(SpvLogicalSectionID::Annotations);
7581+
case SpvOpTypeNodePayloadArrayAMDX:
7582+
return getSection(SpvLogicalSectionID::ConstantsAndTypes);
75677583
default:
75687584
return defaultParent;
75697585
}

source/slang/slang-type-system-shared.h

+2
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ enum class AddressSpace : uint64_t
110110
Image,
111111
// Represents a SPIR-V specialization constant
112112
SpecializationConstant,
113+
// Corresponds to SPIR-V's SpvStorageClassNodePayloadAMDX,
114+
NodePayloadAMDX,
113115

114116
// Default address space for a user-defined pointer
115117
UserPointer = 0x100000001ULL,

tests/workgraphs/consumer.slang

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry main -skip-spirv-validation
2+
struct RecordData
3+
{
4+
int myData;
5+
};
6+
7+
[shader("compute")]
8+
[numthreads(1, 1, 1)]
9+
void main(uint3 dispatchThreadId : SV_GroupThreadID)
10+
{
11+
spirv_asm
12+
{
13+
OpExecutionMode $main ShaderIndexAMDX $(0);
14+
OpExecutionMode $main StaticNumWorkgroupsAMDX $(1) $(1) $(1);
15+
};
16+
17+
DispatchNodeInputRecord<RecordData> inputData;
18+
19+
let recordData = inputData.Get();
20+
int myData = recordData.myData;
21+
}
22+
23+
//CHK: ; Types, variables and constants
24+
//CHK: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1
25+
//CHK: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]]
26+
//CHK: [[PayloadType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]]
27+
//CHK: [[PtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]]
28+
29+
//CHK: ; Function
30+
//CHK: [[VarName:%[a-zA-Z_0-9]+]] = OpVariable [[PtrType]] NodePayloadAMDX
31+
//CHK: = OpAccessChain [[PtrType]] [[VarName]]
32+

0 commit comments

Comments
 (0)