diff --git a/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp b/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
index 0d48e990fb8..fef520a0929 100644
--- a/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
+++ b/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
@@ -737,7 +737,7 @@ size_t jit_brgemm_amx_uker_base_t::C_offset(const brgemm_iteration_t &bi,
     const auto bi_bd_start = get_out_bd(bi.bdi, 0, 0);
     const auto bd = get_out_bd(bi.bdi, bdb, inp_bd);
     const auto bd_shift = bd - (ununroll_bd_loop ? bi_bd_start : 0);
-    size_t ldc_elem = (size_t)ldb * bi.ldi->block(0);
+    size_t ldc_elem = (size_t)ldb * brg.ld_block;
     size_t bloc_idx = ldc_elem / brg.LDC;
     size_t in_block = ldc_elem % brg.LDC;
 
diff --git a/src/cpu/x64/cpu_reducer.cpp b/src/cpu/x64/cpu_reducer.cpp
index 86397f0881c..286f56f1d13 100644
--- a/src/cpu/x64/cpu_reducer.cpp
+++ b/src/cpu/x64/cpu_reducer.cpp
@@ -195,9 +195,11 @@ struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t {
         for (int i = 0; i < nloads; ++i) {
             size_t off = base_off + i * load_len;
 
-            if (load_len == typesize)
-                this->uni_add(Xmm(i), this->ptr[reg_src + off]);
-            else if (load_len == vlen)
+            if (load_len == typesize) {
+                assert(nloads == 1);
+                this->movd(Xmm(nloads + i), this->ptr[reg_src + off]);
+                this->uni_add(Xmm(i), Xmm(nloads + i));
+            } else if (load_len == vlen)
                 this->uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]);
             else
                 assert(!"unsupported");
diff --git a/src/cpu/x64/matmul/amx_blocking_heuristics.cpp b/src/cpu/x64/matmul/amx_blocking_heuristics.cpp
index 78a10866df3..a33a5d5e45a 100644
--- a/src/cpu/x64/matmul/amx_blocking_heuristics.cpp
+++ b/src/cpu/x64/matmul/amx_blocking_heuristics.cpp
@@ -113,7 +113,7 @@ bool matmul_amx_blocking_params_macro_t::is_supported(
             && bgmmc.orig_wei_dt == bgmmc.wei_dt && bgmmc.is_amx
             && !bgmmc.is_runtime_N && !bgmmc.is_runtime_M && a_dt_ok && a_tag_ok
             && (bgmmc.reduce_kind == matmul_reduce_kind::undef) && b_tag_ok
-            && b_dt_ok && !has_zp;
+            && b_dt_ok && !has_zp && !bgmmc.packed_sparse_weights;
 }
 
 bool matmul_amx_blocking_params_macro_t::divs_are_acceptable() const {
@@ -955,9 +955,11 @@ void matmul_amx_blocking_params_micro_t::set_blocking_parameters(
         if (brgemm_k_elems >= K) {
             k_blk_ = K;
             k_chunk_size_ = 1;
+            brgemm_batch_size_ = 1;
         } else {
             k_blk_ = brgemm_k_elems;
             k_chunk_size_ = 1;
+            brgemm_batch_size_ = 1;
         }
     } else if (current_k_tail == 0
             && K % (k_blk_ * brgemm_batch_size_) == 0) {
diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
index 1e0f5a5184e..dd51d5f5fd8 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -1340,8 +1340,10 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag);
 
     // When M == 1 MatMul always considers A to be non-transposed even if A md
-    // was created using "ba" tag.
-    bgmmc.treat_A_as_plain = bgmmc.M == 1;
+    // was created using "ba" tag. It is not plain in the "cab" layout.
+    bgmmc.treat_A_as_plain = bgmmc.M == 1
+            && IMPLICATION(bgmmc.batch != 1,
+                    bm_conf_utils.check_is_plain(bgmmc.src_tag));
     bgmmc.transposed_A = ((transposed_A && !bgmmc.treat_A_as_plain)
             || bgmmc.src_tag == adbc);
     // For batched problems with plain A and C and fully broadcasted across B
@@ -1738,8 +1740,7 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc,
 
     bgmmc.buffer_b_chunk_sz = bgmmc.tr_b_dt_sz * rnd_up(bgmmc.N_blk, bgmmc.LDB)
             * rnd_up(bgmmc.K_chunk_elems, bgmmc.wei_k_blk);
-    bgmmc.buffer_b_per_thread_sz
-            = bgmmc.buffer_b_chunk_sz * bgmmc.brgemm_batch_size;
+    bgmmc.buffer_b_per_thread_sz = bgmmc.buffer_b_chunk_sz;
 
     bgmmc.buffer_reduce_per_thread_sz = 0;
     if (bgmmc.reduce_kind == matmul_reduce_kind::src) {
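
Note on the cpu_reducer.cpp hunk: when load_len == typesize, the old code
passed a memory operand directly to uni_add. As I read it, the packed-integer
path of that helper (paddd) reads a full vector register's width from a memory
operand even though only one element is valid, so the scalar tail could read
past the end of the source buffer; loading through movd first touches exactly
four bytes. The assert(nloads == 1) records that the scratch register
Xmm(nloads + i) stays clear of the accumulators Xmm(0..nloads-1) only under
that condition. A minimal standalone sketch of the same idea in SSE2
intrinsics (illustrative only, not the oneDNN JIT code):

    #include <immintrin.h>
    #include <cstdint>

    // Scalar s32 tail of a reduction. Adding via a full 16-byte load
    // (_mm_loadu_si128) would touch 12 bytes past the valid element;
    // _mm_cvtsi32_si128 is the intrinsic spelling of movd and loads
    // exactly the 4 valid bytes, zeroing the upper lanes.
    static inline __m128i add_s32_tail(__m128i acc, const int32_t *src) {
        __m128i tmp = _mm_cvtsi32_si128(*src); // movd: 4-byte load
        return _mm_add_epi32(acc, tmp);        // paddd, reg-reg, no wide read
    }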
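
Note on the treat_A_as_plain change in brgemm_matmul_utils.cpp: oneDNN's
IMPLICATION(cause, effect) macro expands to (!(cause) || (effect)), so the new
condition keeps the old M == 1 shortcut as-is for single-batch problems and,
for batched ones, applies it only when the source tag is genuinely plain,
which rules out the "cab" layout. A small truth-table sketch of the condition
(the free function and its parameter names are hypothetical, for illustration
only):

    // Mirrors: M == 1 && IMPLICATION(batch != 1, check_is_plain(src_tag))
    bool treat_A_as_plain(bool m_is_one, bool batched, bool src_is_plain) {
        return m_is_one && (!batched || src_is_plain);
    }
    // m_is_one && !batched             -> true  (previous behavior)
    // m_is_one && batched && plain src -> true
    // m_is_one && batched && cab src   -> false (the case this hunk fixes)
    // !m_is_one                        -> false (unchanged)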