@@ -55,7 +55,7 @@ void init_kernel_datatype(
55
55
brg->is_bf16 = one_of (dt_a, data_type::bf16) &&
56
56
one_of (dt_b, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
57
57
brg->is_f32 = one_of (dt_a, data_type::f32) &&
58
- one_of (dt_b, data_type::f32, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
58
+ one_of (dt_b, data_type::f32, data_type::f16, data_type::bf16, data_type:: u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
59
59
brg->is_f16 = utils::one_of (data_type::f16, dt_a, dt_b);
60
60
brg->is_fp8 = one_of (dt_a, data_type::f8_e5m2, data_type::f8_e4m3)
61
61
&& one_of (dt_b, data_type::f8_e5m2, data_type::f8_e4m3);
@@ -905,10 +905,9 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
905
905
brg->bdb2 = 0 ;
906
906
brg->bdb2_tail = 0 ;
907
907
908
- const bool is_vcvtph2ps_kernel = (brg->dt_b == data_type::f16 && brg->dt_a == data_type::f32);
909
908
const bool is_b_in_vnni_format = !(brg->dt_b == data_type::f16 && brg->isa_impl == avx512_core_fp16) &&
910
909
!(one_of (brg->dt_a , data_type::f32, data_type::bf16) && one_of (brg->dt_b , data_type::u8, data_type::s8)) &&
911
- !is_vcvtph2ps_kernel ;
910
+ !( one_of (brg-> dt_a , data_type::f32) && one_of (brg-> dt_b , data_type::bf16, data_type::f16)) ;
912
911
brg->ld_step
913
912
= is_b_in_vnni_format ? data_type_vnni_granularity (brg->dt_b ) : 1 ;
914
913
// const data_type_t ld_step_compute_dt
@@ -924,7 +923,8 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
924
923
&& one_of (brg->isa_impl , avx2_vnni_2, avx512_core_fp16))
925
924
|| (brg->is_bf16 && brg->isa_impl == avx2_vnni_2)
926
925
|| (one_of (brg->dt_a , data_type::f32, data_type::bf16) &&
927
- one_of (brg->dt_b , data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f16));
926
+ one_of (brg->dt_b , data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4))
927
+ || (one_of (brg->dt_a , data_type::f32) && one_of (brg->dt_b , data_type::bf16, data_type::f16));
928
928
brg->rd_step = has_no_vnni_compute_instruction
929
929
? 1
930
930
: data_type_vnni_granularity (brg->dt_b );
0 commit comments