Skip to content

Commit 4aed5ec

Browse files
dmitry-gorokhovazhai219
authored andcommitted
[FORK][FEATURE] Support (f32,bf16,f32) inner-product
3.5 squash list: [FORK][FIX] Corrected brgemm rd_step for bf16 compressed weights
1 parent bd4f691 commit 4aed5ec

6 files changed

+30
-16
lines changed

src/cpu/cpu_inner_product_list.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
8888
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
8989
nullptr,
9090
}},
91+
{{forward, f32, bf16, f32}, {
92+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
93+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
94+
nullptr,
95+
}},
9196
{{forward, bf16, bf16, f32}, {
9297
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
9398
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)

src/cpu/x64/brgemm/brgemm_utils.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void init_kernel_datatype(
5555
brg->is_bf16 = one_of(dt_a, data_type::bf16) &&
5656
one_of(dt_b, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
5757
brg->is_f32 = one_of(dt_a, data_type::f32) &&
58-
one_of(dt_b, data_type::f32, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
58+
one_of(dt_b, data_type::f32, data_type::f16, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
5959
brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
6060
brg->is_fp8 = one_of(dt_a, data_type::f8_e5m2, data_type::f8_e4m3)
6161
&& one_of(dt_b, data_type::f8_e5m2, data_type::f8_e4m3);
@@ -905,10 +905,9 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
905905
brg->bdb2 = 0;
906906
brg->bdb2_tail = 0;
907907

908-
const bool is_vcvtph2ps_kernel = (brg->dt_b == data_type::f16 && brg->dt_a == data_type::f32);
909908
const bool is_b_in_vnni_format = !(brg->dt_b == data_type::f16 && brg->isa_impl == avx512_core_fp16) &&
910909
!(one_of(brg->dt_a, data_type::f32, data_type::bf16) && one_of(brg->dt_b, data_type::u8, data_type::s8)) &&
911-
!is_vcvtph2ps_kernel;
910+
!(one_of(brg->dt_a, data_type::f32) && one_of(brg->dt_b, data_type::bf16, data_type::f16));
912911
brg->ld_step
913912
= is_b_in_vnni_format ? data_type_vnni_granularity(brg->dt_b) : 1;
914913
// const data_type_t ld_step_compute_dt
@@ -924,7 +923,8 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
924923
&& one_of(brg->isa_impl, avx2_vnni_2, avx512_core_fp16))
925924
|| (brg->is_bf16 && brg->isa_impl == avx2_vnni_2)
926925
|| (one_of(brg->dt_a, data_type::f32, data_type::bf16) &&
927-
one_of(brg->dt_b, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f16));
926+
one_of(brg->dt_b, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4))
927+
|| (one_of(brg->dt_a, data_type::f32) && one_of(brg->dt_b, data_type::bf16, data_type::f16));
928928
brg->rd_step = has_no_vnni_compute_instruction
929929
? 1
930930
: data_type_vnni_granularity(brg->dt_b);

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -2556,12 +2556,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
25562556
} else
25572557
vcvtph2ps(vmm_load, addr);
25582558
}
2559-
} else if (brg.dt_b == data_type::bf16
2560-
&& brg.isa_impl == avx2_vnni_2) {
2559+
} else if (brg.dt_b == data_type::bf16 && brg.isa_impl == avx2_vnni_2) {
25612560
if (rd % 2 == 0)
25622561
vcvtneebf162ps(vmm_load, addr);
25632562
else
25642563
vcvtneobf162ps(vmm_load, addr);
2564+
} else if (brg.dt_b == data_type::bf16 && brg.is_f32) {
2565+
vpmovzxwd(vmm_load, addr);
2566+
uni_vpslld(vmm_load, vmm_load, 16);
25652567
} else if (is_ld_tail) {
25662568
if (is_superset(brg.isa_impl, avx512_core)) {
25672569
uni_vmovups(vmm_load, addr);
@@ -2852,12 +2854,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
28522854
} else
28532855
vcvtph2ps(vmm_load, addr);
28542856
}
2855-
} else if (brg.dt_b == data_type::bf16
2856-
&& brg.isa_impl == avx2_vnni_2) {
2857+
} else if (brg.dt_b == data_type::bf16 && brg.isa_impl == avx2_vnni_2) {
28572858
if (rd % 2 == 0)
28582859
vcvtneebf162ps(vmm_load, addr);
28592860
else
28602861
vcvtneobf162ps(vmm_load, addr);
2862+
} else if (brg.dt_b == data_type::bf16 && brg.is_f32) {
2863+
vpmovzxwd(vmm_load, addr);
2864+
uni_vpslld(vmm_load, vmm_load, 16);
28612865
} else if (is_ld_tail) {
28622866
if (is_superset(brg.isa_impl, avx512_core)) {
28632867
uni_vmovups(vmm_load, addr);

src/cpu/x64/jit_brgemm_inner_product.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
6363
auto dst_dt = invariant_dst_md()->data_type;
6464
auto wei_dt = invariant_wei_md()->data_type;
6565
const bool is_int8 = one_of(src_dt, u8, s8);
66-
const bool is_wei_decomp = one_of(src_dt, f32, bf16) &&
67-
one_of(wei_dt, u8, s8, nf4, s4, u4, f16);
66+
const bool is_wei_decomp = (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4)) ||
67+
(one_of(src_dt, f32) && one_of(wei_dt, f16, bf16));
6868

6969
using skip_mask_t = primitive_attr_t::skip_mask_t;
7070
auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt
@@ -120,8 +120,8 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
120120
const float beta = 1.0;
121121
const float beta_init = 0.0;
122122

123-
// f16 weight decompression doesn't need scales/zero-points which handles by normal brgemm kernel
124-
bool brgemm_with_wei_decomp = is_wei_decomp && jbgp_.wei_decomp_algo == weights_decomp_kind_t::immediate && wei_dt != f16;
123+
// f16/bf16 weights decompression doesn't need scales/zero-points which is handled by normal brgemm kernel
124+
bool brgemm_with_wei_decomp = is_wei_decomp && jbgp_.wei_decomp_algo == weights_decomp_kind_t::immediate && !one_of(wei_dt, f16, bf16);
125125

126126
for_(int i_bs = 0; i_bs < 2; i_bs++)
127127
for_(int i_init = 0; i_init < 2; i_init++)

src/cpu/x64/jit_brgemm_inner_product_utils.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ jit_brgemm_ip_conf_t::get_desired_weights_tag() const {
172172
const bool is_xf16 = utils::one_of(jbgp.wei_dt, bf16, f16);
173173
const bool is_not_vnni_tag = (jbgp.wei_dt == f32
174174
|| (jbgp.wei_dt == f16 && jbgp.isa == avx512_core_fp16)) && !jbgp.weights_decompression;
175-
const bool is_vcvtph2ps_kernel = (jbgp.orig_wei_dt == f16 && jbgp.src_dt == f32);
176-
if (is_not_vnni_tag || (jbgp.weights_decompression && (jbgp.orig_wei_dt == u8 || jbgp.orig_wei_dt == s8 || is_vcvtph2ps_kernel) && !jbgp.with_src_dynamic_quant)) {
175+
const bool is_half_prc_weights = (one_of(jbgp.orig_wei_dt, f16, bf16) && jbgp.src_dt == f32);
176+
if (is_not_vnni_tag || (jbgp.weights_decompression && (one_of(jbgp.orig_wei_dt, u8, s8) || is_half_prc_weights) && !jbgp.with_src_dynamic_quant)) {
177177
if (is_superset(jbgp.isa, avx512_core))
178178
return {{64,
179179
pick(n_sp_dims, OI16i64o, OwI16i64o, OhwI16i64o,
@@ -1419,8 +1419,8 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
14191419
jbgp.dst_dt = dst_d.data_type();
14201420
jbgp.wei_dt = weights_d.data_type();
14211421

1422-
jbgp.weights_decompression = one_of(jbgp.src_dt, f32, bf16) &&
1423-
one_of(jbgp.wei_dt, u8, s8, nf4, s4, u4, f16);
1422+
jbgp.weights_decompression = (one_of(jbgp.src_dt, f32, bf16) && one_of(jbgp.wei_dt, u8, s8, nf4, s4, u4)) ||
1423+
(one_of(jbgp.src_dt, f32) && one_of(jbgp.wei_dt, f16, bf16));
14241424
jbgp.wei_decomp_algo = weights_decomp_kind_t::immediate;
14251425
jbgp.orig_wei_dt = jbgp.wei_dt;
14261426
jbgp.with_grouped_weights_decompression = false;

src/cpu/x64/jit_brgemm_weights_decompression_kernel.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ void jit_brgemm_weights_decompression_kernel_t<isa>::load_weights(Vmm vmm_load,
132132
vcvtph2ps(vmm_load, addr);
133133
break;
134134
}
135+
case data_type::bf16: {
136+
vpmovzxwd(vmm_load, addr);
137+
uni_vpslld(vmm_load, vmm_load, 16);
138+
break;
139+
}
135140
default: assert(!"unsupported data type");
136141
}
137142
}

0 commit comments

Comments
 (0)