@@ -249,8 +249,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
249
249
250
250
if (M <= 0 || N <= 0 || K <= 0 ) return status::invalid_arguments;
251
251
252
- if (utils::everyone_is (
253
- false , brg->is_int8 , brg->is_bf16 , brg-> is_f32 , brg-> is_f16 ))
252
+ if (utils::everyone_is (false , brg-> is_int8 , brg-> is_bf16 , brg-> is_f32 ,
253
+ brg->is_f16 , brg->is_fp8 ))
254
254
return status::unimplemented;
255
255
256
256
// Only amx_int8 kernel supports u8 weights.
@@ -319,6 +319,10 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
319
319
is_superset (brg->isa_impl , avx512_core_fp16)
320
320
|| is_superset (brg->isa_impl , avx2_vnni_2)))
321
321
return status::unimplemented;
322
+ if (!IMPLICATION (one_of (data_type::f8_e5m2, dt_bias, dt_d)
323
+ || one_of (data_type::f8_e4m3, dt_bias, dt_d),
324
+ mayiuse (avx512_core_amx_fp16)))
325
+ return status::unimplemented;
322
326
// check that combination of data types is allowed
323
327
if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
324
328
&& (!one_of (dt_d, data_type::u8, data_type::s8, data_type::s32,
@@ -340,6 +344,17 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
340
344
&& one_of (dt_bias, data_type::undef, data_type::f32,
341
345
data_type::f16)))
342
346
return status::unimplemented;
347
+ const auto bias_f8_e5m2_compatible
348
+ = one_of (dt_d, data_type::f32, data_type::f8_e5m2)
349
+ && one_of (dt_bias, data_type::undef, data_type::f32,
350
+ data_type::f8_e5m2);
351
+ const auto bias_f8_e4m3_compatible
352
+ = one_of (dt_d, data_type::f32, data_type::f8_e4m3)
353
+ && one_of (dt_bias, data_type::undef, data_type::f32,
354
+ data_type::f8_e4m3);
355
+ if (!IMPLICATION (brg->is_fp8 ,
356
+ bias_f8_e5m2_compatible || bias_f8_e4m3_compatible))
357
+ return status::unimplemented;
343
358
344
359
brg->dt_d = dt_d;
345
360
brg->typesize_D = types::data_type_size (brg->dt_d );
@@ -541,6 +556,10 @@ status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) {
541
556
&& brg->prfC .dist2 < 0 )
542
557
brg->prfC .dist2 = 0 ;
543
558
559
+ // TODO: update conditions once other brgemm implementations are enabled
560
+ // Currently, fp8 via AMX f16 convert only supported in non-unrolled kernel
561
+ if (brg->is_fp8 && brg->brgattr .use_uker ) return status::unimplemented;
562
+
544
563
return status::success;
545
564
}
546
565
@@ -597,16 +616,19 @@ status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
597
616
598
617
// TODO: Add support of tail processing by reduction dimension
599
618
auto rd_block = (!brg.rdb && brg.rdb_tail ) ? brg.rdb_tail : brg.rd_block ;
600
- if (brg.is_bf32 ) rd_block = utils::rnd_up (rd_block, 2 /* vnni_granularity*/ );
619
+ if (brg.is_input_convert ())
620
+ rd_block = utils::rnd_up (rd_block, 2 /* vnni_granularity*/ );
601
621
602
622
palette_config_t *buff = (palette_config_t *)(palette);
603
623
604
624
char *_tc = (char *)(buff);
605
625
for (int i = 0 ; i < max_palette_size_in_bytes; i++)
606
626
_tc[i] = 0 ;
607
627
608
- const int typesize_A = brg.is_bf32 ? sizeof (bfloat16_t ) : brg.typesize_A ;
609
- const int typesize_B = brg.is_bf32 ? sizeof (bfloat16_t ) : brg.typesize_B ;
628
+ const int typesize_A
629
+ = brg.is_input_convert () ? sizeof (int16_t ) : brg.typesize_A ;
630
+ const int typesize_B
631
+ = brg.is_input_convert () ? sizeof (int16_t ) : brg.typesize_B ;
610
632
611
633
const int rd_step = 4 / typesize_A;
612
634
0 commit comments