@@ -874,12 +874,13 @@ void jit_brgemm_kernel_t<Wmm>::ldb_regs_shift(int ld_block2, bool is_tail) {
874
874
mov (ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
875
875
}
876
876
877
- if (brg.with_wei_decomp ) {
877
+ if (brg.with_wei_decomp_scales && brg. wei_decomp_scales_stride != 0 ) {
878
878
mov (reg_aux_wei_scales, ptr[rsp + reg_aux_wei_scales_offs_]);
879
879
add (reg_aux_wei_scales, (is_tail) ? wei_scales_offset (1 , true ) : wei_scales_offset (ld_block2));
880
880
mov (ptr[rsp + reg_aux_wei_scales_offs_], reg_aux_wei_scales);
881
881
mov (ptr[rsp + reg_aux2_wei_scales_offs_], reg_aux_wei_scales);
882
-
882
+ }
883
+ if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0 ) {
883
884
mov (reg_aux_wei_zp, ptr[rsp + reg_aux_wei_zero_points_offs_]);
884
885
add (reg_aux_wei_zp, (is_tail) ? wei_zp_offset (1 , true ) : wei_zp_offset (ld_block2));
885
886
mov (ptr[rsp + reg_aux_wei_zero_points_offs_], reg_aux_wei_zp);
@@ -966,10 +967,6 @@ void jit_brgemm_kernel_t<Wmm>::copy_post_ops_stack_values_to_aux(
966
967
}
967
968
968
969
}
969
- if (brg.with_grouped_wei_decomp ) {
970
- mov (reg_ic, ptr[rsp + reg_ic_offs_]);
971
- mov (ptr[rsp + reg_aux_ic_offs_], reg_ic);
972
- }
973
970
if (brg.with_src_dyn_quant ) {
974
971
mov (reg_src_scales, ptr[rsp + reg_src_scales_offs_]);
975
972
mov (ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales);
@@ -2298,11 +2295,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2298
2295
if (brg.req_s8s8_compensation ) uni_vpaddb (v1, v1, vmm_inp_shift ());
2299
2296
};
2300
2297
2301
- auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) {
2302
- int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block ) - ld_block - (bd * ld_block + ld);
2303
- return Vmm (idx);
2304
- };
2305
-
2306
2298
auto vmm_zero_point = [&](int ld) {
2307
2299
int idx = isa_num_vregs (brg.isa_impl ) - 3 - ld;
2308
2300
return Vmm (idx);
@@ -2368,9 +2360,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2368
2360
2369
2361
mov (reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);
2370
2362
2363
+ const int vec_size = vreg_traits<Vmm>::vlen;
2364
+ auto accums_stack_space = bd_e * ld_block2 * vec_size;
2365
+ sub (rsp, accums_stack_space);
2371
2366
for (int bd = bd_b; bd < bd_e; bd++) {
2372
2367
for (int ld = 0 ; ld < ld_block2; ld++) {
2373
- auto vmm_accm = vmm_accm_tmp (ld_block2, bd, ld);
2368
+ auto vmm_accm = accm (ld_block2, bd, ld);
2369
+ vmovups (ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm);
2370
+
2374
2371
uni_vxorps (vmm_accm, vmm_accm, vmm_accm);
2375
2372
}
2376
2373
}
@@ -2409,14 +2406,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2409
2406
+ brg.LDB * brg.rd_block * brg.typesize_B ]);
2410
2407
}
2411
2408
for (int ld = 0 ; ld < ld_block2; ld++) {
2412
- auto vmm = vmm_accm_tmp (ld_block2, bd, ld);
2409
+ auto vmm = accm (ld_block2, bd, ld);
2413
2410
vpdpbusd (vmm, load (ld), bcst (), is_superset (brg.isa_impl , avx512_core) ? EvexEncoding : VexEncoding);
2414
2411
}
2415
2412
if (brg.with_wei_decomp_zero_points ) {
2416
2413
uni_vpxor (bcst (), bcst (), vmm_neg_one);
2417
2414
uni_vpsubb (bcst (), bcst (), vmm_neg_one);
2418
2415
for (int ld = 0 ; ld < ld_block2; ld++) {
2419
- auto vmm = vmm_accm_tmp (ld_block2, bd, ld);
2416
+ auto vmm = accm (ld_block2, bd, ld);
2420
2417
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point (0 ) : vmm_zero_point (ld);
2421
2418
vpdpbusd (vmm, vmm_zp, bcst (), is_superset (brg.isa_impl , avx512_core) ? EvexEncoding : VexEncoding);
2422
2419
}
@@ -2426,7 +2423,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2426
2423
2427
2424
auto reg_local_src_scales = reg_local_wei_zp;
2428
2425
auto vmm_src_scales = bcst ();
2429
- mov (reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]);
2426
+ mov (reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space ]);
2430
2427
2431
2428
for (int bd = bd_b; bd < bd_e; bd++) {
2432
2429
uni_vbroadcastss (vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof (float )]);
@@ -2438,15 +2435,17 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
2438
2435
}
2439
2436
}
2440
2437
for (int ld = 0 ; ld < ld_block2; ld++) {
2441
- auto vmm_accm_aux = vmm_accm_tmp (ld_block2, bd, ld);
2442
2438
auto vmm_accm = accm (ld_block2, bd, ld);
2443
2439
2444
- uni_vcvtdq2ps (vmm_accm_aux, vmm_accm_aux);
2445
- uni_vmulps (vmm_accm_aux, vmm_accm_aux, vmm_src_scales);
2446
- uni_vfmadd231ps (vmm_accm, vmm_accm_aux, load (ld));
2440
+ uni_vcvtdq2ps (vmm_accm, vmm_accm);
2441
+ uni_vmulps (vmm_accm, vmm_accm, vmm_src_scales);
2442
+ uni_vmulps (load (ld), vmm_accm, load (ld));
2443
+ uni_vmovups (vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
2444
+ uni_vaddps (vmm_accm, vmm_accm, load (ld));
2447
2445
}
2448
2446
}
2449
2447
2448
+ add (rsp, accums_stack_space);
2450
2449
mov (reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
2451
2450
mov (reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
2452
2451
@@ -3014,6 +3013,11 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
3014
3013
copy_post_ops_stack_values_to_aux (is_reg_tail);
3015
3014
3016
3015
auto ld_loop_body = [&](int vpad) {
3016
+ if (brg.with_grouped_wei_decomp ) {
3017
+ mov (reg_ic, ptr[rsp + reg_ic_offs_]);
3018
+ mov (ptr[rsp + reg_aux_ic_offs_], reg_ic);
3019
+ }
3020
+
3017
3021
set_A_B_matrices ();
3018
3022
3019
3023
int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block ;
@@ -3033,8 +3037,8 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
3033
3037
mov (reg_rdb_loop, brg.rdb );
3034
3038
L_aligned (rdb_loop_label, 64 );
3035
3039
{
3036
- if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
3037
- brg. wei_decomp_zero_points_stride != 0 )) || brg.with_src_dyn_quant ) {
3040
+ if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg. wei_decomp_zero_points_stride != 0 ))
3041
+ || brg.with_src_dyn_quant ) {
3038
3042
auto reg_local_ic = reg_aux_D;
3039
3043
auto reg_local_wei_params = reg_bdb_loop;
3040
3044
auto reg_local_ic_group = reg_ldb_loop;
0 commit comments