Skip to content

Commit c24a2a6

Browse files
authored
cuda : replace remaining shfl_xor with calls to warp_reduce functions (#5744)
1 parent 1f30b7a commit c24a2a6

File tree

1 file changed

+24
-49
lines changed

1 file changed

+24
-49
lines changed

ggml-cuda.cu

+24-49
Original file line numberDiff line numberDiff line change
@@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
696696
return a;
697697
}
698698

699-
//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
700-
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
701-
//#pragma unroll
702-
// for (int mask = 16; mask > 0; mask >>= 1) {
703-
// a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
704-
// }
705-
// return a;
706-
//#else
707-
// (void) a;
708-
// NO_DEVICE_CODE;
709-
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
710-
//}
699+
#ifdef GGML_CUDA_F16
// Sum-reduce a packed half2 across the 32 lanes of a warp.
// Uses an XOR-butterfly shuffle (__shfl_xor_sync), so after the final step
// every lane holds the full warp sum. Needs native half2 arithmetic
// (__hadd2), hence the Pascal-or-newer guard on the CUDA path.
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
    }
    return a;
#else
    // Targets without half2 math (pre-Pascal, or HIP on AMD) must not
    // reach this code path at runtime.
    (void) a;
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
#endif // GGML_CUDA_F16
711713

712714
static __device__ __forceinline__ float warp_reduce_max(float x) {
713715
#pragma unroll
@@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
25212523
#endif
25222524

25232525
// sum up partial sums and write back result
2524-
#pragma unroll
2525-
for (int mask = 16; mask > 0; mask >>= 1) {
2526-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2527-
}
2526+
tmp = warp_reduce_sum(tmp);
25282527

25292528
if (threadIdx.x == 0) {
25302529
dst[row] = tmp;
@@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
26252624
#endif
26262625

26272626
// sum up partial sums and write back result
2628-
#pragma unroll
2629-
for (int mask = 16; mask > 0; mask >>= 1) {
2630-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2631-
}
2627+
tmp = warp_reduce_sum(tmp);
26322628

26332629
if (threadIdx.x == 0) {
26342630
dst[row] = tmp;
@@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
27612757
#endif
27622758

27632759
// sum up partial sums and write back result
2764-
#pragma unroll
2765-
for (int mask = 16; mask > 0; mask >>= 1) {
2766-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2767-
}
2760+
tmp = warp_reduce_sum(tmp);
27682761

27692762
if (tid == 0) {
27702763
dst[row] = tmp;
@@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
28772870
#endif
28782871

28792872
// sum up partial sums and write back result
2880-
#pragma unroll
2881-
for (int mask = 16; mask > 0; mask >>= 1) {
2882-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2883-
}
2873+
tmp = warp_reduce_sum(tmp);
28842874

28852875
if (threadIdx.x == 0) {
28862876
dst[row] = tmp;
@@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
29872977
#endif
29882978

29892979
// sum up partial sums and write back result
2990-
#pragma unroll
2991-
for (int mask = 16; mask > 0; mask >>= 1) {
2992-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2993-
}
2980+
tmp = warp_reduce_sum(tmp);
29942981

29952982
if (tid == 0) {
29962983
dst[row] = tmp;
@@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
30253012
float amax = fabsf(xi);
30263013
float sum = xi;
30273014

3028-
#pragma unroll
3029-
for (int mask = 16; mask > 0; mask >>= 1) {
3030-
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
3031-
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
3032-
}
3015+
amax = warp_reduce_max(amax);
3016+
sum = warp_reduce_sum(sum);
30333017

30343018
const float d = amax / 127;
30353019
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
@@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
62226206
}
62236207

62246208
// sum up partial sums and write back result
6225-
#pragma unroll
6226-
for (int mask = 16; mask > 0; mask >>= 1) {
6227-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
6228-
}
6209+
tmp = warp_reduce_sum(tmp);
62296210

62306211
if (tid == 0) {
62316212
#ifdef GGML_CUDA_F16
@@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
62756256
const int idst = channel*nrows_dst + row_dst;
62766257

62776258
// sum up partial sums and write back result
6278-
#pragma unroll
6279-
for (int mask = 16; mask > 0; mask >>= 1) {
6280-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
6281-
}
6259+
tmp = warp_reduce_sum(tmp);
62826260

62836261
if (threadIdx.x == 0) {
62846262
dst[idst] = tmp;
@@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
63216299
}
63226300

63236301
// sum up partial sums and write back result
6324-
#pragma unroll
6325-
for (int mask = 16; mask > 0; mask >>= 1) {
6326-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
6327-
}
6302+
tmp = warp_reduce_sum(tmp);
63286303

63296304
if (threadIdx.x == 0) {
63306305
dst[idst] = tmp;

0 commit comments

Comments
 (0)