@@ -2997,6 +2997,95 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
2997
2997
int ld_block2, int ldb_loop_length, bool is_reg_tail, bool is_ld_tail,
2998
2998
bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail,
2999
2999
bool skip_accumulation) {
3000
+ auto ic_group_shift_generic = [&]() {
3001
+ if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0 ))
3002
+ || brg.with_src_dyn_quant ) {
3003
+ auto reg_local_ic = reg_aux_D;
3004
+ auto reg_local_wei_params = reg_bdb_loop;
3005
+ auto reg_local_ic_group = reg_ldb_loop;
3006
+
3007
+ auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) {
3008
+ mov (reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
3009
+ mov (reg_local_ic_group, group_size);
3010
+ xor_ (rdx, rdx);
3011
+ idiv (reg_local_ic_group);
3012
+ imul (reg_local_ic, reg_local_ic, stride);
3013
+
3014
+ mov (reg_local_wei_params, ptr[rsp + src_offs]);
3015
+ add (reg_local_wei_params, reg_local_ic);
3016
+ mov (ptr[rsp + dst_offs], reg_local_wei_params);
3017
+ };
3018
+
3019
+ mov (ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
3020
+ mov (ptr[rsp + reg_aux2_D_offs_], reg_aux_D);
3021
+ mov (ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
3022
+ mov (ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv
3023
+
3024
+ if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0 ) {
3025
+ ic_group_shift (reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_,
3026
+ brg.wei_decomp_scales_group_size , brg.wei_decomp_scales_stride * types::data_type_size (brg.wei_decomp_scales_dt ));
3027
+ }
3028
+
3029
+ if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0 ) {
3030
+ ic_group_shift (reg_aux_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_,
3031
+ brg.wei_decomp_zero_points_group_size , brg.wei_decomp_zero_points_stride * types::data_type_size (brg.wei_decomp_zero_points_dt ));
3032
+ }
3033
+
3034
+ if (brg.with_src_dyn_quant ) {
3035
+ ic_group_shift (reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_,
3036
+ brg.src_scales_group_size , sizeof (float ));
3037
+
3038
+ if (brg.with_wei_decomp_zero_points ) {
3039
+ ic_group_shift (reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_,
3040
+ brg.src_sum_group_size , sizeof (int32_t ));
3041
+ }
3042
+ }
3043
+
3044
+ mov (reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
3045
+ add (reg_local_ic, brg.rd_block );
3046
+ mov (ptr[rsp + reg_aux_ic_offs_], reg_local_ic);
3047
+
3048
+ mov (reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
3049
+ mov (reg_aux_D, ptr[rsp + reg_aux2_D_offs_]);
3050
+ mov (reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
3051
+ mov (reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]);
3052
+ }
3053
+ };
3054
+
3055
+ auto ic_group_shift_opt = [&](int rb) {
3056
+ mov (ptr[rsp + reg_bdb_loop_offs_], reg_rdb_loop);
3057
+ auto reg_ptr = reg_rdb_loop;
3058
+
3059
+ auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) {
3060
+ if ((rb + 1 ) * brg.rd_block % group_size == 0 ) {
3061
+ mov (reg_ptr, ptr[rsp + src_offs]);
3062
+ add (reg_ptr, stride);
3063
+ mov (ptr[rsp + dst_offs], reg_ptr);
3064
+ }
3065
+ };
3066
+
3067
+ if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0 ) {
3068
+ ic_group_shift (reg_aux2_wei_scales_offs_, reg_aux2_wei_scales_offs_,
3069
+ brg.wei_decomp_scales_group_size , brg.wei_decomp_scales_stride * types::data_type_size (brg.wei_decomp_scales_dt ));
3070
+ }
3071
+
3072
+ if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0 ) {
3073
+ ic_group_shift (reg_aux2_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_,
3074
+ brg.wei_decomp_zero_points_group_size , brg.wei_decomp_zero_points_stride * types::data_type_size (brg.wei_decomp_zero_points_dt ));
3075
+ }
3076
+
3077
+ if (brg.with_src_dyn_quant ) {
3078
+ ic_group_shift (reg_aux2_src_scales_offs_, reg_aux2_src_scales_offs_,
3079
+ brg.src_scales_group_size , sizeof (float ));
3080
+
3081
+ if (brg.with_wei_decomp_zero_points ) {
3082
+ ic_group_shift (reg_aux2_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_,
3083
+ brg.src_sum_group_size , sizeof (int32_t ));
3084
+ }
3085
+ }
3086
+
3087
+ mov (reg_rdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
3088
+ };
3000
3089
3001
3090
Label ldb_loop_label;
3002
3091
Label BS_loop_label;
@@ -3023,75 +3112,76 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
3023
3112
gemm_microkernel_amx (
3024
3113
bd_block2, is_bdb_tail, ld_block2, is_rd_tail, is_ld_tail);
3025
3114
} else {
3026
- if (brg.rdb > 0 ) {
3027
- Label rdb_loop_label;
3028
- mov (reg_rdb_loop, brg.rdb );
3029
- L_aligned (rdb_loop_label, 64 );
3030
- {
3031
- if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0 ))
3032
- || brg.with_src_dyn_quant ) {
3033
- auto reg_local_ic = reg_aux_D;
3034
- auto reg_local_wei_params = reg_bdb_loop;
3035
- auto reg_local_ic_group = reg_ldb_loop;
3036
-
3037
- auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) {
3038
- mov (reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
3039
- mov (reg_local_ic_group, group_size);
3040
- xor_ (rdx, rdx);
3041
- idiv (reg_local_ic_group);
3042
- imul (reg_local_ic, reg_local_ic, stride);
3043
-
3044
- mov (reg_local_wei_params, ptr[rsp + src_offs]);
3045
- add (reg_local_wei_params, reg_local_ic);
3046
- mov (ptr[rsp + dst_offs], reg_local_wei_params);
3047
- };
3048
-
3049
- mov (ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
3050
- mov (ptr[rsp + reg_aux2_D_offs_], reg_aux_D);
3051
- mov (ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
3052
- mov (ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv
3053
-
3054
- if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0 ) {
3055
- ic_group_shift (reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_,
3056
- brg.wei_decomp_scales_group_size , brg.wei_decomp_scales_stride * types::data_type_size (brg.wei_decomp_scales_dt ));
3057
- }
3115
+ ic_group_shift_generic ();
3116
+
3117
+ if (brg.with_src_dyn_quant ) {
3118
+ auto rdb_group = brg.rd_block ;
3119
+ auto rd_size = brg.rdb * brg.rd_block + brg.rdb_tail ;
3120
+ if (brg.wei_decomp_scales_group_size < rd_size)
3121
+ rdb_group = nstl::max (rdb_group, brg.wei_decomp_scales_group_size );
3122
+ if (brg.wei_decomp_zero_points_group_size < rd_size)
3123
+ rdb_group = nstl::max (rdb_group, brg.wei_decomp_zero_points_group_size );
3124
+ if (brg.with_src_dyn_quant ) {
3125
+ rdb_group = nstl::max (rdb_group, brg.src_scales_group_size );
3126
+ if (brg.with_wei_decomp_zero_points ) {
3127
+ rdb_group = nstl::max (rdb_group, brg.src_sum_group_size );
3128
+ }
3129
+ }
3130
+ rdb_group = rdb_group / brg.rd_block ;
3131
+ auto rbd_blocks = brg.rdb / rdb_group;
3058
3132
3059
- if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0 ) {
3060
- ic_group_shift (reg_aux_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_,
3061
- brg.wei_decomp_zero_points_group_size , brg.wei_decomp_zero_points_stride * types::data_type_size (brg.wei_decomp_zero_points_dt ));
3062
- }
3133
+ if (rbd_blocks > 0 ) {
3134
+ Label rdb_loop_label;
3135
+ mov (reg_rdb_loop, rbd_blocks);
3136
+ L_aligned (rdb_loop_label, 64 );
3137
+ {
3138
+ for (int rb = 0 ; rb < rdb_group; rb++) {
3139
+ gemm_microkernel (bd_block2, is_bdb_tail, ld_block2, false ,
3140
+ is_ld_tail, vpad, rows_for_rd_tail);
3063
3141
3064
- if (brg.with_src_dyn_quant ) {
3065
- ic_group_shift (reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_,
3066
- brg.src_scales_group_size , sizeof (float ));
3142
+ add (reg_aux_A, rdb_A_offset ());
3143
+ add (reg_aux_B, rdb_B_offset ());
3067
3144
3068
- if (brg.with_wei_decomp_zero_points ) {
3069
- ic_group_shift (reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_,
3070
- brg.src_sum_group_size , sizeof (int32_t ));
3071
- }
3145
+ ic_group_shift_opt (rb);
3072
3146
}
3073
3147
3074
- mov (reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
3075
- add (reg_local_ic, brg.rd_block );
3076
- mov (ptr[rsp + reg_aux_ic_offs_], reg_local_ic);
3077
-
3078
- mov (reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
3079
- mov (reg_aux_D, ptr[rsp + reg_aux2_D_offs_]);
3080
- mov (reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
3081
- mov (reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]);
3148
+ dec (reg_rdb_loop);
3149
+ cmp (reg_rdb_loop, 0 );
3082
3150
}
3151
+ jg (rdb_loop_label, T_NEAR);
3152
+ }
3083
3153
3084
- const bool is_rd_tail = false ;
3085
- gemm_microkernel (bd_block2, is_bdb_tail, ld_block2,
3086
- is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail);
3154
+ for ( int rb = rbd_blocks * rdb_group; rb < brg. rdb ; rb++) {
3155
+ gemm_microkernel (bd_block2, is_bdb_tail, ld_block2, false ,
3156
+ is_ld_tail, vpad, rows_for_rd_tail);
3087
3157
3088
3158
add (reg_aux_A, rdb_A_offset ());
3089
3159
add (reg_aux_B, rdb_B_offset ());
3090
3160
3091
- dec (reg_rdb_loop);
3092
- cmp (reg_rdb_loop, 0 );
3161
+ ic_group_shift_opt (rb);
3162
+
3163
+ mov (reg_rdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
3164
+ }
3165
+ } else {
3166
+ if (brg.rdb > 0 ) {
3167
+ Label rdb_loop_label;
3168
+ mov (reg_rdb_loop, brg.rdb );
3169
+ L_aligned (rdb_loop_label, 64 );
3170
+ {
3171
+ const bool is_rd_tail = false ;
3172
+ gemm_microkernel (bd_block2, is_bdb_tail, ld_block2,
3173
+ is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail);
3174
+
3175
+ add (reg_aux_A, rdb_A_offset ());
3176
+ add (reg_aux_B, rdb_B_offset ());
3177
+
3178
+ ic_group_shift_generic ();
3179
+
3180
+ dec (reg_rdb_loop);
3181
+ cmp (reg_rdb_loop, 0 );
3182
+ }
3183
+ jg (rdb_loop_label, T_NEAR);
3093
3184
}
3094
- jg (rdb_loop_label, T_NEAR);
3095
3185
}
3096
3186
}
3097
3187
if (brg.rdb_tail != 0 ) {
0 commit comments