Skip to content

Commit a10d9cd

Browse files
authored
WIP Prefix Sum for CUDA (shader-slang#1268)
* Fix some typos. * Add wave-prefix-sum.slang test * First pass at implementing prefixSum. * Small improvements to prefixSum CUDA. * Small improvement to prefix sum. * Enable prefix sum in stdlib.
1 parent 721d2e8 commit a10d9cd

File tree

4 files changed

+76
-10
lines changed

4 files changed

+76
-10
lines changed

prelude/slang-cuda-prelude.h

+41
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,47 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane)
823823
return outVal;
824824
}
825825

826+
// Exclusive prefix sum of `val` across the lanes reported by __activemask().
//
// Each calling lane receives the sum of `val` taken from every participating
// lane whose lane index is lower than its own; the lowest participating lane
// receives 0. The calling lane's own `val` is never included in its result.
//
// Two strategies are used, selected by _waveCalcPow2Offset(mask):
//   * offsetSize > 0 : a log-step shuffle-up scan over `offsetSize` lanes.
//   * otherwise      : a broadcast from each participating lane in turn,
//                      accumulated only by lanes above the source lane.
//
// NOTE(review): _waveCalcPow2Offset and _getLaneId are defined outside this
// view. The scan path presumably relies on _waveCalcPow2Offset returning a
// power-of-two lane count N only when the active lanes are exactly [0, N),
// and 0 otherwise -- confirm against its definition.
// NOTE(review): __activemask() reports the lanes converged at this call site;
// callers presumably invoke this from wave-uniform control flow -- confirm.
__inline__ __device__ int _wavePrefixSum(int val)
{
    // The lane mask is a 32-bit bit set: keep it unsigned so the bit tricks
    // below stay well defined even when bit 31 (lane 31) is set.
    const unsigned int mask = __activemask();
    // Pass the same int-typed value the helper has always received.
    const int offsetSize = _waveCalcPow2Offset(int(mask));

    const int laneId = _getLaneId();
    if (offsetSize > 0)
    {
        // Build an inclusive scan by doubling the shuffle distance each step.
        int sum = val;
        for (int i = 1; i < offsetSize; i += i)
        {
            const int readVal = __shfl_up_sync(mask, sum, i, offsetSize);
            // Lanes below `i` have no source lane `i` positions down, so they
            // must not accumulate the value returned by the shuffle.
            if (laneId >= i)
            {
                sum += readVal;
            }
        }
        // Remove this lane's own contribution: inclusive -> exclusive.
        return sum - val;
    }
    else
    {
        int result = 0;
        unsigned int remaining = mask;
        while (remaining)
        {
            // Isolate the lowest set bit. Written as unsigned arithmetic
            // because negating a signed 0x80000000 (only lane 31 left) would
            // be signed overflow.
            const unsigned int laneBit = remaining & (0u - remaining);
            // Get the source lane index for that bit (__ffs is 1-based).
            const int srcLane = __ffs(int(laneBit)) - 1;
            // Broadcast (can also broadcast to self)
            const int readValue = __shfl_sync(mask, val, srcLane);
            // Only accumulate if srcLane is less than this lane
            if (srcLane < laneId)
            {
                result += readValue;
            }
            // Clear the bit just handled and move on to the next lane.
            remaining &= ~laneBit;
        }
        return result;
    }
}
866+
826867
/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
827868

828869

source/slang/hlsl.meta.slang

+11-10
Original file line numberDiff line numberDiff line change
@@ -2497,37 +2497,38 @@ bool WaveIsFirstLane();
24972497
__generic<T : __BuiltinArithmeticType>
24982498
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
24992499
__spirv_version(1.3)
2500-
__target_intrinsic(glsl, "subgroupExcusiveMul($0)")
2500+
__target_intrinsic(glsl, "subgroupExclusiveMul($0)")
25012501
T WavePrefixProduct(T expr);
25022502
__generic<T : __BuiltinArithmeticType, let N : int>
25032503
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25042504
__spirv_version(1.3)
2505-
__target_intrinsic(glsl, "subgroupExcusiveMul($0)")
2505+
__target_intrinsic(glsl, "subgroupExclusiveMul($0)")
25062506
vector<T,N> WavePrefixProduct(vector<T,N> expr);
25072507
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
25082508
matrix<T,N,M> WavePrefixProduct(matrix<T,N,M> expr);
25092509

