Skip to content

Commit 38ff3a7

Browse files
authored
[rls-v3.6 backport] cpu: x64: matmul/brgemm: enable f16 dst dt for int8 matmul (uxlfoundation#2282)
1 parent 2eb3dd1 commit 38ff3a7

File tree

5 files changed

+16
-12
lines changed

5 files changed

+16
-12
lines changed

src/cpu/x64/brgemm/brgemm_utils.cpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,11 @@ void set_isa_impl(brgemm_desc_t *brg) {
145145
is_isa_ok(avx512_core_fp16), avx512_core_fp16);
146146
}
147147
} else if (brg->is_int8) {
148-
brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx),
149-
avx512_core_amx, is_isa_ok(avx512_core_vnni), avx512_core_vnni,
148+
brg->isa_impl = utils::map(true, isa_undef,
149+
is_isa_ok(avx512_core_amx_fp16), avx512_core_amx_fp16,
150+
is_isa_ok(avx512_core_amx), avx512_core_amx,
151+
is_isa_ok(avx512_core_fp16), avx512_core_fp16,
152+
is_isa_ok(avx512_core_vnni), avx512_core_vnni,
150153
is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2),
151154
avx2_vnni_2, is_isa_ok(avx2_vnni), avx2_vnni);
152155
} else if (brg->is_fp8) {
@@ -847,7 +850,8 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
847850

848851
brg->isa_user = isa;
849852
set_isa_impl(brg);
850-
brg->is_int8_tmm = brg->is_int8 && brg->isa_impl == avx512_core_amx;
853+
brg->is_int8_tmm
854+
= brg->is_int8 && is_superset(brg->isa_impl, avx512_core_amx);
851855
brg->is_bf16_tmm = brg->is_bf16 && brg->isa_impl == avx512_core_amx;
852856
brg->is_f16_tmm = brg->is_f16 && brg->isa_impl == avx512_core_amx_fp16;
853857
brg->is_bf32 = is_bf32

src/cpu/x64/matmul/brgemm_matmul.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
5252

5353
const bool is_f32 = everyone_is(f32, src_dt, wei_dt, dst_dt);
5454
const bool is_int8 = one_of(src_dt, u8, s8) && wei_dt == s8
55-
&& one_of(dst_dt, u8, s8, s32, f32, bf16);
55+
&& one_of(dst_dt, u8, s8, s32, f32, f16, bf16);
5656
const bool is_f8 = one_of(src_dt, f8_e5m2, f8_e4m3)
5757
&& one_of(wei_dt, f8_e5m2, f8_e4m3)
5858
&& one_of(dst_dt, f32, f16, bf16, f8_e5m2, f8_e4m3);

src/cpu/x64/matmul/brgemm_matmul_utils.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ status_t check_isa_with_datatype(
178178
= IMPLICATION(bm_conf_utils.is_f32(),
179179
one_of(isa, avx512_core, avx2) || bm_conf_utils.is_bf32())
180180
&& IMPLICATION(bm_conf_utils.is_int8(),
181-
one_of(isa, avx512_core_amx, avx512_core_vnni, avx512_core,
182-
avx2_vnni_2, avx2_vnni))
181+
is_superset(isa, avx512_core)
182+
|| is_superset(isa, avx2_vnni))
183183
&& IMPLICATION(bm_conf_utils.is_bf16(),
184184
one_of(isa, avx512_core_amx, avx512_core_bf16, avx2_vnni_2))
185185
&& IMPLICATION(bm_conf_utils.is_f16(),

tests/benchdnn/inputs/matmul/test_matmul_ci

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Plain cases
22
--reset
3-
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,f8_e5m2:f8_e4m3:f32,u8:s8:s8,s8:s8:f32
3+
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,f8_e5m2:f8_e4m3:f32,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16
44
--bia_dt=f32
55
--bia_mask=2
66
--batch=shapes_2d_ci
@@ -10,7 +10,7 @@
1010

1111
# Post-ops check for different data types
1212
--reset
13-
--dt=f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32
13+
--dt=f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16
1414
--attr-post-ops=sum+relu:0.5+add:f32
1515
--batch=shapes_2d_ci
1616

@@ -31,7 +31,7 @@
3131

3232
# Different tags
3333
--reset
34-
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32
34+
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16
3535
--stag=ab,ba
3636
--wtag=ab,ba
3737
--dtag=ab,ba
@@ -54,7 +54,7 @@
5454

5555
# Arg scales check
5656
--reset
57-
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:u8,s8:s8:f32
57+
--dt=f64,f32,bf16,f16,f8_e5m2,f8_e4m3,u8:s8:u8,s8:s8:f32,s8:s8:f16,s8:u8:f16
5858
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:2,wei:per_oc
5959
--batch=shapes_2d_ci
6060

@@ -68,7 +68,7 @@
6868

6969
# Zero-points check
7070
--reset
71-
--dt=s8:s8:s8,u8:s8:f32,u8:s8:bf16
71+
--dt=s8:s8:s8,u8:s8:f32,u8:s8:bf16,s8:s8:f16,s8:u8:f16
7272
--attr-zero-points=src:common:1+wei:common:-1+dst:common:2
7373
--batch=shapes_2d_ci
7474

tests/benchdnn/inputs/matmul/test_matmul_int8

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# int8
22
--reset
33

4-
--dt=u8:s8:s8,s8:s8:f32
4+
--dt=u8:s8:s8,s8:s8:f32,s8:s8:f16,s8:u8:f16
55
--stag=ab --wtag=ab,ba --dtag=ab
66
--runtime_dims_masks=0,2:1,1:0,3:1
77
--bia_dt=undef,f32 --bia_mask=2

0 commit comments

Comments
 (0)