Skip to content

Commit 7ee48d7

Browse files
committed
common,cpu: matmul: add fp4_e3m0 support
1 parent 73f9434 commit 7ee48d7

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

src/common/matmul.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,15 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc,
284284
// the memory to be byte aligned.
285285

286286
// s4/u4/f4 weights require n to be a multiple of 2 to be byte aligned
287-
VCHECK_MATMUL(
288-
IMPLICATION(utils::one_of(weights_desc->data_type, data_type::s4,
289-
data_type::u4, data_type::f4_e2m1),
290-
weights_desc->dims[n_idx] % 2 == 0),
287+
VCHECK_MATMUL(IMPLICATION(utils::one_of(weights_desc->data_type,
288+
data_type::s4, data_type::u4,
289+
data_type::f4_e2m1, data_type::f4_e3m0),
290+
weights_desc->dims[n_idx] % 2 == 0),
291291
VERBOSE_BAD_DIM, "weights", n_idx);
292292
// s4/u4/f4 src requires k to be a multiple of 2 to be byte aligned
293293
VCHECK_MATMUL(IMPLICATION(utils::one_of(src_desc->data_type, data_type::s4,
294-
data_type::u4, data_type::f4_e2m1),
294+
data_type::u4, data_type::f4_e2m1,
295+
data_type::f4_e3m0),
295296
src_desc->dims[k_idx_src] % 2 == 0),
296297
VERBOSE_BAD_DIM, "src", n_idx);
297298

src/cpu/matmul/ref_matmul.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,17 @@ struct ref_matmul_t : public primitive_t {
5252
VDISPATCH_MATMUL(
5353
is_dense_format_kind(), VERBOSE_UNSUPPORTED_SPARSE_CFG);
5454
VDISPATCH_MATMUL(utils::one_of(src_type, f32, bf16, f16, f8_e5m2,
55-
f8_e4m3, f4_e2m1),
55+
f8_e4m3, f4_e2m1, f4_e3m0),
5656
VERBOSE_UNSUPPORTED_DT);
5757
VDISPATCH_MATMUL(utils::one_of(wei_type, f32, bf16, f16, f8_e5m2,
58-
f8_e4m3, f4_e2m1, u8, s8, u4, s4),
58+
f8_e4m3, f4_e2m1, f4_e3m0, u8, s8, u4, s4),
5959
VERBOSE_UNSUPPORTED_DT);
6060
VDISPATCH_MATMUL(utils::one_of(dst_type, f32, bf16, f16, f8_e5m2,
61-
f8_e4m3, f4_e2m1),
61+
f8_e4m3, f4_e2m1, f4_e3m0),
6262
VERBOSE_UNSUPPORTED_DT);
63-
VDISPATCH_MATMUL(
64-
(src_type == wei_type
65-
|| utils::one_of(wei_type, u8, s8, u4, s4)),
63+
VDISPATCH_MATMUL((src_type == wei_type
64+
|| utils::one_of(wei_type, u8, s8, u4, s4,
65+
f4_e3m0)),
6666
VERBOSE_UNSUPPORTED_DT);
6767
/* int8 weights decompression support */
6868
VDISPATCH_MATMUL(IMPLICATION(utils::one_of(wei_type, u8, s8),

0 commit comments

Comments
 (0)