@@ -203,6 +203,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
203
203
const reg64_t reg_aux_wei_zp = reg_rdb_loop;
204
204
const reg64_t reg_ic = reg_rdb_loop;
205
205
const reg64_t reg_src_scales = reg_rdb_loop;
206
+ const reg64_t reg_src_grouped_sum = reg_rdb_loop;
206
207
const reg64_t reg_tmp_read_values = reg_rdb_loop;
207
208
208
209
const reg64_t reg_aux_scales = reg_aux_B;
@@ -280,12 +281,13 @@ struct jit_brgemm_kernel_t : public jit_generator {
280
281
constexpr static int reg_src_scales_offs_ = 336 ;
281
282
constexpr static int reg_aux_src_scales_offs_ = 344 ;
282
283
constexpr static int reg_aux2_src_scales_offs_ = 352 ;
283
- // constexpr static int stack_space_needed_ = 360;
284
+ constexpr static int reg_src_grouped_sum_offs_ = 360 ;
285
+ constexpr static int reg_aux_src_grouped_sum_offs_ = 368 ;
286
+ constexpr static int reg_aux2_src_grouped_sum_offs_ = 376 ;
284
287
// these are used for FP8 as temporary push/pop spaces
285
- constexpr static int reg_val_tmp_1_ = 368 ;
286
- constexpr static int reg_val_tmp_2_ = 376 ;
287
- constexpr static int stack_space_needed_ = 384 ;
288
- // regsiters for dynamic quant
288
+ constexpr static int reg_val_tmp_1_ = 384 ;
289
+ constexpr static int reg_val_tmp_2_ = 392 ;
290
+ constexpr static int stack_space_needed_ = 400 ;
289
291
290
292
291
293
bool is_ldb_loop_ = false ;
@@ -323,7 +325,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
323
325
}
324
326
325
327
if (brg.with_src_dyn_quant ) {
326
- used_vregs += 2 ;
328
+ used_vregs += 1 ;
327
329
}
328
330
329
331
if (brg.with_src_dyn_quant && brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0 ) {
@@ -971,6 +973,12 @@ void jit_brgemm_kernel_t<Wmm>::copy_post_ops_stack_values_to_aux(
971
973
mov (reg_src_scales, ptr[rsp + reg_src_scales_offs_]);
972
974
mov (ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales);
973
975
mov (ptr[rsp + reg_aux2_src_scales_offs_], reg_src_scales);
976
+
977
+ if (brg.with_wei_decomp_zero_points ) {
978
+ mov (reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]);
979
+ mov (ptr[rsp + reg_aux_src_grouped_sum_offs_], reg_src_grouped_sum);
980
+ mov (ptr[rsp + reg_aux2_src_grouped_sum_offs_], reg_src_grouped_sum);
981
+ }
974
982
}
975
983
if (brg.zp_type_b != brgemm_broadcast_t ::none) {
976
984
mov (reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]);
@@ -1048,6 +1056,9 @@ void jit_brgemm_kernel_t<Wmm>::read_params() {
1048
1056
if (brg.with_src_dyn_quant ) {
1049
1057
mov (reg_src_scales, ptr[param1 + GET_OFF (ptr_src_scales)]);
1050
1058
mov (ptr[rsp + reg_src_scales_offs_], reg_src_scales);
1059
+
1060
+ mov (reg_src_grouped_sum, ptr[param1 + GET_OFF (ptr_src_grouped_sum)]);
1061
+ mov (ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum);
1051
1062
}
1052
1063
1053
1064
if (brg.zp_type_c != brgemm_broadcast_t ::none) {
@@ -2296,21 +2307,10 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2296
2307
};
2297
2308
2298
2309
auto vmm_zero_point = [&](int ld) {
2299
- int idx = isa_num_vregs (brg.isa_impl ) - 3 - ld;
2310
+ int idx = isa_num_vregs (brg.isa_impl ) - 2 - ld;
2300
2311
return Vmm (idx);
2301
2312
};
2302
2313
2303
- static const int8_t negative_one[64 ] = {
2304
- -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
2305
- -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
2306
- -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
2307
- -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
2308
- -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
2309
- -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
2310
- -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
2311
- -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1
2312
- };
2313
-
2314
2314
static const int8_t mask_low_half[64 ] = {
2315
2315
0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F ,
2316
2316
0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F , 0x0F ,
@@ -2328,33 +2328,18 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2328
2328
if (brg.with_wei_decomp_zero_points ) {
2329
2329
mov (reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_]);
2330
2330
if (brg.wei_decomp_zero_points_stride == 0 ) {
2331
- auto reg_ptr_8 = Reg8 (reg_ptr.getIdx ());
2332
- mov (reg_ptr_8, ptr[reg_local_wei_zp]);
2333
- uni_vpbroadcastb (vmm_zero_point (0 ), reg_ptr_8);
2331
+ auto reg_ptr_32 = Reg32 (reg_ptr.getIdx ());
2332
+ movzx (reg_ptr_32, ptr[reg_local_wei_zp]);
2333
+ uni_vmovq (Xmm (vmm_zero_point (0 ).getIdx ()), reg_ptr);
2334
+ uni_vbroadcastss (vmm_zero_point (0 ), Xmm (vmm_zero_point (0 ).getIdx ()));
2334
2335
} else {
2335
- static const int8_t index_table[64 ] = {
2336
- 0x00 , 0x00 , 0x00 , 0x00 , 0x04 , 0x04 , 0x04 , 0x04 , 0x08 , 0x08 , 0x08 , 0x08 , 0x0C , 0x0C , 0x0C , 0x0C ,
2337
- 0x00 , 0x00 , 0x00 , 0x00 , 0x04 , 0x04 , 0x04 , 0x04 , 0x08 , 0x08 , 0x08 , 0x08 , 0x0C , 0x0C , 0x0C , 0x0C ,
2338
- 0x00 , 0x00 , 0x00 , 0x00 , 0x04 , 0x04 , 0x04 , 0x04 , 0x08 , 0x08 , 0x08 , 0x08 , 0x0C , 0x0C , 0x0C , 0x0C ,
2339
- 0x00 , 0x00 , 0x00 , 0x00 , 0x04 , 0x04 , 0x04 , 0x04 , 0x08 , 0x08 , 0x08 , 0x08 , 0x0C , 0x0C , 0x0C , 0x0C
2340
- };
2341
-
2342
- auto vmm_indexes = Vmm (isa_num_vregs (brg.isa_impl ) - 1 );
2343
- mov (reg_ptr, (size_t )index_table);
2344
- uni_vmovups (vmm_indexes, ptr[reg_ptr]);
2345
-
2346
2336
for (int ld = 0 ; ld < ld_block2; ld++) {
2347
2337
uni_vpmovzxbd (vmm_zero_point (ld), ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size (brg.wei_decomp_zero_points_dt )]);
2348
- vpshufb (vmm_zero_point (ld), vmm_zero_point (ld), vmm_indexes);
2349
2338
}
2350
2339
}
2351
2340
}
2352
2341
2353
- auto vmm_neg_one = Vmm (isa_num_vregs (brg.isa_impl ) - 1 );
2354
- mov (reg_ptr, (size_t )negative_one);
2355
- uni_vmovups (vmm_neg_one, ptr[reg_ptr]);
2356
-
2357
- auto vmm_mask_low_half = Vmm (isa_num_vregs (brg.isa_impl ) - 2 );
2342
+ auto vmm_mask_low_half = Vmm (isa_num_vregs (brg.isa_impl ) - 1 );
2358
2343
mov (reg_ptr, (size_t )mask_low_half);
2359
2344
uni_vmovups (vmm_mask_low_half, ptr[reg_ptr]);
2360
2345
@@ -2409,22 +2394,28 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2409
2394
auto vmm = accm (ld_block2, bd, ld);
2410
2395
vpdpbusd (vmm, load (ld), bcst (), is_superset (brg.isa_impl , avx512_core) ? EvexEncoding : VexEncoding);
2411
2396
}
2412
- if (brg.with_wei_decomp_zero_points ) {
2413
- uni_vpxor (bcst (), bcst (), vmm_neg_one);
2414
- uni_vpsubb (bcst (), bcst (), vmm_neg_one);
2415
- for (int ld = 0 ; ld < ld_block2; ld++) {
2416
- auto vmm = accm (ld_block2, bd, ld);
2417
- Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point (0 ) : vmm_zero_point (ld);
2418
- vpdpbusd (vmm, vmm_zp, bcst (), is_superset (brg.isa_impl , avx512_core) ? EvexEncoding : VexEncoding);
2419
- }
2420
- }
2421
2397
}
2422
2398
}
2423
2399
2424
2400
auto reg_local_src_scales = reg_local_wei_zp;
2401
+ auto reg_local_src_grouped_sum = reg_local_wei_zp;
2425
2402
auto vmm_src_scales = bcst ();
2426
- mov (reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space] );
2403
+ auto vmm_src_grouped_sum = bcst ( );
2427
2404
2405
+ if (brg.with_wei_decomp_zero_points ) {
2406
+ mov (reg_local_src_grouped_sum, ptr[rsp + reg_aux2_src_grouped_sum_offs_ + accums_stack_space]);
2407
+ for (int bd = bd_b; bd < bd_e; bd++) {
2408
+ for (int ld = 0 ; ld < ld_block2; ld++) {
2409
+ auto vmm_accm = accm (ld_block2, bd, ld);
2410
+ Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point (0 ) : vmm_zero_point (ld);
2411
+ uni_vbroadcastss (vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof (int32_t )]);
2412
+ uni_vpmulld (vmm_src_grouped_sum, vmm_src_grouped_sum, vmm_zp);
2413
+ uni_vpsubd (vmm_accm, vmm_accm, vmm_src_grouped_sum);
2414
+ }
2415
+ }
2416
+ }
2417
+
2418
+ mov (reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
2428
2419
for (int bd = bd_b; bd < bd_e; bd++) {
2429
2420
uni_vbroadcastss (vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof (float )]);
2430
2421
for (int ld = 0 ; ld < ld_block2; ld++) {
@@ -3073,6 +3064,11 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
3073
3064
if (brg.with_src_dyn_quant ) {
3074
3065
ic_group_shift (reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_,
3075
3066
brg.src_scales_group_size , sizeof (float ));
3067
+
3068
+ if (brg.with_wei_decomp_zero_points ) {
3069
+ ic_group_shift (reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_,
3070
+ brg.src_sum_group_size , sizeof (int32_t ));
3071
+ }
3076
3072
}
3077
3073
3078
3074
mov (reg_local_ic, ptr[rsp + reg_aux_ic_offs_]);
@@ -3306,6 +3302,10 @@ void jit_brgemm_kernel_t<Wmm>::bdb_loop() {
3306
3302
mov (reg_src_scales, ptr[rsp + reg_src_scales_offs_]);
3307
3303
add (reg_src_scales, bd_block2 * brg.bd_block * brg.src_scales_stride * sizeof (float ));
3308
3304
mov (ptr[rsp + reg_src_scales_offs_], reg_src_scales);
3305
+
3306
+ mov (reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]);
3307
+ add (reg_src_grouped_sum, bd_block2 * brg.bd_block * brg.src_grouped_sum_stride * sizeof (int32_t ));
3308
+ mov (ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum);
3309
3309
}
3310
3310
3311
3311
advance_bd_block2_post_op_regs (bd_block2);
0 commit comments