Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow duplication of OpTypeNodePayloadArrayAMDX #6470

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
23 changes: 21 additions & 2 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -4211,6 +4211,25 @@ attribute_syntax [QuadDerivatives] : QuadDerivativesAttribute;
__attributeTarget(FuncDecl)
attribute_syntax [RequireFullQuads] : RequireFullQuadsAttribute;

__generic<T>
typealias NodePayloadPtr = Ptr<T, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
// Work-graphs

/// @internal
/// __SPIRVNodePayloadArray is to emit OpTypeNodePayloadArrayAMDX and
/// SpvDecorationPayloadNodeNameAMDX for it.
///
__generic<BaseType, let nodeID : int>
__intrinsic_type($(kIROp_SPIRVNodePayloadArrayType))
struct __SPIRVNodePayloadArray;

/// @internal
/// NodePayloadArrayPtr is a pointer type for "NodePayloadAMDX" storage-class and it points to an array of work-graph nodes
///
__generic<BaseType, let nodeID : int>
typealias NodePayloadArrayPtr = Ptr<__SPIRVNodePayloadArray<BaseType, nodeID>, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;

/// @public
/// FunctionPtr is a pointer type for "Function" storage-class and it points to a `BaseType`.
///
__generic<BaseType>
typealias FunctionPtr = Ptr<BaseType, $( (uint64_t)AddressSpace::Function)>;

91 changes: 80 additions & 11 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -24280,28 +24280,97 @@ int8_t4_packed pack_clamp_s8(int16_t4 unpackedValue)

// Work-graphs

///@public
/// Set of zero or more records for each thread.
///
__generic<RecordType, let nodeID : int>
struct ThreadNodeOutputRecords
{
FunctionPtr<__SPIRVNodePayloadArray<RecordType, nodeID>> alloc;

[ForceInline]
[require(hlsl_spirv)] // [require(hlsl_spirv, thread)]
__init(int recordCount, int nodeIndex = 0)
{
__target_switch
{
case hlsl: return;
case spirv:
alloc = spirv_asm
{
%alloc = OpAllocateNodePayloadsAMDX $$NodePayloadArrayPtr<RecordType, nodeID> Workgroup $recordCount $nodeIndex;

%loaded = OpLoad $$__SPIRVNodePayloadArray<RecordType, nodeID> %alloc;

%tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray<RecordType, nodeID>> Function;
OpStore %tempPtrVar %loaded;

result : $$FunctionPtr<__SPIRVNodePayloadArray<RecordType, nodeID>> = OpCopyObject %tempPtrVar;
};
return;
}
}

[ForceInline]
[require(hlsl_spirv)] // [require(hlsl_spirv, thread)]
void OutputComplete()
{
__target_switch
{
case hlsl: __intrinsic_asm ".OutputComplete";
case spirv: spirv_asm { OpEnqueueNodePayloadsAMDX $alloc; };
}
}

__subscript(uint index) -> FunctionPtr<RecordType>
{
[ForceInline]
get { return Get(index); }
}

[ForceInline]
FunctionPtr<RecordType> Get(int index = 0)
{
__target_switch
{
case hlsl: __intrinsic_asm ".Get";
case spirv:
return spirv_asm
{
result : $$FunctionPtr<RecordType> = OpAccessChain $alloc $index
};
}
}
};

//@public:
/// read-only input to Broadcasting launch node.
__generic<T>
//TODO: DispatchNodeInputRecord should be available only for broadcasting node shader.
//[require(broadcasting_node)]
[require(spirv)]
/// Read-only input to Broadcasting launch node.
///
__generic<RecordType, let nodeID : int>
struct DispatchNodeInputRecord
{
/// Provide an access to a record object that only holds a single record.
NodePayloadPtr<T> Get()
FunctionPtr<RecordType> Get()
{
int index = 0;

__target_switch
{
case hlsl: __intrinsic_asm ".Get";
case spirv:
return spirv_asm
let ptr = spirv_asm
{
%in_payload_t = OpTypeNodePayloadArrayAMDX $$T;
%in_payload_ptr_t = OpTypePointer NodePayloadAMDX %in_payload_t;
%var = OpVariable %in_payload_ptr_t NodePayloadAMDX;
result : $$NodePayloadPtr<T> = OpAccessChain %var $index;
%nodePtrType = OpTypePointer NodePayloadAMDX $$__SPIRVNodePayloadArray<RecordType, nodeID>;
%nodePtrVar = OpVariable %nodePtrType NodePayloadAMDX;

%tempPtrVar = OpVariable $$FunctionPtr<__SPIRVNodePayloadArray<RecordType, nodeID>> Function;

%loaded = OpLoad $$__SPIRVNodePayloadArray<RecordType, nodeID> %nodePtrVar;
OpStore %tempPtrVar %loaded;

result : $$FunctionPtr<RecordType> = OpAccessChain %tempPtrVar $index;
};
return ptr;
}
}
};
Expand Down
26 changes: 25 additions & 1 deletion source/slang/slang-emit-spirv-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -2585,12 +2585,36 @@ template<typename T>
SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type)
{
static_assert(isSingular<T>);
return emitInstMemoized(
return emitInst(
getSection(SpvLogicalSectionID::ConstantsAndTypes),
inst,
SpvOpTypeNodePayloadArrayAMDX,
kResultID,
type);
}

// https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/AMD/SPV_AMDX_shader_enqueue.html#OpTypeNodePayloadArrayAMDX
SpvInst* emitOpConstantString(IRInst* inst, const UnownedStringSlice& str)
{
return emitInstMemoized(
getSection(SpvLogicalSectionID::ConstantsAndTypes),
inst,
SpvOpConstantStringAMDX,
kResultID,
str);
}

// https://github.khronos.org/SPIRV-Registry/extensions/AMD/SPV_AMDX_shader_enqueue.html#_decorations
template<typename T1, typename T2>
SpvInst* emitOpDecoratePayloadNodeName(IRInst* inst, const T1& target, const T2& id)
{
return emitInst(
getSection(SpvLogicalSectionID::Annotations),
inst,
SpvOpDecorateId,
target,
SpvDecorationPayloadNodeNameAMDX,
id);
}

#endif // SLANG_IN_SPIRV_EMIT_CONTEXT
Loading
Loading