Skip to content

Commit 7fee383

Browse files
Author: dmitrygo (committed)
[FORK][FEATURE] Enable prepack algorithm for 4bit weights decompression
1 parent 2ead5d4 commit 7fee383

11 files changed, with 287 additions and 109 deletions.

src/cpu/reorder/cpu_reorder_regular_nf4.cpp

+4
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,10 @@ const impl_list_map_t &regular_nf4_impl_list_map() {
3131
REG_SR(nf4, any, nf4, OI8i24o2i, fmt_order_keep)
3232
REG_SR(nf4, any, nf4, OI8i32o2i, fmt_order_keep)
3333
REG_SR(nf4, any, nf4, OI8i64o2i, fmt_order_keep)
34+
REG_SR(nf4, any, nf4, OI16i16o2i, fmt_order_keep)
35+
REG_SR(nf4, any, nf4, OI16i32o2i, fmt_order_keep)
36+
REG_SR(nf4, any, nf4, OI16i48o2i, fmt_order_keep)
37+
REG_SR(nf4, any, nf4, OI16i64o2i, fmt_order_keep)
3438
nullptr,
3539
}},
3640
});

src/cpu/reorder/cpu_reorder_regular_s4.cpp

+4
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,10 @@ const impl_list_map_t &regular_s4_impl_list_map() {
3131
REG_SR(s4, any, s4, OI8i24o2i, fmt_order_keep)
3232
REG_SR(s4, any, s4, OI8i32o2i, fmt_order_keep)
3333
REG_SR(s4, any, s4, OI8i64o2i, fmt_order_keep)
34+
REG_SR(s4, any, s4, OI16i16o2i, fmt_order_keep)
35+
REG_SR(s4, any, s4, OI16i32o2i, fmt_order_keep)
36+
REG_SR(s4, any, s4, OI16i48o2i, fmt_order_keep)
37+
REG_SR(s4, any, s4, OI16i64o2i, fmt_order_keep)
3438
nullptr,
3539
}},
3640
});

src/cpu/reorder/cpu_reorder_regular_u4.cpp

+4
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,10 @@ const impl_list_map_t &regular_u4_impl_list_map() {
3131
REG_SR(u4, any, u4, OI8i24o2i, fmt_order_keep)
3232
REG_SR(u4, any, u4, OI8i32o2i, fmt_order_keep)
3333
REG_SR(u4, any, u4, OI8i64o2i, fmt_order_keep)
34+
REG_SR(u4, any, u4, OI16i16o2i, fmt_order_keep)
35+
REG_SR(u4, any, u4, OI16i32o2i, fmt_order_keep)
36+
REG_SR(u4, any, u4, OI16i48o2i, fmt_order_keep)
37+
REG_SR(u4, any, u4, OI16i64o2i, fmt_order_keep)
3438
nullptr,
3539
}},
3640
});

src/cpu/x64/brgemm/brgemm_utils.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ int calculate_max_bcast_block(brgemm_t *brg, const int adj_ld_block2) {
206206
if (brg->is_int8 && !brg->has_int8_vnni) max_bcast_block -= 2;
207207

208208
if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5;
209-
if (one_of(brg->dt_b, data_type::u4, data_type::nf4)) max_bcast_block -= 1;
209+
if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl != avx2) max_bcast_block -= 1;
210210
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1;
211211

212212
max_bcast_block /= adj_ld_block2;

src/cpu/x64/brgemm/jit_brgemm_kernel.cpp

