Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 4b56d54

Browse files
committed
save
1 parent 90662a8 commit 4b56d54

File tree

2 files changed

+79
-98
lines changed

2 files changed

+79
-98
lines changed

include/subgroup/tile/impl/load_xe.hpp

+70-89
Original file line numberDiff line numberDiff line change
@@ -108,81 +108,86 @@ tile_load(tile_t& tile, payload_t& payload) {
108108

109109
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
110110

111-
// static constexpr uint32_t max_load_width_in_elem = trans
112-
// ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
113-
// : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
111+
// static constexpr uint32_t max_load_width_in_elem = trans
112+
// ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
113+
// : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
114114
// static constexpr uint32_t max_load_height_in_elem = trans
115115
// ? load_store_attr::max_trans_load_height_in_elem
116116
// : load_store_attr::max_load_height_in_elem;
117-
static constexpr uint32_t max_trans_load_width_in_elem =
118-
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119-
static constexpr uint32_t max_load_width_in_elem =
120-
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
117+
// static constexpr uint32_t max_trans_load_width_in_elem =
118+
// load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119+
// static constexpr uint32_t max_load_width_in_elem =
120+
// load_store_attr::max_load_width_in_bytes / sizeof(dtype);
121121

122122
// static constexpr uint32_t max_trans_load_height_in_elem =
123123
// load_store_attr::max_trans_load_height_in_elem;
124-
static constexpr uint32_t max_load_height_in_elem =
125-
load_store_attr::max_load_height_in_elem;
124+
125+
// static constexpr uint32_t max_load_height_in_elem =
126+
// load_store_attr::max_load_height_in_elem;
126127

127128
static constexpr uint32_t elems_per_CL =
128129
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
129130

130131
static constexpr uint32_t elems_per_reg =
131132
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
132133

133-
static constexpr uint32_t ld_blk_size_y_limit =
134-
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135-
static constexpr uint32_t ld_blk_size_y = reg_transpose
136-
? block_size_y
137-
: std::min(ld_blk_size_y_limit, block_size_y);
134+
static constexpr uint32_t max_load_width_in_elem = trans
135+
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
136+
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);
137+
138+
static constexpr uint32_t max_load_blk_height_in_elem = trans
139+
? load_store_attr::max_trans_load_height_in_elem
140+
: load_store_attr::max_load_height_in_elem;
141+
142+
static constexpr uint32_t ld_blk_width = std::min(
143+
(mem_transpose ? block_size_y : block_size_x), max_load_width_in_elem);
144+
145+
static constexpr uint32_t ld_blk_height = std::min(
146+
(mem_transpose ? block_size_x : block_size_y),
147+
max_load_blk_height_in_elem);
148+
149+
static constexpr uint32_t ld_blk_size_y =
150+
mem_transpose ? ld_blk_width : ld_blk_height;
151+
152+
static constexpr uint32_t ld_blk_size_y_limit = mem_transpose
153+
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
154+
: load_store_attr::max_load_height_in_elem;
138155

