Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Add zp no degrad dequant #297

Open
wants to merge 4 commits into
base: xetla
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion include/common/core/common_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };

enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };

enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };
enum class quant_mode : uint8_t {
I4_ASYM = 0,
I4_SYM = 1,
I4_ASYM_FP_ZERO = 2
};

struct quant_info {
quant_mode quant_mode;
Expand Down
13 changes: 13 additions & 0 deletions include/common/core/explicit_conv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ xetla_cvt(xetla_vector<T_src, N> src) {
return dst;
}

/// @brief xetla explicit data conversion, bf16->fp16.
/// @tparam T_dst is the float16 data type.
/// @tparam T_src is the bfloat16 data type.
/// @tparam N is the element number in xetla_vector.
template <typename T_dst, typename T_src, int N>
__XETLA_API typename std::enable_if_t<
std::is_same<T_dst, fp16>::value && std::is_same<T_src, bf16>::value,
xetla_vector<T_dst, N>>
xetla_cvt(xetla_vector<T_src, N> src) {
xetla_vector<T_dst, N> dst = src;
return dst;
}

/// @brief xetla explicit data conversion, bf16->fp32.
/// @tparam T_dst is the bfloat16 data type.
/// @tparam T_src is the float32 data type.
Expand Down
52 changes: 30 additions & 22 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,16 @@ class gemm_t<
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x8>>::value,
"this is for 4bit matB ");
static_assert(
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x2>>::
value ||
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x8>>::
value,
quant_info_.quant_mode == quant_mode::I4_ASYM_FP_ZERO
? std::is_same_v<
remove_const_t<dtype_zero_pt>,
remove_const_t<dtype_a>>
: (std::is_same_v<
remove_const_t<dtype_zero_pt>,
remove_const_t<int4x2>> ||
std::is_same_v<
remove_const_t<dtype_zero_pt>,
remove_const_t<int4x8>>),
"this is for 4bit zero_pt ");

/******** set memory attribute **********/
Expand Down Expand Up @@ -284,12 +290,20 @@ class gemm_t<
arch_tag>;

// compress int4 along N dimensions
using zero_pt_tile_desc_t = subgroup::tile_desc_t<
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
tile_size_y_zero_pt,
(block_size_x_b + pack_ratio - 1) / pack_ratio,
block_size_y_zero_pt,
reg_layout::tiled>;
using zero_pt_tile_desc_t = std::conditional_t<
quant_info_.quant_mode != quant_mode::I4_ASYM_FP_ZERO,
subgroup::tile_desc_t<
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
tile_size_y_zero_pt,
(block_size_x_b + pack_ratio - 1) / pack_ratio,
block_size_y_zero_pt,
reg_layout::tiled>,
subgroup::tile_desc_t<
tile_size_x_b,
tile_size_y_zero_pt,
block_size_x_b,
block_size_y_zero_pt,
reg_layout::tiled>>;

