Commit b380b1a

Wave Prefix Product (shader-slang#1270)
* Fix some typos.
* Add wave-prefix-sum.slang test.
* First pass at implementing prefixSum.
* Small improvements to prefixSum CUDA.
* Small improvement to prefix sum.
* Enable prefix sum in stdlib.
* Wave prefix product without using a divide.
* Split out SM6.5 Wave intrinsics. Template mechanism for doing prefix calculations.

1 parent a10d9cd commit b380b1a

File tree

4 files changed: +169 -57 lines changed

prelude/slang-cuda-prelude.h

+98-18
@@ -534,20 +534,25 @@ struct WaveOpXor
 {
     __inline__ __device__ static T getInitial(T a) { return 0; }
     __inline__ __device__ static T doOp(T a, T b) { return a ^ b; }
+    __inline__ __device__ static T doInverse(T a, T b) { return a ^ b; }
 };
 
 template <typename T>
 struct WaveOpAdd
 {
     __inline__ __device__ static T getInitial(T a) { return 0; }
     __inline__ __device__ static T doOp(T a, T b) { return a + b; }
+    __inline__ __device__ static T doInverse(T a, T b) { return a - b; }
 };
 
 template <typename T>
 struct WaveOpMul
 {
     __inline__ __device__ static T getInitial(T a) { return T(1); }
     __inline__ __device__ static T doOp(T a, T b) { return a * b; }
+    // Using this inverse for int is probably undesirable, because in general it requires T to have more precision.
+    // There is also a performance aspect: divides are generally significantly slower.
+    __inline__ __device__ static T doInverse(T a, T b) { return a / b; }
 };
 
 template <typename T>
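To make the doOp/doInverse pairing above concrete: for an invertible op, the exclusive prefix value falls out of the inclusive one by removing the lane's own contribution at the end. A tiny sketch (illustration only; exclusiveFromInclusive is a made-up helper, not in the prelude):

// Illustration only: exclusive = doInverse(inclusive, laneValue).
// With WaveOpAdd<int>: (x0 + ... + xk) - xk == x0 + ... + x(k-1)
// With WaveOpXor<int>: (x0 ^ ... ^ xk) ^ xk == x0 ^ ... ^ x(k-1)
template <typename INTF, typename T>
__inline__ __device__ T exclusiveFromInclusive(T inclusive, T laneValue)
{
    return INTF::doInverse(inclusive, laneValue);
}

WaveOpMul supplies a divide as its inverse, but for integers that identity only holds while the inclusive product stays representable (and the lane value is nonzero), which is why the product path below avoids doInverse altogether.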
@@ -823,46 +828,121 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane)
     return outVal;
 }
 
-__device__ int _wavePrefixSum(int val)
+// Scalar
+
+// "Invertable" means that when we get to the end of the reduce, we can remove val (to make the result
+// exclusive) using the inverse of the op.
+template <typename INTF, typename T>
+__device__ T _wavePrefixInvertableScalar(T val)
 {
     const int mask = __activemask();
     const int offsetSize = _waveCalcPow2Offset(mask);
 
     const int laneId = _getLaneId();
+    T result;
     if (offsetSize > 0)
     {
-        int sum = val;
+        // The result is calculated inclusive of this lane's value
+        result = val;
         for (int i = 1; i < offsetSize; i += i)
         {
-            const int readVal = __shfl_up_sync(mask, sum, i, offsetSize);
+            const T readVal = __shfl_up_sync(mask, result, i, offsetSize);
             if (laneId >= i)
             {
-                sum += readVal;
+                result = INTF::doOp(result, readVal);
             }
         }
-        return sum - val;
+        // Remove val from the result, by applying the inverse
+        result = INTF::doInverse(result, val);
     }
     else
     {
-        int result = 0;
-        int remaining = mask;
-        while (remaining)
+        result = INTF::getInitial(val);
+        if (!_waveIsSingleLane(mask))
         {
-            const int laneBit = remaining & -remaining;
-            // Get the sourceLane
-            const int srcLane = __ffs(laneBit) - 1;
-            // Broadcast (can also broadcast to self)
-            int readValue = __shfl_sync(mask, val, srcLane);
-            // Only accumulate if srcLane is less than this lane
-            if (srcLane < laneId)
+            int remaining = mask;
+            while (remaining)
             {
-                result += readValue;
+                const int laneBit = remaining & -remaining;
+                // Get the sourceLane
+                const int srcLane = __ffs(laneBit) - 1;
+                // Broadcast (can also broadcast to self)
+                const T readValue = __shfl_sync(mask, val, srcLane);
+                // Only accumulate if srcLane is less than this lane
+                if (srcLane < laneId)
+                {
+                    result = INTF::doOp(result, readValue);
+                }
+                remaining &= ~laneBit;
             }
-            remaining &= ~laneBit;
         }
-        return result;
     }
+    return result;
 }
+
+// This implementation separately tracks the value to be propagated, and the value
+// that is the final result.
+template <typename INTF, typename T>
+__device__ T _wavePrefixScalar(T val)
+{
+    const int mask = __activemask();
+    const int offsetSize = _waveCalcPow2Offset(mask);
+
+    const int laneId = _getLaneId();
+    T result = INTF::getInitial(val);
+    if (offsetSize > 0)
+    {
+        // The transmitted value is calculated inclusive of this lane's value, while the result is not.
+        // This means an extra multiply for each iteration, but in exchange we don't need a divide at
+        // the end, and it also removes the overflow issues the divide approach brings.
+        for (int i = 1; i < offsetSize; i += i)
+        {
+            const T readVal = __shfl_up_sync(mask, val, i, offsetSize);
+            if (laneId >= i)
+            {
+                result = INTF::doOp(result, readVal);
+                val = INTF::doOp(val, readVal);
+            }
+        }
+    }
+    else
+    {
+        if (!_waveIsSingleLane(mask))
+        {
+            int remaining = mask;
+            while (remaining)
+            {
+                const int laneBit = remaining & -remaining;
+                // Get the sourceLane
+                const int srcLane = __ffs(laneBit) - 1;
+                // Broadcast (can also broadcast to self)
+                const T readValue = __shfl_sync(mask, val, srcLane);
+                // Only accumulate if srcLane is less than this lane
+                if (srcLane < laneId)
+                {
+                    result = INTF::doOp(result, readValue);
+                }
+                remaining &= ~laneBit;
+            }
+        }
+    }
+    return result;
+}
+
+template <typename T>
+__inline__ __device__ T _wavePrefixProduct(T val) { return _wavePrefixScalar<WaveOpMul<T>, T>(val); }
+
+template <typename T>
+__inline__ __device__ T _wavePrefixSum(T val) { return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(val); }
+
+template <typename T>
+__inline__ __device__ T _wavePrefixAnd(T val) { return _wavePrefixScalar<WaveOpAnd<T>, T>(val); }
+
+template <typename T>
+__inline__ __device__ T _wavePrefixOr(T val) { return _wavePrefixScalar<WaveOpOr<T>, T>(val); }
+
+template <typename T>
+__inline__ __device__ T _wavePrefixXor(T val) { return _wavePrefixInvertableScalar<WaveOpXor<T>, T>(val); }
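As a concrete illustration of the divide-free scheme in _wavePrefixScalar, here is a minimal, self-contained CUDA sketch (not part of the commit) that runs the same two-variable scan to compute an exclusive prefix product over one full 32-lane warp. prefixProductKernel is a made-up name, and the full-warp assumption lets it skip the __activemask() / _waveCalcPow2Offset() handling the real prelude needs for partial masks:

#include <cstdio>
#include <cuda_runtime.h>

// Exclusive prefix product over one full warp, with no divide:
// `val` carries the inclusive product up the warp, `result` stays exclusive.
__global__ void prefixProductKernel(const int* in, int* out)
{
    const unsigned mask = 0xffffffffu;  // all 32 lanes assumed active
    const int laneId = threadIdx.x & 31;

    int val = in[threadIdx.x];          // inclusive running value
    int result = 1;                     // exclusive result starts at the identity

    for (int i = 1; i < 32; i += i)     // log2(32) = 5 Hillis-Steele steps
    {
        const int readVal = __shfl_up_sync(mask, val, i);
        if (laneId >= i)
        {
            result *= readVal;          // fold the neighbor's inclusive value into the exclusive result
            val *= readVal;             // keep the inclusive value moving up the warp
        }
    }
    out[threadIdx.x] = result;
}

int main()
{
    int h_in[32], h_out[32];
    for (int i = 0; i < 32; ++i) h_in[i] = (i & 1) + 1;  // alternating 1s and 2s, keeps products small

    int *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    prefixProductKernel<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

    for (int i = 0; i < 32; ++i)
        printf("lane %2d: in=%d exclusive product=%d\n", i, h_in[i], h_out[i]);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Each lane ends up with the product of all earlier lanes' inputs, lane 0 getting the identity 1, without ever dividing out its own value.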

/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */

source/slang/hlsl.meta.slang

+47-39
@@ -2498,6 +2498,7 @@ __generic<T : __BuiltinArithmeticType>
 __glsl_extension(GL_KHR_shader_subgroup_arithmetic)
 __spirv_version(1.3)
 __target_intrinsic(glsl, "subgroupExclusiveMul($0)")
+__target_intrinsic(cuda, "_wavePrefixProduct($0)")
 T WavePrefixProduct(T expr);
 __generic<T : __BuiltinArithmeticType, let N : int>
 __glsl_extension(GL_KHR_shader_subgroup_arithmetic)
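The cuda target intrinsic strings work by argument substitution: $0 (and $1, where present) stand for the translated arguments, so on the CUDA target a WavePrefixProduct call lowers to a call of the prelude template added above. A hypothetical fragment of the generated CUDA:

// Hypothetical CUDA emitted for the Slang expression `WavePrefixProduct(idx + 1)`
int val = _wavePrefixProduct(idx + 1);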
@@ -2521,10 +2522,54 @@ vector<T,N> WavePrefixSum(vector<T,N> expr);
 __generic<T : __BuiltinArithmeticType, let N : int, let M : int>
 matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr);
 
+__generic<T : __BuiltinType>
+__glsl_extension(GL_KHR_shader_subgroup_ballot)
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBroadcastFirst($0)")
+__target_intrinsic(cuda, "_waveReadFirst($0)")
+T WaveReadLaneFirst(T expr);
+__generic<T : __BuiltinType, let N : int>
+__glsl_extension(GL_KHR_shader_subgroup_ballot)
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBroadcastFirst($0)")
+__target_intrinsic(cuda, "_waveReadFirstMultiple($0)")
+vector<T,N> WaveReadLaneFirst(vector<T,N> expr);
+__generic<T : __BuiltinType, let N : int, let M : int>
+__target_intrinsic(cuda, "_waveReadFirstMultiple($0)")
+matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr);
+
+// NOTE! On GLSL-based targets the lane index *must* be a compile-time expression!
+// See https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_shader_subgroup.txt
+__generic<T : __BuiltinType>
+__glsl_extension(GL_KHR_shader_subgroup_ballot)
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBroadcast($0, $1)")
+__target_intrinsic(cuda, "__shfl_sync(__activemask(), $0, $1)")
+T WaveReadLaneAt(T value, int lane);
+__generic<T : __BuiltinType, let N : int>
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBroadcast($0, $1)")
+__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)")
+vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane);
+__generic<T : __BuiltinType, let N : int, let M : int>
+__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)")
+matrix<T,N,M> WaveReadLaneAt(matrix<T,N,M> value, int lane);
+
+__glsl_extension(GL_KHR_shader_subgroup_ballot)
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBallotExclusiveBitCount(subgroupBallot($0))")
+__target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0) & _getLaneLtMask())")
+uint WavePrefixCountBits(bool value);
+
+// Shader Model 6.5 intrinsics
+// https://github.com/microsoft/DirectX-Specs/blob/master/d3d/HLSL_ShaderModel6_5.md
+// TODO(JS): It looks like these need a mask parameter
+
 __generic<T : __BuiltinArithmeticType>
 __glsl_extension(GL_KHR_shader_subgroup_arithmetic)
 __spirv_version(1.3)
 __target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
+__target_intrinsic(cuda, "_wavePrefixAnd($0)")
 T WaveMultiPrefixBitAnd(T expr);
 __glsl_extension(GL_KHR_shader_subgroup_arithmetic)
 __spirv_version(1.3)
@@ -2538,6 +2583,7 @@ __generic<T : __BuiltinArithmeticType>
 __glsl_extension(GL_KHR_shader_subgroup_arithmetic)
 __spirv_version(1.3)
 __target_intrinsic(glsl, "subgroupExclusiveOr($0)")
+__target_intrinsic(cuda, "_wavePrefixOr($0)")
 T WaveMultiPrefixBitOr(T expr);
 __generic<T : __BuiltinArithmeticType, let N : int>
 __glsl_extension(GL_KHR_shader_subgroup_arithmetic)
@@ -2551,6 +2597,7 @@ __generic<T : __BuiltinArithmeticType>
 __glsl_extension(GL_KHR_shader_subgroup_arithmetic)
 __spirv_version(1.3)
 __target_intrinsic(glsl, "subgroupExclusiveXor($0)")
+__target_intrinsic(cuda, "_wavePrefixXor($0)")
 T WaveMultiPrefixBitXor(T expr);
 __generic<T : __BuiltinArithmeticType, let N : int>
 __glsl_extension(GL_KHR_shader_subgroup_arithmetic)
@@ -2560,11 +2607,6 @@ vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr);
 __generic<T : __BuiltinArithmeticType, let N : int, let M : int>
 matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr);
 
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBallotExclusiveBitCount(subgroupBallot($0))")
-__target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0) & _getLaneLtMask())")
-uint WavePrefixCountBits(bool value);
 
 uint WaveMultiPrefixCountBits(bool value, uint4 mask);
 
@@ -2576,40 +2618,6 @@ __generic<T : __BuiltinArithmeticType> T WaveMultiPrefixSum(T value, uint4 mask)
 __generic<T : __BuiltinArithmeticType, let N : int> vector<T,N> WaveMultiPrefixSum(vector<T,N> value, uint4 mask);
 __generic<T : __BuiltinArithmeticType, let N : int, let M : int> matrix<T,N,M> WaveMultiPrefixSum(matrix<T,N,M> value, uint4 mask);
 
-__generic<T : __BuiltinType>
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcastFirst($0)")
-__target_intrinsic(cuda, "_waveReadFirst($0)")
-T WaveReadLaneFirst(T expr);
-__generic<T : __BuiltinType, let N : int>
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcastFirst($0)")
-__target_intrinsic(cuda, "_waveReadFirstMultiple($0)")
-vector<T,N> WaveReadLaneFirst(vector<T,N> expr);
-__generic<T : __BuiltinType, let N : int, let M : int>
-__target_intrinsic(cuda, "_waveReadFirstMultiple($0)")
-matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr);
-
-// NOTE! On GLSL based targets the lane index *must* be a compile time expression!
-// See https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_shader_subgroup.txt
-__generic<T : __BuiltinType>
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcast($0, $1)")
-__target_intrinsic(cuda, "__shfl_sync(__activemask(), $0, $1)")
-T WaveReadLaneAt(T value, int lane);
-__generic<T : __BuiltinType, let N : int>
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcast($0, $1)")
-__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)")
-vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane);
-__generic<T : __BuiltinType, let N : int, let M : int>
-__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)")
-matrix<T,N,M> WaveReadLaneAt(matrix<T,N,M> value, int lane);
-
 
 // `typedef`s to help with the fact that HLSL has been sorta-kinda case insensitive at various points
 typedef Texture2D texture2D;

@@ -0,0 +1,16 @@
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
+//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(8, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+    int idx = int(dispatchThreadID.x);
+    int val = WavePrefixProduct(idx + 1);
+    outputBuffer[idx] = val;
+}
@@ -0,0 +1,8 @@
+1
+1
+2
+6
+18
+78
+2D0
+13B0
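The expected values are hexadecimal. With all eight threads in one wave, lane i receives the input i + 1, so its exclusive prefix product is i!: 1, 1, 2, 6, 24, 120, 720, 5040, which is 0x1, 0x1, 0x2, 0x6, 0x18, 0x78, 0x2D0, 0x13B0. Lane 0 has no preceding lanes, so it gets the identity, 1.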
