Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 074e76d

Browse files
zhewang1-intc authored and DDEle committed
Squashed commit of the following:
commit 96d2966 Author: Ding, Yi1 <yi1.ding@intel.com> Date: Wed Jul 10 08:26:45 2024 +0000 s4=>i4 commit 696820f Author: Ding, Yi1 <yi1.ding@intel.com> Date: Wed Jul 10 05:50:03 2024 +0000 add back dtype_zero_pt check commit 20da116 Author: Wang,Zhe <zhe1.wang@intel.com> Date: Mon Jul 1 10:51:38 2024 +0800 support bf16 activation commit ea42a9a Author: Zhe, Wang <zhe1.wang@intel.com> Date: Tue Jun 18 09:23:19 2024 +0800 support fp_zp quant
1 parent 22c2123 commit 074e76d

File tree

5 files changed

+105
-34
lines changed

5 files changed

+105
-34
lines changed

include/common/core/common_types.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
2727

2828
enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
2929

30-
enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1 };
30+
enum class quant_mode : uint8_t {
31+
I4_ASYM = 0,
32+
I4_SYM = 1,
33+
I4_ASYM_FP_ZERO = 2
34+
};
3135

3236
struct quant_info {
3337
quant_mode quant_mode;

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

+30-13
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,16 @@ class gemm_t<
102102
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x8>>::value,
103103
"this is for 4bit matB ");
104104
static_assert(
105-
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x2>>::
106-
value ||
107-
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x8>>::
108-
value,
105+
quant_info_.quant_mode == quant_mode::I4_ASYM_FP_ZERO
106+
? std::is_same_v<
107+
remove_const_t<dtype_zero_pt>,
108+
remove_const_t<dtype_a>>
109+
: (std::is_same_v<
110+
remove_const_t<dtype_zero_pt>,
111+
remove_const_t<int4x2>> ||
112+
std::is_same_v<
113+
remove_const_t<dtype_zero_pt>,
114+
remove_const_t<int4x8>>),
109115
"this is for 4bit zero_pt ");
110116

111117
/******** set memory attribute **********/
@@ -284,12 +290,20 @@ class gemm_t<
284290
arch_tag>;
285291

286292
// compress int4 along N dimensions
287-
using zero_pt_tile_desc_t = subgroup::tile_desc_t<
288-
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
289-
tile_size_y_zero_pt,
290-
(block_size_x_b + pack_ratio - 1) / pack_ratio,
291-
block_size_y_zero_pt,
292-
reg_layout::tiled>;
293+
using zero_pt_tile_desc_t = std::conditional_t<
294+
quant_info_.quant_mode != quant_mode::I4_ASYM_FP_ZERO,
295+
subgroup::tile_desc_t<
296+
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
297+
tile_size_y_zero_pt,
298+
(block_size_x_b + pack_ratio - 1) / pack_ratio,
299+
block_size_y_zero_pt,
300+
reg_layout::tiled>,
301+
subgroup::tile_desc_t<
302+
tile_size_x_b,
303+
tile_size_y_zero_pt,
304+
block_size_x_b,
305+
block_size_y_zero_pt,
306+
reg_layout::tiled>>;
293307

