Skip to content

Commit d97791a

Browse files
dzarukin authored and vpirogov committed
common: matmul: adjust check for int4 tensors w.r.t. strides
1 parent aef1cec commit d97791a

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

src/common/matmul.cpp

+33-15
Original file line numberDiff line numberDiff line change
@@ -355,21 +355,39 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc,
355355
? utils::get_dims_mask(dst_desc->dims, op_d.bias_desc.dims, ndims)
356356
: 0;
357357

358-
// TODO: requirement is for innermost dim to be multiple of 2 for
359-
// the memory to be byte aligned.
360-
361-
// s4/u4/f4 weights requires n to be multiple of 2 to be byte aligned
362-
VCHECK_MATMUL(IMPLICATION(utils::one_of(weights_desc->data_type,
363-
data_type::s4, data_type::u4,
364-
data_type::f4_e2m1, data_type::f4_e3m0),
365-
weights_desc->dims[n_idx] % 2 == 0),
366-
VERBOSE_BAD_DIM, "weights", n_idx);
367-
// s4/u4/f4 src requires k to be multiple of 2 to be byte aligned
368-
VCHECK_MATMUL(IMPLICATION(utils::one_of(src_desc->data_type, data_type::s4,
369-
data_type::u4, data_type::f4_e2m1,
370-
data_type::f4_e3m0),
371-
src_desc->dims[k_idx_src] % 2 == 0),
372-
VERBOSE_BAD_DIM, "src", n_idx);
358+
using namespace data_type;
359+
if (weights_desc->format_kind == format_kind::blocked
360+
&& utils::one_of(
361+
weights_desc->data_type, s4, u4, f4_e2m1, f4_e3m0)) {
362+
const auto &wei_strides = weights_desc->format_desc.blocking.strides;
363+
364+
int n_unit_strides = 0;
365+
for (int d = 0; d < ndims; d++) {
366+
if (wei_strides[d] == 1) {
367+
n_unit_strides++;
368+
VCHECK_MATMUL(
369+
n_unit_strides <= 1, VERBOSE_BAD_DIM, "weights", d);
370+
}
371+
VCHECK_MATMUL(
372+
IMPLICATION(wei_strides[d] > 1, wei_strides[d] % 2 == 0),
373+
VERBOSE_BAD_DIM, "weights", d);
374+
}
375+
}
376+
if (src_desc->format_kind == format_kind::blocked
377+
&& utils::one_of(src_desc->data_type, s4, u4, f4_e2m1, f4_e3m0)) {
378+
const auto &src_strides = src_desc->format_desc.blocking.strides;
379+
380+
int n_unit_strides = 0;
381+
for (int d = 0; d < ndims; d++) {
382+
if (src_strides[d] == 1) {
383+
n_unit_strides++;
384+
VCHECK_MATMUL(n_unit_strides <= 1, VERBOSE_BAD_DIM, "src", d);
385+
}
386+
VCHECK_MATMUL(
387+
IMPLICATION(src_strides[d] > 1, src_strides[d] % 2 == 0),
388+
VERBOSE_BAD_DIM, "src", d);
389+
}
390+
}
373391

374392
// check if other dims match.
375393
for (int d = 0; d < ndims - 2; ++d) {

0 commit comments

Comments
 (0)