Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 2017cb0

Browse files
committed
rename the new quant type
1 parent cb58590 commit 2017cb0

File tree

5 files changed

+14
-14
lines changed

5 files changed

+14
-14
lines changed

include/common/core/common_types.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
3030
enum class quant_mode : uint8_t {
3131
S4_ASYM = 0,
3232
S4_FULLRANGE_NO_ZP = 1,
33-
INT4_ASYM_ZERO_NO_DEGRAD = 2
33+
INT4_ASYM_FP_ZERO = 2
3434
};
3535

3636
struct quant_info {

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ class gemm_t<
285285

286286
// compress int4 along N dimensions
287287
using zero_pt_tile_desc_t = std::conditional_t<
288-
quant_info_.quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD,
288+
quant_info_.quant_mode != quant_mode::INT4_ASYM_FP_ZERO,
289289
subgroup::tile_desc_t<
290290
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
291291
tile_size_y_zero_pt,

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ class gemm_universal_t<
618618
int start_y_scale = start_k / dequant_s;
619619

620620
int start_x_zero_pt = gemm_t::compute_policy::quant_mode ==
621-
quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
621+
quant_mode::INT4_ASYM_FP_ZERO
622622
? start_n
623623
: start_n / pack_ratio;
624624
int start_y_zero_pt = start_k / dequant_s;
@@ -691,7 +691,7 @@ class gemm_universal_t<
691691
mem_desc_zero_pt);
692692
} else if constexpr (
693693
gemm_t::compute_policy::quant_mode ==
694-
quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
694+
quant_mode::INT4_ASYM_FP_ZERO) {
695695
mem_desc_zero_pt_t mem_desc_zero_pt(
696696
args.zero_pt_base,
697697
{args.matrix_n,

include/subgroup/tile/impl/tile_op_functor.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,15 @@ struct dequant_int4_weight_t {
151151
zero_pt_i8;
152152
} else if constexpr (
153153
quant_mode == quant_mode::S4_FULLRANGE_NO_ZP ||
154-
quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
154+
quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
155155
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
156156
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
157157
int8_t(8);
158158
}
159159
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
160160
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
161161
scale.reg[scale_idx];
162-
if constexpr (quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
162+
if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
163163
uint32_t zero_pt_idx =
164164
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
165165
offset_x_in_tile;

tests/integration/gemv/int4/main.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class test_col_major_1 {
4141
static constexpr size_t dequant_s = 128;
4242
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
4343
// static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44-
static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_ZERO_NO_DEGRAD;
44+
static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_FP_ZERO;
4545

4646
static constexpr size_t local_kslicing = 1;
4747
static constexpr size_t global_kslicing = 1;
@@ -133,15 +133,15 @@ std::vector<fp16> convert_int4(
133133
std::vector<fp16> dequant_fp16(sizeof(data_type_b) * 2);
134134

135135
int8_t zero_pt_i8;
136-
if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD)
136+
if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
137137
zero_pt_i8 = zero_pt & 0xf;
138138
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
139139
int8_t dequant_8bit = data_b & 0xf;
140140
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
141141
dequant_fp16[i] = scale * (dequant_8bit - 8);
142142
} else if constexpr (quant_mode == quant_mode::S4_ASYM) {
143143
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144-
} else if constexpr (quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
144+
} else if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
145145
dequant_fp16[i] = scale * (dequant_8bit - 8) + zero_pt;
146146
} else {
147147
assert(0);
@@ -176,13 +176,13 @@ std::vector<data_type_acc_in> dequantize_weight(
176176
for (uint32_t j = 0; j < width; j += step) {
177177
int start_b_in = i * width + j;
178178
int start_scale_in = start_b_in / step;
179-
int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
179+
int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_FP_ZERO
180180
? (j / step) * matrix_n + i
181181
: (j / step) * (matrix_n / pack_radio) + i / pack_radio;
182182
int start_out =
183183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184184
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185-
if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD)
185+
if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
186186
zp_value = zp_value >> (4 * (i % pack_radio));
187187
for (uint32_t jj = 0; jj < step; jj++) {
188188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
@@ -225,7 +225,7 @@ void dequantize_gemv_run(int iter) {
225225
using data_type_b = typename Test::data_type_b;
226226
using data_type_c = typename Test::data_type_c;
227227
using data_type_zero_pt = std::conditional_t<
228-
Test::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD,
228+
Test::quant_mode == quant_mode::INT4_ASYM_FP_ZERO,
229229
data_type_c,
230230
data_type_b>;
231231
using data_type_scale = fp16;
@@ -246,7 +246,7 @@ void dequantize_gemv_run(int iter) {
246246
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
247247
constexpr size_t size_zero_pt_n = matrix_n;
248248
constexpr size_t size_zero_pt =
249-
Test::quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
249+
Test::quant_mode != quant_mode::INT4_ASYM_FP_ZERO
250250
? size_zero_pt_k * size_zero_pt_n / 2
251251
: size_zero_pt_k * size_zero_pt_n;
252252

@@ -509,7 +509,7 @@ void dequantize_gemv_run(int iter) {
509509
epilogue_args);
510510
} else if constexpr (
511511
compute_policy::quant_mode == quant_mode::S4_ASYM ||
512-
compute_policy::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
512+
compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
513513
gemm_arg =
514514
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
515515
matrix_m,

0 commit comments

Comments
 (0)