294308
using zero_pt_t = subgroup::tile_t<dtype_zero_pt, zero_pt_tile_desc_t>;
295309
using zero_pt_payload_t = subgroup::mem_payload_t<
@@ -576,7 +590,8 @@ class gemm_t<
576590
// TODO 1D prefetch need pack to U32/U64
577591
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
578592
scale_prefetch_payload);
579-
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
593+
if constexpr (
594+
compute_policy::quant_mode != quant_mode::I4_SYM) {
580595
// TODO 1D prefetch need pack to U32/U64
581596
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
582597
zero_pt_prefetch_payload);
@@ -589,7 +604,8 @@ class gemm_t<
589604
if (tile_k_idx % scale_addr_update_freq == 0) {
590605
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
591606
}
592-
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
607+
if constexpr (
608+
compute_policy::quant_mode != quant_mode::I4_SYM) {
593609
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
594610
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
595611
zero_pt_t::tile_size_y);
@@ -603,7 +619,8 @@ class gemm_t<
603619
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
604620
scale_prefetch_payload.template update_tdesc<tdesc_update_dir::y_dir>(
605621
scale_t::tile_size_y);
606-
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
622+
if constexpr (
623+
compute_policy::quant_mode != quant_mode::I4_SYM) {
607624
zero_pt_prefetch_payload
608625
.template update_tdesc<tdesc_update_dir::y_dir>(
609626
zero_pt_t::tile_size_y);

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

+24-3
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ class gemm_universal_t<
570570
// check for int4x2
571571
implementable &=
572572
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
573-
if constexpr (gemm_t::compute_policy::quant_mode != quant_mode::I4_SYM) {
573+
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
574574
implementable &= (args.zero_pt_ld % pack_ratio == 0);
575575
}
576576

@@ -621,7 +621,10 @@ class gemm_universal_t<
621621
int start_x_scale = start_n;
622622
int start_y_scale = start_k / dequant_s;
623623

624-
int start_x_zero_pt = start_n / pack_ratio;
624+
int start_x_zero_pt =
625+
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO
626+
? start_n
627+
: start_n / pack_ratio;
625628
int start_y_zero_pt = start_k / dequant_s;
626629

627630
// set up arguments
@@ -674,7 +677,8 @@ class gemm_universal_t<
674677
inner_loop_start,
675678
inner_loop_count,
676679
mem_desc_scale);
677-
} else {
680+
} else if constexpr (
681+
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
678682
mem_desc_zero_pt_t mem_desc_zero_pt(
679683
args.zero_pt_base,
680684
{(args.matrix_n + pack_ratio - 1) / pack_ratio,
@@ -688,6 +692,23 @@ class gemm_universal_t<
688692
inner_loop_count,
689693
mem_desc_scale,
690694
mem_desc_zero_pt);
695+
} else if constexpr (
696+
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
697+
mem_desc_zero_pt_t mem_desc_zero_pt(
698+
args.zero_pt_base,
699+
{args.matrix_n,
700+
((args.matrix_k + dequant_s - 1) / dequant_s),
701+
args.zero_pt_ld},
702+
{start_x_zero_pt, start_y_zero_pt});
703+
gemm_args = gemm_args_t(
704+
mem_desc_a,
705+
mem_desc_b,
706+
inner_loop_start,
707+
inner_loop_count,
708+
mem_desc_scale,
709+
mem_desc_zero_pt);
710+
} else {
711+
assert(0);
691712
}
692713
matAcc_t matAcc;
693714
matAcc.init(0);

include/subgroup/tile/impl/tile_op_functor.hpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,25 @@ struct dequant_int4_weight_t {
149149
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
150150
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
151151
zero_pt_i8;
152-
} else if constexpr (quant_mode == quant_mode::I4_SYM) {
152+
} else if constexpr (
153+
quant_mode == quant_mode::I4_SYM ||
154+
quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
153155
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
154156
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
155157
int8_t(8);
156158
}
157159
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
158160
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
159161
scale.reg[scale_idx];
160-
162+
if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
163+
uint32_t zero_pt_idx =
164+
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
165+
offset_x_in_tile;
166+
xetla_vector<fp16, 1> zero_pt_pack = zero_pt.reg[zero_pt_idx];
167+
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
168+
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) +
169+
zero_pt_pack[0];
170+
}
161171
// sycl::ext::oneapi::experimental::printf(
162172
// "scale[%d] %f \n",
163173
// scale_idx,

tests/integration/gemv/int4/main.cpp

+34-15
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ constexpr int ITER = 200;
2727
#endif
2828
constexpr size_t UNDEFINED_DATA_SIZE = 1024;
2929

