Skip to content

Commit 2eb3dd1

Browse files
ankalinin authored and dzarukin committed on Dec 5, 2024
x64: brgemm: avx: bd_block should not be smaller than vpad
1 parent 19ef223 commit 2eb3dd1

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed
 

‎src/cpu/x64/brgemm/brgemm_utils.cpp

+14 -7
Original file line number | Diff line number | Diff line change
@@ -240,20 +240,27 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
240240
brg->ldb = brg->load_dim / brg->ld_block;
241241
brg->ldb_tail = brg->load_dim % brg->ld_block;
242242

243+
const int max_vpad = nstl::max(
244+
brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad);
245+
243246
int adj_ld_block2 = calculate_ldb_params(brg, 4);
244247
int max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);
245-
246248
// reduce 'ld_block2' to allow a larger 'bd_block'
247-
const int max_vpad = nstl::max(
248-
brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad);
249249
if (is_superset(brg->isa_impl, avx2) && max_bcast_block < max_vpad) {
250-
adj_ld_block2 = calculate_ldb_params(brg, 2);
251-
max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);
250+
for (int try_ld_block2 = 2; try_ld_block2 > 0; --try_ld_block2) {
251+
adj_ld_block2 = calculate_ldb_params(brg, try_ld_block2);
252+
max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);
253+
if (max_bcast_block >= max_vpad) break;
254+
}
255+
// bcast block in brgemm kernel should be greater than virtual
256+
// padding to avoid possible functional issues
257+
if (max_bcast_block < max_vpad) return status::unimplemented;
252258
}
253259

254-
const int min_block = 1;
260+
const int min_block = nstl::max(1, max_vpad);
261+
255262
float best_bd_block_eff = 0.f;
256-
brg->bd_block = 1;
263+
brg->bd_block = max_bcast_block;
257264
for (int bd_block = max_bcast_block; bd_block >= min_block;
258265
bd_block--) {
259266
const auto bd_block_disb = static_cast<float>(brg->bcast_dim)

0 commit comments

Comments
 (0)
Please sign in to comment.