Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit a0b94f3

Browse files
committed
save
1 parent 90662a8 commit a0b94f3

File tree

2 files changed

+45
-54
lines changed

2 files changed

+45
-54
lines changed

include/subgroup/tile/impl/load_xe.hpp

+36 -45
Original file line number | Diff line number | Diff line change
@@ -114,27 +114,45 @@ tile_load(tile_t& tile, payload_t& payload) {
114114
// static constexpr uint32_t max_load_height_in_elem = trans
115115
// ? load_store_attr::max_trans_load_height_in_elem
116116
// : load_store_attr::max_load_height_in_elem;
117-
static constexpr uint32_t max_trans_load_width_in_elem =
118-
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119-
static constexpr uint32_t max_load_width_in_elem =
120-
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
117+
// static constexpr uint32_t max_trans_load_width_in_elem =
118+
// load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119+
// static constexpr uint32_t max_load_width_in_elem =
120+
// load_store_attr::max_load_width_in_bytes / sizeof(dtype);
121121

122122
// static constexpr uint32_t max_trans_load_height_in_elem =
123123
// load_store_attr::max_trans_load_height_in_elem;
124-
static constexpr uint32_t max_load_height_in_elem =
125-
load_store_attr::max_load_height_in_elem;
124+
125+
// static constexpr uint32_t max_load_height_in_elem =
126+
// load_store_attr::max_load_height_in_elem;
126127

127128
static constexpr uint32_t elems_per_CL =
128129
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
129130

130131
static constexpr uint32_t elems_per_reg =
131132
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
132133

133-
static constexpr uint32_t ld_blk_size_y_limit =
134-
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135-
static constexpr uint32_t ld_blk_size_y = reg_transpose
136-
? block_size_y
137-
: std::min(ld_blk_size_y_limit, block_size_y);
134+
static constexpr uint32_t max_ld_blk_width_in_elem = trans
135+
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
136+
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);
137+
138+
static constexpr uint32_t max_ld_blk_height_in_elem = trans
139+
? load_store_attr::max_trans_load_height_in_elem
140+
: load_store_attr::max_load_height_in_elem;
141+
142+
static constexpr uint32_t ld_blk_width = std::min(
143+
(mem_transpose ? block_size_y : block_size_x), max_ld_blk_width_in_elem);
144+
static constexpr uint32_t ld_blk_height = std::min(
145+
(mem_transpose ? block_size_x : block_size_y), max_ld_blk_height_in_elem);
146+
147+
static constexpr uint32_t ld_blk_size_y =
148+
mem_transpose ? ld_blk_width : ld_blk_height;
149+
150+
static constexpr uint32_t ld_blk_size_y_limit = mem_transpose
151+
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
152+
: load_store_attr::max_load_height_in_elem;
153+
// static constexpr uint32_t ld_blk_size_y = reg_transpose
154+
// ? block_size_y
155+
// : std::min(ld_blk_size_y_limit, block_size_y);
138156

139157
// array len is used to make sure memory load is cache line aligned
140158
// disabled while register or memory transpose
@@ -198,10 +216,7 @@ tile_load(tile_t& tile, payload_t& payload) {
198216
constexpr uint32_t load_block_elems = block_elems * arr_len;
199217
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
200218
(i * num_block_x + j) * block_elems);
201-
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
202-
? detail::getNextPowerOf2<ld_blk_size_y>()
203-
: ld_blk_size_y;
204-
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
219+
constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
205220
xetla_vector<dtype, tmp_size> reg_tmp;
206221
#pragma unroll
207222
for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
@@ -213,10 +228,8 @@ tile_load(tile_t& tile, payload_t& payload) {
213228
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
214229
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
215230
native_type_t<load_dtype>,
216-
(trans ? ld_blk_size_y : block_size_x) / scale_factor,
217-
(trans ? block_size_x : ld_blk_size_y),
218-
// block_size_x / scale_factor,
219-
// ld_blk_size_y,
231+
ld_blk_width / scale_factor,
232+
ld_blk_height,
220233
arr_len,
221234
trans,
222235
mem_transform,
@@ -261,11 +274,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261274
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262275
constexpr uint8_t block_height =
263276
mem_transpose ? block_size_x : remained_blk_size_y;
264-
// constexpr uint32_t block_widthx_widthy_arrlen =
265-
// (block_width - 1) | ((block_height - 1) << 8);
266-
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267-
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268-
269277
reg_blk.xetla_select<load_elems, 1>(remained_start)
270278
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
271279
native_type_t<load_dtype>,
@@ -283,15 +291,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283291
payload.surface_pitch,
284292
payload.offset_x + offset_x / scale_factor,
285293
payload.offset_y + offset_y + remained_start_y);
286-
287-
// xetla_tload_global<
288-
// load_dtype,
289-
// (load_elems / scale_factor),
290-
// L1,
291-
// L2,
292-
// trans,
293-
// mem_transform,
294-
// arch_tag>(tdesc);
295294
}
296295
}
297296
}
@@ -304,24 +303,16 @@ tile_load(tile_t& tile, payload_t& payload) {
304303
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305304
? ld_blk_size_y_limit
306305
: remained_size_y;
307-
// auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308-
// num_block_y * num_block_x, 0);
309-
// detail::reset_tile_desc_core<
310-
// num_block_x,
311-
// block_size_x,
312-
// remained_ld_blk_size_y,
313-
// scale_factor,
314-
// arr_len,
315-
// mem_transpose>(payload_row);
306+
316307
#pragma unroll
317308
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
318309
int32_t offset_x = j * block_size_x;
319310
// xetla_tdescriptor tdesc = payload_row.row(j);
320311
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
321312
processed_elems + j * remained_block_elems);
322-
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
323-
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
324-
: remained_ld_blk_size_y;
313+
// constexpr uint32_t ld_blk_height = (reg_transpose && trans)
314+
// ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
315+
// : remained_ld_blk_size_y;
325316
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
326317
xetla_vector<dtype, tmp_size> reg_tmp;
327318
#pragma unroll

tests/integration/gemm/fp32/main.cpp

+9 -9
Original file line number | Diff line number | Diff line change
@@ -34,15 +34,15 @@ TYPED_TEST_P(fp32_gemm_test, esimd) {
3434

3535
REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd);
3636
using tests = ::testing::Types<
37-
// Test1,
38-
// Test2,
39-
// Test3,
40-
// Test4,
41-
// Test5,
42-
// Test6,
43-
// Test7,
44-
// Test8,
45-
// Test9,
37+
Test1,
38+
Test2,
39+
Test3,
40+
Test4,
41+
Test5,
42+
Test6,
43+
Test7,
44+
Test8,
45+
Test9,
4646
Test10,
4747
Test11>;
4848
INSTANTIATE_TYPED_TEST_SUITE_P(fp32_gemm_test_suite, fp32_gemm_test, tests);

0 commit comments

Comments
 (0)