From adb9d66a5bb7959128bc18b53b81ab47a61bfb6c Mon Sep 17 00:00:00 2001 From: Dmitrii Zarukin Date: Tue, 18 Mar 2025 12:48:10 -0700 Subject: [PATCH 1/2] benchdnn: matmul: adjust int4 invalid cases --- tests/benchdnn/matmul/matmul.cpp | 42 ++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/tests/benchdnn/matmul/matmul.cpp b/tests/benchdnn/matmul/matmul.cpp index ca709152f3d..92626e7d3f4 100644 --- a/tests/benchdnn/matmul/matmul.cpp +++ b/tests/benchdnn/matmul/matmul.cpp @@ -676,15 +676,41 @@ void skip_invalid_prb(const prb_t *prb, res_t *res) { } } + // Check int4 weights byte alignment if format is specified. if ((prb->wei_dt() == dnnl_s4 || prb->wei_dt() == dnnl_u4) - && (prb->n % 2)) { - BENCHDNN_PRINT(2, - "[INVALID][%s:%d]: Int4 Weights decompression requires OC " - "('%d') to be even.\n", - __FILE__, __LINE__, (int)prb->n); - res->state = SKIPPED; - res->reason = skip_reason::invalid_case; - return; + && (!prb->strides[WEI].empty() + || (prb->wtag != tag::any && prb->wtag != tag::undef))) { + const auto &weights_rt_dims = get_runtime_dims( + prb->weights_dims(), prb->weights_runtime_dim_mask()); + const auto wei_md + = dnn_mem_t::init_md(prb->ndims, weights_rt_dims.data(), + prb->wei_dt(), prb->wtag, prb->strides[STRIDES_WEI]); + + const auto wei_strides = query_md_strides(wei_md); + int n_unit_strides = 0; + for (int d = 0; d < query_md_ndims(wei_md); d++) { + if (wei_strides[d] == 1) { + n_unit_strides++; + if (n_unit_strides > 1) { + BENCHDNN_PRINT(2, + "[INVALID][%s:%d]: Int4 Weights decompression " + "requires byte alignment for the tensor.\n", + __FILE__, __LINE__); + res->state = SKIPPED; + res->reason = skip_reason::invalid_case; + return; + } + } + if (wei_strides[d] > 1 && (wei_strides[d] % 2)) { + BENCHDNN_PRINT(2, + "[INVALID][%s:%d]: Int4 Weights decompression requires " + "byte alignment for the tensor.\n", + __FILE__, __LINE__); + res->state = SKIPPED; + res->reason = skip_reason::invalid_case; + return; + } + } } auto src_rt_mask = prb->src_runtime_dim_mask(); From 44b0e820a2c5ddfb2eb2cdaeaacb00ee0881365c Mon Sep 17 00:00:00 2001 From: Dmitrii Zarukin Date: Tue, 18 Mar 2025 13:35:52 -0700 Subject: [PATCH 2/2] common: matmul: adjust check for int4 tensors w.r.t. strides --- src/common/matmul.cpp | 48 +++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/common/matmul.cpp b/src/common/matmul.cpp index 756ace608f7..7de9b5317ac 100644 --- a/src/common/matmul.cpp +++ b/src/common/matmul.cpp @@ -355,21 +355,39 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc, ? utils::get_dims_mask(dst_desc->dims, op_d.bias_desc.dims, ndims) : 0; - // TODO: requirement is for innermost dim to be multiple of 2 for - // the memory to be byte aligned. - - // s4/u4/f4 weights requires n to be multiple of 2 to be byte aligned - VCHECK_MATMUL(IMPLICATION(utils::one_of(weights_desc->data_type, - data_type::s4, data_type::u4, - data_type::f4_e2m1, data_type::f4_e3m0), - weights_desc->dims[n_idx] % 2 == 0), - VERBOSE_BAD_DIM, "weights", n_idx); - // s4/u4/f4 src requires k to be multiple of 2 to be byte aligned - VCHECK_MATMUL(IMPLICATION(utils::one_of(src_desc->data_type, data_type::s4, - data_type::u4, data_type::f4_e2m1, - data_type::f4_e3m0), - src_desc->dims[k_idx_src] % 2 == 0), - VERBOSE_BAD_DIM, "src", n_idx); + using namespace data_type; + if (weights_desc->format_kind == format_kind::blocked + && utils::one_of( + weights_desc->data_type, s4, u4, f4_e2m1, f4_e3m0)) { + const auto &wei_strides = weights_desc->format_desc.blocking.strides; + + int n_unit_strides = 0; + for (int d = 0; d < ndims; d++) { + if (wei_strides[d] == 1) { + n_unit_strides++; + VCHECK_MATMUL( + n_unit_strides <= 1, VERBOSE_BAD_DIM, "weights", d); + } + VCHECK_MATMUL( + IMPLICATION(wei_strides[d] > 1, wei_strides[d] % 2 == 0), + VERBOSE_BAD_DIM, "weights", d); + } + } + if (src_desc->format_kind == format_kind::blocked + && utils::one_of(src_desc->data_type, s4, u4, f4_e2m1, f4_e3m0)) { + const auto &src_strides = src_desc->format_desc.blocking.strides; + + int n_unit_strides = 0; + for (int d = 0; d < ndims; d++) { + if (src_strides[d] == 1) { + n_unit_strides++; + VCHECK_MATMUL(n_unit_strides <= 1, VERBOSE_BAD_DIM, "src", d); + } + VCHECK_MATMUL( + IMPLICATION(src_strides[d] > 1, src_strides[d] % 2 == 0), + VERBOSE_BAD_DIM, "src", d); + } + } // check if other dims match. for (int d = 0; d < ndims - 2; ++d) {