Skip to content

Commit d421730

Browse files
[FORK][FIX] DQ IP: allocate aux accums via stack
[FORK][FEATURE] InnerProduct primitive: squashed weight decompression
1 parent 1789b1e commit d421730

File tree

3 files changed

+25
-29
lines changed

3 files changed

+25
-29
lines changed

src/cpu/x64/brgemm/brgemm_utils.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,6 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
235235
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;
236236

237237
max_bcast_block /= adj_ld_block2;
238-
if (brg->with_src_dyn_quant) {
239-
max_bcast_block /= 2;
240-
}
241238

242239
return max_bcast_block;
243240
}

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+25-21
Original file line numberDiff line numberDiff line change
@@ -874,12 +874,13 @@ void jit_brgemm_kernel_t<Wmm>::ldb_regs_shift(int ld_block2, bool is_tail) {
874874
mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
875875
}
876876

877-
if (brg.with_wei_decomp) {
877+
if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) {
878878
mov(reg_aux_wei_scales, ptr[rsp + reg_aux_wei_scales_offs_]);
879879
add(reg_aux_wei_scales, (is_tail) ? wei_scales_offset(1, true) : wei_scales_offset(ld_block2));
880880
mov(ptr[rsp + reg_aux_wei_scales_offs_], reg_aux_wei_scales);
881881
mov(ptr[rsp + reg_aux2_wei_scales_offs_], reg_aux_wei_scales);
882-
882+
}
883+
if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) {
883884
mov(reg_aux_wei_zp, ptr[rsp + reg_aux_wei_zero_points_offs_]);
884885
add(reg_aux_wei_zp, (is_tail) ? wei_zp_offset(1, true) : wei_zp_offset(ld_block2));
885886
mov(ptr[rsp + reg_aux_wei_zero_points_offs_], reg_aux_wei_zp);
@@ -966,10 +967,6 @@ void jit_brgemm_kernel_t<Wmm>::copy_post_ops_stack_values_to_aux(
966967
}
967968

968969
}
969-
if (brg.with_grouped_wei_decomp) {
970-
mov(reg_ic, ptr[rsp + reg_ic_offs_]);
971-
mov(ptr[rsp + reg_aux_ic_offs_], reg_ic);
972-
}
973970
if (brg.with_src_dyn_quant) {
974971
mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]);
975972
mov(ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales);
@@ -2298,11 +2295,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
22982295
if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
22992296
};
23002297

2301-
auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) {
2302-
int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld_block - (bd * ld_block + ld);
2303-
return Vmm(idx);
2304-
};
2305-
23062298
auto vmm_zero_point = [&](int ld) {
23072299
int idx = isa_num_vregs(brg.isa_impl) - 3 - ld;
23082300
return Vmm(idx);
@@ -2368,9 +2360,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
23682360

23692361
mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);
23702362

2363+
const int vec_size = vreg_traits<Vmm>::vlen;
2364+
auto accums_stack_space = bd_e * ld_block2 * vec_size;
2365+
sub(rsp, accums_stack_space);
23712366
for (int bd = bd_b; bd < bd_e; bd++) {
23722367
for (int ld = 0; ld < ld_block2; ld++) {
2373-
auto vmm_accm = vmm_accm_tmp(ld_block2, bd, ld);
2368+
auto vmm_accm = accm(ld_block2, bd, ld);
2369+
vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm);
2370+
23742371
uni_vxorps(vmm_accm, vmm_accm, vmm_accm);
23752372
}
23762373
}
@@ -2409,14 +2406,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24092406
+ brg.LDB * brg.rd_block * brg.typesize_B]);
24102407
}
24112408
for (int ld = 0; ld < ld_block2; ld++) {
2412-
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
2409+
auto vmm = accm(ld_block2, bd, ld);
24132410
vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
24142411
}
24152412
if (brg.with_wei_decomp_zero_points) {
24162413
uni_vpxor(bcst(), bcst(), vmm_neg_one);
24172414
uni_vpsubb(bcst(), bcst(), vmm_neg_one);
24182415
for (int ld = 0; ld < ld_block2; ld++) {
2419-
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
2416+
auto vmm = accm(ld_block2, bd, ld);
24202417
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
24212418
vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
24222419
}
@@ -2426,7 +2423,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24262423

24272424
auto reg_local_src_scales = reg_local_wei_zp;
24282425
auto vmm_src_scales = bcst();
2429-
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]);
2426+
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
24302427

24312428
for (int bd = bd_b; bd < bd_e; bd++) {
24322429
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
@@ -2438,15 +2435,17 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24382435
}
24392436
}
24402437
for (int ld = 0; ld < ld_block2; ld++) {
2441-
auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld);
24422438
auto vmm_accm = accm(ld_block2, bd, ld);
24432439

2444-
uni_vcvtdq2ps(vmm_accm_aux, vmm_accm_aux);
2445-
uni_vmulps(vmm_accm_aux, vmm_accm_aux, vmm_src_scales);
2446-
uni_vfmadd231ps(vmm_accm, vmm_accm_aux, load(ld));
2440+
uni_vcvtdq2ps(vmm_accm, vmm_accm);
2441+
uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales);
2442+
uni_vmulps(load(ld), vmm_accm, load(ld));
2443+
uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
2444+
uni_vaddps(vmm_accm, vmm_accm, load(ld));
24472445
}
24482446
}
24492447

2448+
add(rsp, accums_stack_space);
24502449
mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
24512450
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
24522451

@@ -3014,6 +3013,11 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
30143013
copy_post_ops_stack_values_to_aux(is_reg_tail);
30153014

30163015
auto ld_loop_body = [&](int vpad) {
3016+
if (brg.with_grouped_wei_decomp) {
3017+
mov(reg_ic, ptr[rsp + reg_ic_offs_]);
3018+
mov(ptr[rsp + reg_aux_ic_offs_], reg_ic);
3019+
}
3020+
30173021
set_A_B_matrices();
30183022

30193023
int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
@@ -3033,8 +3037,8 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
30333037
mov(reg_rdb_loop, brg.rdb);
30343038
L_aligned(rdb_loop_label, 64);
30353039
{
3036-
if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
3037-
brg.wei_decomp_zero_points_stride != 0)) || brg.with_src_dyn_quant) {
3040+
if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0))
3041+
|| brg.with_src_dyn_quant) {
30383042
auto reg_local_ic = reg_aux_D;
30393043
auto reg_local_wei_params = reg_bdb_loop;
30403044
auto reg_local_ic_group = reg_ldb_loop;

src/cpu/x64/jit_brgemm_inner_product_utils.cpp

-5
Original file line numberDiff line numberDiff line change
@@ -1441,11 +1441,6 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
14411441
jbgp.wei_zero_points_ic_group_size = div_up(jbgp.ic, attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[1]);
14421442
}
14431443

1444-
// todo: fix avx2 brgemm kernel behavior for non scalar zp
1445-
if (!is_superset(isa, avx512_core) && attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[0] != 1) {
1446-
jbgp.with_src_dynamic_quant = false;
1447-
}
1448-
14491444
jbgp.wei_decomp_zero_points_dt = attr.zero_points_.get_data_type(DNNL_ARG_WEIGHTS);
14501445
if (!one_of(jbgp.wei_decomp_zero_points_dt, f32, u8))
14511446
return status::unimplemented;

0 commit comments

Comments
 (0)