Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 1b234c5

Browse files
committed
Support floating-point zero-point (fp_zp) quantization: add quant_mode::INT4_ASYM_FP_ZERO
1 parent 631b2a3 commit 1b234c5

File tree

6 files changed

+95
-45
lines changed

6 files changed

+95
-45
lines changed

include/common/core/common_types.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
2727

2828
enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
2929

30-
enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };
30+
enum class quant_mode : uint8_t {
31+
S4_ASYM = 0,
32+
S4_FULLRANGE_NO_ZP = 1,
33+
INT4_ASYM_FP_ZERO = 2
34+
};
3135

3236
struct quant_info {
3337
quant_mode quant_mode;

include/common/core/memory.hpp

+6-12
Original file line numberDiff line numberDiff line change
@@ -355,13 +355,9 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
355355
__ESIMD_NS::cache_hint_L1<gpu::xetla::detail::get_cache_hint(L1H)>,
356356
__ESIMD_NS::cache_hint_L2<gpu::xetla::detail::get_cache_hint(L2H)>,
357357
__ESIMD_NS::alignment<alignment>};
358-
if constexpr (sizeof(T) * N < sizeof(uint32_t) || N == 1) {
359-
xetla_vector<T, N> ret;
360-
#pragma unroll
361-
for (uint32_t i = 0; i < N; i++) {
362-
ret[i] = ptr[i + byte_offset / sizeof(T)];
363-
}
364-
return ret;
358+
if constexpr (sizeof(T) * N < sizeof(uint32_t)) {
359+
xetla_vector<uint32_t, N> offsets(byte_offset, sizeof(T));
360+
return __ESIMD_NS::gather<T, N, uint32_t>(ptr, offsets);
365361
} else {
366362
return __ESIMD_NS::block_load<T, N>(ptr, byte_offset, props);
367363
}
@@ -505,11 +501,9 @@ __XETLA_API void xetla_store_global(
505501
__ESIMD_NS::cache_hint_L2<gpu::xetla::detail::get_cache_hint(L2H)>,
506502
__ESIMD_NS::alignment<alignment>};
507503

508-
if constexpr (sizeof(T) * N < sizeof(uint32_t) || N == 1) {
509-
#pragma unroll
510-
for (uint32_t i = 0; i < N; i++) {
511-
ptr[i + byte_offset / sizeof(T)] = vals[i];
512-
}
504+
if constexpr (sizeof(T) * N < sizeof(uint32_t)) {
505+
xetla_vector<uint32_t, N> offsets(byte_offset, sizeof(T));
506+
return __ESIMD_NS::scatter<T, N, uint32_t>(ptr, offsets, vals);
513507
} else {
514508
__ESIMD_NS::block_store<T, N>(ptr, byte_offset, vals, props);
515509
}

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

+14-12
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,6 @@ class gemm_t<
101101
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x2>>::value ||
102102
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x8>>::value,
103103
"this is for 4bit matB ");
104-
static_assert(
105-
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x2>>::
106-
value ||
107-
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x8>>::
108-
value,
109-
"this is for 4bit zero_pt ");
110104

111105
/******** set memory attribute **********/
112106
static constexpr mem_space mem_space_a = mem_desc_a_t::space;
@@ -284,12 +278,20 @@ class gemm_t<
284278
arch_tag>;
285279

286280
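// Zero points are packed along the N dimension (pack_ratio per element), except in INT4_ASYM_FP_ZERO mode, where the tile keeps the full N extent.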
287-
using zero_pt_tile_desc_t = subgroup::tile_desc_t<
288-
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
289-
tile_size_y_zero_pt,
290-
(block_size_x_b + pack_ratio - 1) / pack_ratio,
291-
block_size_y_zero_pt,
292-
reg_layout::tiled>;
281+
using zero_pt_tile_desc_t = std::conditional_t<
282+
quant_info_.quant_mode != quant_mode::INT4_ASYM_FP_ZERO,
283+
subgroup::tile_desc_t<
284+
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
285+
tile_size_y_zero_pt,
286+
(block_size_x_b + pack_ratio - 1) / pack_ratio,
287+
block_size_y_zero_pt,
288+
reg_layout::tiled>,
289+
subgroup::tile_desc_t<
290+
tile_size_x_b,
291+
tile_size_y_zero_pt,
292+
block_size_x_b,
293+
block_size_y_zero_pt,
294+
reg_layout::tiled>>;
293295

294296
using zero_pt_t = subgroup::tile_t<dtype_zero_pt, zero_pt_tile_desc_t>;
295297
using zero_pt_payload_t = subgroup::mem_payload_t<

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

+25-4
Original file line numberDiff line numberDiff line change
@@ -566,8 +566,7 @@ class gemm_universal_t<
566566
// check for int4x2
567567
implementable &=
568568
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
569-
if constexpr (
570-
gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
569+
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::S4_ASYM) {
571570
implementable &= (args.zero_pt_ld % pack_ratio == 0);
572571
}
573572

@@ -618,7 +617,10 @@ class gemm_universal_t<
618617
int start_x_scale = start_n;
619618
int start_y_scale = start_k / dequant_s;
620619

621-
int start_x_zero_pt = start_n / pack_ratio;
620+
int start_x_zero_pt = gemm_t::compute_policy::quant_mode ==
621+
quant_mode::INT4_ASYM_FP_ZERO
622+
? start_n
623+
: start_n / pack_ratio;
622624
int start_y_zero_pt = start_k / dequant_s;
623625

624626
// set up arguments
@@ -672,7 +674,8 @@ class gemm_universal_t<
672674
inner_loop_start,
673675
inner_loop_count,
674676
mem_desc_scale);
675-
} else {
677+
} else if constexpr (
678+
gemm_t::compute_policy::quant_mode == quant_mode::S4_ASYM) {
676679
mem_desc_zero_pt_t mem_desc_zero_pt(
677680
args.zero_pt_base,
678681
{(args.matrix_n + pack_ratio - 1) / pack_ratio,
@@ -686,6 +689,24 @@ class gemm_universal_t<
686689
inner_loop_count,
687690
mem_desc_scale,
688691
mem_desc_zero_pt);
692+
} else if constexpr (
693+
gemm_t::compute_policy::quant_mode ==
694+
quant_mode::INT4_ASYM_FP_ZERO) {
695+
mem_desc_zero_pt_t mem_desc_zero_pt(
696+
args.zero_pt_base,
697+
{args.matrix_n,
698+
((args.matrix_k + dequant_s - 1) / dequant_s),
699+
args.zero_pt_ld},
700+
{start_x_zero_pt, start_y_zero_pt});
701+
gemm_args = gemm_args_t(
702+
mem_desc_a,
703+
mem_desc_b,
704+
inner_loop_start,
705+
inner_loop_count,
706+
mem_desc_scale,
707+
mem_desc_zero_pt);
708+
} else {
709+
assert(0);
689710
}
690711
matAcc_t matAcc;
691712
matAcc.init(0);

include/subgroup/tile/impl/tile_op_functor.hpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,26 @@ struct dequant_int4_weight_t {
149149
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
150150
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
151151
zero_pt_i8;
152-
} else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
152+
} else if constexpr (
153+
quant_mode == quant_mode::S4_FULLRANGE_NO_ZP ||
154+
quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
153155
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
154156
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
155157
int8_t(8);
156158
}
157159
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
158160
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
159161
scale.reg[scale_idx];
160-
162+
if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
163+
uint32_t zero_pt_idx =
164+
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
165+
offset_x_in_tile;
166+
native_type_t<typename zero_pt_t::dtype> zero_pt_pack =
167+
zero_pt.reg[zero_pt_idx];
168+
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
169+
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) +
170+
zero_pt_pack;
171+
}
161172
// sycl::ext::oneapi::experimental::printf(
162173
// "scale[%d] %f \n",
163174
// scale_idx,

tests/integration/gemv/int4/main.cpp

+32-14
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class test_col_major_1 {
4040
static constexpr size_t sg_k = 512 / sg_m;
4141
static constexpr size_t dequant_s = 128;
4242
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
43-
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
43+
// static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44+
static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_FP_ZERO;
4445

4546
static constexpr size_t local_kslicing = 1;
4647
static constexpr size_t global_kslicing = 1;
@@ -131,13 +132,19 @@ std::vector<fp16> convert_int4(
131132
data_type_zero_pt zero_pt) {
132133
std::vector<fp16> dequant_fp16(sizeof(data_type_b) * 2);
133134

134-
int8_t zero_pt_i8 = zero_pt & 0xf;
135+
int8_t zero_pt_i8;
136+
if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
137+
zero_pt_i8 = zero_pt & 0xf;
135138
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
136139
int8_t dequant_8bit = data_b & 0xf;
137140
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
138141
dequant_fp16[i] = scale * (dequant_8bit - 8);
139-
} else {
142+
} else if constexpr (quant_mode == quant_mode::S4_ASYM) {
140143
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144+
} else if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
145+
dequant_fp16[i] = scale * (dequant_8bit - 8) + zero_pt;
146+
} else {
147+
assert(0);
141148
}
142149
data_b = data_b >> 4;
143150
}
@@ -169,15 +176,17 @@ std::vector<data_type_acc_in> dequantize_weight(
169176
for (uint32_t j = 0; j < width; j += step) {
170177
int start_b_in = i * width + j;
171178
int start_scale_in = start_b_in / step;
172-
int start_zero_pt_in =
173-
(j / step) * (matrix_n / pack_radio) + i / pack_radio;
179+
int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_FP_ZERO
180+
? (j / step) * matrix_n + i
181+
: (j / step) * (matrix_n / pack_radio) + i / pack_radio;
174182
int start_out =
175183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184+
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185+
if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
186+
zp_value = zp_value >> (4 * (i % pack_radio));
176187
for (uint32_t jj = 0; jj < step; jj++) {
177188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
178-
b[start_b_in + jj],
179-
scale[start_scale_in],
180-
zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
189+
b[start_b_in + jj], scale[start_scale_in], zp_value);
181190
for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) {
182191
b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
183192
}
@@ -215,7 +224,10 @@ void dequantize_gemv_run(int iter) {
215224
using data_type_a = typename Test::data_type_a;
216225
using data_type_b = typename Test::data_type_b;
217226
using data_type_c = typename Test::data_type_c;
218-
using data_type_zero_pt = data_type_b;
227+
using data_type_zero_pt = std::conditional_t<
228+
Test::quant_mode == quant_mode::INT4_ASYM_FP_ZERO,
229+
data_type_c,
230+
data_type_b>;
219231
using data_type_scale = fp16;
220232
using data_type_acc_in = fp16;
221233
using data_type_acc = float;
@@ -225,7 +237,7 @@ void dequantize_gemv_run(int iter) {
225237
constexpr mem_layout layout_b = Test::layout_b;
226238

227239
constexpr size_t size_a = matrix_m * matrix_k;
228-
constexpr size_t size_b = matrix_k * matrix_n / (2 * sizeof(data_type_b));
240+
constexpr size_t size_b = matrix_k * matrix_n / 2;
229241

230242
constexpr size_t size_scale_k = matrix_k / dequant_s;
231243
constexpr size_t size_scale_n = matrix_n;
@@ -234,7 +246,9 @@ void dequantize_gemv_run(int iter) {
234246
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
235247
constexpr size_t size_zero_pt_n = matrix_n;
236248
constexpr size_t size_zero_pt =
237-
size_zero_pt_k * size_zero_pt_n / (2 * sizeof(data_type_b));
249+
Test::quant_mode != quant_mode::INT4_ASYM_FP_ZERO
250+
? size_zero_pt_k * size_zero_pt_n / 2
251+
: size_zero_pt_k * size_zero_pt_n;
238252

239253
constexpr size_t size_c = matrix_m * matrix_n;
240254
constexpr size_t size_bias = matrix_n;
@@ -405,16 +419,18 @@ void dequantize_gemv_run(int iter) {
405419
scale_h[i] = INFINITY;
406420
}
407421
for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
408-
if constexpr (std::is_same_v<int4x2, data_type_b>) {
422+
if constexpr (std::is_same_v<int4x2, data_type_zero_pt>) {
409423
zero_pt_h[i] = random_uint8();
410424
#ifdef UT_DEBUG
411425
zero_pt_h[i] = 0x12 << i;
412426
#endif
413-
} else if constexpr (std::is_same_v<int4x8, data_type_b>) {
427+
} else if constexpr (std::is_same_v<int4x8, data_type_zero_pt>) {
414428
zero_pt_h[i] = random_uint32();
415429
#ifdef UT_DEBUG
416430
zero_pt_h[i] = 0x33333333;
417431
#endif
432+
} else if constexpr (std::is_same_v<fp16, data_type_zero_pt>) {
433+
zero_pt_h[i] = random_float();
418434
}
419435
}
420436

@@ -491,7 +507,9 @@ void dequantize_gemv_run(int iter) {
491507
Acc_d,
492508
Cnt_d,
493509
epilogue_args);
494-
} else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
510+
} else if constexpr (
511+
compute_policy::quant_mode == quant_mode::S4_ASYM ||
512+
compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
495513
gemm_arg =
496514
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
497515
matrix_m,

0 commit comments

Comments
 (0)