30-
template <typename scalar_t>
30+
template <typename scalar_t, quant_mode quant_mode_>
3131
class test_col_major_1 {
3232
public:
3333
// Extract the parameters required by different test cases
@@ -41,7 +41,7 @@ class test_col_major_1 {
4141
static constexpr size_t sg_k = 512 / sg_m;
4242
static constexpr size_t dequant_s = 128;
4343
// static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
44-
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;
44+
static constexpr quant_mode quant_mode = quant_mode_;
4545

4646
static constexpr size_t local_kslicing = 1;
4747
static constexpr size_t global_kslicing = 1;
@@ -132,13 +132,19 @@ std::vector<fp16> convert_int4(
132132
data_type_zero_pt zero_pt) {
133133
std::vector<fp16> dequant_fp16(sizeof(data_type_b) * 2);
134134

135-
int8_t zero_pt_i8 = zero_pt & 0xf;
135+
int8_t zero_pt_i8;
136+
if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
137+
zero_pt_i8 = zero_pt & 0xf;
136138
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
137139
int8_t dequant_8bit = data_b & 0xf;
138140
if constexpr (quant_mode == quant_mode::I4_SYM) {
139141
dequant_fp16[i] = scale * (dequant_8bit - 8);
140-
} else {
142+
} else if constexpr (quant_mode == quant_mode::I4_ASYM) {
141143
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144+
} else if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
145+
dequant_fp16[i] = scale * (dequant_8bit - 8) + zero_pt;
146+
} else {
147+
assert(0);
142148
}
143149
data_b = data_b >> 4;
144150
}
@@ -170,12 +176,14 @@ std::vector<data_type_acc_in> dequantize_weight(
170176
for (uint32_t j = 0; j < width; j += step) {
171177
int start_b_in = i * width + j;
172178
int start_scale_in = start_b_in / step;
173-
int start_zero_pt_in =
174-
(j / step) * (matrix_n / pack_radio) + i / pack_radio;
179+
int start_zero_pt_in = quant_mode == quant_mode::I4_ASYM_FP_ZERO
180+
? (j / step) * matrix_n + i
181+
: (j / step) * (matrix_n / pack_radio) + i / pack_radio;
175182
int start_out =
176183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
177184
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
178-
zp_value = zp_value >> (4 * (i % pack_radio));
185+
if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
186+
zp_value = zp_value >> (4 * (i % pack_radio));
179187
for (uint32_t jj = 0; jj < step; jj++) {
180188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
181189
b[start_b_in + jj], scale[start_scale_in], zp_value);
@@ -216,7 +224,10 @@ void dequantize_gemv_run(int iter) {
216224
using data_type_a = typename Test::data_type_a;
217225
using data_type_b = typename Test::data_type_b;
218226
using data_type_c = typename Test::data_type_c;
219-
using data_type_zero_pt = data_type_b;
227+
using data_type_zero_pt = std::conditional_t<
228+
Test::quant_mode == quant_mode::I4_ASYM_FP_ZERO,
229+
data_type_c,
230+
data_type_b>;
220231
using data_type_scale = fp16;
221232
using data_type_acc_in = fp16;
222233
using data_type_acc = float;
@@ -226,7 +237,7 @@ void dequantize_gemv_run(int iter) {
226237
constexpr mem_layout layout_b = Test::layout_b;
227238

228239
constexpr size_t size_a = matrix_m * matrix_k;
229-
constexpr size_t size_b = matrix_k * matrix_n / (2 * sizeof(data_type_b));
240+
constexpr size_t size_b = matrix_k * matrix_n / 2;
230241

231242
constexpr size_t size_scale_k = matrix_k / dequant_s;
232243
constexpr size_t size_scale_n = matrix_n;
@@ -235,7 +246,9 @@ void dequantize_gemv_run(int iter) {
235246
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
236247
constexpr size_t size_zero_pt_n = matrix_n;
237248
constexpr size_t size_zero_pt =
238-
size_zero_pt_k * size_zero_pt_n / (2 * sizeof(data_type_b));
249+
Test::quant_mode != quant_mode::I4_ASYM_FP_ZERO
250+
? size_zero_pt_k * size_zero_pt_n / 2
251+
: size_zero_pt_k * size_zero_pt_n;
239252

240253
constexpr size_t size_c = matrix_m * matrix_n;
241254
constexpr size_t size_bias = matrix_n;
@@ -406,16 +419,18 @@ void dequantize_gemv_run(int iter) {
406419
scale_h[i] = INFINITY;
407420
}
408421
for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
409-
if constexpr (std::is_same_v<int4x2, data_type_b>) {
422+
if constexpr (std::is_same_v<int4x2, data_type_zero_pt>) {
410423
zero_pt_h[i] = random_uint8();
411424
#ifdef UT_DEBUG
412425
zero_pt_h[i] = 0x12 << i;
413426
#endif
414-
} else if constexpr (std::is_same_v<int4x8, data_type_b>) {
427+
} else if constexpr (std::is_same_v<int4x8, data_type_zero_pt>) {
415428
zero_pt_h[i] = random_uint32();
416429
#ifdef UT_DEBUG
417430
zero_pt_h[i] = 0x33333333;
418431
#endif
432+
} else if constexpr (std::is_same_v<fp16, data_type_zero_pt>) {
433+
zero_pt_h[i] = random_float();
419434
}
420435
}
421436

@@ -492,7 +507,9 @@ void dequantize_gemv_run(int iter) {
492507
Acc_d,
493508
Cnt_d,
494509
epilogue_args);
495-
} else if constexpr (compute_policy::quant_mode == quant_mode::I4_ASYM) {
510+
} else if constexpr (
511+
compute_policy::quant_mode == quant_mode::I4_ASYM ||
512+
compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
496513
gemm_arg =
497514
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
498515
matrix_m,
@@ -604,8 +621,10 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) {
604621

605622
REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd);
606623
using tests = ::testing::Types< //
607-
test_col_major_1<fp16>,
608-
test_col_major_1<bf16>,
624+
test_col_major_1<fp16, quant_mode::I4_SYM>,
625+
test_col_major_1<bf16, quant_mode::I4_SYM>,
626+
test_col_major_1<fp16, quant_mode::I4_ASYM_FP_ZERO>,
627+
test_col_major_1<bf16, quant_mode::I4_ASYM_FP_ZERO>,
609628
// test_col_major_2,
610629
void>;
611630

0 commit comments

Comments
 (0)