Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 5a2ad7f

Browse files
sunjiweiswiftDDEle
authored andcommitted
fix group_qkv
1 parent 1ba27e4 commit 5a2ad7f

File tree

7 files changed

+64
-47
lines changed

7 files changed

+64
-47
lines changed

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -598,12 +598,8 @@ class gemm_universal_t<
598598
int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n;
599599
int start_k = 0;
600600
uint32_t wg_tile_k = args.matrix_k;
601-
uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n
602-
? args.matrix_n
603-
: (start_n + wg_tile_n);
604-
uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m
605-
? args.matrix_m
606-
: (start_m + wg_tile_m);
601+
uint32_t boundary_n = std::min(start_n + wg_tile_n, args.matrix_n);
602+
uint32_t boundary_m = std::min(start_m + wg_tile_m, args.matrix_m);
607603
uint32_t boundary_k = wg_tile_k;
608604
if constexpr (num_global_kslicing > 1) {
609605
wg_tile_k = (wg_tile_k + num_global_kslicing - 1) / num_global_kslicing;

include/group/epilogue/impl/default_xe.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class epilogue_t<
9595
using mat_tile_desc = typename matAcc_t::tile_desc;
9696
using matC_t = subgroup::tile_t<dtype_c, mat_tile_desc>;
9797

98+
// static constexpr msg_type msg_type_c = msg_type::unaligned_2d;
9899
static constexpr msg_type msg_type_c =
99100
subgroup::msg_type_v<mat_tile_desc, mem_desc_c_t>;
100101
using matC_payload_t = subgroup::
@@ -192,9 +193,9 @@ class epilogue_t<
192193
using mat_tile_desc = typename matAcc_t::tile_desc;
193194
using matC_t = subgroup::tile_t<dtype_c, mat_tile_desc>;
194195

195-
// static constexpr msg_type msg_type_c = msg_type::block_2d;
196-
static constexpr msg_type msg_type_c =
197-
subgroup::msg_type_v<mat_tile_desc, mem_desc_c_t>;
196+
static constexpr msg_type msg_type_c = msg_type::block_2d;
197+
// static constexpr msg_type msg_type_c =
198+
// subgroup::msg_type_v<mat_tile_desc, mem_desc_c_t>;
198199

199200
using matC_payload_t = subgroup::
200201
mem_payload_t<mem_desc_c_t, mat_tile_desc, msg_type_c, arch_tag>;

include/kernel/gemm/impl/kslicing_xe.hpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,8 @@ class gemm_universal_t<
426426
int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n;
427427
int start_k = 0;
428428
uint32_t wg_tile_k = args.matrix_k;
429-
uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n
430-
? args.matrix_n
431-
: (start_n + wg_tile_n);
432-
uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m
433-
? args.matrix_m
434-
: (start_m + wg_tile_m);
429+
uint32_t boundary_n = std::min(start_n + wg_tile_n,args.matrix_n);
430+
uint32_t boundary_m = std::min(start_m + wg_tile_m, args.matrix_m);
435431
uint32_t boundary_k = wg_tile_k;
436432
if constexpr (num_global_kslicing > 1) {
437433
wg_tile_k = (wg_tile_k + num_global_kslicing - 1) / num_global_kslicing;

include/subgroup/tile/impl/payload_xe.hpp

+25-4
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,15 @@ struct mem_payload_t<
435435
uint64_t base_offset;
436436
dtype* base_ptr;
437437
uint32_t pitch_in_bytes;
438+
uint32_t height_in_elems;
439+
uint32_t width_in_elems;
440+
uint32_t payload_bytes;
438441

439442
inline mem_payload_t(mem_desc_t& mem_tdesc) {
440443
pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
444+
width_in_elems = mem_tdesc.shape.x;
445+
height_in_elems = mem_tdesc.shape.y;
446+
payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
441447
uint32_t offset_x = mem_tdesc.coord.x;
442448
uint32_t offset_y = mem_tdesc.coord.y;
443449
base_offset = mem_transpose
@@ -448,14 +454,17 @@ struct mem_payload_t<
448454

449455
inline mem_payload_t(
450456
dtype* p,
451-
[[maybe_unused]] int surface_width,
452-
[[maybe_unused]] int surface_height,
457+
int surface_width,
458+
int surface_height,
453459
int surface_pitch,
454460
int surface_offset_x,
455461
int surface_offset_y) {
456462
pitch_in_bytes = surface_pitch * sizeof(dtype);
457463
uint32_t offset_x = surface_offset_x;
458464
uint32_t offset_y = surface_offset_y;
465+
width_in_elems = surface_width;
466+
height_in_elems = surface_height;
467+
payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
459468
base_offset = mem_transpose
460469
? offset_x * pitch_in_bytes + offset_y * sizeof(dtype)
461470
: offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
@@ -466,6 +475,9 @@ struct mem_payload_t<
466475
pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
467476
uint32_t offset_x = mem_tdesc.coord.x;
468477
uint32_t offset_y = mem_tdesc.coord.y;
478+
width_in_elems = mem_tdesc.shape.x;
479+
height_in_elems = mem_tdesc.shape.y;
480+
payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
469481
base_offset = mem_transpose
470482
? offset_x * pitch_in_bytes + offset_y * sizeof(dtype)
471483
: offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
@@ -474,14 +486,17 @@ struct mem_payload_t<
474486

475487
__XETLA_API void init(
476488
dtype* p,
477-
[[maybe_unused]] int surface_width,
478-
[[maybe_unused]] int surface_height,
489+
int surface_width,
490+
int surface_height,
479491
int surface_pitch,
480492
int surface_offset_x,
481493
int surface_offset_y) {
482494
pitch_in_bytes = surface_pitch * sizeof(dtype);
483495
uint32_t offset_x = surface_offset_x;
484496
uint32_t offset_y = surface_offset_y;
497+
width_in_elems = surface_width;
498+
height_in_elems = surface_height;
499+
payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
485500
base_offset = mem_transpose
486501
? offset_x * pitch_in_bytes + offset_y * sizeof(dtype)
487502
: offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
@@ -492,13 +507,19 @@ struct mem_payload_t<
492507
this->base_offset = rhs.base_offset;
493508
this->base_ptr = rhs.base_ptr;
494509
this->pitch_in_bytes = rhs.pitch_in_bytes;
510+
this->width_in_elems = rhs.width_in_elems;
511+
this->height_in_elems = rhs.height_in_elems;
512+
this->payload_bytes = rhs.payload_bytes;
495513
}
496514

497515
inline mem_payload_t() = default;
498516
inline this_payload_t& operator=(const this_payload_t& rhs) {
499517
this->base_offset = rhs.base_offset;
500518
this->base_ptr = rhs.base_ptr;
501519
this->pitch_in_bytes = rhs.pitch_in_bytes;
520+
this->width_in_elems = rhs.width_in_elems;
521+
this->height_in_elems = rhs.height_in_elems;
522+
this->payload_bytes = rhs.payload_bytes;
502523
return *this;
503524
}
504525

include/subgroup/tile/impl/store_xe.hpp

+27-24
Original file line numberDiff line numberDiff line change
@@ -283,34 +283,38 @@ tile_store(tile_t& tile, payload_t& payload) {
283283
static constexpr uint32_t store_len = tile_t::tile_elems;
284284
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
285285

286-
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
287-
static constexpr uint32_t max_store_vec_len =
288-
load_store_attr::max_store_vec_len;
289-
static constexpr uint32_t max_store_vec_elems =
290-
max_store_vec_len / sizeof(dtype);
286+
if (payload.base_offset <= payload.payload_bytes) {
287+
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
288+
static constexpr uint32_t max_store_vec_len =
289+
load_store_attr::max_store_vec_len;
290+
static constexpr uint32_t max_store_vec_elems =
291+
max_store_vec_len / sizeof(dtype);
292+
static constexpr uint32_t store_iter_steps =
293+
store_len / max_store_vec_elems;
291294

292-
static constexpr uint32_t store_iter_steps = store_len / max_store_vec_elems;
293-
if constexpr (store_len >= max_store_vec_elems) {
295+
if constexpr (store_len >= max_store_vec_elems) {
294296
#pragma unroll
295-
for (uint32_t i = 0; i < store_iter_steps; i++) {
296-
uint32_t offset = i * max_store_vec_elems;
297-
auto reg_sub = tile.reg.xetla_select<max_store_vec_elems, 1>(offset);
298-
uint32_t address_offset = offset * sizeof(dtype);
297+
for (uint32_t i = 0; i < store_iter_steps; i++) {
298+
uint32_t offset = i * max_store_vec_elems;
299+
auto reg_sub = tile.reg.xetla_select<max_store_vec_elems, 1>(offset);
300+
uint32_t address_offset = offset * sizeof(dtype);
299301

300-
xetla_store_global<dtype, max_store_vec_elems, L1, L2>(
301-
payload.base_ptr,
302-
payload.base_offset + address_offset,
303-
reg_sub.xetla_format<dtype>());
302+
xetla_store_global<dtype, max_store_vec_elems, L1, L2>(
303+
payload.base_ptr,
304+
payload.base_offset + address_offset,
305+
reg_sub.xetla_format<dtype>());
306+
}
304307
}
308+
constexpr uint32_t tail_len =
309+
store_len % max_store_vec_elems * sizeof(dtype);
310+
uint32_t tail_offset = store_iter_steps * max_store_vec_len;
311+
detail::process_1d_tail<
312+
tail_len,
313+
(max_store_vec_len >> 1),
314+
detail::process_flag::store,
315+
L1,
316+
L2>(tile, payload, tail_offset);
305317
}
306-
constexpr uint32_t tail_len = store_len % max_store_vec_elems * sizeof(dtype);
307-
uint32_t tail_offset = store_iter_steps * max_store_vec_len;
308-
detail::process_1d_tail<
309-
tail_len,
310-
(max_store_vec_len >> 1),
311-
detail::process_flag::store,
312-
L1,
313-
L2>(tile, payload, tail_offset);
314318
}
315319

316320
/// @brief Is the func storing data from register file to unaligned global
@@ -348,7 +352,6 @@ tile_store(
348352
constexpr uint32_t num_channel_y = payload_t::num_channel_y;
349353
constexpr uint32_t store_elems = num_channel_y * payload_t::num_channel_x;
350354
constexpr uint32_t scale_factor = payload_t::scale_factor;
351-
352355
#pragma unroll
353356
for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
354357
i++) {

tests/integration/gemm/fp16/common.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class TestBase {
5252
class Test0 : public TestBase {
5353
public:
5454
static constexpr size_t mat_m = 1;
55-
static constexpr size_t mat_n = 1280;
55+
static constexpr size_t mat_n = 64;
5656
static constexpr size_t mat_k = 8192;
5757
static constexpr size_t wg_m = 8;
5858
static constexpr size_t wg_n = 32;

tests/integration/gemv/int4/main.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ class test_col_major_1 {
3232
// Extract the parameters required by different test cases
3333
static constexpr size_t mat_m = 1;
3434
static constexpr size_t mat_n = 4096;
35-
static constexpr size_t mat_k = 11008;
35+
static constexpr size_t mat_k = 4096;
3636
static constexpr size_t wg_m = 1;
3737
static constexpr size_t wg_n = 1;
3838
static constexpr size_t sg_m = 1;
3939
static constexpr size_t sg_n = 1;
40-
static constexpr size_t sg_k = 256 / 1;
40+
static constexpr size_t sg_k = 1024 / 1;
4141
static constexpr size_t dequant_s = 128;
4242
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
4343
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
@@ -47,7 +47,7 @@ class test_col_major_1 {
4747
static constexpr mem_layout layout_a = mem_layout::row_major;
4848
static constexpr mem_layout layout_b = mem_layout::col_major;
4949
static constexpr mma_engine mma_eng = mma_engine::fpu;
50-
static constexpr gpu_arch arch = gpu_arch::XeHpc;
50+
static constexpr gpu_arch arch = gpu_arch::XeLpg;
5151
using data_type_a = fp16;
5252
using data_type_b = int4x8;
5353
using data_type_c = fp16;

0 commit comments

Comments
 (0)