Skip to content

Commit bc4e68a

Browse files
author
dmitrygo
committed
[FORK][FEATURE] DQ IP: reduce the number of auxiliary vector registers required by the microkernel
1 parent 77250c1 commit bc4e68a

File tree

2 files changed: 46 additions, 49 deletions

src/cpu/x64/brgemm/brgemm_utils.cpp

+1 -2
Original file line number | Diff line number | Diff line change
@@ -230,9 +230,8 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
230230
if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5;
231231
if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2;
232232
if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1;
233-
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1;
233+
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0 && !brg->with_src_dyn_quant) max_bcast_block -= 1;
234234
if (brg->with_src_dyn_quant) max_bcast_block -= 1;
235-
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;
236235

237236
max_bcast_block /= adj_ld_block2;
238237

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+45 -47
Original file line number | Diff line number | Diff line change
@@ -320,17 +320,13 @@ struct jit_brgemm_kernel_t : public jit_generator {
320320
used_vregs += 1;
321321
}
322322

323-
if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0) {
323+
if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0 && !brg.with_src_dyn_quant) {
324324
used_vregs += 1;
325325
}
326326

327327
if (brg.with_src_dyn_quant) {
328328
used_vregs += 1;
329329
}
330-
331-
if (brg.with_src_dyn_quant && brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) {
332-
used_vregs += brg.ld_block2;
333-
}
334330
return isa_num_vregs(brg.isa_impl) - used_vregs;
335331
}
336332

@@ -2306,11 +2302,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
23062302
if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
23072303
};
23082304

2309-
auto vmm_zero_point = [&](int ld) {
2310-
int idx = isa_num_vregs(brg.isa_impl) - 2 - ld;
2311-
return Vmm(idx);
2312-
};
2313-
23142305
static const int8_t mask_low_half[64] = {
23152306
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
23162307
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
@@ -2321,30 +2312,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
23212312
mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
23222313
mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
23232314

2324-
auto reg_local_wei_scales = reg_bdb_loop;
2325-
auto reg_local_wei_zp = reg_ldb_loop;
2326-
auto reg_ptr = reg_local_wei_scales;
2327-
2328-
if (brg.with_wei_decomp_zero_points) {
2329-
mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_]);
2330-
if (brg.wei_decomp_zero_points_stride == 0) {
2331-
auto reg_ptr_32 = Reg32(reg_ptr.getIdx());
2332-
movzx(reg_ptr_32, ptr[reg_local_wei_zp]);
2333-
uni_vmovq(Xmm(vmm_zero_point(0).getIdx()), reg_ptr);
2334-
uni_vbroadcastss(vmm_zero_point(0), Xmm(vmm_zero_point(0).getIdx()));
2335-
} else {
2336-
for (int ld = 0; ld < ld_block2; ld++) {
2337-
uni_vpmovzxbd(vmm_zero_point(ld), ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]);
2338-
}
2339-
}
2340-
}
2341-
2315+
auto reg_ptr = reg_bdb_loop;
23422316
auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 1);
23432317
mov(reg_ptr, (size_t)mask_low_half);
23442318
uni_vmovups(vmm_mask_low_half, ptr[reg_ptr]);
23452319

2346-
mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);
2347-
23482320
const int vec_size = vreg_traits<Vmm>::vlen;
23492321
auto accums_stack_space = bd_e * ld_block2 * vec_size;
23502322
sub(rsp, accums_stack_space);
@@ -2397,42 +2369,68 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
23972369
}
23982370
}
23992371

2400-
auto reg_local_src_scales = reg_local_wei_zp;
2401-
auto reg_local_src_grouped_sum = reg_local_wei_zp;
2402-
auto vmm_src_scales = bcst();
2372+
auto vmm_zero_point = [&](int ld) {
2373+
return load(ld);
2374+
};
2375+
2376+
auto reg_local_wei_zp = reg_ldb_loop;
2377+
auto reg_local_src_grouped_sum = reg_bdb_loop;
2378+
auto vmm_tmp = Vmm(isa_num_vregs(brg.isa_impl) - 1);
24032379
auto vmm_src_grouped_sum = bcst();
24042380