using zero_pt_t = subgroup::tile_t<dtype_zero_pt, zero_pt_tile_desc_t>;
using zero_pt_payload_t = subgroup::mem_payload_t<
Expand Down Expand Up @@ -520,8 +534,7 @@ class gemm_t<
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
scale_prefetch_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
zero_pt_prefetch_payload);
Expand All @@ -534,8 +547,7 @@ class gemm_t<
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
scale_prefetch_payload.template update_tdesc<update_dir_b>(
scale_t::tile_size_y);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
zero_pt_prefetch_payload
.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
Expand Down Expand Up @@ -564,8 +576,7 @@ class gemm_t<
// matB, matB_payload);
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
scale, scale_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
zero_pt, zero_pt_payload);
}
Expand All @@ -579,8 +590,7 @@ class gemm_t<
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
scale_prefetch_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
zero_pt_prefetch_payload);
Expand All @@ -593,8 +603,7 @@ class gemm_t<
if (tile_k_idx % scale_addr_update_freq == 0) {
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
}
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
Expand All @@ -608,8 +617,7 @@ class gemm_t<
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
scale_prefetch_payload.template update_tdesc<tdesc_update_dir::y_dir>(
scale_t::tile_size_y);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
zero_pt_prefetch_payload
.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class gemm_universal_t<
/// @brief GEMM arguments.
/// This is the interface for users to pass the application-related runtime
/// variables.
template <quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP>
template <quant_mode quant_mode = quant_mode::I4_SYM>
struct arguments_t {
/// @brief Is the size of the m dimension of the matrix multiplication (m x
/// k x n).
Expand Down Expand Up @@ -295,7 +295,7 @@ class gemm_universal_t<
}
};
template <>
struct arguments_t<quant_mode::S4_FULLRANGE_NO_ZP> {
struct arguments_t<quant_mode::I4_SYM> {
/// @brief Is the size of the m dimension of the matrix multiplication (m x
/// k x n).
uint32_t matrix_m;
Expand Down Expand Up @@ -526,6 +526,10 @@ class gemm_universal_t<
template <quant_mode quant_mode>
static bool can_implement(arguments_t<quant_mode>& args) {
bool implementable = true;
if (arch_tag == gpu_arch::XeLpg) {
implementable &= !std::is_same_v<dtype_a, bf16>; // XeLpg arch dosen't
// have bf16 related isa.
}
if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
if (gemm_t::msg_type_a == msg_type::block_2d) {
implementable &= kernel::block_2d<arch_tag, dtype_a>::check_tensor(
Expand Down Expand Up @@ -566,8 +570,7 @@ class gemm_universal_t<
// check for int4x2
implementable &=
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
if constexpr (
gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
implementable &= (args.zero_pt_ld % pack_ratio == 0);
}

Expand Down Expand Up @@ -618,7 +621,10 @@ class gemm_universal_t<
int start_x_scale = start_n;
int start_y_scale = start_k / dequant_s;

int start_x_zero_pt = start_n / pack_ratio;
int start_x_zero_pt =
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO
? start_n
: start_n / pack_ratio;
int start_y_zero_pt = start_k / dequant_s;

// set up arguments
Expand Down Expand Up @@ -664,15 +670,15 @@ class gemm_universal_t<
uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride;
uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
gemm_args_t gemm_args;
if constexpr (
gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_SYM) {
gemm_args = gemm_args_t(
mem_desc_a,
mem_desc_b,
inner_loop_start,
inner_loop_count,
mem_desc_scale);
} else {
} else if constexpr (
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
mem_desc_zero_pt_t mem_desc_zero_pt(
args.zero_pt_base,
{(args.matrix_n + pack_ratio - 1) / pack_ratio,
Expand All @@ -686,6 +692,23 @@ class gemm_universal_t<
inner_loop_count,
mem_desc_scale,
mem_desc_zero_pt);
} else if constexpr (
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
mem_desc_zero_pt_t mem_desc_zero_pt(
args.zero_pt_base,
{args.matrix_n,
((args.matrix_k + dequant_s - 1) / dequant_s),
args.zero_pt_ld},
{start_x_zero_pt, start_y_zero_pt});
gemm_args = gemm_args_t(
mem_desc_a,
mem_desc_b,
inner_loop_start,
inner_loop_count,
mem_desc_scale,
mem_desc_zero_pt);
} else {
assert(0);
}
matAcc_t matAcc;
matAcc.init(0);
Expand Down
16 changes: 13 additions & 3 deletions include/subgroup/tile/impl/tile_op_functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ struct dequant_int4_weight_t {
(offset_y_in_tile) / dequant_s * scale_t::block_size_x +
offset_x_in_tile;

if constexpr (quant_mode == quant_mode::S4_ASYM) {
if constexpr (quant_mode == quant_mode::I4_ASYM) {
uint32_t zero_pt_idx =
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
offset_x_in_tile / pack_ratio;
Expand All @@ -149,15 +149,25 @@ struct dequant_int4_weight_t {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
zero_pt_i8;
} else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
} else if constexpr (
quant_mode == quant_mode::I4_SYM ||
quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
int8_t(8);
}
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];

if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
uint32_t zero_pt_idx =
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
offset_x_in_tile;
xetla_vector<fp16, 1> zero_pt_pack = zero_pt.reg[zero_pt_idx];
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) +
zero_pt_pack[0];
}
// sycl::ext::oneapi::experimental::printf(
// "scale[%d] %f \n",
// scale_idx,
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/gemm/int4_dequantization/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,9 @@ void dequantize_gemm_run(uint32_t iter) {
compute_attr_t<data_type_acc_in, data_type_acc_in, data_type_acc>;
using perf_tuning_knob = xetla::group::
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;

static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b};

static constexpr quant_info quant_info{
quant_mode::I4_ASYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ void dequantize_gemm_run(int iter) {
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;

static constexpr quant_info quant_info{
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
quant_mode::I4_SYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
Expand Down Expand Up @@ -1043,4 +1043,4 @@ REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd);
INSTANTIATE_TYPED_TEST_SUITE_P(
dequantize_gemm_act_shuf_test_suite,
dequantize_gemm_act_shuf_test,
tests);
tests);
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ void dequantize_gemm_run(int iter) {
using perf_tuning_knob = xetla::group::
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
static constexpr quant_info quant_info{
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
quant_mode::I4_SYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
Expand Down
Loading
Loading