@@ -41,7 +41,7 @@ class test_col_major_1 {
41
41
static constexpr size_t dequant_s = 128 ;
42
42
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
43
43
// static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44
- static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_ZERO_NO_DEGRAD ;
44
+ static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_FP_ZERO ;
45
45
46
46
static constexpr size_t local_kslicing = 1 ;
47
47
static constexpr size_t global_kslicing = 1 ;
@@ -133,15 +133,15 @@ std::vector<fp16> convert_int4(
133
133
std::vector<fp16> dequant_fp16 (sizeof (data_type_b) * 2 );
134
134
135
135
int8_t zero_pt_i8;
136
- if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD )
136
+ if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO )
137
137
zero_pt_i8 = zero_pt & 0xf ;
138
138
for (uint32_t i = 0 ; i < dequant_fp16.size (); i++) {
139
139
int8_t dequant_8bit = data_b & 0xf ;
140
140
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
141
141
dequant_fp16[i] = scale * (dequant_8bit - 8 );
142
142
} else if constexpr (quant_mode == quant_mode::S4_ASYM) {
143
143
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144
- } else if constexpr (quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD ) {
144
+ } else if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO ) {
145
145
dequant_fp16[i] = scale * (dequant_8bit - 8 ) + zero_pt;
146
146
} else {
147
147
assert (0 );
@@ -176,13 +176,13 @@ std::vector<data_type_acc_in> dequantize_weight(
176
176
for (uint32_t j = 0 ; j < width; j += step) {
177
177
int start_b_in = i * width + j;
178
178
int start_scale_in = start_b_in / step;
179
- int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
179
+ int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_FP_ZERO
180
180
? (j / step) * matrix_n + i
181
181
: (j / step) * (matrix_n / pack_radio) + i / pack_radio;
182
182
int start_out =
183
183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184
184
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185
- if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD )
185
+ if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO )
186
186
zp_value = zp_value >> (4 * (i % pack_radio));
187
187
for (uint32_t jj = 0 ; jj < step; jj++) {
188
188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
@@ -225,7 +225,7 @@ void dequantize_gemv_run(int iter) {
225
225
using data_type_b = typename Test::data_type_b;
226
226
using data_type_c = typename Test::data_type_c;
227
227
using data_type_zero_pt = std::conditional_t <
228
- Test::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD ,
228
+ Test::quant_mode == quant_mode::INT4_ASYM_FP_ZERO ,
229
229
data_type_c,
230
230
data_type_b>;
231
231
using data_type_scale = fp16;
@@ -246,7 +246,7 @@ void dequantize_gemv_run(int iter) {
246
246
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
247
247
constexpr size_t size_zero_pt_n = matrix_n;
248
248
constexpr size_t size_zero_pt =
249
- Test::quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
249
+ Test::quant_mode != quant_mode::INT4_ASYM_FP_ZERO
250
250
? size_zero_pt_k * size_zero_pt_n / 2
251
251
: size_zero_pt_k * size_zero_pt_n;
252
252
@@ -509,7 +509,7 @@ void dequantize_gemv_run(int iter) {
509
509
epilogue_args);
510
510
} else if constexpr (
511
511
compute_policy::quant_mode == quant_mode::S4_ASYM ||
512
- compute_policy::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD ) {
512
+ compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO ) {
513
513
gemm_arg =
514
514
typename gemm_op_t ::template arguments_t <compute_policy::quant_mode>(
515
515
matrix_m,
0 commit comments