Skip to content

Commit 188ae7f

Browse files
amakarev authored and tprimak committed
matmul: x64: added support for bf16,f16 bias dt
1 parent bf58e72 commit 188ae7f

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

src/cpu/x64/brgemm/brgemm.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -348,13 +348,15 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg,
348348
data_type::f16)))
349349
return status::unimplemented;
350350
const auto bias_f8_e5m2_compatible
351-
= one_of(dt_d, data_type::f32, data_type::f16, data_type::f8_e5m2)
351+
= one_of(dt_d, data_type::f32, data_type::f16, data_type::bf16,
352+
data_type::f8_e5m2)
352353
&& one_of(dt_bias, data_type::undef, data_type::f32, data_type::f16,
353-
data_type::f8_e5m2, data_type::f8_e4m3);
354+
data_type::bf16, data_type::f8_e5m2, data_type::f8_e4m3);
354355
const auto bias_f8_e4m3_compatible
355-
= one_of(dt_d, data_type::f32, data_type::f16, data_type::f8_e4m3)
356+
= one_of(dt_d, data_type::f32, data_type::f16, data_type::bf16,
357+
data_type::f8_e4m3)
356358
&& one_of(dt_bias, data_type::undef, data_type::f32, data_type::f16,
357-
data_type::f8_e4m3, data_type::f8_e5m2);
359+
data_type::bf16, data_type::f8_e4m3, data_type::f8_e5m2);
358360
if (!IMPLICATION(brg->is_fp8,
359361
bias_f8_e5m2_compatible || bias_f8_e4m3_compatible))
360362
return status::unimplemented;

src/cpu/x64/matmul/brgemm_matmul.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
7070
const bool is_bia_dt_correct
7171
= IMPLICATION(is_int8 == true,
7272
one_of(bia_dt, f32, s32, s8, u8, bf16))
73-
&& IMPLICATION(!is_int8, one_of(bia_dt, f32, src_dt));
73+
&& IMPLICATION(
74+
is_f8 == true, one_of(bia_dt, f32, f16, bf16, src_dt))
75+
&& IMPLICATION(
76+
!(is_int8 || is_f8), one_of(bia_dt, f32, src_dt));
7477
return IMPLICATION(with_bias(), is_bia_dt_correct && is_bias_1xN());
7578
};
7679

tests/benchdnn/inputs/matmul/test_matmul_fp8

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
--dt=f8_e4m3:f8_e4m3:f32,f8_e4m3,f8_e5m2:f8_e5m2:f32,f8_e5m2
55
--stag=ab,ba --wtag=ab,ba --dtag=ab
66
--runtime_dims_masks=0,2:1,1:0,3:1
7-
--bia_dt=undef,f32 --bia_mask=2
7+
--bia_dt=undef,f32,f16,bf16 --bia_mask=2
88

99
--attr-scales=
1010
--attr-post-ops=
@@ -21,8 +21,8 @@
2121

2222
--stag=ba --wtag=ab,ba --dtag=ab
2323
--runtime_dims_masks=3:1,3:3
24-
--bia_dt=f8_e4m3,f8_e5m2 --bia_mask=1,2,3
25-
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:2.25
24+
--bia_dt=f8_e4m3,f8_e5m2,f16,bf16 --bia_mask=1,2,3
25+
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:4
2626
--attr-post-ops=add:f32,sum+mul:s32:per_oc+linear:2:-1
2727
--batch=shapes_2d
2828

0 commit comments

Comments
 (0)