Skip to content

Commit 28e5f46

Browse files
committed
x64: brgemm: initial fp8 enabling
1 parent b5a0f3a commit 28e5f46

File tree

3 files changed

+58
-16
lines changed

3 files changed

+58
-16
lines changed

src/cpu/x64/brgemm/brgemm.cpp

+27-5
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
249249

250250
if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
251251

252-
if (utils::everyone_is(
253-
false, brg->is_int8, brg->is_bf16, brg->is_f32, brg->is_f16))
252+
if (utils::everyone_is(false, brg->is_int8, brg->is_bf16, brg->is_f32,
253+
brg->is_f16, brg->is_fp8))
254254
return status::unimplemented;
255255

256256
// Only amx_int8 kernel supports u8 weights.
@@ -319,6 +319,10 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
319319
is_superset(brg->isa_impl, avx512_core_fp16)
320320
|| is_superset(brg->isa_impl, avx2_vnni_2)))
321321
return status::unimplemented;
322+
if (!IMPLICATION(one_of(data_type::f8_e5m2, dt_bias, dt_d)
323+
|| one_of(data_type::f8_e4m3, dt_bias, dt_d),
324+
mayiuse(avx512_core_amx_fp16)))
325+
return status::unimplemented;
322326
// check that combination of data types is allowed
323327
if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
324328
&& (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
@@ -340,6 +344,17 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
340344
&& one_of(dt_bias, data_type::undef, data_type::f32,
341345
data_type::f16)))
342346
return status::unimplemented;
347+
const auto bias_f8_e5m2_compatible
348+
= one_of(dt_d, data_type::f32, data_type::f8_e5m2)
349+
&& one_of(dt_bias, data_type::undef, data_type::f32,
350+
data_type::f8_e5m2);
351+
const auto bias_f8_e4m3_compatible
352+
= one_of(dt_d, data_type::f32, data_type::f8_e4m3)
353+
&& one_of(dt_bias, data_type::undef, data_type::f32,
354+
data_type::f8_e4m3);
355+
if (!IMPLICATION(brg->is_fp8,
356+
bias_f8_e5m2_compatible || bias_f8_e4m3_compatible))
357+
return status::unimplemented;
343358

344359
brg->dt_d = dt_d;
345360
brg->typesize_D = types::data_type_size(brg->dt_d);
@@ -541,6 +556,10 @@ status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) {
541556
&& brg->prfC.dist2 < 0)
542557
brg->prfC.dist2 = 0;
543558

559+
// TODO: update conditions once other brgemm implementations are enabled
560+
// Currently, fp8 via AMX f16 convert is only supported in the non-unrolled kernel
561+
if (brg->is_fp8 && brg->brgattr.use_uker) return status::unimplemented;
562+
544563
return status::success;
545564
}
546565

@@ -597,16 +616,19 @@ status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
597616

598617
//TODO: Add support of tail processing by reduction dimension
599618
auto rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;
600-
if (brg.is_bf32) rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/);
619+
if (brg.is_input_convert())
620+
rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/);
601621

602622
palette_config_t *buff = (palette_config_t *)(palette);
603623

604624
char *_tc = (char *)(buff);
605625
for (int i = 0; i < max_palette_size_in_bytes; i++)
606626
_tc[i] = 0;
607627

608-
const int typesize_A = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_A;
609-
const int typesize_B = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_B;
628+
const int typesize_A
629+
= brg.is_input_convert() ? sizeof(int16_t) : brg.typesize_A;
630+
const int typesize_B
631+
= brg.is_input_convert() ? sizeof(int16_t) : brg.typesize_B;
610632

611633
const int rd_step = 4 / typesize_A;
612634

