Skip to content

Commit c7ecd8f

Browse files
[FORK][FIX] IP weights compression: scalar scale
[FORK][FEATURE] InnerProduct primitive: squashed weight decompression
1 parent 1efdaaa commit c7ecd8f

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+13-5
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
305305
used_vregs = 5;
306306
else if (brg.is_f16_b_non_amx_vnni())
307307
used_vregs = 2;
308-
308+
309309
if (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl == avx2) {
310310
used_vregs += 5;
311311
}
@@ -2431,7 +2431,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
24312431
for (int bd = bd_b; bd < bd_e; bd++) {
24322432
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
24332433
for (int ld = 0; ld < ld_block2; ld++) {
2434-
uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]);
2434+
if (brg.wei_decomp_scales_stride == 0) {
2435+
uni_vbroadcastss(load(ld), ptr[reg_local_wei_scales]);
2436+
} else {
2437+
uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]);
2438+
}
24352439
}
24362440
for (int ld = 0; ld < ld_block2; ld++) {
24372441
auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld);
@@ -2901,7 +2905,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
29012905
for (int ld = 0; ld < ld_block2; ld++) {
29022906
auto vmm_accm_tmp = accm_tmp(ld_block2, 0, ld);
29032907
auto vmm_accm = accm(ld_block2, 0, ld);
2904-
load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]);
2908+
if (brg.wei_decomp_scales_stride == 0) {
2909+
load_scales(bcst(), ptr[reg_local_wei_scales]);
2910+
} else {
2911+
load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]);
2912+
}
29052913
uni_vfmadd231ps(vmm_accm, vmm_accm_tmp, bcst());
29062914
}
29072915
}
@@ -3025,8 +3033,8 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
30253033
mov(reg_rdb_loop, brg.rdb);
30263034
L_aligned(rdb_loop_label, 64);
30273035
{
3028-
if (brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
3029-
brg.wei_decomp_zero_points_stride != 0)) {
3036+
if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
3037+
brg.wei_decomp_zero_points_stride != 0)) || brg.with_src_dyn_quant) {
30303038
auto reg_local_ic = reg_aux_D;
30313039
auto reg_local_wei_params = reg_bdb_loop;
30323040
auto reg_local_ic_group = reg_ldb_loop;

0 commit comments

Comments
 (0)