Skip to content

Commit 5d7ed69

Browse files
Yobodovs/amx blocking heuristics fixes (#2938)
1 parent caa770a commit 5d7ed69

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -737,7 +737,7 @@ size_t jit_brgemm_amx_uker_base_t::C_offset(const brgemm_iteration_t &bi,
737737
const auto bi_bd_start = get_out_bd(bi.bdi, 0, 0);
738738
const auto bd = get_out_bd(bi.bdi, bdb, inp_bd);
739739
const auto bd_shift = bd - (ununroll_bd_loop ? bi_bd_start : 0);
740-
size_t ldc_elem = (size_t)ldb * bi.ldi->block(0);
740+
size_t ldc_elem = (size_t)ldb * brg.ld_block;
741741
size_t bloc_idx = ldc_elem / brg.LDC;
742742
size_t in_block = ldc_elem % brg.LDC;
743743

src/cpu/x64/cpu_reducer.cpp

+5 -3
Original file line number | Diff line number | Diff line change
@@ -195,9 +195,11 @@ struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t<data_type> {
195195
for (int i = 0; i < nloads; ++i) {
196196
size_t off = base_off + i * load_len;
197197

198-
if (load_len == typesize)
199-
this->uni_add(Xmm(i), this->ptr[reg_src + off]);
200-
else if (load_len == vlen)
198+
if (load_len == typesize) {
199+
assert(nloads == 1);
200+
this->movd(Xmm(nloads + i), this->ptr[reg_src + off]);
201+
this->uni_add(Xmm(i), Xmm(nloads + i));
202+
} else if (load_len == vlen)
201203
this->uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]);
202204
else
203205
assert(!"unsupported");

src/cpu/x64/matmul/amx_blocking_heuristics.cpp

+3 -1
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,7 @@ bool matmul_amx_blocking_params_macro_t::is_supported(
113113
&& bgmmc.orig_wei_dt == bgmmc.wei_dt && bgmmc.is_amx
114114
&& !bgmmc.is_runtime_N && !bgmmc.is_runtime_M && a_dt_ok && a_tag_ok
115115
&& (bgmmc.reduce_kind == matmul_reduce_kind::undef) && b_tag_ok
116-
&& b_dt_ok && !has_zp;
116+
&& b_dt_ok && !has_zp && !bgmmc.packed_sparse_weights;
117117
}
118118

119119
bool matmul_amx_blocking_params_macro_t::divs_are_acceptable() const {
@@ -955,9 +955,11 @@ void matmul_amx_blocking_params_micro_t::set_blocking_parameters(
955955
if (brgemm_k_elems >= K) {
956956
k_blk_ = K;
957957
k_chunk_size_ = 1;
958+
brgemm_batch_size_ = 1;
958959
} else {
959960
k_blk_ = brgemm_k_elems;
960961
k_chunk_size_ = 1;
962+
brgemm_batch_size_ = 1;
961963
}
962964
} else if (current_k_tail == 0
963965
&& K % (k_blk_ * brgemm_batch_size_) == 0) {

src/cpu/x64/matmul/brgemm_matmul_utils.cpp

+5 -4
Original file line number | Diff line number | Diff line change
@@ -1340,8 +1340,10 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
13401340

13411341
const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag);
13421342
// When M == 1 MatMul always considers A to be non-transposed even if A md
1343-
// was created using "ba" tag.
1344-
bgmmc.treat_A_as_plain = bgmmc.M == 1;
1343+
// was created using "ba" tag. It is not plain in cab layout.
1344+
bgmmc.treat_A_as_plain = bgmmc.M == 1
1345+
&& IMPLICATION(bgmmc.batch != 1,
1346+
bm_conf_utils.check_is_plain(bgmmc.src_tag));
13451347
bgmmc.transposed_A = ((transposed_A && !bgmmc.treat_A_as_plain)
13461348
|| bgmmc.src_tag == adbc);
13471349
// For batched problems with plain A and C and fully broadcasted across B
@@ -1738,8 +1740,7 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc,
17381740
bgmmc.buffer_b_chunk_sz = bgmmc.tr_b_dt_sz * rnd_up(bgmmc.N_blk, bgmmc.LDB)
17391741
* rnd_up(bgmmc.K_chunk_elems, bgmmc.wei_k_blk);
17401742

1741-
bgmmc.buffer_b_per_thread_sz
1742-
= bgmmc.buffer_b_chunk_sz * bgmmc.brgemm_batch_size;
1743+
bgmmc.buffer_b_per_thread_sz = bgmmc.buffer_b_chunk_sz;
17431744

17441745
bgmmc.buffer_reduce_per_thread_sz = 0;
17451746
if (bgmmc.reduce_kind == matmul_reduce_kind::src) {

0 commit comments

Comments
 (0)