Skip to content

Commit 55058f1

Browse files
author
dmitrygo
committed
[FORK][FIX] DQ IP: allocate aux accums via stack
[FORK][FEATURE] InnerProduct primitive: squashed weight decompression
1 parent c7ecd8f commit 55058f1

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

src/cpu/x64/brgemm/brgemm_utils.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,6 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
234234
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;
235235

236236
max_bcast_block /= adj_ld_block2;
237-
if (brg->with_src_dyn_quant) {
238-
max_bcast_block /= 2;
239-
}
240237

241238
return max_bcast_block;
242239
}

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+15-13
Original file line numberDiff line numberDiff line change
@@ -2298,11 +2298,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
22982298
if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
22992299
};
23002300

2301-
auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) {
2302-
int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld_block - (bd * ld_block + ld);
2303-
return Vmm(idx);
2304-
};
2305-
23062301
auto vmm_zero_point = [&](int ld) {
23072302
int idx = isa_num_vregs(brg.isa_impl) - 3 - ld;
23082303
return Vmm(idx);
@@ -2368,9 +2363,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
23682363

23692364
mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);
23702365

2366+
const int vec_size = vreg_traits<Vmm>::vlen;
2367+
auto accums_stack_space = bd_e * ld_block2 * vec_size;
2368+
sub(rsp, accums_stack_space);
23712369
for (int bd = bd_b; bd < bd_e; bd++) {
23722370
for (int ld = 0; ld < ld_block2; ld++) {
2373-
auto vmm_accm = vmm_accm_tmp(ld_block2, bd, ld);
2371+
auto vmm_accm = accm(ld_block2, bd, ld);
2372+
vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm);
2373+
23742374
uni_vxorps(vmm_accm, vmm_accm, vmm_accm);
23752375
}
23762376
}
@@ -2409,14 +2409,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24092409
+ brg.LDB * brg.rd_block * brg.typesize_B]);
24102410
}
24112411
for (int ld = 0; ld < ld_block2; ld++) {
2412-
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
2412+
auto vmm = accm(ld_block2, bd, ld);
24132413
vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
24142414
}
24152415
if (brg.with_wei_decomp_zero_points) {
24162416
uni_vpxor(bcst(), bcst(), vmm_neg_one);
24172417
uni_vpsubb(bcst(), bcst(), vmm_neg_one);
24182418
for (int ld = 0; ld < ld_block2; ld++) {
2419-
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
2419+
auto vmm = accm(ld_block2, bd, ld);
24202420
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
24212421
vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
24222422
}
@@ -2426,7 +2426,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24262426

24272427
auto reg_local_src_scales = reg_local_wei_zp;
24282428
auto vmm_src_scales = bcst();
2429-
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]);
2429+
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
24302430

24312431
for (int bd = bd_b; bd < bd_e; bd++) {
24322432
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
@@ -2438,15 +2438,17 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24382438
}
24392439
}
24402440
for (int ld = 0; ld < ld_block2; ld++) {
2441-
auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld);
24422441
auto vmm_accm = accm(ld_block2, bd, ld);
24432442

2444-
uni_vcvtdq2ps(vmm_accm_aux, vmm_accm_aux);
2445-
uni_vmulps(vmm_accm_aux, vmm_accm_aux, vmm_src_scales);
2446-
uni_vfmadd231ps(vmm_accm, vmm_accm_aux, load(ld));
2443+
uni_vcvtdq2ps(vmm_accm, vmm_accm);
2444+
uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales);
2445+
uni_vmulps(load(ld), vmm_accm, load(ld));
2446+
uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
2447+
uni_vaddps(vmm_accm, vmm_accm, load(ld));
24472448
}
24482449
}
24492450

2451+
add(rsp, accums_stack_space);
24502452
mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
24512453
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
24522454

0 commit comments

Comments
 (0)