25102510
__generic<T : __BuiltinArithmeticType>
25112511
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25122512
__spirv_version(1.3)
2513-
__target_intrinsic(glsl, "subgroupExcusiveAdd($0)")
2513+
__target_intrinsic(glsl, "subgroupExclusiveAdd($0)")
2514+
__target_intrinsic(cuda, "_wavePrefixSum($0)")
25142515
T WavePrefixSum(T expr);
25152516
__generic<T : __BuiltinArithmeticType, let N : int>
25162517
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25172518
__spirv_version(1.3)
2518-
__target_intrinsic(glsl, "subgroupExcusiveAdd($0)")
2519+
__target_intrinsic(glsl, "subgroupExclusiveAdd($0)")
25192520
vector<T,N> WavePrefixSum(vector<T,N> expr);
25202521
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
25212522
matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr);
25222523

25232524
__generic<T : __BuiltinArithmeticType>
25242525
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25252526
__spirv_version(1.3)
2526-
__target_intrinsic(glsl, "subgroupExcusiveAnd($0)")
2527+
__target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
25272528
T WaveMultiPrefixBitAnd(T expr);
25282529
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25292530
__spirv_version(1.3)
2530-
__target_intrinsic(glsl, "subgroupExcusiveAnd($0)")
2531+
__target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
25312532
__generic<T : __BuiltinArithmeticType, let N : int>
25322533
vector<T,N> WaveMultiPrefixBitAnd(vector<T,N> expr);
25332534
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
@@ -2536,25 +2537,25 @@ matrix<T,N,M> WaveMultiPrefixBitAnd(matrix<T,N,M> expr);
25362537
__generic<T : __BuiltinArithmeticType>
25372538
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25382539
__spirv_version(1.3)
2539-
__target_intrinsic(glsl, "subgroupExcusiveOr($0)")
2540+
__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
25402541
T WaveMultiPrefixBitOr(T expr);
25412542
__generic<T : __BuiltinArithmeticType, let N : int>
25422543
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25432544
__spirv_version(1.3)
2544-
__target_intrinsic(glsl, "subgroupExcusiveOr($0)")
2545+
__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
25452546
vector<T,N> WaveMultiPrefixBitOr(vector<T,N> expr);
25462547
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
25472548
matrix<T,N,M> WaveMultiPrefixBitOr(matrix<T,N,M> expr);
25482549

25492550
__generic<T : __BuiltinArithmeticType>
25502551
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25512552
__spirv_version(1.3)
2552-
__target_intrinsic(glsl, "subgroupExcusiveXor($0)")
2553+
__target_intrinsic(glsl, "subgroupExclusiveXor($0)")
25532554
T WaveMultiPrefixBitXor(T expr);
25542555
__generic<T : __BuiltinArithmeticType, let N : int>
25552556
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
25562557
__spirv_version(1.3)
2557-
__target_intrinsic(glsl, "subgroupExcusiveXor($0)")
2558+
__target_intrinsic(glsl, "subgroupExclusiveXor($0)")
25582559
vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr);
25592560
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
25602561
matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr);
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute
2+
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
3+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0
4+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
5+
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute
6+
7+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
8+
RWStructuredBuffer<int> outputBuffer;
9+
10+
// Compute entry point for the wave prefix-sum test.
//
// Each of the 8 threads in the group feeds a distinct single-bit value
// (1 << x, where x is its dispatch thread index) into WavePrefixSum and
// stores the returned value at outputBuffer[x].
// NOTE(review): the expected results (0, 1, 3, 7, ...) presumably assume all
// 8 threads execute in the same wave -- confirm for each target.
[numthreads(8, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    const int slot = int(dispatchThreadID.x);

    // One unique bit per thread, so the prefix sum shows which lower
    // threads contributed.
    const int contribution = 1 << slot;
    const int prefix = WavePrefixSum(contribution);

    outputBuffer[slot] = prefix;
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
0
2+
1
3+
3
4+
7
5+
F
6+
1F
7+
3F
8+
7F

0 commit comments

Comments
 (0)