24052381
if (brg.with_wei_decomp_zero_points) {
2382+
mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_ + accums_stack_space]);
2383+
if (brg.wei_decomp_zero_points_stride == 0) {
2384+
Vmm vmm_zp = vmm_zero_point(0);
2385+
auto reg_ptr_32 = Reg32(reg_ptr.getIdx());
2386+
movzx(reg_ptr_32, ptr[reg_local_wei_zp]);
2387+
uni_vmovq(Xmm(vmm_zp.getIdx()), reg_ptr);
2388+
uni_vbroadcastss(vmm_zp, Xmm(vmm_zp.getIdx()));
2389+
}
2390+
24062391
mov(reg_local_src_grouped_sum, ptr[rsp + reg_aux2_src_grouped_sum_offs_ + accums_stack_space]);
24072392
for (int bd = bd_b; bd < bd_e; bd++) {
2393+
uni_vbroadcastss(vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof(int32_t)]);
24082394
for (int ld = 0; ld < ld_block2; ld++) {
2409-
auto vmm_accm = accm(ld_block2, bd, ld);
24102395
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
2411-
uni_vbroadcastss(vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof(int32_t)]);
2412-
uni_vpmulld(vmm_src_grouped_sum, vmm_src_grouped_sum, vmm_zp);
2413-
uni_vpsubd(vmm_accm, vmm_accm, vmm_src_grouped_sum);
2396+
if (bd == bd_b && brg.wei_decomp_zero_points_stride != 0) {
2397+
uni_vpmovzxbd(vmm_zp, ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]);
2398+
}
2399+
2400+
auto vmm_accm = accm(ld_block2, bd, ld);
2401+
uni_vpmulld(vmm_tmp, vmm_src_grouped_sum, vmm_zp);
2402+
uni_vpsubd(vmm_accm, vmm_accm, vmm_tmp);
24142403
}
24152404
}
24162405
}
24172406

2407+
auto wei_scale = [&](int ld) {
2408+
return load(ld);
2409+
};
2410+
2411+
auto reg_local_src_scales = reg_ldb_loop;
2412+
auto reg_local_wei_scales = reg_bdb_loop;
2413+
auto vmm_src_scales = bcst();
2414+
2415+
mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_ + accums_stack_space]);
24182416
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);
2417+
if (brg.wei_decomp_scales_stride == 0) {
2418+
uni_vbroadcastss(wei_scale(0), ptr[reg_local_wei_scales]);
2419+
}
2420+
24192421
for (int bd = bd_b; bd < bd_e; bd++) {
24202422
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
24212423
for (int ld = 0; ld < ld_block2; ld++) {
2422-
if (brg.wei_decomp_scales_stride == 0) {
2423-
uni_vbroadcastss(load(ld), ptr[reg_local_wei_scales]);
2424-
} else {
2425-
uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]);
2424+
auto vmm_wei_scale = brg.wei_decomp_scales_stride == 0 ? wei_scale(0) : wei_scale(ld);
2425+
if (bd == bd_b && brg.wei_decomp_scales_stride != 0) {
2426+
uni_vmovups(vmm_wei_scale, ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]);
24262427
}
2427-
}
2428-
for (int ld = 0; ld < ld_block2; ld++) {
2429-
auto vmm_accm = accm(ld_block2, bd, ld);
24302428

2429+
auto vmm_accm = accm(ld_block2, bd, ld);
24312430
uni_vcvtdq2ps(vmm_accm, vmm_accm);
2432-
uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales);
2433-
uni_vmulps(load(ld), vmm_accm, load(ld));
2431+
uni_vmulps(vmm_tmp, vmm_accm, vmm_src_scales);
24342432
uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
2435-
uni_vaddps(vmm_accm, vmm_accm, load(ld));
2433+
uni_vfmadd231ps(vmm_accm, vmm_tmp, vmm_wei_scale);
24362434
}
24372435
}
24382436

0 commit comments

Comments
 (0)