@@ -240,20 +240,27 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
240
240
brg->ldb = brg->load_dim / brg->ld_block ;
241
241
brg->ldb_tail = brg->load_dim % brg->ld_block ;
242
242
243
+ const int max_vpad = nstl::max (
244
+ brg->brgattr .max_top_vpad , brg->brgattr .max_bottom_vpad );
245
+
243
246
int adj_ld_block2 = calculate_ldb_params (brg, 4 );
244
247
int max_bcast_block = calculate_max_bcast_block (brg, adj_ld_block2);
245
-
246
248
// reduce 'ld_block2' to allow a larger 'bd_block'
247
- const int max_vpad = nstl::max (
248
- brg->brgattr .max_top_vpad , brg->brgattr .max_bottom_vpad );
249
249
if (is_superset (brg->isa_impl , avx2) && max_bcast_block < max_vpad) {
250
- adj_ld_block2 = calculate_ldb_params (brg, 2 );
251
- max_bcast_block = calculate_max_bcast_block (brg, adj_ld_block2);
250
+ for (int try_ld_block2 = 2 ; try_ld_block2 > 0 ; --try_ld_block2) {
251
+ adj_ld_block2 = calculate_ldb_params (brg, try_ld_block2);
252
+ max_bcast_block = calculate_max_bcast_block (brg, adj_ld_block2);
253
+ if (max_bcast_block >= max_vpad) break ;
254
+ }
255
+ // bcast block in brgemm kernel should be greater than virtual
256
+ // padding to avoid possible functional issues
257
+ if (max_bcast_block < max_vpad) return status::unimplemented;
252
258
}
253
259
254
- const int min_block = 1 ;
260
+ const int min_block = nstl::max (1 , max_vpad);
261
+
255
262
float best_bd_block_eff = 0 .f ;
256
- brg->bd_block = 1 ;
263
+ brg->bd_block = max_bcast_block ;
257
264
for (int bd_block = max_bcast_block; bd_block >= min_block;
258
265
bd_block--) {
259
266
const auto bd_block_disb = static_cast <float >(brg->bcast_dim )
0 commit comments