@@ -40,14 +40,15 @@ class test_col_major_1 {
40
40
static constexpr size_t sg_k = 1024 / 1 ;
41
41
static constexpr size_t dequant_s = 128 ;
42
42
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
43
- static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
43
+ // static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44
+ static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_ZERO_NO_DEGRAD;
44
45
45
46
static constexpr size_t local_kslicing = 1 ;
46
47
static constexpr size_t global_kslicing = 1 ;
47
48
static constexpr mem_layout layout_a = mem_layout::row_major;
48
49
static constexpr mem_layout layout_b = mem_layout::col_major;
49
50
static constexpr mma_engine mma_eng = mma_engine::fpu;
50
- static constexpr gpu_arch arch = gpu_arch::XeHpc ;
51
+ static constexpr gpu_arch arch = gpu_arch::XeHpg ;
51
52
using data_type_a = fp16;
52
53
using data_type_b = int4x8;
53
54
using data_type_c = fp16;
@@ -131,7 +132,9 @@ std::vector<fp16> convert_int4(
131
132
data_type_zero_pt zero_pt) {
132
133
std::vector<fp16> dequant_fp16 (sizeof (data_type_b) * 2 );
133
134
134
- int8_t zero_pt_i8 = zero_pt & 0xf ;
135
+ int8_t zero_pt_i8;
136
+ if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD)
137
+ zero_pt_i8 = zero_pt & 0xf ;
135
138
for (uint32_t i = 0 ; i < dequant_fp16.size (); i++) {
136
139
int8_t dequant_8bit = data_b & 0xf ;
137
140
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
@@ -173,15 +176,17 @@ std::vector<data_type_acc_in> dequantize_weight(
173
176
for (uint32_t j = 0 ; j < width; j += step) {
174
177
int start_b_in = i * width + j;
175
178
int start_scale_in = start_b_in / step;
176
- int start_zero_pt_in =
177
- (j / step) * (matrix_n / pack_radio) + i / pack_radio;
179
+ int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
180
+ ? (j / step) * matrix_n + i
181
+ : (j / step) * (matrix_n / pack_radio) + i / pack_radio;
178
182
int start_out =
179
183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184
+ data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185
+ if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD)
186
+ zp_value = zp_value >> (4 * (i % pack_radio));
180
187
for (uint32_t jj = 0 ; jj < step; jj++) {
181
188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
182
- b[start_b_in + jj],
183
- scale[start_scale_in],
184
- zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
189
+ b[start_b_in + jj], scale[start_scale_in], zp_value);
185
190
for (uint32_t jjj = 0 ; jjj < dequant_fp16.size (); jjj++) {
186
191
b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
187
192
}
@@ -502,7 +507,9 @@ void dequantize_gemv_run(int iter) {
502
507
Acc_d,
503
508
Cnt_d,
504
509
epilogue_args);
505
- } else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
510
+ } else if constexpr (
511
+ compute_policy::quant_mode == quant_mode::S4_ASYM ||
512
+ compute_policy::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
506
513
gemm_arg =
507
514
typename gemm_op_t ::template arguments_t <compute_policy::quant_mode>(
508
515
matrix_m,
0 commit comments