src/cpu/x64/brgemm/brgemm_types.hpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ struct brgemm_t {
265265
bool is_tmm = false;
266266
bool is_int8 = false, is_int8_tmm = false;
267267
bool is_bf16 = false, is_bf16_tmm = false, is_bf16_emu = false;
268+
bool is_fp8 = false, is_fp8_tmm = false;
268269
bool is_f16 = false, is_f16_tmm = false;
269270
bool is_f32 = false;
270271
bool is_bf32 = false;
@@ -295,6 +296,13 @@ struct brgemm_t {
295296
const primitive_attr_t *attr() const { return attr_; };
296297
const memory_desc_t *dst_md() const { return dst_md_; };
297298

299+
// return 'true' when FP8 MAC is not natively supported by the CPU ISA
300+
bool is_fp8_via_convert() const {
301+
return is_fp8 && utils::one_of(isa_impl, avx10_1_512_amx_fp16);
302+
}
303+
304+
bool is_input_convert() const { return is_bf32 || is_fp8_via_convert(); }
305+
298306
bool is_row_major() const {
299307
assert(layout != brgemm_layout_undef);
300308
return layout == brgemm_row_major;
@@ -355,7 +363,7 @@ struct brgemm_t {
355363
if (is_tmm) {
356364
constexpr int tilesize = 1024;
357365
sz = get_num_C_tiles() * tilesize; // postops buffer
358-
if (is_bf32) {
366+
if (is_input_convert()) {
359367
const int n_bdb = bd_block2;
360368
const int n_rdb = rdb + (rdb_tail != 0);
361369
const int n_ldb = ldb + (ldb_tail != 0);

src/cpu/x64/brgemm/brgemm_utils.cpp

+22-10
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ enum {
4242
impl::data_type_t get_accum_datatype(brgemm_t *brg) {
4343
// this assert should check if 'init_kernel_datatype()' was previously
4444
// called.
45-
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
45+
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16
46+
|| brg->is_fp8);
4647
return brg->is_int8 ? data_type::s32 : data_type::f32;
4748
}
4849

@@ -54,7 +55,10 @@ void init_kernel_datatype(
5455
brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
5556
brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
5657
brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
57-
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
58+
brg->is_fp8 = one_of(dt_a, data_type::f8_e5m2, data_type::f8_e4m3)
59+
&& one_of(dt_b, data_type::f8_e5m2, data_type::f8_e4m3);
60+
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16
61+
|| brg->is_fp8);
5862
}
5963

6064
void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha,
@@ -145,12 +149,15 @@ void set_isa_impl(brgemm_t *brg) {
145149
avx512_core_amx, is_isa_ok(avx512_core_vnni), avx512_core_vnni,
146150
is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2),
147151
avx2_vnni_2, is_isa_ok(avx2_vnni), avx2_vnni);
152+
} else if (brg->is_fp8) {
153+
brg->isa_impl = utils::map(true, isa_undef,
154+
is_isa_ok(avx10_1_512_amx_fp16), avx10_1_512_amx_fp16);
148155
}
149156
}
150157

151158
void set_brg_vmm(brgemm_t *brg) {
152159
brg->is_tmm = brg->is_int8_tmm || brg->is_bf16_tmm || brg->is_f16_tmm
153-
|| brg->is_bf32;
160+
|| brg->is_bf32 || brg->is_fp8_tmm;
154161
brg->is_zmm = !brg->is_tmm && mayiuse(avx512_core)
155162
&& is_superset(brg->isa_impl, avx512_core);
156163
brg->is_ymm
@@ -672,11 +679,10 @@ status_t brgemm_blocking(brgemm_t *brg) {
672679
brg->load_nt_B
673680
= (brg->brgattr.hint_load_nt_B == brgemm_hint_nt_true);
674681

675-
const auto max_rd_block
676-
= (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 32
677-
: 64;
678-
const auto rd_block_step
679-
= (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 2 : 4;
682+
const bool reduce_by_words = brg->is_bf16_tmm || brg->is_f16_tmm
683+
|| brg->is_input_convert();
684+
const auto max_rd_block = reduce_by_words ? 32 : 64;
685+
const auto rd_block_step = reduce_by_words ? 2 : 4;
680686
// TODO: if rd_block calculated is very small then maybe it makes
681687
// sense to use 1x2 or 2x1 blocking with supporting rd_block
682688
// and rdb_tail
@@ -692,14 +698,18 @@ status_t brgemm_blocking(brgemm_t *brg) {
692698

693699
// Remove these guards in the future (add tail processing by reduction
694700
// dimension)
695-
if (!IMPLICATION(brg->rdb > 0 && brg->rdb_tail, brg->is_bf32))
701+
// TODO: these checks do not work for fp8-f16 and f16-fp8 cfgs
702+
if (!IMPLICATION(
703+
brg->rdb > 0 && brg->rdb_tail, brg->is_input_convert())) {
696704
return status::unimplemented;
705+
}
697706
if (!IMPLICATION(
698707
(brg->rdb_tail
699708
% ((brg->is_bf16_tmm || brg->is_f16_tmm) ? 2 : 4))
700709
!= 0,
701-
brg->is_bf32))
710+
brg->is_input_convert())) {
702711
return status::unimplemented;
712+
}
703713

704714
//TODO: check this condition
705715
brg->interleave_tilestores_ = brg->beta == 0
@@ -822,6 +832,8 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
822832
brg->is_bf32 = is_bf32
823833
&& utils::one_of(brg->isa_user, isa_undef, avx512_core_amx)
824834
&& mayiuse(avx512_core_amx);
835+
brg->is_fp8_tmm
836+
= brg->is_fp8 && one_of(brg->isa_impl, avx512_core_amx_fp16);
825837

826838
brg->has_int8_vnni = isa_has_int8_vnni(brg->isa_impl);
827839

0 commit comments

Comments
 (0)