Skip to content

Commit cc986b3

Browse files
author
dmitrygo
committed
[FORK][FEATURE] InnerProduct primitive: 4bit weights decompression support on SPR
1 parent 7fee383 commit cc986b3

8 files changed

+206
-152
lines changed

src/cpu/cpu_inner_product_list.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,42 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
9090
}},
9191
{{forward, bf16, u8, f32}, {
9292
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
93+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
9394
nullptr,
9495
}},
9596
{{forward, bf16, u8, bf16}, {
9697
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
98+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
99+
nullptr,
100+
}},
101+
{{forward, bf16, nf4, f32}, {
102+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
103+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
104+
nullptr,
105+
}},
106+
{{forward, bf16, nf4, bf16}, {
107+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
108+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
109+
nullptr,
110+
}},
111+
{{forward, bf16, s4, f32}, {
112+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
113+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
114+
nullptr,
115+
}},
116+
{{forward, bf16, s4, bf16}, {
117+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
118+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
119+
nullptr,
120+
}},
121+
{{forward, bf16, u4, f32}, {
122+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
123+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
124+
nullptr,
125+
}},
126+
{{forward, bf16, u4, bf16}, {
127+
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
128+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
97129
nullptr,
98130
}},
99131
{{forward, f16, f16, f32}, {

src/cpu/x64/brgemm/brgemm_utils.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ void init_kernel_datatype(
5151
assert(dt_a != data_type::undef && dt_b != data_type::undef);
5252
brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8)
5353
&& utils::one_of(dt_b, data_type::u8, data_type::s8);
54-
brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
54+
brg->is_bf16 = one_of(dt_a, data_type::bf16) &&
55+
one_of(dt_b, data_type::bf16, data_type::u8, data_type::nf4, data_type::s4, data_type::u4);
5556
brg->is_f32 = one_of(dt_a, data_type::f32) &&
5657
one_of(dt_b, data_type::f32, data_type::u8, data_type::nf4, data_type::s4, data_type::u4);
5758
brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
@@ -833,15 +834,15 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
833834
brg->bdb2_tail = 0;
834835

835836
const bool is_b_in_vnni_format = !(brg->dt_b == data_type::f16 && brg->isa_impl == avx512_core_fp16) &&
836-
!(brg->dt_a == data_type::f32 && one_of(brg->dt_b, data_type::u8));
837+
!(one_of(brg->dt_a, data_type::f32, data_type::bf16) && one_of(brg->dt_b, data_type::u8));
837838
brg->ld_step
838839
= is_b_in_vnni_format ? data_type_vnni_granularity(brg->dt_b) : 1;
839840

840841
const bool has_no_vnni_compute_instruction
841842
= (brg->is_f16
842843
&& one_of(brg->isa_impl, avx2_vnni_2, avx512_core_fp16))
843844
|| (brg->is_bf16 && brg->isa_impl == avx2_vnni_2)
844-
|| (one_of(brg->dt_a, data_type::f32) &&
845+
|| (one_of(brg->dt_a, data_type::f32, data_type::bf16) &&
845846
one_of(brg->dt_b, data_type::u8, data_type::nf4, data_type::s4, data_type::u4));
846847
brg->rd_step = has_no_vnni_compute_instruction
847848
? 1

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+10-5
Original file line numberDiff line numberDiff line change
@@ -2027,8 +2027,14 @@ void jit_brgemm_kernel_t<isa, Wmm>::gemm_microkernel(int bd_block2,
20272027
const auto bd_by_load_bytes
20282028
= (bd >= bd_e - rows_by_load_bytes
20292029
|| brg.brgattr.wary_tail_read);
2030-
broadcast(bcst(), A_offset(bd, rd),
2031-
have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
2030+
if (brg.dt_a == data_type::bf16) {
2031+
vpbroadcastw(bcst(), ptr[reg_aux_A + A_offset(bd, rd)]);
2032+
uni_vpmovzxwd(bcst(), bcst());
2033+
uni_vpslld(bcst(), bcst(), 16);
2034+
} else {
2035+
broadcast(bcst(), A_offset(bd, rd),
2036+
have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
2037+
}
20322038
}
20332039
if (prefetch_count_B < ld_block2) {
20342040
prefetcht0(ptr[reg_aux_B + B_offset(prefetch_count_B++, rd)
@@ -2044,10 +2050,9 @@ void jit_brgemm_kernel_t<isa, Wmm>::gemm_microkernel(int bd_block2,
20442050
uni_vmulps(vmm, load(ld), bcst());
20452051
} else {
20462052
if (is_emdbd)
2047-
uni_vfmadd231ps(vmm, load(ld),
2048-
ptr_b[reg_aux_A + A_offset(bd, rd)]);
2053+
uni_vfmadd231ps(vmm, load(ld), ptr_b[reg_aux_A + A_offset(bd, rd)]);
20492054
else
2050-
dot_product(vmm, load(ld), bcst());
2055+
uni_vfmadd231ps(vmm, load(ld), bcst());
20512056
}
20522057
}
20532058
}

src/cpu/x64/jit_brgemm_inner_product.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
307307
const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
308308
auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr + wei_ic_stride * b * ic_blocks_per_batch;
309309

310-
const int ic_internal_block = is_amx || one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
310+
const int ic_internal_block = pd()->jbgp_.wei_dt == data_type::bf16 ||
311+
one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
311312
auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc;
312313
auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc;
313314

@@ -408,7 +409,8 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
408409
const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
409410
auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr;
410411

411-
const int ic_internal_block = is_amx || one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
412+
const int ic_internal_block = pd()->jbgp_.wei_dt == data_type::bf16 ||
413+
one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
412414
auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc;
413415
auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc;
414416

src/cpu/x64/jit_brgemm_inner_product.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
212212
if (pd()->jbgp_.weights_decompression && pd()->jbgp_.wei_decomp_algo == weights_decomp_kind_t::prepack) {
213213
weights_decompression_compile_params_t jcp = {};
214214
jcp.oc_size = pd()->jbgp_.oc_block;
215-
jcp.ic_internal_size = pd()->jbgp_.is_amx || utils::one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
215+
jcp.ic_internal_size = pd()->jbgp_.wei_dt == data_type::bf16 ||
216+
utils::one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
216217
jcp.with_scales = !pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values();
217218
jcp.broadcast_scales = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).dims_[0] == 1;
218219
jcp.with_zero_points = !pd()->attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS);

src/cpu/x64/jit_brgemm_inner_product_utils.cpp

+12-7
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ jit_brgemm_ip_conf_t::get_desired_weights_tag() const {
151151
const bool is_xf16 = utils::one_of(jbgp.wei_dt, bf16, f16);
152152
const bool is_not_vnni_tag = (jbgp.wei_dt == f32
153153
|| (jbgp.wei_dt == f16 && jbgp.isa == avx512_core_fp16)) && !jbgp.weights_decompression;
154-
if (is_not_vnni_tag || (jbgp.weights_decompression && jbgp.orig_wei_dt == u8 && jbgp.wei_dt != bf16)) {
154+
if (is_not_vnni_tag || (jbgp.weights_decompression && jbgp.orig_wei_dt == u8)) {
155155
if (is_superset(jbgp.isa, avx512_core))
156156
return {{64,
157157
pick(n_sp_dims, OI16i64o, OIw16i64o, OIhw16i64o,
@@ -176,7 +176,7 @@ jit_brgemm_ip_conf_t::get_desired_weights_tag() const {
176176
pick(n_sp_dims, OI8i16o, OIw8i16o, OIhw8i16o,
177177
OIdhw8i16o)},
178178
{8, pick(n_sp_dims, OI8i8o, OIw8i8o, OIhw8i8o, OIdhw8i8o)}};
179-
} else if (is_xf16 || (jbgp.weights_decompression && jbgp.orig_wei_dt == u8 && jbgp.wei_dt == bf16)) {
179+
} else if (is_xf16) {
180180
if (jbgp.is_amx) {
181181
return {{64,
182182
pick(n_sp_dims, OI16i64o2i, OIw16i64o2i,
@@ -374,10 +374,11 @@ int jit_brgemm_ip_conf_t::get_adjusted_oc_block() const {
374374
// time for weights reorder are key optimization points there.
375375
const size_t wei_size = static_cast<size_t>(jbgp.ic * jbgp.oc) * types::data_type_size(jbgp.wei_dt);
376376
// Use oc block to be 32 if weight size >= 8MB on amx bf16 to optimized memory consumption.
377-
if (jbgp.is_amx && jbgp.wei_dt == bf16 && !jbgp.is_bf32 && wei_size >= 8 * (1 << 20))
377+
if (jbgp.is_amx && jbgp.orig_wei_dt == bf16 && !jbgp.is_bf32 && wei_size >= 8 * (1 << 20))
378378
return 32;
379379
// Use oc block to be 64 if weight size >= 16MB on avx512 f32 to optimized memory consumption.
380-
if (is_f32_compute_avx512 && wei_size >= 16 * (1 << 20))
380+
if ((is_f32_compute_avx512 || (jbgp.is_amx && jbgp.orig_wei_dt != bf16 && !jbgp.is_bf32))
381+
&& wei_size >= 16 * (1 << 20))
381382
return 64;
382383
// Use oc block to be 24 if weight size >= 16MB on avx2 f32 to optimized memory consumption.
383384
if (is_f32_compute_avx2 && wei_size >= 16 * (1 << 20))
@@ -1339,8 +1340,7 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
13391340

13401341
jbgp.weights_decompression = one_of(jbgp.src_dt, f32, bf16) &&
13411342
one_of(jbgp.wei_dt, u8, nf4, s4, u4);
1342-
jbgp.wei_decomp_algo = jbgp.is_amx ? weights_decomp_kind_t::prepack
1343-
: weights_decomp_kind_t::immediate;
1343+
jbgp.wei_decomp_algo = weights_decomp_kind_t::immediate;
13441344
jbgp.orig_wei_dt = jbgp.wei_dt;
13451345
jbgp.with_grouped_weights_decompression = false;
13461346
if (jbgp.weights_decompression) {
@@ -1366,6 +1366,10 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
13661366
}
13671367
}
13681368

1369+
// Current AMX implementation cannot provide perfromance benefit for immediate algorithm over avx512 version
1370+
if (jbgp.is_amx && jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::immediate)
1371+
return status::unimplemented;
1372+
13691373
jbgp.bia_dt = jbgp.with_bias
13701374
? pick_by_prop_kind(jbgp.prop_kind, ipd.bias_desc.data_type,
13711375
data_type::undef, ipd.diff_bias_desc.data_type)
@@ -1382,7 +1386,8 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
13821386
everyone_is(bf16, jbgp.wei_dt, jbgp.dst_dt)
13831387
&& jbgp.src_dt == f32,
13841388
everyone_is(bf16, jbgp.src_dt, jbgp.dst_dt)
1385-
&& jbgp.wei_dt == f32);
1389+
&& jbgp.wei_dt == f32)
1390+
|| (jbgp.weights_decompression && jbgp.src_dt == bf16 && one_of(jbgp.dst_dt, f32, bf16));
13861391
const bool is_f16 = everyone_is(f16, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt)
13871392
|| pick_by_prop_kind(jbgp.prop_kind,
13881393
everyone_is(f16, jbgp.src_dt, jbgp.wei_dt)

0 commit comments

Comments
 (0)