@@ -27,7 +27,7 @@ constexpr int ITER = 200;
27
27
#endif
28
28
constexpr size_t UNDEFINED_DATA_SIZE = 1024 ;
29
29
30
- template <typename scalar_t >
30
+ template <typename scalar_t , quant_mode quant_mode_ >
31
31
class test_col_major_1 {
32
32
public:
33
33
// Extract the parameters required by different test cases
@@ -41,7 +41,7 @@ class test_col_major_1 {
41
41
static constexpr size_t sg_k = 512 / sg_m;
42
42
static constexpr size_t dequant_s = 128 ;
43
43
// static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
44
- static constexpr quant_mode quant_mode = quant_mode::I4_SYM ;
44
+ static constexpr quant_mode quant_mode = quant_mode_ ;
45
45
46
46
static constexpr size_t local_kslicing = 1 ;
47
47
static constexpr size_t global_kslicing = 1 ;
@@ -132,13 +132,19 @@ std::vector<fp16> convert_int4(
132
132
data_type_zero_pt zero_pt) {
133
133
std::vector<fp16> dequant_fp16 (sizeof (data_type_b) * 2 );
134
134
135
- int8_t zero_pt_i8 = zero_pt & 0xf ;
135
+ int8_t zero_pt_i8;
136
+ if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
137
+ zero_pt_i8 = zero_pt & 0xf ;
136
138
for (uint32_t i = 0 ; i < dequant_fp16.size (); i++) {
137
139
int8_t dequant_8bit = data_b & 0xf ;
138
140
if constexpr (quant_mode == quant_mode::I4_SYM) {
139
141
dequant_fp16[i] = scale * (dequant_8bit - 8 );
140
- } else {
142
+ } else if constexpr (quant_mode == quant_mode::I4_ASYM) {
141
143
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144
+ } else if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
145
+ dequant_fp16[i] = scale * (dequant_8bit - 8 ) + zero_pt;
146
+ } else {
147
+ assert (0 );
142
148
}
143
149
data_b = data_b >> 4 ;
144
150
}
@@ -170,12 +176,14 @@ std::vector<data_type_acc_in> dequantize_weight(
170
176
for (uint32_t j = 0 ; j < width; j += step) {
171
177
int start_b_in = i * width + j;
172
178
int start_scale_in = start_b_in / step;
173
- int start_zero_pt_in =
174
- (j / step) * (matrix_n / pack_radio) + i / pack_radio;
179
+ int start_zero_pt_in = quant_mode == quant_mode::I4_ASYM_FP_ZERO
180
+ ? (j / step) * matrix_n + i
181
+ : (j / step) * (matrix_n / pack_radio) + i / pack_radio;
175
182
int start_out =
176
183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
177
184
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
178
- zp_value = zp_value >> (4 * (i % pack_radio));
185
+ if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
186
+ zp_value = zp_value >> (4 * (i % pack_radio));
179
187
for (uint32_t jj = 0 ; jj < step; jj++) {
180
188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
181
189
b[start_b_in + jj], scale[start_scale_in], zp_value);
@@ -216,7 +224,10 @@ void dequantize_gemv_run(int iter) {
216
224
using data_type_a = typename Test::data_type_a;
217
225
using data_type_b = typename Test::data_type_b;
218
226
using data_type_c = typename Test::data_type_c;
219
- using data_type_zero_pt = data_type_b;
227
+ using data_type_zero_pt = std::conditional_t <
228
+ Test::quant_mode == quant_mode::I4_ASYM_FP_ZERO,
229
+ data_type_c,
230
+ data_type_b>;
220
231
using data_type_scale = fp16;
221
232
using data_type_acc_in = fp16;
222
233
using data_type_acc = float ;
@@ -226,7 +237,7 @@ void dequantize_gemv_run(int iter) {
226
237
constexpr mem_layout layout_b = Test::layout_b;
227
238
228
239
constexpr size_t size_a = matrix_m * matrix_k;
229
- constexpr size_t size_b = matrix_k * matrix_n / ( 2 * sizeof (data_type_b)) ;
240
+ constexpr size_t size_b = matrix_k * matrix_n / 2 ;
230
241
231
242
constexpr size_t size_scale_k = matrix_k / dequant_s;
232
243
constexpr size_t size_scale_n = matrix_n;
@@ -235,7 +246,9 @@ void dequantize_gemv_run(int iter) {
235
246
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
236
247
constexpr size_t size_zero_pt_n = matrix_n;
237
248
constexpr size_t size_zero_pt =
238
- size_zero_pt_k * size_zero_pt_n / (2 * sizeof (data_type_b));
249
+ Test::quant_mode != quant_mode::I4_ASYM_FP_ZERO
250
+ ? size_zero_pt_k * size_zero_pt_n / 2
251
+ : size_zero_pt_k * size_zero_pt_n;
239
252
240
253
constexpr size_t size_c = matrix_m * matrix_n;
241
254
constexpr size_t size_bias = matrix_n;
@@ -406,16 +419,18 @@ void dequantize_gemv_run(int iter) {
406
419
scale_h[i] = INFINITY;
407
420
}
408
421
for (unsigned i = 0 ; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
409
- if constexpr (std::is_same_v<int4x2, data_type_b >) {
422
+ if constexpr (std::is_same_v<int4x2, data_type_zero_pt >) {
410
423
zero_pt_h[i] = random_uint8 ();
411
424
#ifdef UT_DEBUG
412
425
zero_pt_h[i] = 0x12 << i;
413
426
#endif
414
- } else if constexpr (std::is_same_v<int4x8, data_type_b >) {
427
+ } else if constexpr (std::is_same_v<int4x8, data_type_zero_pt >) {
415
428
zero_pt_h[i] = random_uint32 ();
416
429
#ifdef UT_DEBUG
417
430
zero_pt_h[i] = 0x33333333 ;
418
431
#endif
432
+ } else if constexpr (std::is_same_v<fp16, data_type_zero_pt>) {
433
+ zero_pt_h[i] = random_float ();
419
434
}
420
435
}
421
436
@@ -492,7 +507,9 @@ void dequantize_gemv_run(int iter) {
492
507
Acc_d,
493
508
Cnt_d,
494
509
epilogue_args);
495
- } else if constexpr (compute_policy::quant_mode == quant_mode::I4_ASYM) {
510
+ } else if constexpr (
511
+ compute_policy::quant_mode == quant_mode::I4_ASYM ||
512
+ compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
496
513
gemm_arg =
497
514
typename gemm_op_t ::template arguments_t <compute_policy::quant_mode>(
498
515
matrix_m,
@@ -604,8 +621,10 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) {
604
621
605
622
REGISTER_TYPED_TEST_SUITE_P (dequantize_gemv_test, esimd);
606
623
using tests = ::testing::Types< //
607
- test_col_major_1<fp16>,
608
- test_col_major_1<bf16>,
624
+ test_col_major_1<fp16, quant_mode::I4_SYM>,
625
+ test_col_major_1<bf16, quant_mode::I4_SYM>,
626
+ test_col_major_1<fp16, quant_mode::I4_ASYM_FP_ZERO>,
627
+ test_col_major_1<bf16, quant_mode::I4_ASYM_FP_ZERO>,
609
628
// test_col_major_2,
610
629
void >;
611
630
0 commit comments