139156
// array len is used to make sure memory load is cache line aligned
140157
// disabled while register or memory transpose
141158
static constexpr uint8_t arr_len_candidate =
142-
(reg_transpose ||
143-
mem_transpose
159+
((reg_transpose || mem_transpose)
144160
// block elements should be integer
145161
// times of register bytes
146-
|| ((block_size_y * block_size_x) % elems_per_reg != 0)
162+
|| ((block_elems) % elems_per_reg != 0)
147163
// tail blocks also need to meet above condition
148-
||
149-
(((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) ||
150-
(block_size_y > ld_blk_size_y_limit)
164+
|| (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0))
165+
// || (block_size_y > load_store_attr::max_load_height_in_elem)
151166
? 1
152167
: (((tile_size_x % elems_per_CL) == 0)
153168
? (((elems_per_CL % block_size_x) == 0)
154169
? elems_per_CL / block_size_x
155170
: 1)
156171
: ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
157172
: 1));
158-
static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
159-
(arr_len_candidate == 2) || (arr_len_candidate == 4);
160-
161-
static constexpr uint8_t arr_len =
162-
is_valid_arr_len_candidate ? arr_len_candidate : 1;
163-
164-
static_assert(
165-
reg_transpose || mem_transpose ||
166-
(!mem_transpose &&
167-
(block_size_x * arr_len) <= max_load_width_in_elem),
168-
"When reg_transpose was disabled, check 2d block width "
169-
"restriction");
170-
static_assert(
171-
!reg_transpose ||
172-
(!mem_transpose &&
173-
(block_size_x * arr_len) <= max_trans_load_width_in_elem) ||
174-
(mem_transpose && (block_size_y * arr_len) <= max_load_width_in_elem),
175-
"When reg_transpose was enabled, check 2d block width "
176-
"restriction");
177-
static_assert(
178-
!reg_transpose ||
179-
(!mem_transpose && (block_size_y <= max_load_height_in_elem)) ||
180-
(mem_transpose && (block_size_x) <= max_load_height_in_elem),
181-
"When reg_transpose was enabled, check 2d block height "
182-
"restriction");
183-
static_assert(
184-
tile_size_x % (block_size_x * arr_len) == 0,
185-
"tile_size_x should be a multiple of (block_size_x * arr_len)");
173+
// NBlocks must be {1,2,4} for bytes and words, {1,2} for dwords, 1 for
174+
// qwords.
175+
static constexpr bool arr_len =
176+
((arr_len_candidate == 1) ||
177+
(arr_len_candidate == 2 && sizeof(dtype) <= 4) ||
178+
(arr_len_candidate == 4 && sizeof(dtype) <= 2))
179+
? arr_len_candidate
180+
: 1;
181+
182+
if constexpr (!trans && !mem_transform) {
183+
static_assert(
184+
(ld_blk_width * arr_len) <= max_load_width_in_elem,
185+
"When Transposed and Transformed are both set to false, BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8 for qwords");
186+
} else if constexpr (mem_transform) {
187+
static_assert(
188+
(ld_blk_width * arr_len) <= max_load_width_in_elem,
189+
"When Transformed is true then, BlockWidth * NBlocks must not exceed 64 for bytes and 32 for words.");
190+
}
186191
static_assert(
187192
(reg_transpose &&
188193
((block_size_x * sizeof(dtype)) % sizeof(load_dtype) == 0)) ||
@@ -198,10 +203,7 @@ tile_load(tile_t& tile, payload_t& payload) {
198203
constexpr uint32_t load_block_elems = block_elems * arr_len;
199204
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
200205
(i * num_block_x + j) * block_elems);
201-
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
202-
? detail::getNextPowerOf2<ld_blk_size_y>()
203-
: ld_blk_size_y;
204-
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
206+
constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
205207
xetla_vector<dtype, tmp_size> reg_tmp;
206208
#pragma unroll
207209
for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
@@ -213,10 +215,8 @@ tile_load(tile_t& tile, payload_t& payload) {
213215
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
214216
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
215217
native_type_t<load_dtype>,
216-
(trans ? ld_blk_size_y : block_size_x) / scale_factor,
217-
(trans ? block_size_x : ld_blk_size_y),
218-
// block_size_x / scale_factor,
219-
// ld_blk_size_y,
218+
ld_blk_width / scale_factor,
219+
ld_blk_height,
220220
arr_len,
221221
trans,
222222
mem_transform,
@@ -261,11 +261,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261261
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262262
constexpr uint8_t block_height =
263263
mem_transpose ? block_size_x : remained_blk_size_y;
264-
// constexpr uint32_t block_widthx_widthy_arrlen =
265-
// (block_width - 1) | ((block_height - 1) << 8);
266-
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267-
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268-
269264
reg_blk.xetla_select<load_elems, 1>(remained_start)
270265
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
271266
native_type_t<load_dtype>,
@@ -283,15 +278,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283278
payload.surface_pitch,
284279
payload.offset_x + offset_x / scale_factor,
285280
payload.offset_y + offset_y + remained_start_y);
286-
287-
// xetla_tload_global<
288-
// load_dtype,
289-
// (load_elems / scale_factor),
290-
// L1,
291-
// L2,
292-
// trans,
293-
// mem_transform,
294-
// arch_tag>(tdesc);
295281
}
296282
}
297283
}
@@ -304,24 +290,16 @@ tile_load(tile_t& tile, payload_t& payload) {
304290
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305291
? ld_blk_size_y_limit
306292
: remained_size_y;
307-
// auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308-
// num_block_y * num_block_x, 0);
309-
// detail::reset_tile_desc_core<
310-
// num_block_x,
311-
// block_size_x,
312-
// remained_ld_blk_size_y,
313-
// scale_factor,
314-
// arr_len,
315-
// mem_transpose>(payload_row);
293+
316294
#pragma unroll
317295
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
318296
int32_t offset_x = j * block_size_x;
319297
// xetla_tdescriptor tdesc = payload_row.row(j);
320298
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
321299
processed_elems + j * remained_block_elems);
322-
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
323-
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
324-
: remained_ld_blk_size_y;
300+
// constexpr uint32_t ld_blk_height = (reg_transpose && trans)
301+
// ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
302+
// : remained_ld_blk_size_y;
325303
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
326304
xetla_vector<dtype, tmp_size> reg_tmp;
327305
#pragma unroll
@@ -490,7 +468,8 @@ tile_load(tile_t& tile, payload_t& payload) {
490468

491469
/// @brief This function loads data from an unaligned 2D memory surface.
492470
/// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
493-
/// registers. Each block will be loaded serially by its corresponding payload.
471+
/// registers. Each block will be loaded serially by its corresponding
472+
/// payload.
494473
/// @tparam tile_t Is the tile_t struct that contains the registers.
495474
/// These registers will be the destination of load operation.
496475
/// @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -614,7 +593,8 @@ tile_load(tile_t& tile, payload_t& payload) {
614593

615594
/// @brief This function loads data from an unaligned 2D memory surface.
616595
/// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
617-
/// registers. Each block will be loaded serially by its corresponding payload.
596+
/// registers. Each block will be loaded serially by its corresponding
597+
/// payload.
618598
/// @tparam tile_t Is the tile_t struct that contains the registers.
619599
/// These registers will be the destination of load operation.
620600
/// @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -679,7 +659,8 @@ tile_load(tile_t& tile, payload_t& payload) {
679659

680660
/// @brief This function loads data from an unaligned 2D memory surface.
681661
/// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
682-
/// registers. Each block will be loaded serially by its corresponding payload.
662+
/// registers. Each block will be loaded serially by its corresponding
663+
/// payload.
683664
/// @tparam tile_t Is the tile_t struct that contains the registers.
684665
/// These registers will be the destination of load operation.
685666
/// @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -819,8 +800,8 @@ tile_load(
819800
}
820801