+5 -2
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
4747
, max_effective_vregs(
4848
max_vregs - (brg.is_int8 && !brg.has_int8_vnni ? 2 : 0)
4949
- (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl == avx2 ? 5 : 0)
50-
- (one_of(brg.dt_b, data_type::nf4) ? 1 : 0)
50+
- (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl != avx2 ? 1 : 0)
5151
- (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0 ? 1 : 0)) {
5252

5353
// The implementation uses is_superset(), is_subset() utilities.
@@ -1930,7 +1930,10 @@ void jit_brgemm_kernel_t<isa, Wmm>::gemm_microkernel(int bd_block2,
19301930
uni_vmovups(vmm_mask8, ptr[reg_ptr]);
19311931
mov(reg_ptr, (size_t)mask7);
19321932
uni_vmovups(vmm_mask7, ptr[reg_ptr]);
1933-
vmm_zero_points = Vmm(max_vregs - 5);
1933+
if (brg.wei_decomp_zero_points_stride == 0)
1934+
vmm_zero_points = Vmm(max_vregs - 6);
1935+
else
1936+
vmm_zero_points = Vmm(max_vregs - 5);
19341937
} else {
19351938
mov(reg_ptr, (size_t)lookup);
19361939
uni_vmovups(vmm_lookup, ptr[reg_ptr]);

src/cpu/x64/jit_brgemm_inner_product.cpp

+44 -43
Original file line number | Diff line number | Diff line change
@@ -87,15 +87,6 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
8787
const auto wei_zero_points_d = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
8888
int wei_scales_oc_stride = wei_scales_d.dims()[0] > 1 ? 1 : 0;
8989
int wei_zero_points_oc_stride = wei_zero_points_d.dims()[0] > 1 ? 1 : 0;
90-
int wei_scales_ic_group_size, wei_zero_points_ic_group_size;
91-
if (jbgp.with_grouped_weights_decompression) {
92-
int wei_scales_ic_group_num = wei_scales_d.dims()[1];
93-
int wei_zero_points_ic_group_num = wei_zero_points_d.dims()[1];
94-
wei_scales_ic_group_size = wei_scales_ic_group_num ? div_up(jbgp.ic, wei_scales_ic_group_num) : jbgp.ic;
95-
wei_zero_points_ic_group_size = wei_zero_points_ic_group_num ? div_up(jbgp.ic, wei_zero_points_ic_group_num) : jbgp.ic;
96-
} else {
97-
wei_scales_ic_group_size = wei_zero_points_ic_group_size = jbgp.ic;
98-
}
9990

10091
const float *oscales = nullptr;
10192
if (jbgp.weights_decompression) {
@@ -170,8 +161,6 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
170161
const auto wei_ic_stride
171162
= types::data_type_size(jbgp.wei_dt) * weights_d.off_v(ic_dims);
172163

173-
int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
174-
175164
const auto ker = [&](int ithr_oc_mb, int nthr_oc_mb, int ithr_ic, int osb,
176165
int osb_s, int ocb, int ocb_s, int icc, int icc_s,
177166
bool copy_buffer_a, int &prev_ker_idx) {
@@ -269,7 +258,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
269258
int brg_ker_idx = brgemm_inner_product_utils::get_brg_kernel_index(
270259
is_bs_tail, kernel_init, is_os_tail, is_oc_tail, false);
271260
auto brg_kernel = brg_kernels_[brg_ker_idx].get();
272-
const int ic_blocks_per_batch = jbgp.K / jbgp.ic_block;
261+
const int ic_blocks_per_batch = div_up(jbgp.K, jbgp.ic_block);
273262
const dim_t wei_cur_ocb
274263
= get_blk_off(weights_d, jbgp.wei_dt, cur_ocb, 0);
275264

@@ -290,7 +279,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
290279
ic + b * jbgp.K));
291280
addr_batch[b].ptr.A = A_ptr;
292281
const dim_t wei_offset = (wei_cur_ocb
293-
+ wei_ic_stride * (icb + b * ic_blocks_per_batch)) / typesize_scale;
282+
+ wei_ic_stride * (icb + b * ic_blocks_per_batch));
294283
if (jbgp.weights_compressed) {
295284
using comp_tile_len_type = int;
296285
const comp_tile_len_type *compressed_tile_lengths_ptr
@@ -311,30 +300,35 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
311300
(*brg_decomp_kernel_)(&dcomp_params);
312301
addr_batch[b].ptr.B = decomp_buf;
313302
} else if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) {
314-
auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt);
303+
int typesize_scale = one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
304+
auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt) / typesize_scale;
315305
auto weights_ptr = reinterpret_cast<const uint8_t *>(&weights[w_off]);
316306

