@@ -40,7 +40,8 @@ class test_col_major_1 {
40
40
static constexpr size_t sg_k = 512 / sg_m;
41
41
static constexpr size_t dequant_s = 128 ;
42
42
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
43
- static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
43
+ // static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44
+ static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_FP_ZERO;
44
45
45
46
static constexpr size_t local_kslicing = 1 ;
46
47
static constexpr size_t global_kslicing = 1 ;
@@ -131,13 +132,19 @@ std::vector<fp16> convert_int4(
131
132
data_type_zero_pt zero_pt) {
132
133
std::vector<fp16> dequant_fp16 (sizeof (data_type_b) * 2 );
133
134
134
- int8_t zero_pt_i8 = zero_pt & 0xf ;
135
+ int8_t zero_pt_i8;
136
+ if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
137
+ zero_pt_i8 = zero_pt & 0xf ;
135
138
for (uint32_t i = 0 ; i < dequant_fp16.size (); i++) {
136
139
int8_t dequant_8bit = data_b & 0xf ;
137
140
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
138
141
dequant_fp16[i] = scale * (dequant_8bit - 8 );
139
- } else {
142
+ } else if constexpr (quant_mode == quant_mode::S4_ASYM) {
140
143
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144
+ } else if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
145
+ dequant_fp16[i] = scale * (dequant_8bit - 8 ) + zero_pt;
146
+ } else {
147
+ assert (0 );
141
148
}
142
149
data_b = data_b >> 4 ;
143
150
}
@@ -169,15 +176,17 @@ std::vector<data_type_acc_in> dequantize_weight(
169
176
for (uint32_t j = 0 ; j < width; j += step) {
170
177
int start_b_in = i * width + j;
171
178
int start_scale_in = start_b_in / step;
172
- int start_zero_pt_in =
173
- (j / step) * (matrix_n / pack_radio) + i / pack_radio;
179
+ int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_FP_ZERO
180
+ ? (j / step) * matrix_n + i
181
+ : (j / step) * (matrix_n / pack_radio) + i / pack_radio;
174
182
int start_out =
175
183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184
+ data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185
+ if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
186
+ zp_value = zp_value >> (4 * (i % pack_radio));
176
187
for (uint32_t jj = 0 ; jj < step; jj++) {
177
188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
178
- b[start_b_in + jj],
179
- scale[start_scale_in],
180
- zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
189
+ b[start_b_in + jj], scale[start_scale_in], zp_value);
181
190
for (uint32_t jjj = 0 ; jjj < dequant_fp16.size (); jjj++) {
182
191
b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
183
192
}
@@ -215,7 +224,10 @@ void dequantize_gemv_run(int iter) {
215
224
using data_type_a = typename Test::data_type_a;
216
225
using data_type_b = typename Test::data_type_b;
217
226
using data_type_c = typename Test::data_type_c;
218
- using data_type_zero_pt = data_type_b;
227
+ using data_type_zero_pt = std::conditional_t <
228
+ Test::quant_mode == quant_mode::INT4_ASYM_FP_ZERO,
229
+ data_type_c,
230
+ data_type_b>;
219
231
using data_type_scale = fp16;
220
232
using data_type_acc_in = fp16;
221
233
using data_type_acc = float ;
@@ -225,7 +237,7 @@ void dequantize_gemv_run(int iter) {
225
237
constexpr mem_layout layout_b = Test::layout_b;
226
238
227
239
constexpr size_t size_a = matrix_m * matrix_k;
228
- constexpr size_t size_b = matrix_k * matrix_n / ( 2 * sizeof (data_type_b)) ;
240
+ constexpr size_t size_b = matrix_k * matrix_n / 2 ;
229
241
230
242
constexpr size_t size_scale_k = matrix_k / dequant_s;
231
243
constexpr size_t size_scale_n = matrix_n;
@@ -234,7 +246,9 @@ void dequantize_gemv_run(int iter) {
234
246
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
235
247
constexpr size_t size_zero_pt_n = matrix_n;
236
248
constexpr size_t size_zero_pt =
237
- size_zero_pt_k * size_zero_pt_n / (2 * sizeof (data_type_b));
249
+ Test::quant_mode != quant_mode::INT4_ASYM_FP_ZERO
250
+ ? size_zero_pt_k * size_zero_pt_n / 2
251
+ : size_zero_pt_k * size_zero_pt_n;
238
252
239
253
constexpr size_t size_c = matrix_m * matrix_n;
240
254
constexpr size_t size_bias = matrix_n;
@@ -405,16 +419,18 @@ void dequantize_gemv_run(int iter) {
405
419
scale_h[i] = INFINITY;
406
420
}
407
421
for (unsigned i = 0 ; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
408
- if constexpr (std::is_same_v<int4x2, data_type_b >) {
422
+ if constexpr (std::is_same_v<int4x2, data_type_zero_pt >) {
409
423
zero_pt_h[i] = random_uint8 ();
410
424
#ifdef UT_DEBUG
411
425
zero_pt_h[i] = 0x12 << i;
412
426
#endif
413
- } else if constexpr (std::is_same_v<int4x8, data_type_b >) {
427
+ } else if constexpr (std::is_same_v<int4x8, data_type_zero_pt >) {
414
428
zero_pt_h[i] = random_uint32 ();
415
429
#ifdef UT_DEBUG
416
430
zero_pt_h[i] = 0x33333333 ;
417
431
#endif
432
+ } else if constexpr (std::is_same_v<fp16, data_type_zero_pt>) {
433
+ zero_pt_h[i] = random_float ();
418
434
}
419
435
}
420
436
@@ -491,7 +507,9 @@ void dequantize_gemv_run(int iter) {
491
507
Acc_d,
492
508
Cnt_d,
493
509
epilogue_args);
494
- } else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
510
+ } else if constexpr (
511
+ compute_policy::quant_mode == quant_mode::S4_ASYM ||
512
+ compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
495
513
gemm_arg =
496
514
typename gemm_op_t ::template arguments_t <compute_policy::quant_mode>(
497
515
matrix_m,
0 commit comments