Skip to content

Commit 9960b67

Browse files
author
dmitrygo
committed
[FORK][FIX] DQ IP: allocate aux accums via stack
[FORK][FEATURE] InnerProduct primitive: squashed weight decompression
1 parent 1789b1e commit 9960b67

File tree

3 files changed

+15
-21
lines changed

3 files changed

+15
-21
lines changed

src/cpu/x64/brgemm/brgemm_utils.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,6 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
235235
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;
236236

237237
max_bcast_block /= adj_ld_block2;
238-
if (brg->with_src_dyn_quant) {
239-
max_bcast_block /= 2;
240-
}
241238

242239
return max_bcast_block;
243240
}

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+15-13
Original file line numberDiff line numberDiff line change
@@ -2298,11 +2298,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
22982298
if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
22992299
};
23002300

2301-
auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) {
2302-
int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld_block - (bd * ld_block + ld);
2303-
return Vmm(idx);
2304-
};
2305-
23062301
auto vmm_zero_point = [&](int ld) {
23072302
int idx = isa_num_vregs(brg.isa_impl) - 3 - ld;
23082303
return Vmm(idx);
@@ -2368,9 +2363,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
23682363

23692364
mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);
23702365

2366+
const int vec_size = vreg_traits<Vmm>::vlen;
2367+
auto accums_stack_space = bd_e * ld_block2 * vec_size;
2368+
sub(rsp, accums_stack_space);
23712369
for (int bd = bd_b; bd < bd_e; bd++) {
23722370
for (int ld = 0; ld < ld_block2; ld++) {
2373-
auto vmm_accm = vmm_accm_tmp(ld_block2, bd, ld);
2371+
auto vmm_accm = accm(ld_block2, bd, ld);
2372+
vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm);
2373+
23742374
uni_vxorps(vmm_accm, vmm_accm, vmm_accm);
23752375
}
23762376
}
@@ -2409,14 +2409,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24092409
+ brg.LDB * brg.rd_block * brg.typesize_B]);
24102410
}
24112411
for (int ld = 0; ld < ld_block2; ld++) {
2412-
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
2412+
auto vmm = accm(ld_block2, bd, ld);
24132413
vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
24142414
}
24152415
if (brg.with_wei_decomp_zero_points) {
24162416
uni_vpxor(bcst(), bcst(), vmm_neg_one);
24172417
uni_vpsubb(bcst(), bcst(), vmm_neg_one);
24182418
for (int ld = 0; ld < ld_block2; ld++) {
2419-
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
2419+
auto vmm = accm(ld_block2, bd, ld);
24202420
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
24212421
vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
24222422
}
@@ -2426,7 +2426,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24262426

24272427
auto reg_local_src_scales = reg_local_wei_zp;
24282428
auto vmm_src_scales = bcst();
2429-
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]);
2429+
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
24302430

24312431
for (int bd = bd_b; bd < bd_e; bd++) {
24322432
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
@@ -2438,15 +2438,17 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24382438
}
24392439
}
24402440
for (int ld = 0; ld < ld_block2; ld++) {
2441-
auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld);
24422441
auto vmm_accm = accm(ld_block2, bd, ld);
24432442

2444-
uni_vcvtdq2ps(vmm_accm_aux, vmm_accm_aux);
2445-
uni_vmulps(vmm_accm_aux, vmm_accm_aux, vmm_src_scales);
2446-
uni_vfmadd231ps(vmm_accm, vmm_accm_aux, load(ld));
2443+
uni_vcvtdq2ps(vmm_accm, vmm_accm);
2444+
uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales);
2445+
uni_vmulps(load(ld), vmm_accm, load(ld));
2446+
uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
2447+
uni_vaddps(vmm_accm, vmm_accm, load(ld));
24472448
}
24482449
}
24492450

2451+
add(rsp, accums_stack_space);
24502452
mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
24512453
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
24522454

src/cpu/x64/jit_brgemm_inner_product_utils.cpp

-5
Original file line numberDiff line numberDiff line change
@@ -1441,11 +1441,6 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
14411441
jbgp.wei_zero_points_ic_group_size = div_up(jbgp.ic, attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[1]);
14421442
}
14431443

1444-
// todo: fix avx2 brgemm kernel behavior for non scalar zp
1445-
if (!is_superset(isa, avx512_core) && attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[0] != 1) {
1446-
jbgp.with_src_dynamic_quant = false;
1447-
}
1448-
14491444
jbgp.wei_decomp_zero_points_dt = attr.zero_points_.get_data_type(DNNL_ARG_WEIGHTS);
14501445
if (!one_of(jbgp.wei_decomp_zero_points_dt, f32, u8))
14511446
return status::unimplemented;

0 commit comments

Comments
 (0)