From 157fb8153d8c54cf37311d95a63ed63e66b14ac5 Mon Sep 17 00:00:00 2001 From: dmitrygo Date: Tue, 21 Jan 2025 10:37:33 +0400 Subject: [PATCH] [FORK][FEATURE] DQ IP: performance enhancements - allocate aux accums regs on stack - precompute grouped src sums - optimize pointer arithmetic - reduce aux vecs count required for the microkernel --- src/common/memory_tracking.hpp | 1 + src/cpu/x64/brgemm/brgemm.cpp | 22 +- src/cpu/x64/brgemm/brgemm.hpp | 8 +- src/cpu/x64/brgemm/brgemm_types.hpp | 3 + src/cpu/x64/brgemm/brgemm_utils.cpp | 19 +- src/cpu/x64/brgemm/jit_brgemm_kernel.cpp | 399 +++++++++++------- src/cpu/x64/jit_brgemm_inner_product.cpp | 35 +- src/cpu/x64/jit_brgemm_inner_product.hpp | 2 + .../x64/jit_brgemm_inner_product_utils.cpp | 49 ++- src/cpu/x64/jit_brgemm_primitive_conf.hpp | 1 + .../jit_brgemm_src_quantization_kernel.cpp | 70 ++- .../jit_brgemm_src_quantization_kernel.hpp | 11 + 12 files changed, 418 insertions(+), 202 deletions(-) diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp index 2b37f9fc05a..8752a8240af 100644 --- a/src/common/memory_tracking.hpp +++ b/src/common/memory_tracking.hpp @@ -305,6 +305,7 @@ enum { key_decompression_zero_points, key_src_quantized, key_src_dequantized_scales, + key_src_grouped_sum, // These two keys should always be the last ones, // even though they are not in alphabetical order key_nested, diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index 52cbf163ca2..e8b88515348 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -82,7 +82,8 @@ void brgemm_desc_t::cleanup_dst_md() { void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *scratch, const brgemm_dynamic_values_t *dynamic_values, - const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) { + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void 
*ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -105,6 +106,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.ptr_wei_scales = ptr_wei_scales; brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; brgemm_p.ic = ic; assert(brg_kernel); @@ -116,7 +118,8 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, const void *addr_A, const void *addr_B, const brgemm_batch_element_t *batch, void *ptr_C, void *scratch, const brgemm_dynamic_values_t *dynamic_values, - const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) { + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -133,6 +136,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.ptr_wei_scales = ptr_wei_scales; brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; @@ -148,7 +152,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, const brgemm_post_ops_data_t &post_ops_data, void *scratch, const brgemm_dynamic_values_t *dynamic_values, - const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) { + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -178,6 +183,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t 
*brg_kernel, int bs, brgemm_p.ptr_wei_scales = ptr_wei_scales; brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; @@ -194,7 +200,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, const brgemm_post_ops_data_t &post_ops_data, void *scratch, const brgemm_dynamic_values_t *dynamic_values, - const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) { + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -224,6 +231,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.ptr_wei_scales = ptr_wei_scales; brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; @@ -318,6 +326,12 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa, CHECK(brgemm_blocking(brg)); + brg->src_sum_group_size = wei_d.dims()[1]; + if (brg->with_src_dyn_quant) { + brg->src_sum_group_size = brg->rd_block; + brg->src_grouped_sum_stride = div_up(wei_d.dims()[1], brg->src_sum_group_size); + } + // avx2_vnni_2 kernel with xf16 data type requires blocked weights. 
if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16() && brg->LDB % brg->ld_block > 0) diff --git a/src/cpu/x64/brgemm/brgemm.hpp b/src/cpu/x64/brgemm/brgemm.hpp index e53fdf18999..bbd39ffe6f4 100644 --- a/src/cpu/x64/brgemm/brgemm.hpp +++ b/src/cpu/x64/brgemm/brgemm.hpp @@ -175,7 +175,7 @@ void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, void *scratch = nullptr, const brgemm_dynamic_values_t *dynamic_values = nullptr, const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, - const void *ptr_src_scales = nullptr, size_t ic = 0); + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version) /// @@ -205,7 +205,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, void *scratch = nullptr, const brgemm_dynamic_values_t *dynamic_values = nullptr, const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, - const void *ptr_src_scales = nullptr, size_t ic = 0); + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_addr version) /// @@ -234,7 +234,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr, const brgemm_dynamic_values_t *dynamic_values = nullptr, const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, - const void *ptr_src_scales = nullptr, size_t ic = 0); + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version) /// @@ -267,7 +267,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr, const brgemm_dynamic_values_t *dynamic_values = nullptr, const void *ptr_wei_scales = 
nullptr, const void *ptr_wei_zero_points = nullptr, - const void *ptr_src_scales = nullptr, size_t ic = 0); + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// AMX utilities: Creates a palette based on BRGEMM descriptor /// diff --git a/src/cpu/x64/brgemm/brgemm_types.hpp b/src/cpu/x64/brgemm/brgemm_types.hpp index f9a619e4912..e84bb8386b3 100644 --- a/src/cpu/x64/brgemm/brgemm_types.hpp +++ b/src/cpu/x64/brgemm/brgemm_types.hpp @@ -321,6 +321,8 @@ struct brgemm_desc_t { bool with_src_dyn_quant = false; int src_scales_group_size = 0; int src_scales_stride = 0; + int src_sum_group_size = 0; + int src_grouped_sum_stride = 0; bool is_row_major() const { assert(layout != brgemm_layout_undef); @@ -500,6 +502,7 @@ struct brgemm_kernel_params_t { const void *ptr_wei_scales = nullptr; const void *ptr_wei_zero_points = nullptr; const void *ptr_src_scales = nullptr; + const void *ptr_src_grouped_sum = nullptr; size_t ic; dim_t dynamic_LDA = 0; dim_t dynamic_LDB = 0; diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp index 0edf6e6b482..9eb1e7d2c31 100644 --- a/src/cpu/x64/brgemm/brgemm_utils.cpp +++ b/src/cpu/x64/brgemm/brgemm_utils.cpp @@ -230,14 +230,10 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) { if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5; if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2; if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1; - if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1; - if (brg->with_src_dyn_quant) max_bcast_block -= 2; - if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2; + if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0 && 
!brg->with_src_dyn_quant) max_bcast_block -= 1; + if (brg->with_src_dyn_quant) max_bcast_block -= 1; max_bcast_block /= adj_ld_block2; - if (brg->with_src_dyn_quant) { - max_bcast_block /= 2; - } return max_bcast_block; } @@ -301,15 +297,22 @@ status_t brgemm_blocking(brgemm_desc_t *brg) { = (brg->is_f16 && brg->isa_impl == avx512_core_fp16) ? 1 : data_type_vnni_granularity(brg->dt_a); + int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4, data_type::s4, data_type::f4_e2m1) ? 32 : 4; - if (brg->with_grouped_wei_decomp) { + if (brg->with_grouped_wei_decomp && !brg->with_src_dyn_quant) { auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size); min_group_size = nstl::min(min_group_size, brg->src_scales_group_size); rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity); rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity); + brg->rd_block = rd_unroll * vnni_granularity; + } else if (brg->with_src_dyn_quant) { + brg->rd_block = brg->src_scales_group_size; + auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size); + brg->rd_block = nstl::min(brg->rd_block, min_group_size); + } else { + brg->rd_block = rd_unroll * vnni_granularity; } - brg->rd_block = rd_unroll * vnni_granularity; brg->rdb = brg->reduce_dim / brg->rd_block; brg->rdb_tail = brg->reduce_dim % brg->rd_block; diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index aa7bce903ae..d1498358ec0 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -203,6 +203,7 @@ struct jit_brgemm_kernel_t : public jit_generator { const reg64_t reg_aux_wei_zp = reg_rdb_loop; const reg64_t reg_ic = reg_rdb_loop; const reg64_t reg_src_scales = reg_rdb_loop; + const reg64_t reg_src_grouped_sum = reg_rdb_loop; const reg64_t reg_tmp_read_values = reg_rdb_loop; const reg64_t reg_aux_scales = 
reg_aux_B; @@ -280,12 +281,13 @@ struct jit_brgemm_kernel_t : public jit_generator { constexpr static int reg_src_scales_offs_ = 336; constexpr static int reg_aux_src_scales_offs_ = 344; constexpr static int reg_aux2_src_scales_offs_ = 352; - // constexpr static int stack_space_needed_ = 360; + constexpr static int reg_src_grouped_sum_offs_ = 360; + constexpr static int reg_aux_src_grouped_sum_offs_ = 368; + constexpr static int reg_aux2_src_grouped_sum_offs_ = 376; // these are used for FP8 as temporary push/pop spaces - constexpr static int reg_val_tmp_1_ = 368; - constexpr static int reg_val_tmp_2_ = 376; - constexpr static int stack_space_needed_ = 384; - // regsiters for dynamic quant + constexpr static int reg_val_tmp_1_ = 384; + constexpr static int reg_val_tmp_2_ = 392; + constexpr static int stack_space_needed_ = 400; bool is_ldb_loop_ = false; @@ -318,16 +320,12 @@ struct jit_brgemm_kernel_t : public jit_generator { used_vregs += 1; } - if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0) { + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0 && !brg.with_src_dyn_quant) { used_vregs += 1; } if (brg.with_src_dyn_quant) { - used_vregs += 2; - } - - if (brg.with_src_dyn_quant && brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { - used_vregs += brg.ld_block2; + used_vregs += 1; } return isa_num_vregs(brg.isa_impl) - used_vregs; } @@ -874,12 +872,13 @@ void jit_brgemm_kernel_t::ldb_regs_shift(int ld_block2, bool is_tail) { mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales); } - if (brg.with_wei_decomp) { + if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) { mov(reg_aux_wei_scales, ptr[rsp + reg_aux_wei_scales_offs_]); add(reg_aux_wei_scales, (is_tail) ? 
wei_scales_offset(1, true) : wei_scales_offset(ld_block2)); mov(ptr[rsp + reg_aux_wei_scales_offs_], reg_aux_wei_scales); mov(ptr[rsp + reg_aux2_wei_scales_offs_], reg_aux_wei_scales); - + } + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { mov(reg_aux_wei_zp, ptr[rsp + reg_aux_wei_zero_points_offs_]); add(reg_aux_wei_zp, (is_tail) ? wei_zp_offset(1, true) : wei_zp_offset(ld_block2)); mov(ptr[rsp + reg_aux_wei_zero_points_offs_], reg_aux_wei_zp); @@ -966,14 +965,16 @@ void jit_brgemm_kernel_t::copy_post_ops_stack_values_to_aux( } } - if (brg.with_grouped_wei_decomp) { - mov(reg_ic, ptr[rsp + reg_ic_offs_]); - mov(ptr[rsp + reg_aux_ic_offs_], reg_ic); - } if (brg.with_src_dyn_quant) { mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]); mov(ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales); mov(ptr[rsp + reg_aux2_src_scales_offs_], reg_src_scales); + + if (brg.with_wei_decomp_zero_points) { + mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]); + mov(ptr[rsp + reg_aux_src_grouped_sum_offs_], reg_src_grouped_sum); + mov(ptr[rsp + reg_aux2_src_grouped_sum_offs_], reg_src_grouped_sum); + } } if (brg.zp_type_b != brgemm_broadcast_t::none) { mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]); @@ -1051,6 +1052,9 @@ void jit_brgemm_kernel_t::read_params() { if (brg.with_src_dyn_quant) { mov(reg_src_scales, ptr[param1 + GET_OFF(ptr_src_scales)]); mov(ptr[rsp + reg_src_scales_offs_], reg_src_scales); + + mov(reg_src_grouped_sum, ptr[param1 + GET_OFF(ptr_src_grouped_sum)]); + mov(ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum); } if (brg.zp_type_c != brgemm_broadcast_t::none) { @@ -1237,7 +1241,7 @@ void jit_brgemm_kernel_t::apply_alpha_beta( uni_vaddps(vmm_masked, vmm, ptr_C); } else { vmaskmovps(vmm_prev_dst, vmm_tail_mask(), ptr_C); - if (brg.is_int8) + if (brg.is_int8 && !brg.with_src_dyn_quant) uni_vpaddd(vmm, vmm, vmm_prev_dst); else uni_vaddps(vmm, vmm, vmm_prev_dst); @@ -2298,27 +2302,6 @@ void 
jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift()); }; - auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) { - int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld_block - (bd * ld_block + ld); - return Vmm(idx); - }; - - auto vmm_zero_point = [&](int ld) { - int idx = isa_num_vregs(brg.isa_impl) - 3 - ld; - return Vmm(idx); - }; - - static const int8_t negative_one[64] = { - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1 - }; - static const int8_t mask_low_half[64] = { 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, @@ -2329,48 +2312,19 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop); - auto reg_local_wei_scales = reg_bdb_loop; - auto reg_local_wei_zp = reg_ldb_loop; - auto reg_ptr = reg_local_wei_scales; - - if (brg.with_wei_decomp_zero_points) { - mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_]); - if (brg.wei_decomp_zero_points_stride == 0) { - auto reg_ptr_8 = Reg8(reg_ptr.getIdx()); - mov(reg_ptr_8, ptr[reg_local_wei_zp]); - uni_vpbroadcastb(vmm_zero_point(0), reg_ptr_8); - } else { - static const int8_t index_table[64] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C, - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C, - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C, - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 
0x08, 0x08, 0x08, 0x08, 0x0C, 0x0C, 0x0C, 0x0C - }; - - auto vmm_indexes = Vmm(isa_num_vregs(brg.isa_impl) - 1); - mov(reg_ptr, (size_t)index_table); - uni_vmovups(vmm_indexes, ptr[reg_ptr]); - - for (int ld = 0; ld < ld_block2; ld++) { - uni_vpmovzxbd(vmm_zero_point(ld), ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]); - vpshufb(vmm_zero_point(ld), vmm_zero_point(ld), vmm_indexes); - } - } - } - - auto vmm_neg_one = Vmm(isa_num_vregs(brg.isa_impl) - 1); - mov(reg_ptr, (size_t)negative_one); - uni_vmovups(vmm_neg_one, ptr[reg_ptr]); - - auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 2); + auto reg_ptr = reg_bdb_loop; + auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 1); mov(reg_ptr, (size_t)mask_low_half); uni_vmovups(vmm_mask_low_half, ptr[reg_ptr]); - mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]); - + const int vec_size = vreg_traits::vlen; + auto accums_stack_space = bd_e * ld_block2 * vec_size; + sub(rsp, accums_stack_space); for (int bd = bd_b; bd < bd_e; bd++) { for (int ld = 0; ld < ld_block2; ld++) { - auto vmm_accm = vmm_accm_tmp(ld_block2, bd, ld); + auto vmm_accm = accm(ld_block2, bd, ld); + vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm); + uni_vxorps(vmm_accm, vmm_accm, vmm_accm); } } @@ -2405,48 +2359,83 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, have_to_load_bytes && bd_by_load_bytes, brg.dt_a); } if (prefetch_count_B < ld_block2) { + int typesize_scale = brg.dt_b == data_type::u4 ? 2 : 1; prefetcht0(ptr[reg_aux_B + B_offset(prefetch_count_B++, rd) - + brg.LDB * brg.rd_block * brg.typesize_B]); + + brg.LDB * brg.rd_block * brg.typesize_B / typesize_scale]); } for (int ld = 0; ld < ld_block2; ld++) { - auto vmm = vmm_accm_tmp(ld_block2, bd, ld); + auto vmm = accm(ld_block2, bd, ld); vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? 
EvexEncoding : VexEncoding); } - if (brg.with_wei_decomp_zero_points) { - uni_vpxor(bcst(), bcst(), vmm_neg_one); - uni_vpsubb(bcst(), bcst(), vmm_neg_one); - for (int ld = 0; ld < ld_block2; ld++) { - auto vmm = vmm_accm_tmp(ld_block2, bd, ld); - Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld); - vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding); + } + } + + auto vmm_zero_point = [&](int ld) { + return load(ld); + }; + + auto reg_local_wei_zp = reg_ldb_loop; + auto reg_local_src_grouped_sum = reg_bdb_loop; + auto vmm_tmp = Vmm(isa_num_vregs(brg.isa_impl) - 1); + auto vmm_src_grouped_sum = bcst(); + + if (brg.with_wei_decomp_zero_points) { + mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_ + accums_stack_space]); + if (brg.wei_decomp_zero_points_stride == 0) { + Vmm vmm_zp = vmm_zero_point(0); + auto reg_ptr_32 = Reg32(reg_ptr.getIdx()); + movzx(reg_ptr_32, ptr[reg_local_wei_zp]); + uni_vmovq(Xmm(vmm_zp.getIdx()), reg_ptr); + uni_vbroadcastss(vmm_zp, Xmm(vmm_zp.getIdx())); + } + + mov(reg_local_src_grouped_sum, ptr[rsp + reg_aux2_src_grouped_sum_offs_ + accums_stack_space]); + for (int bd = bd_b; bd < bd_e; bd++) { + uni_vbroadcastss(vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof(int32_t)]); + for (int ld = 0; ld < ld_block2; ld++) { + Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? 
vmm_zero_point(0) : vmm_zero_point(ld); + if (bd == bd_b && brg.wei_decomp_zero_points_stride != 0) { + uni_vpmovzxbd(vmm_zp, ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]); } + + auto vmm_accm = accm(ld_block2, bd, ld); + uni_vpmulld(vmm_tmp, vmm_src_grouped_sum, vmm_zp); + uni_vpsubd(vmm_accm, vmm_accm, vmm_tmp); } } } - auto reg_local_src_scales = reg_local_wei_zp; + auto wei_scale = [&](int ld) { + return load(ld); + }; + + auto reg_local_src_scales = reg_ldb_loop; + auto reg_local_wei_scales = reg_bdb_loop; auto vmm_src_scales = bcst(); - mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]); + + mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_ + accums_stack_space]); + mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]); + if (brg.wei_decomp_scales_stride == 0) { + uni_vbroadcastss(wei_scale(0), ptr[reg_local_wei_scales]); + } for (int bd = bd_b; bd < bd_e; bd++) { uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]); for (int ld = 0; ld < ld_block2; ld++) { - if (brg.wei_decomp_scales_stride == 0) { - uni_vbroadcastss(load(ld), ptr[reg_local_wei_scales]); - } else { - uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]); + auto vmm_wei_scale = brg.wei_decomp_scales_stride == 0 ? 
wei_scale(0) : wei_scale(ld); + if (bd == bd_b && brg.wei_decomp_scales_stride != 0) { + uni_vmovups(vmm_wei_scale, ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]); } - } - for (int ld = 0; ld < ld_block2; ld++) { - auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld); - auto vmm_accm = accm(ld_block2, bd, ld); - uni_vcvtdq2ps(vmm_accm_aux, vmm_accm_aux); - uni_vmulps(vmm_accm_aux, vmm_accm_aux, vmm_src_scales); - uni_vfmadd231ps(vmm_accm, vmm_accm_aux, load(ld)); + auto vmm_accm = accm(ld_block2, bd, ld); + uni_vcvtdq2ps(vmm_accm, vmm_accm); + uni_vmulps(vmm_tmp, vmm_accm, vmm_src_scales); + uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]); + uni_vfmadd231ps(vmm_accm, vmm_tmp, vmm_wei_scale); } } + add(rsp, accums_stack_space); mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]); mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); @@ -3007,6 +2996,98 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, int ld_block2, int ldb_loop_length, bool is_reg_tail, bool is_ld_tail, bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail, bool skip_accumulation) { + auto ic_group_shift_generic = [&]() { + if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0)) + || brg.with_src_dyn_quant) { + auto reg_local_ic = reg_aux_D; + auto reg_local_wei_params = reg_bdb_loop; + auto reg_local_ic_group = reg_ldb_loop; + + auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) { + mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); + mov(reg_local_ic_group, group_size); + xor_(rdx, rdx); + idiv(reg_local_ic_group); + imul(reg_local_ic, reg_local_ic, stride); + + mov(reg_local_wei_params, ptr[rsp + src_offs]); + add(reg_local_wei_params, reg_local_ic); + mov(ptr[rsp + dst_offs], reg_local_wei_params); + }; + + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + mov(ptr[rsp + reg_aux2_D_offs_], reg_aux_D); + mov(ptr[rsp + reg_ldb_loop_offs_], 
reg_ldb_loop); + mov(ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv + + if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) { + ic_group_shift(reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_, + brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt)); + } + + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { + ic_group_shift(reg_aux_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_, + brg.wei_decomp_zero_points_group_size, brg.wei_decomp_zero_points_stride * types::data_type_size(brg.wei_decomp_zero_points_dt)); + } + + if (brg.with_src_dyn_quant) { + ic_group_shift(reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_, + brg.src_scales_group_size, sizeof(float)); + + if (brg.with_wei_decomp_zero_points) { + ic_group_shift(reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_, + brg.src_sum_group_size, sizeof(int32_t)); + } + } + + mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); + add(reg_local_ic, brg.rd_block); + mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic); + + mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + mov(reg_aux_D, ptr[rsp + reg_aux2_D_offs_]); + mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]); + mov(reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]); + } + }; + + auto ic_group_shift_opt = [&](int rb) { + if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0)) + || brg.with_src_dyn_quant) { + mov(ptr[rsp + reg_bdb_loop_offs_], reg_rdb_loop); + auto reg_ptr = reg_rdb_loop; + + auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) { + if ((rb + 1) * brg.rd_block % group_size == 0) { + mov(reg_ptr, ptr[rsp + src_offs]); + add(reg_ptr, stride); + mov(ptr[rsp + dst_offs], reg_ptr); + } + }; + + if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) { + ic_group_shift(reg_aux2_wei_scales_offs_, 
reg_aux2_wei_scales_offs_, + brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt)); + } + + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { + ic_group_shift(reg_aux2_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_, + brg.wei_decomp_zero_points_group_size, brg.wei_decomp_zero_points_stride * types::data_type_size(brg.wei_decomp_zero_points_dt)); + } + + if (brg.with_src_dyn_quant) { + ic_group_shift(reg_aux2_src_scales_offs_, reg_aux2_src_scales_offs_, + brg.src_scales_group_size, sizeof(float)); + + if (brg.with_wei_decomp_zero_points) { + ic_group_shift(reg_aux2_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_, + brg.src_sum_group_size, sizeof(int32_t)); + } + } + + mov(reg_rdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + } + }; Label ldb_loop_label; Label BS_loop_label; @@ -3014,6 +3095,11 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, copy_post_ops_stack_values_to_aux(is_reg_tail); auto ld_loop_body = [&](int vpad) { + if (brg.with_grouped_wei_decomp) { + mov(reg_ic, ptr[rsp + reg_ic_offs_]); + mov(ptr[rsp + reg_aux_ic_offs_], reg_ic); + } + set_A_B_matrices(); int bd_block = (is_bdb_tail) ? 
brg.bdb_tail : brg.bd_block; @@ -3028,70 +3114,77 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, gemm_microkernel_amx( bd_block2, is_bdb_tail, ld_block2, is_rd_tail, is_ld_tail); } else { - if (brg.rdb > 0) { - Label rdb_loop_label; - mov(reg_rdb_loop, brg.rdb); - L_aligned(rdb_loop_label, 64); - { - if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || - brg.wei_decomp_zero_points_stride != 0)) || brg.with_src_dyn_quant) { - auto reg_local_ic = reg_aux_D; - auto reg_local_wei_params = reg_bdb_loop; - auto reg_local_ic_group = reg_ldb_loop; - - auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) { - mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); - mov(reg_local_ic_group, group_size); - xor_(rdx, rdx); - idiv(reg_local_ic_group); - imul(reg_local_ic, reg_local_ic, stride); - - mov(reg_local_wei_params, ptr[rsp + src_offs]); - add(reg_local_wei_params, reg_local_ic); - mov(ptr[rsp + dst_offs], reg_local_wei_params); - }; - - mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); - mov(ptr[rsp + reg_aux2_D_offs_], reg_aux_D); - mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop); - mov(ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv - - if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) { - ic_group_shift(reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_, - brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt)); - } + ic_group_shift_generic(); + + auto rdb_group = brg.rd_block; + auto rd_size = brg.rdb * brg.rd_block + brg.rdb_tail; + if (brg.wei_decomp_scales_group_size < rd_size) + rdb_group = nstl::max(rdb_group, brg.wei_decomp_scales_group_size); + if (brg.wei_decomp_zero_points_group_size < rd_size) + rdb_group = nstl::max(rdb_group, brg.wei_decomp_zero_points_group_size); + if (brg.with_src_dyn_quant) { + rdb_group = nstl::max(rdb_group, brg.src_scales_group_size); + if 
(brg.with_wei_decomp_zero_points) { + rdb_group = nstl::max(rdb_group, brg.src_sum_group_size); + } + } + rdb_group = rdb_group / brg.rd_block; + auto rbd_blocks = brg.rdb / rdb_group; + auto max_rdb_unroll = 8; + + if (brg.with_wei_decomp && rdb_group <= max_rdb_unroll) { + if (rbd_blocks > 0) { + Label rdb_loop_label; + mov(reg_rdb_loop, rbd_blocks); + L_aligned(rdb_loop_label, 64); + { + for (int rb = 0; rb < rdb_group; rb++) { + gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, false, + is_ld_tail, vpad, rows_for_rd_tail); - if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { - ic_group_shift(reg_aux_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_, - brg.wei_decomp_zero_points_group_size, brg.wei_decomp_zero_points_stride * types::data_type_size(brg.wei_decomp_zero_points_dt)); - } + add(reg_aux_A, rdb_A_offset()); + add(reg_aux_B, rdb_B_offset()); - if (brg.with_src_dyn_quant) { - ic_group_shift(reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_, - brg.src_scales_group_size, sizeof(float)); + ic_group_shift_opt(rb); } - mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); - add(reg_local_ic, brg.rd_block); - mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic); - - mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); - mov(reg_aux_D, ptr[rsp + reg_aux2_D_offs_]); - mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]); - mov(reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]); + dec(reg_rdb_loop); + cmp(reg_rdb_loop, 0); } + jg(rdb_loop_label, T_NEAR); + } - const bool is_rd_tail = false; - gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, - is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail); + for (int rb = rbd_blocks * rdb_group; rb < brg.rdb; rb++) { + gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, false, + is_ld_tail, vpad, rows_for_rd_tail); add(reg_aux_A, rdb_A_offset()); add(reg_aux_B, rdb_B_offset()); - dec(reg_rdb_loop); - cmp(reg_rdb_loop, 0); + ic_group_shift_opt(rb); + + mov(reg_rdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + } + } 
else { + if (brg.rdb > 0) { + Label rdb_loop_label; + mov(reg_rdb_loop, brg.rdb); + L_aligned(rdb_loop_label, 64); + { + const bool is_rd_tail = false; + gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, + is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail); + + add(reg_aux_A, rdb_A_offset()); + add(reg_aux_B, rdb_B_offset()); + + ic_group_shift_generic(); + + dec(reg_rdb_loop); + cmp(reg_rdb_loop, 0); + } + jg(rdb_loop_label, T_NEAR); } - jg(rdb_loop_label, T_NEAR); } } if (brg.rdb_tail != 0) { @@ -3302,6 +3395,10 @@ void jit_brgemm_kernel_t::bdb_loop() { mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]); add(reg_src_scales, bd_block2 * brg.bd_block * brg.src_scales_stride * sizeof(float)); mov(ptr[rsp + reg_src_scales_offs_], reg_src_scales); + + mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]); + add(reg_src_grouped_sum, bd_block2 * brg.bd_block * brg.src_grouped_sum_stride * sizeof(int32_t)); + mov(ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum); } advance_bd_block2_post_op_regs(bd_block2); diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index 5879b5b5a89..f6f2092cb8f 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -143,21 +143,26 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int8_t* qsrc = nullptr; float* src_dscales = nullptr; + int32_t* src_grouped_sum = nullptr; if (jbgp.with_src_dynamic_quant) { qsrc = scratchpad.template get(key_src_quantized); src_dscales = scratchpad.template get(key_src_dequantized_scales); + src_grouped_sum = scratchpad.template get(key_src_grouped_sum); int ic_groups = div_up(jbgp.ic, jbgp.src_quant_group_size); + int ic_sum_groups = div_up(jbgp.ic, jbgp.src_sum_group_size); auto src_ptr = reinterpret_cast(src); auto qsrc_ptr = qsrc; auto src_dscales_ptr = src_dscales; - int vec_loop_end = (ic_groups - 1) * jbgp.src_quant_group_size; + auto src_grouped_sum_ptr = src_grouped_sum; + int 
vec_loop_end = rnd_dn(jbgp.ic, jbgp.src_quant_group_size); parallel_nd(jbgp.mb, [&](int mb) { src_quantization_runtime_params_t rt_params = {}; rt_params.src_ptr = src_ptr + mb * jbgp.ic; rt_params.qsrc_ptr = qsrc_ptr + mb * jbgp.ic; rt_params.src_scales_ptr = src_dscales_ptr + mb * ic_groups; + rt_params.src_grouped_sum_ptr = src_grouped_sum_ptr + mb * ic_sum_groups; rt_params.ic_size = vec_loop_end; (*brg_src_quant_kernel_)(&rt_params); @@ -175,6 +180,18 @@ status_t brgemm_inner_product_fwd_t::execute_forward( qsrc_ptr[mb * jbgp.ic + ic] = std::round(src_ptr[mb * jbgp.ic + ic] * qscale); } } + + if (jbgp.wei_decomp_zero_points_dt) { + for (int icb = vec_loop_end / jbgp.src_quant_group_size; icb < ic_sum_groups; icb++) { + int ic_begin = icb * jbgp.src_sum_group_size; + int ic_end = nstl::min(static_cast((icb + 1) * jbgp.src_sum_group_size), jbgp.ic); + int sum = 0; + for (int ic = ic_begin; ic < ic_end; ic++) { + sum += qsrc_ptr[mb * jbgp.ic + ic]; + } + src_grouped_sum_ptr[mb * ic_sum_groups + icb] = sum; + } + } }); src = reinterpret_cast(qsrc); @@ -429,10 +446,12 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int wei_scales_offset = 0; int wei_zero_points_offset = 0; int src_scales_offset = 0; + int src_grouped_sum_offset = 0; if (jbgp.weights_decompression) { wei_scales_offset = wei_scales_oc_stride * oc * wei_scales_dt_size; wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); + src_grouped_sum_offset = n * div_up(jbgp.ic, jbgp.src_sum_group_size); } auto ptr_D = dst + dst_off; @@ -456,10 +475,12 @@ status_t brgemm_inner_product_fwd_t::execute_forward( brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data, - scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); + scratch, nullptr, wei_scales + wei_scales_offset, 
wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); } else { brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch, - (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); + (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); } } @@ -534,10 +555,12 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int wei_scales_offset = 0; int wei_zero_points_offset = 0; int src_scales_offset = 0; + int src_grouped_sum_offset = 0; if (jbgp.weights_decompression) { wei_scales_offset = wei_scales_oc_stride * oc * wei_scales_dt_size; wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); + src_grouped_sum_offset = n * div_up(jbgp.ic, jbgp.src_sum_group_size); } auto brg_kernel_ic_tail = brg_kernels_[brg_ker_ic_tail_idx].get(); @@ -560,10 +583,12 @@ status_t brgemm_inner_product_fwd_t::execute_forward( nullptr, false, 1, false, false, dst_scales}; brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch, - (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); + (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); } else { brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch, - (void *)ptr_C, is_amx ? 
(void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, src_dscales + src_scales_offset, ic); + (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); } } }; diff --git a/src/cpu/x64/jit_brgemm_inner_product.hpp b/src/cpu/x64/jit_brgemm_inner_product.hpp index 118b6b79fc4..041c27abd37 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.hpp +++ b/src/cpu/x64/jit_brgemm_inner_product.hpp @@ -265,6 +265,8 @@ struct brgemm_inner_product_fwd_t : public primitive_t { if (pd()->jbgp_.with_src_dynamic_quant) { src_quantization_compile_params_t jcp = {}; jcp.ic_quant_block = pd()->jbgp_.src_quant_group_size; + jcp.with_src_grouped_sum = !pd()->attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS); + jcp.src_sum_group_size = pd()->jbgp_.src_sum_group_size; jcp.src_dt = pd()->jbgp_.orig_src_dt; jcp.qsrc_dt = data_type::s8; diff --git a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp index c1a247c241c..b679f41caba 100644 --- a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp @@ -699,7 +699,22 @@ status_t jit_brgemm_ip_fwd_conf_t::init_conf(cpu_isa_t isa, // Current implementation of grouped weights decompression algorithm requires K size to be aligned on group size. // Besides that "batched" usage of brgemm block is not covered, so forcing the value to 1. 
- if (jbgp.with_grouped_weights_decompression || jbgp.with_src_dynamic_quant) { + if (jbgp.with_src_dynamic_quant) { + size_t max_ic_group_size = k_blk; + if (jbgp.wei_scales_ic_group_size != static_cast(jbgp.ic)) + max_ic_group_size = std::max(max_ic_group_size, jbgp.wei_scales_ic_group_size); + if (jbgp.wei_zero_points_ic_group_size != static_cast(jbgp.ic)) + max_ic_group_size = std::max(max_ic_group_size, jbgp.wei_zero_points_ic_group_size); + max_ic_group_size = std::max(max_ic_group_size, jbgp.src_quant_group_size); + max_ic_group_size = std::max(max_ic_group_size, jbgp.src_sum_group_size); + + if ((jbgp.nb_ic_blocking * k_blk) % max_ic_group_size != 0) { + jbgp.nb_ic_blocking = max_ic_group_size; + } + jbgp.K = k_blk * jbgp.nb_ic_blocking; + jbgp.gemm_batch_size = 1; + jbgp.nthr_ic_b = 1; + } else if (jbgp.with_grouped_weights_decompression) { auto min_ic_group_size = std::min(jbgp.wei_scales_ic_group_size, jbgp.wei_zero_points_ic_group_size); min_ic_group_size = std::min(min_ic_group_size, jbgp.src_quant_group_size); if ((jbgp.nb_ic_blocking * k_blk) % min_ic_group_size != 0) { @@ -1421,6 +1436,7 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, jbgp.with_src_dynamic_quant = false; if (jbgp.weights_decompression) { jbgp.src_quant_group_size = jbgp.ic; + jbgp.src_sum_group_size = jbgp.ic; if (!attr.src_dyn_quant_params_.has_default_values()) { jbgp.with_src_dynamic_quant = true; } @@ -1441,11 +1457,6 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, jbgp.wei_zero_points_ic_group_size = div_up(jbgp.ic, attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[1]); } - // todo: fix avx2 brgemm kernel behavior for non scalar zp - if (!is_superset(isa, avx512_core) && attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[0] != 1) { - jbgp.with_src_dynamic_quant = false; - } - jbgp.wei_decomp_zero_points_dt = attr.zero_points_.get_data_type(DNNL_ARG_WEIGHTS); if (!one_of(jbgp.wei_decomp_zero_points_dt, f32, u8)) return status::unimplemented; @@ -1467,6 
+1478,12 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, if (jbgp.is_amx && jbgp.wei_decomp_algo == weights_decomp_kind_t::immediate) return status::unimplemented; + auto min_group_size = nstl::min(jbgp.wei_scales_ic_group_size, jbgp.wei_zero_points_ic_group_size); + if (jbgp.wei_scales_ic_group_size % min_group_size) + return status::unimplemented; + if (jbgp.wei_zero_points_ic_group_size % min_group_size) + return status::unimplemented; + if (jbgp.with_src_dynamic_quant) { if (!(one_of(jbgp.wei_dt, u4, u8) && one_of(jbgp.wei_decomp_scales_dt, f32) && @@ -1476,12 +1493,20 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, const size_t simd_width = 16; if (jbgp.src_quant_group_size == 0 || jbgp.src_quant_group_size % simd_width) return status::unimplemented; - } - } - if (jbgp.with_src_dynamic_quant) { - jbgp.orig_src_dt = jbgp.src_dt; - jbgp.src_dt = s8; + jbgp.orig_src_dt = jbgp.src_dt; + jbgp.src_dt = s8; + + size_t rd_unroll = jbgp.src_quant_group_size; + jbgp.src_sum_group_size = nstl::min(rd_unroll, min_group_size); + + if (jbgp.wei_scales_ic_group_size != static_cast(jbgp.ic) && jbgp.wei_scales_ic_group_size % jbgp.src_sum_group_size) + return status::unimplemented; + if (jbgp.wei_zero_points_ic_group_size != static_cast(jbgp.ic) && jbgp.wei_zero_points_ic_group_size % jbgp.src_sum_group_size) + return status::unimplemented; + if (jbgp.src_quant_group_size % jbgp.src_sum_group_size) + return status::unimplemented; + } } jbgp.bia_dt = jbgp.with_bias @@ -1696,6 +1721,8 @@ void jit_brgemm_ip_conf_t::init_scratchpad_base( if (jbgp.with_src_dynamic_quant) { scratchpad.book(key_src_quantized, jbgp.mb * jbgp.ic, sizeof(int8_t)); scratchpad.book(key_src_dequantized_scales, jbgp.mb * div_up(jbgp.ic, jbgp.src_quant_group_size), sizeof(float)); + if (jbgp.wei_decomp_zero_points_dt) + scratchpad.book(key_src_grouped_sum, jbgp.mb * div_up(jbgp.ic, jbgp.src_sum_group_size), sizeof(int32_t)); } } diff --git 
a/src/cpu/x64/jit_brgemm_primitive_conf.hpp b/src/cpu/x64/jit_brgemm_primitive_conf.hpp index 5f7ebd2cf0e..60cbe03f718 100644 --- a/src/cpu/x64/jit_brgemm_primitive_conf.hpp +++ b/src/cpu/x64/jit_brgemm_primitive_conf.hpp @@ -113,6 +113,7 @@ struct jit_brgemm_primitive_conf_t { bool with_src_dynamic_quant; size_t src_quant_group_size; + size_t src_sum_group_size; data_type_t orig_src_dt; }; diff --git a/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp b/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp index 90256ade02d..b6aff7896b0 100644 --- a/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp +++ b/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp @@ -43,6 +43,39 @@ void jit_brgemm_src_quantization_kernel_t::load_src(Vmm vmm_load, const Xby } } +template +void jit_brgemm_src_quantization_kernel_t::horiz_op(Vmm vmm_src, Vmm vmm_aux, op_type type) { + auto uni_op = [&](const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { + if (type == op_type::max) { + uni_vmaxps(x1, x2, op); + } else if (type == op_type::sum) { + uni_vpaddd(x1, x2, op); + } else { + assert(!"unsupported op type"); + } + }; + + if (isa == avx512_core) { + Xbyak::Zmm zmm_src = Xbyak::Zmm(vmm_src.getIdx()); + Xbyak::Zmm zmm_aux = Xbyak::Zmm(vmm_aux.getIdx()); + vshuff32x4(zmm_aux, zmm_src, zmm_src, 0x4E); + uni_op(zmm_src, zmm_src, zmm_aux); + vshuff32x4(zmm_aux, zmm_src, zmm_src, 0xB1); + uni_op(zmm_src, zmm_src, zmm_aux); + } else if (isa == avx2) { + Xbyak::Ymm ymm_src = Xbyak::Ymm(vmm_src.getIdx()); + Xbyak::Ymm ymm_aux = Xbyak::Ymm(vmm_aux.getIdx()); + vperm2i128(ymm_aux, ymm_src, ymm_src, 0x01); + uni_op(ymm_src, ymm_src, ymm_aux); + } else { + assert(!"unsupported isa"); + } + uni_vshufps(vmm_aux, vmm_src, vmm_src, 0x4E); + uni_op(vmm_src, vmm_src, vmm_aux); + uni_vshufps(vmm_aux, vmm_src, vmm_src, 0xB1); + uni_op(vmm_src, vmm_src, vmm_aux); +} + template void jit_brgemm_src_quantization_kernel_t::generate() { preamble(); @@ -50,6 +83,7 @@ void 
jit_brgemm_src_quantization_kernel_t::generate() { mov(reg_src, ptr[param1 + GET_OFF(src_ptr)]); mov(reg_qsrc, ptr[param1 + GET_OFF(qsrc_ptr)]); mov(reg_src_scales, ptr[param1 + GET_OFF(src_scales_ptr)]); + mov(reg_src_grouped_sum, ptr[param1 + GET_OFF(src_grouped_sum_ptr)]); mov(reg_ic_size, ptr[param1 + GET_OFF(ic_size)]); Xbyak::Label ic_loop_label; @@ -58,6 +92,7 @@ void jit_brgemm_src_quantization_kernel_t::generate() { size_t src_dt_size = types::data_type_size(jcp_.src_dt); size_t qsrc_dt_size = types::data_type_size(jcp_.qsrc_dt); size_t src_scales_dt_size = types::data_type_size(data_type::f32); + size_t src_grouped_sum_dt_size = types::data_type_size(data_type::s32); static const float negative_zero[16] = { -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, @@ -89,6 +124,7 @@ void jit_brgemm_src_quantization_kernel_t::generate() { jl(ic_end_label, T_NEAR); assert(!(jcp_.ic_quant_block % vec_size)); + assert(!(jcp_.src_sum_group_size % vec_size)); int ic_blocks = jcp_.ic_quant_block / vec_size; uni_vpxor(vmm_max(), vmm_max(), vmm_max()); @@ -98,25 +134,7 @@ void jit_brgemm_src_quantization_kernel_t::generate() { uni_vmaxps(vmm_max(), vmm_max(), vmm_src()); } - if (isa == avx512_core) { - Xbyak::Zmm max_zmm = Xbyak::Zmm(vmm_max().getIdx()); - Xbyak::Zmm aux_zmm = Xbyak::Zmm(vmm_aux().getIdx()); - vshuff32x4(aux_zmm, max_zmm, max_zmm, 0x4E); - uni_vmaxps(max_zmm, max_zmm, aux_zmm); - vshuff32x4(aux_zmm, max_zmm, max_zmm, 0xB1); - uni_vmaxps(max_zmm, max_zmm, aux_zmm); - } else if (isa == avx2) { - Xbyak::Ymm max_ymm = Xbyak::Ymm(vmm_max().getIdx()); - Xbyak::Ymm aux_ymm = Xbyak::Ymm(vmm_aux().getIdx()); - vperm2i128(aux_ymm, max_ymm, max_ymm, 0x01); - uni_vmaxps(max_ymm, max_ymm, aux_ymm); - } else { - assert(!"unsupported isa"); - } - uni_vshufps(vmm_aux(), vmm_max(), vmm_max(), 0x4E); - uni_vmaxps(vmm_max(), vmm_max(), vmm_aux()); - uni_vshufps(vmm_aux(), vmm_max(), vmm_max(), 0xB1); - uni_vmaxps(vmm_max(), vmm_max(), vmm_aux()); + horiz_op(vmm_max(), 
vmm_aux(), op_type::max); auto vmm_dscale = vmm_max(); uni_vbroadcastss(vmm_dscale, Xmm(vmm_dscale.getIdx())); @@ -126,11 +144,25 @@ void jit_brgemm_src_quantization_kernel_t::generate() { uni_vdivps(vmm_qscale(), vmm_one(), vmm_dscale); uni_vmovss(ptr[reg_src_scales], Xmm(vmm_dscale.getIdx())); + if (jcp_.with_src_grouped_sum) { + uni_vxorps(vmm_src_sum_accum(), vmm_src_sum_accum(), vmm_src_sum_accum()); + } for (int icb = 0; icb < ic_blocks; icb++) { load_src(vmm_src(), ptr[reg_src + icb * vec_size * src_dt_size]); uni_vmulps(vmm_src(), vmm_src(), vmm_qscale()); uni_vcvtps2dq(vmm_src(), vmm_src()); + if (jcp_.with_src_grouped_sum) { + uni_vpaddd(vmm_src_sum_accum(), vmm_src_sum_accum(), vmm_src()); + + if (((icb + 1) * vec_size) % jcp_.src_sum_group_size == 0) { + horiz_op(vmm_src_sum_accum(), vmm_aux(), op_type::sum); + uni_vmovss(ptr[reg_src_grouped_sum], Xmm(vmm_src_sum_accum().getIdx())); + uni_vxorps(vmm_src_sum_accum(), vmm_src_sum_accum(), vmm_src_sum_accum()); + add(reg_src_grouped_sum, src_grouped_sum_dt_size); + } + } + if (isa == avx512_core) { vpmovsdb(ptr[reg_qsrc + icb * vec_size * qsrc_dt_size], vmm_src()); } else { diff --git a/src/cpu/x64/jit_brgemm_src_quantization_kernel.hpp b/src/cpu/x64/jit_brgemm_src_quantization_kernel.hpp index 15c2621940e..b93b19ee25d 100644 --- a/src/cpu/x64/jit_brgemm_src_quantization_kernel.hpp +++ b/src/cpu/x64/jit_brgemm_src_quantization_kernel.hpp @@ -32,6 +32,8 @@ namespace x64 { struct src_quantization_compile_params_t { size_t ic_quant_block; + bool with_src_grouped_sum; + size_t src_sum_group_size; data_type_t src_dt; data_type_t qsrc_dt; }; @@ -40,6 +42,7 @@ struct src_quantization_runtime_params_t { const void *src_ptr; const void *qsrc_ptr; const void *src_scales_ptr; + const void *src_grouped_sum_ptr; size_t ic_size; }; @@ -76,6 +79,9 @@ struct jit_brgemm_src_quantization_kernel_t : public jit_src_quantization_kernel void generate() override; void load_src(Vmm vmm_load, const Xbyak::Address& addr); + enum 
class op_type {max, sum}; + void horiz_op(Vmm vmm_src, Vmm vmm_aux, op_type op); + Vmm vmm_src() { return Vmm(0); } @@ -104,11 +110,16 @@ struct jit_brgemm_src_quantization_kernel_t : public jit_src_quantization_kernel return Vmm(6); } + Vmm vmm_src_sum_accum() { + return Vmm(7); + } + Xbyak::Reg64 reg_src = r8; Xbyak::Reg64 reg_qsrc = r9; Xbyak::Reg64 reg_src_scales = r10; Xbyak::Reg64 reg_ic_size = r11; Xbyak::Reg64 reg_tmp = r12; + Xbyak::Reg64 reg_src_grouped_sum = r13; size_t vec_size; };