Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 243c9e6

Browse files
committed
fix compile
1 parent 27bbe56 commit 243c9e6

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

-7
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,6 @@ class gemm_t<
101101
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x2>>::value ||
102102
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x8>>::value,
103103
"this is for 4bit matB ");
104-
static_assert(
105-
quant_info_.quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD &&
106-
(std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x2>>::
107-
value ||
108-
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x8>>::
109-
value),
110-
"this is for 4bit zero_pt ");
111104

112105
/******** set memory attribute **********/
113106
static constexpr mem_space mem_space_a = mem_desc_a_t::space;

tests/integration/gemv/int4/main.cpp

+16-9
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,15 @@ class test_col_major_1 {
4040
static constexpr size_t sg_k = 1024 / 1;
4141
static constexpr size_t dequant_s = 128;
4242
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
43-
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
43+
// static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44+
static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_ZERO_NO_DEGRAD;
4445

4546
static constexpr size_t local_kslicing = 1;
4647
static constexpr size_t global_kslicing = 1;
4748
static constexpr mem_layout layout_a = mem_layout::row_major;
4849
static constexpr mem_layout layout_b = mem_layout::col_major;
4950
static constexpr mma_engine mma_eng = mma_engine::fpu;
50-
static constexpr gpu_arch arch = gpu_arch::XeHpc;
51+
static constexpr gpu_arch arch = gpu_arch::XeHpg;
5152
using data_type_a = fp16;
5253
using data_type_b = int4x8;
5354
using data_type_c = fp16;
@@ -131,7 +132,9 @@ std::vector<fp16> convert_int4(
131132
data_type_zero_pt zero_pt) {
132133
std::vector<fp16> dequant_fp16(sizeof(data_type_b) * 2);
133134

134-
int8_t zero_pt_i8 = zero_pt & 0xf;
135+
int8_t zero_pt_i8;
136+
if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD)
137+
zero_pt_i8 = zero_pt & 0xf;
135138
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
136139
int8_t dequant_8bit = data_b & 0xf;
137140
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
@@ -173,15 +176,17 @@ std::vector<data_type_acc_in> dequantize_weight(
173176
for (uint32_t j = 0; j < width; j += step) {
174177
int start_b_in = i * width + j;
175178
int start_scale_in = start_b_in / step;
176-
int start_zero_pt_in =
177-
(j / step) * (matrix_n / pack_radio) + i / pack_radio;
179+
int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
180+
? (j / step) * matrix_n + i
181+
: (j / step) * (matrix_n / pack_radio) + i / pack_radio;
178182
int start_out =
179183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184+
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185+
if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD)
186+
zp_value = zp_value >> (4 * (i % pack_radio));
180187
for (uint32_t jj = 0; jj < step; jj++) {
181188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
182-
b[start_b_in + jj],
183-
scale[start_scale_in],
184-
zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
189+
b[start_b_in + jj], scale[start_scale_in], zp_value);
185190
for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) {
186191
b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
187192
}
@@ -502,7 +507,9 @@ void dequantize_gemv_run(int iter) {
502507
Acc_d,
503508
Cnt_d,
504509
epilogue_args);
505-
} else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
510+
} else if constexpr (
511+
compute_policy::quant_mode == quant_mode::S4_ASYM ||
512+
compute_policy::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
506513
gemm_arg =
507514
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
508515
matrix_m,

0 commit comments

Comments
 (0)