Skip to content

Commit 7363edc

Browse files
inteldimitriusazhai219
authored andcommitted
x64: matmul: fix LDA init via strides (uxlfoundation#2462)
1 parent 9502763 commit 7363edc

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

src/cpu/x64/matmul/brgemm_matmul_utils.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,8 @@ struct matmul_avx512_blocking_params_t {
688688

689689
bgmmc.use_buffer_c = is_buffer_c_required(
690690
bgmmc.acc_dt, bgmmc.dst_dt, bgmmc.with_sum);
691-
bgmmc.LDA = bgmmc.use_buffer_a || bgmmc.treat_transposed_A_as_plain
691+
bgmmc.LDA = bgmmc.adjust_a_strides || bgmmc.use_buffer_a
692+
|| bgmmc.treat_transposed_A_as_plain
692693
? get_actual_lda(bgmmc.use_buffer_a, bgmmc.tr_a_dt_sz)
693694
: bgmmc.A_strides[1] / bgmmc.a_dt_sz;
694695
}
@@ -1522,10 +1523,9 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
15221523

15231524
// We need to correct A_strides if batched dimensions are merged in M and
15241525
// A layout is formally transposed but could be treated as plain
1525-
if (merge_batch_dims_into_M
1526-
&& (src_d.matches_tag(acbd) || bgmmc.treat_transposed_A_as_plain)) {
1527-
bgmmc.A_strides[1] = bgmmc.A_strides[2];
1528-
}
1526+
bgmmc.adjust_a_strides = merge_batch_dims_into_M
1527+
&& (src_d.matches_tag(acbd) || bgmmc.treat_transposed_A_as_plain);
1528+
if (bgmmc.adjust_a_strides) bgmmc.A_strides[1] = bgmmc.A_strides[2];
15291529

15301530
// We need to correct C_strides if batched dimensions are merged in M and
15311531
// C layout is formally transposed but could be treated as plain

src/cpu/x64/matmul/brgemm_matmul_utils.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ struct brgemm_matmul_conf_t {
176176
bool blocked_B;
177177
bool treat_transposed_A_as_plain;
178178

179+
// A_strides could be changed during
180+
// Matmul conf initialization in case when batches merged into M.
181+
// This flag helps to properly initialize LDA when A_strides
182+
// were changed.
183+
bool adjust_a_strides = false;
184+
179185
dim_t zp_a_comp_shift_n;
180186
dim_t zp_a_comp_elems_per_thr;
181187

tests/benchdnn/inputs/matmul/harness_matmul_regression_f32

+4
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@
2020
# test for K parallel_reduction with batched case
2121
--reset
2222
--stag=acb --wtag=abc --dtag=abc 2x16x2048:2x2048x16_n"large_K_with_batch"
23+
24+
# test correct LDA initialization, when batches are merged into M dimension
25+
--reset
26+
--stag=abcd --dtag=abcd 2x1x8x2:1x1x2x8_n"merge_batches_into_M"

0 commit comments

Comments
 (0)