317307
const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
318308
auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr + wei_ic_stride * b * ic_blocks_per_batch;
319309

320-
const int ic_internal_block = is_amx ? 2 : 1;
321-
auto wei_zero_points_ptr = wei_zero_points + oc;
322-
auto wei_scales_ptr = wei_scales + oc;
310+
const int ic_internal_block = is_amx || one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
311+
auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc;
312+
auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc;
323313

324314
if (jbgp.with_grouped_weights_decompression) {
325315
weights_decompression_runtime_params_t rt_params = {};
326316
auto ic_size = jbgp.ic_block * ic_blocks_per_batch / ic_internal_block;
327-
auto wei_scales_ic_group_size_local = wei_scales_ic_group_size / ic_internal_block;
328-
auto wei_zero_points_ic_group_size_local = wei_zero_points_ic_group_size / ic_internal_block;
317+
auto wei_scales_ic_group_size_local = jbgp.wei_scales_ic_group_size / ic_internal_block;
318+
auto wei_zero_points_ic_group_size_local = jbgp.wei_zero_points_ic_group_size / ic_internal_block;
329319
auto group_size = nstl::min(wei_scales_ic_group_size_local, wei_zero_points_ic_group_size_local);
330320
auto group_ic_blocks = div_up(ic_size, group_size);
321+
auto start_group_scales = ic / jbgp.wei_scales_ic_group_size;
322+
auto start_group_zero_points = ic / jbgp.wei_zero_points_ic_group_size;
331323
for (int icb_idx = 0; icb_idx < group_ic_blocks; icb_idx++) {
332324
auto ic_idx = icb_idx * group_size;
325+
auto scales_idx = ic_idx / wei_scales_ic_group_size_local + start_group_scales;
326+
auto zero_points_idx = ic_idx / wei_zero_points_ic_group_size_local + start_group_zero_points;
333327

334-
rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt);
328+
rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt) / typesize_scale;
335329
rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block *jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
336-
rt_params.scales_ptr = wei_scales_ptr + (ic_idx * wei_scales_d.dims()[0]) / wei_scales_ic_group_size_local;
337-
rt_params.zero_points_ptr = wei_zero_points_ptr + (ic_idx * wei_zero_points_d.dims()[0]) / wei_zero_points_ic_group_size_local;
330+
rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0];
331+
rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims()[0];
338332
rt_params.ic_size = nstl::min(group_size, ic_size - icb_idx * group_size);
339333
(*brg_weights_decomp_kernel_)(&rt_params);
340334
}
@@ -350,15 +344,16 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
350344

351345
addr_batch[b].ptr.B = decomp_buf;
352346
} else {
353-
addr_batch[b].ptr.B = weights + wei_offset;
347+
int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
348+
addr_batch[b].ptr.B = weights + wei_offset / typesize_scale;
354349
}
355350
}
356351

357352
int wei_scales_offset = 0;
358353
int wei_zero_points_offset = 0;
359354
if (jbgp.weights_decompression) {
360-
wei_scales_offset = (ic / wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
361-
wei_zero_points_offset = (ic / wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc;
355+
wei_scales_offset = (ic / jbgp.wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
356+
wei_zero_points_offset = (ic / jbgp.wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc;
362357
}
363358

364359
auto ptr_D = dst + dst_off;
@@ -382,10 +377,10 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
382377

383378
brgemm_kernel_execute_postops(brg_kernel, gemm_batch,
384379
addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data,
385-
scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic);
380+
scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0);
386381
} else {
387382
brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
388-
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic);
383+
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0);
389384
}
390385
}
391386

@@ -403,33 +398,38 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
403398
+ get_blk_off(src_d, jbgp.src_dt, n,
404399
ic + ic_block * jbgp.ic_block);
405400
const dim_t wei_offset
406-
= (wei_cur_ocb + wei_ic_stride * (icb + ic_block)) / typesize_scale;
401+
= (wei_cur_ocb + wei_ic_stride * (icb + ic_block));
407402

408403
if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) {
409-
auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt);
404+
int typesize_scale = one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
405+
auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt) / typesize_scale;
410406
auto weights_ptr = reinterpret_cast<const uint8_t *>(&weights[w_off]);
411407

