Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 29946d2

Browse files
committed
support bf16 activation
1 parent 24b1ec4 commit 29946d2

File tree

4 files changed

+29
-12
lines changed

4 files changed

+29
-12
lines changed

include/common/core/explicit_conv.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ xetla_cvt(xetla_vector<T_src, N> src) {
6262
return dst;
6363
}
6464

65+
/// @brief xetla explicit data conversion, bf16->fp16.
66+
/// @tparam T_dst is the float16 data type.
67+
/// @tparam T_src is the bfloat16 data type.
68+
/// @tparam N is the element number in xetla_vector.
69+
template <typename T_dst, typename T_src, int N>
70+
__XETLA_API typename std::enable_if_t<
71+
std::is_same<T_dst, fp16>::value && std::is_same<T_src, bf16>::value,
72+
xetla_vector<T_dst, N>>
73+
xetla_cvt(xetla_vector<T_src, N> src) {
74+
xetla_vector<T_dst, N> dst = src;
75+
return dst;
76+
}
77+
6578
/// @brief xetla explicit data conversion, bf16->fp32.
6679
/// @tparam T_dst is the bfloat16 data type.
6780
/// @tparam T_src is the float32 data type.

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,10 @@ class gemm_universal_t<
526526
template <quant_mode quant_mode>
527527
static bool can_implement(arguments_t<quant_mode>& args) {
528528
bool implementable = true;
529+
if (arch_tag == gpu_arch::XeLpg) {
530+
implementable &= !std::is_same_v<dtype_a, bf16>; // XeLpg arch dosen't
531+
// have bf16 related isa.
532+
}
529533
if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
530534
if (gemm_t::msg_type_a == msg_type::block_2d) {
531535
implementable &= kernel::block_2d<arch_tag, dtype_a>::check_tensor(
@@ -617,8 +621,8 @@ class gemm_universal_t<
617621
int start_x_scale = start_n;
618622
int start_y_scale = start_k / dequant_s;
619623

620-
int start_x_zero_pt = gemm_t::compute_policy::quant_mode ==
621-
quant_mode::INT4_ASYM_FP_ZERO
624+
int start_x_zero_pt =
625+
gemm_t::compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO
622626
? start_n
623627
: start_n / pack_ratio;
624628
int start_y_zero_pt = start_k / dequant_s;
@@ -690,8 +694,7 @@ class gemm_universal_t<
690694
mem_desc_scale,
691695
mem_desc_zero_pt);
692696
} else if constexpr (
693-
gemm_t::compute_policy::quant_mode ==
694-
quant_mode::INT4_ASYM_FP_ZERO) {
697+
gemm_t::compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
695698
mem_desc_zero_pt_t mem_desc_zero_pt(
696699
args.zero_pt_base,
697700
{args.matrix_n,

include/subgroup/tile/impl/tile_op_functor.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,10 @@ struct dequant_int4_weight_t {
163163
uint32_t zero_pt_idx =
164164
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
165165
offset_x_in_tile;
166-
native_type_t<typename zero_pt_t::dtype> zero_pt_pack =
167-
zero_pt.reg[zero_pt_idx];
166+
xetla_vector<fp16, 1> zero_pt_pack = zero_pt.reg[zero_pt_idx];
168167
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
169168
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) +
170-
zero_pt_pack;
169+
zero_pt_pack[0];
171170
}
172171
// sycl::ext::oneapi::experimental::printf(
173172
// "scale[%d] %f \n",

tests/integration/gemv/int4/main.cpp

+7-5
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ class test_col_major_1 {
4949
static constexpr mem_layout layout_b = mem_layout::col_major;
5050
static constexpr mma_engine mma_eng = mma_engine::fpu;
5151
static constexpr gpu_arch arch = gpu_arch::XeLpg;
52-
using data_type_a = fp16;
52+
using data_type_a = bf16;
5353
using data_type_b = int4x8;
54-
using data_type_c = fp16;
54+
using data_type_c = bf16;
5555
};
5656
class test_col_major_2 {
5757
public:
@@ -569,9 +569,11 @@ void dequantize_gemv_run(int iter) {
569569
// performance
570570
prof.print_profiling_result(profiling_selector::GPU);
571571
// check result
572-
std::vector<typename Test::data_type_a> dequantize_b =
573-
dequantize_weight<dequant_s, layout_b, compute_policy::quant_mode>(
574-
matrix_k, matrix_n, B_h, scale_h, zero_pt_h);
572+
std::vector<typename Test::data_type_a> dequantize_b = dequantize_weight<
573+
dequant_s,
574+
layout_b,
575+
compute_policy::quant_mode,
576+
data_type_c>(matrix_k, matrix_n, B_h, scale_h, zero_pt_h);
575577

576578
queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait();
577579
ASSERT_EQ(

0 commit comments

Comments
 (0)