821802
/// @brief Is the data load func from local shared memory to register file,
822-
/// which supports the memory surface is 1d or 2d scenario. And we always assume
823-
/// data in SLM is row major.
803+
/// which supports the memory surface is 1d or 2d scenario. And we always
804+
/// assume data in SLM is row major.
824805
/// @tparam tile_t Is the tile_t struct that contains the registers.
825806
/// These registers will be the destination of load operation.
826807
/// @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -902,8 +883,8 @@ tile_load(tile_t& tile, payload_t& payload) {
902883
}
903884

904885
/// @brief Is the data load func from shared local memory to register file,
905-
/// which supports the memory surface is 1d scenario. And the src memory layout
906-
/// is always row major.
886+
/// which supports the memory surface is 1d scenario. And the src memory
887+
/// layout is always row major.
907888
/// @tparam tile_t Is the tile_t struct that contains the registers.
908889
/// These registers will be the destination of load operation.
909890
/// @tparam payload_t Is the mem_payload_t struct describing the memory

tests/integration/gemm/fp32/main.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ TYPED_TEST_P(fp32_gemm_test, esimd) {
3434

3535
REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd);
3636
using tests = ::testing::Types<
37-
// Test1,
38-
// Test2,
39-
// Test3,
40-
// Test4,
41-
// Test5,
42-
// Test6,
43-
// Test7,
44-
// Test8,
45-
// Test9,
37+
Test1,
38+
Test2,
39+
Test3,
40+
Test4,
41+
Test5,
42+
Test6,
43+
Test7,
44+
Test8,
45+
Test9,
4646
Test10,
4747
Test11>;
4848
INSTANTIATE_TYPED_TEST_SUITE_P(fp32_gemm_test_suite, fp32_gemm_test, tests);

0 commit comments

Comments
 (0)