@@ -2298,11 +2298,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
        if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
    };

-   auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) {
-       int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld_block - (bd * ld_block + ld);
-       return Vmm(idx);
-   };
-
    auto vmm_zero_point = [&](int ld) {
        int idx = isa_num_vregs(brg.isa_impl) - 3 - ld;
        return Vmm(idx);
@@ -2368,9 +2363,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,

    mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);

+   const int vec_size = vreg_traits<Vmm>::vlen;
+   auto accums_stack_space = bd_e * ld_block2 * vec_size;
+   sub(rsp, accums_stack_space);
    for (int bd = bd_b; bd < bd_e; bd++) {
        for (int ld = 0; ld < ld_block2; ld++) {
-           auto vmm_accm = vmm_accm_tmp(ld_block2, bd, ld);
+           auto vmm_accm = accm(ld_block2, bd, ld);
+           vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm);
+
            uni_vxorps(vmm_accm, vmm_accm, vmm_accm);
        }
    }
@@ -2409,14 +2409,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
                    + brg.LDB * brg.rd_block * brg.typesize_B]);
        }
        for (int ld = 0; ld < ld_block2; ld++) {
-           auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
+           auto vmm = accm(ld_block2, bd, ld);
            vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
        }
        if (brg.with_wei_decomp_zero_points) {
            uni_vpxor(bcst(), bcst(), vmm_neg_one);
            uni_vpsubb(bcst(), bcst(), vmm_neg_one);
            for (int ld = 0; ld < ld_block2; ld++) {
-               auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
+               auto vmm = accm(ld_block2, bd, ld);
                Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
                vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
            }
@@ -2426,7 +2426,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,

    auto reg_local_src_scales = reg_local_wei_zp;
    auto vmm_src_scales = bcst();
-   mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]);
+   mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);

    for (int bd = bd_b; bd < bd_e; bd++) {
        uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
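
Note on the offset change in this hunk: the spill area reserved with sub(rsp, accums_stack_space) in the earlier hunk shifts rsp, so a slot that used to be read as ptr[rsp + reg_aux2_src_scales_offs_] now has to be read at ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space] until the matching add(rsp, accums_stack_space) restores the pointer. A tiny stand-alone model of that invariant (toy code, not part of oneDNN; the names are placeholders):

#include <cassert>
#include <cstddef>

int main() {
    unsigned char stack[256] = {};
    std::size_t rsp = 128;            // toy stack pointer (grows downwards)
    const std::size_t offs = 16;      // plays the role of reg_aux2_src_scales_offs_
    stack[rsp + offs] = 42;           // value saved before the accumulator spill

    const std::size_t reserved = 64;  // plays the role of accums_stack_space
    rsp -= reserved;                  // sub(rsp, accums_stack_space)

    // The old slot is still there, but only at the adjusted offset.
    assert(stack[rsp + offs + reserved] == 42);

    rsp += reserved;                  // add(rsp, accums_stack_space)
    assert(stack[rsp + offs] == 42);  // original addressing is valid again
    return 0;
}
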
@@ -2438,15 +2438,17 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
            }
        }
        for (int ld = 0; ld < ld_block2; ld++) {
-           auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld);
            auto vmm_accm = accm(ld_block2, bd, ld);

-           uni_vcvtdq2ps(vmm_accm_aux, vmm_accm_aux);
-           uni_vmulps(vmm_accm_aux, vmm_accm_aux, vmm_src_scales);
-           uni_vfmadd231ps(vmm_accm, vmm_accm_aux, load(ld));
+           uni_vcvtdq2ps(vmm_accm, vmm_accm);
+           uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales);
+           uni_vmulps(load(ld), vmm_accm, load(ld));
+           uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
+           uni_vaddps(vmm_accm, vmm_accm, load(ld));
        }
    }

+   add(rsp, accums_stack_space);
    mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
    mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
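
The pattern introduced by this change — spill the fp32 accm registers to a bd_e * ld_block2 * vec_size byte area below rsp, reuse the same registers for the int32 vpdpbusd sums, then convert, scale, and merge back onto the reloaded fp32 values — can be sketched in scalar form as below. This is a minimal model with hypothetical names, not oneDNN code; treating load(ld) as holding per-column weight scales at the merge step is an assumption based on reg_local_wei_scales being loaded just above.

#include <cstdint>
#include <vector>

// Scalar model of the accumulator handling in this patch (hypothetical helper,
// not oneDNN code): "accm" stands in for the fp32 accumulator registers,
// "spill" for the rsp-relative save area of bd_e * ld_block2 vector slots, and
// "dots" for the int32 sums that vpdpbusd would produce in the same registers.
void dyn_quant_accumulate(std::vector<float> &accm,   // bd_e * ld_block2 fp32 accumulators
        const std::vector<int32_t> &dots,             // int32 dot products per (bd, ld)
        const std::vector<float> &src_scales,         // one scale per bd row
        const std::vector<float> &wei_scales,         // one scale per ld column (assumed)
        int bd_e, int ld_block2) {
    std::vector<float> spill(bd_e * ld_block2);       // sub(rsp, accums_stack_space)

    for (int bd = 0; bd < bd_e; bd++)
        for (int ld = 0; ld < ld_block2; ld++) {
            const int idx = bd * ld_block2 + ld;
            spill[idx] = accm[idx];                   // vmovups(ptr[rsp + idx * vec_size], vmm_accm)
            accm[idx] = 0.f;                          // uni_vxorps: register now free for int32 sums
        }

    // ... vpdpbusd accumulation happens here in the kernel; modelled by "dots" ...

    for (int bd = 0; bd < bd_e; bd++)
        for (int ld = 0; ld < ld_block2; ld++) {
            const int idx = bd * ld_block2 + ld;
            float v = static_cast<float>(dots[idx]);  // uni_vcvtdq2ps
            v *= src_scales[bd];                      // uni_vmulps with vmm_src_scales
            v *= wei_scales[ld];                      // uni_vmulps with load(ld)
            accm[idx] = spill[idx] + v;               // reload the spilled value and uni_vaddps
        }
}                                                     // spill released: add(rsp, accums_stack_space)

The stack spill removes the need for the second register bank that vmm_accm_tmp addressed, at the cost of one store and one reload per accumulator.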