412408
const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
413409
auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr;
414410

415-
const int ic_internal_block = is_amx ? 2 : 1;
416-
auto wei_zero_points_ptr = wei_zero_points + oc;
417-
auto wei_scales_ptr = wei_scales + oc;
411+
const int ic_internal_block = is_amx || one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
412+
auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc;
413+
auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc;
418414

419415
if (jbgp.with_grouped_weights_decompression) {
416+
weights_decompression_runtime_params_t rt_params = {};
420417
auto ic_size = (jbgp.ic - (ic + ic_block * jbgp.ic_block)) / ic_internal_block;
421-
auto wei_scales_ic_group_size_local = wei_scales_ic_group_size / ic_internal_block;
422-
auto wei_zero_points_ic_group_size_local = wei_zero_points_ic_group_size / ic_internal_block;
418+
auto wei_scales_ic_group_size_local = jbgp.wei_scales_ic_group_size / ic_internal_block;
419+
auto wei_zero_points_ic_group_size_local = jbgp.wei_zero_points_ic_group_size / ic_internal_block;
423420
auto group_size = nstl::min(wei_scales_ic_group_size_local, wei_zero_points_ic_group_size_local);
424421
auto group_ic_blocks = div_up(ic_size, group_size);
425-
weights_decompression_runtime_params_t rt_params = {};
422+
auto start_group_scales = ic / jbgp.wei_scales_ic_group_size;
423+
auto start_group_zero_points = ic / jbgp.wei_zero_points_ic_group_size;
426424
for (int icb_idx = 0; icb_idx < group_ic_blocks; icb_idx++) {
427425
auto ic_idx = icb_idx * group_size;
426+
auto scales_idx = ic_idx / wei_scales_ic_group_size_local + start_group_scales;
427+
auto zero_points_idx = ic_idx / wei_zero_points_ic_group_size_local + start_group_zero_points;
428428

429-
rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt);
429+
rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt) / typesize_scale;
430430
rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.wei_dt);
431-
rt_params.scales_ptr = wei_scales_ptr + (ic_idx * wei_scales_d.dims()[0]) / wei_scales_ic_group_size_local;
432-
rt_params.zero_points_ptr = wei_zero_points_ptr + (ic_idx * wei_zero_points_d.dims()[0]) / wei_zero_points_ic_group_size_local;
431+
rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0];
432+
rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims()[0];
433433
rt_params.ic_size = nstl::min(group_size, ic_size - icb_idx * group_size);
434434
(*brg_weights_decomp_kernel_)(&rt_params);
435435
}
@@ -445,14 +445,15 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
445445

446446
addr_batch[0].ptr.B = decomp_buf;
447447
} else {
448-
addr_batch[0].ptr.B = weights + wei_offset;
448+
int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1;
449+
addr_batch[0].ptr.B = weights + wei_offset / typesize_scale;
449450
}
450451

451452
int wei_scales_offset = 0;
452453
int wei_zero_points_offset = 0;
453454
if (jbgp.weights_decompression) {
454-
wei_scales_offset = (ic / wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
455-
wei_zero_points_offset = (ic / wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc;
455+
wei_scales_offset = (ic / jbgp.wei_scales_ic_group_size) * wei_scales_d.dims()[0] + wei_scales_oc_stride * oc;
456+
wei_zero_points_offset = (ic / jbgp.wei_zero_points_ic_group_size) * wei_zero_points_d.dims()[0] + wei_zero_points_oc_stride * oc;
456457
}
457458

458459
auto brg_kernel_ic_tail = brg_kernels_[brg_ker_ic_tail_idx].get();
@@ -474,10 +475,10 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
474475
nullptr, false, 1, false, false, dst_scales};
475476

476477
brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch,
477-
(void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic);
478+
(void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0);
478479
} else {
479480
brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch,
480-
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], ic);
481+
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, &wei_scales[wei_scales_offset], &wei_zero_points[wei_zero_points_offset], 0);
481482
}
482483
}
483484
};

0 commit comments

Comments
 (0)