Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yobodovs/amx blocking heuristics fixes #2938

Merged
merged 5 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ size_t jit_brgemm_amx_uker_base_t::C_offset(const brgemm_iteration_t &bi,
const auto bi_bd_start = get_out_bd(bi.bdi, 0, 0);
const auto bd = get_out_bd(bi.bdi, bdb, inp_bd);
const auto bd_shift = bd - (ununroll_bd_loop ? bi_bd_start : 0);
size_t ldc_elem = (size_t)ldb * bi.ldi->block(0);
size_t ldc_elem = (size_t)ldb * brg.ld_block;
size_t bloc_idx = ldc_elem / brg.LDC;
size_t in_block = ldc_elem % brg.LDC;

Expand Down
8 changes: 5 additions & 3 deletions src/cpu/x64/cpu_reducer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,11 @@ struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t<data_type> {
for (int i = 0; i < nloads; ++i) {
size_t off = base_off + i * load_len;

if (load_len == typesize)
this->uni_add(Xmm(i), this->ptr[reg_src + off]);
else if (load_len == vlen)
if (load_len == typesize) {
assert(nloads == 1);
this->movd(Xmm(nloads + i), this->ptr[reg_src + off]);
this->uni_add(Xmm(i), Xmm(nloads + i));
} else if (load_len == vlen)
this->uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]);
else
assert(!"unsupported");
Expand Down
4 changes: 3 additions & 1 deletion src/cpu/x64/matmul/amx_blocking_heuristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ bool matmul_amx_blocking_params_macro_t::is_supported(
&& bgmmc.orig_wei_dt == bgmmc.wei_dt && bgmmc.is_amx
&& !bgmmc.is_runtime_N && !bgmmc.is_runtime_M && a_dt_ok && a_tag_ok
&& (bgmmc.reduce_kind == matmul_reduce_kind::undef) && b_tag_ok
&& b_dt_ok && !has_zp;
&& b_dt_ok && !has_zp && !bgmmc.packed_sparse_weights;
}

bool matmul_amx_blocking_params_macro_t::divs_are_acceptable() const {
Expand Down Expand Up @@ -955,9 +955,11 @@ void matmul_amx_blocking_params_micro_t::set_blocking_parameters(
if (brgemm_k_elems >= K) {
k_blk_ = K;
k_chunk_size_ = 1;
brgemm_batch_size_ = 1;
} else {
k_blk_ = brgemm_k_elems;
k_chunk_size_ = 1;
brgemm_batch_size_ = 1;
}
} else if (current_k_tail == 0
&& K % (k_blk_ * brgemm_batch_size_) == 0) {
Expand Down
9 changes: 5 additions & 4 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1340,8 +1340,10 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,

const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag);
// When M == 1 MatMul always considers A to be non-transposed even if A md
// was created using "ba" tag.
bgmmc.treat_A_as_plain = bgmmc.M == 1;
// was created using "ba" tag. It is not plain in cab layout.
bgmmc.treat_A_as_plain = bgmmc.M == 1
&& IMPLICATION(bgmmc.batch != 1,
bm_conf_utils.check_is_plain(bgmmc.src_tag));
bgmmc.transposed_A = ((transposed_A && !bgmmc.treat_A_as_plain)
|| bgmmc.src_tag == adbc);
// For batched problems with plain A and C and fully broadcasted across B
Expand Down Expand Up @@ -1738,8 +1740,7 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc,
bgmmc.buffer_b_chunk_sz = bgmmc.tr_b_dt_sz * rnd_up(bgmmc.N_blk, bgmmc.LDB)
* rnd_up(bgmmc.K_chunk_elems, bgmmc.wei_k_blk);

bgmmc.buffer_b_per_thread_sz
= bgmmc.buffer_b_chunk_sz * bgmmc.brgemm_batch_size;
bgmmc.buffer_b_per_thread_sz = bgmmc.buffer_b_chunk_sz;

bgmmc.buffer_reduce_per_thread_sz = 0;
if (bgmmc.reduce_kind == matmul_reduce_kind::src) {
Expand Down
Loading