@@ -355,21 +355,39 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc,
355
355
? utils::get_dims_mask (dst_desc->dims , op_d.bias_desc .dims , ndims)
356
356
: 0 ;
357
357
358
- // TODO: requirement is for innermost dim to be multiple of 2 for
359
- // the memory to be byte aligned.
360
-
361
- // s4/u4/f4 weights requires n to be multiple of 2 to be byte aligned
362
- VCHECK_MATMUL (IMPLICATION (utils::one_of (weights_desc->data_type ,
363
- data_type::s4, data_type::u4,
364
- data_type::f4_e2m1, data_type::f4_e3m0),
365
- weights_desc->dims [n_idx] % 2 == 0 ),
366
- VERBOSE_BAD_DIM, " weights" , n_idx);
367
- // s4/u4/f4 src requires k to be multiple of 2 to be byte aligned
368
- VCHECK_MATMUL (IMPLICATION (utils::one_of (src_desc->data_type , data_type::s4,
369
- data_type::u4, data_type::f4_e2m1,
370
- data_type::f4_e3m0),
371
- src_desc->dims [k_idx_src] % 2 == 0 ),
372
- VERBOSE_BAD_DIM, " src" , n_idx);
358
+ using namespace data_type ;
359
+ if (weights_desc->format_kind == format_kind::blocked
360
+ && utils::one_of (
361
+ weights_desc->data_type , s4, u4, f4_e2m1, f4_e3m0)) {
362
+ const auto &wei_strides = weights_desc->format_desc .blocking .strides ;
363
+
364
+ int n_unit_strides = 0 ;
365
+ for (int d = 0 ; d < ndims; d++) {
366
+ if (wei_strides[d] == 1 ) {
367
+ n_unit_strides++;
368
+ VCHECK_MATMUL (
369
+ n_unit_strides <= 1 , VERBOSE_BAD_DIM, " weights" , d);
370
+ }
371
+ VCHECK_MATMUL (
372
+ IMPLICATION (wei_strides[d] > 1 , wei_strides[d] % 2 == 0 ),
373
+ VERBOSE_BAD_DIM, " weights" , d);
374
+ }
375
+ }
376
+ if (src_desc->format_kind == format_kind::blocked
377
+ && utils::one_of (src_desc->data_type , s4, u4, f4_e2m1, f4_e3m0)) {
378
+ const auto &src_strides = src_desc->format_desc .blocking .strides ;
379
+
380
+ int n_unit_strides = 0 ;
381
+ for (int d = 0 ; d < ndims; d++) {
382
+ if (src_strides[d] == 1 ) {
383
+ n_unit_strides++;
384
+ VCHECK_MATMUL (n_unit_strides <= 1 , VERBOSE_BAD_DIM, " src" , d);
385
+ }
386
+ VCHECK_MATMUL (
387
+ IMPLICATION (src_strides[d] > 1 , src_strides[d] % 2 == 0 ),
388
+ VERBOSE_BAD_DIM, " src" , d);
389
+ }
390
+ }
373
391
374
392
// check if other dims match.
375
393
for (int d = 0 ; d < ndims - 2 ; ++d) {
0 commit comments