@@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
-//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//#pragma unroll
-//    for (int mask = 16; mask > 0; mask >>= 1) {
-//        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-//    }
-//    return a;
-//#else
-//    (void) a;
-//    NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//}
+#ifdef GGML_CUDA_F16
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    (void) a;
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+#endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
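For context, the pattern being factored out here is a butterfly reduction over the 32 lanes of a warp: each `__shfl_xor_sync` step exchanges values between lanes `mask` apart, halving the number of distinct partial sums until every lane holds the full total. The standalone sketch below illustrates this with a float overload written in the same style as the reductions in this file; the test kernel, its name, and the host driver are hypothetical and only meant to show the expected result.

// Minimal standalone sketch (not part of the patch): a float warp_reduce_sum
// using the same butterfly pattern, plus a toy kernel to exercise it.
#include <cstdio>

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // each XOR step folds in the value from the lane `mask` positions away
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x; // every lane now holds the sum over all 32 lanes
}

static __global__ void test_warp_reduce(float * out) {
    const float sum = warp_reduce_sum((float) threadIdx.x); // lanes contribute 0..31
    if (threadIdx.x == 0) {
        *out = sum; // expected: 0 + 1 + ... + 31 = 496
    }
}

int main() {
    float * d_out = nullptr;
    cudaMalloc(&d_out, sizeof(float));
    test_warp_reduce<<<1, 32>>>(d_out);
    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %f\n", h_out); // prints 496.000000
    cudaFree(d_out);
    return 0;
}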
@@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     float amax = fabsf(xi);
     float sum = xi;
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
-        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
-    }
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
 
     const float d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
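As a quick sanity check on the arithmetic above: after `amax = warp_reduce_max(amax)` every lane of the warp holds the block's largest |x|, and each lane then quantizes its own value against the shared scale `d = amax / 127`. The host-side sketch below reproduces just that scale-and-round step with made-up numbers; the values and the tiny `main` are illustrative, not part of the kernel.

// Host-side sketch of the q8_1 scale/quantize arithmetic (values are hypothetical).
#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main() {
    const float xi   = 0.42f; // one value in the 32-wide block
    const float amax = 1.7f;  // result of warp_reduce_max over |x| in the block

    const float  d = amax / 127;                                  // scale so amax maps to +/-127
    const int8_t q = amax == 0.0f ? 0 : (int8_t) roundf(xi / d);  // 0.42 / (1.7/127) ~= 31.4 -> 31

    printf("d = %f, q = %d\n", d, (int) q);
    return 0;
}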
@@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
 #ifdef GGML_CUDA_F16
@@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
     const int idst = channel*nrows_dst + row_dst;
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;
@@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;