Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit b1e3a00

Browse files
committed
opt store_xe.hpp
1 parent bba4180 commit b1e3a00

File tree

3 files changed: +70 -121 lines changed

3 files changed: +70 -121 lines changed

include/common/core/arch_config.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
4747
// BlockWidth must be 1,2,4 for qwords and be in range [1..8] for dwords.
4848
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
4949

50+
// BlockHeight must be 8 for qwords and be in range [1..32] for dwords.
51+
static constexpr uint32_t max_trans_load_height_in_elem = 32;
52+
5053
// If Transformed is true
5154
// BlockWidth must be in range [4..16] for bytes and [2..16] for word.
5255
static constexpr uint32_t max_vnni_load_width_in_elems = 16;

include/subgroup/tile/impl/payload_xe.hpp

+18
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,24 @@ struct mem_payload_t<
8484
using mem_dtype = typename std::
8585
conditional_t<mem_transpose_dtype_less4bytes, uint32_t, dtype>;
8686
static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
87+
88+
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
89+
90+
static constexpr uint32_t max_load_width_in_elem = trans
91+
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
92+
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);
93+
static constexpr uint32_t max_load_height_in_elem = trans
94+
? load_store_attr::max_trans_load_height_in_elem
95+
: load_store_attr::max_load_height_in_elem;
96+
97+
static constexpr uint32_t max_store_width_in_elem =
98+
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
99+
static constexpr uint32_t max_store_height_in_elem =
100+
load_store_attr::max_store_height_in_elem;
101+
102+
static constexpr uint32_t elems_per_CL =
103+
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
104+
87105
mem_dtype* base_ptr;
88106
uint32_t surface_width;
89107
uint32_t surface_height;

include/subgroup/tile/impl/store_xe.hpp

+49 -121
Original file line numberDiff line numberDiff line change
@@ -98,125 +98,85 @@ tile_store(tile_t& tile, payload_t& payload) {
9898

9999
static constexpr uint32_t num_block_x = tile_desc::num_block_x;
100100
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
101-
// static constexpr uint32_t num_block = tile_desc::num_block;
102101

103-
using load_store_attr = typename arch_attr_t<
104-
payload_t::arch_tag>::template load_store_attr<msg_type::block_2d>;
105-
106-
static constexpr uint32_t max_block_width =
107-
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
108-
static constexpr uint32_t max_block_height =
109-
load_store_attr::max_store_height_in_elem;
110102
static_assert(
111-
(max_block_width % block_size_x) == 0,
112-
"max_block_width should be a multiply of block_size_x.");
113-
static constexpr uint32_t elems_per_CL =
114-
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
115-
static constexpr uint32_t st_block_size_y =
116-
std::min(block_size_y, max_block_height);
103+
(payload_t::max_store_width_in_elem % block_size_x) == 0,
104+
"max_store_width_in_elem should be a multiply of block_size_x.");
105+
106+
static constexpr uint32_t st_blk_size_y =
107+
std::min(block_size_y, payload_t::max_store_height_in_elem);
117108

118109
// to make sure full CL store
119-
static constexpr uint32_t st_block_size_x =
120-
((tile_size_x % elems_per_CL) == 0)
121-
? elems_per_CL
122-
: (((elems_per_CL % tile_size_x) == 0) ? tile_size_x : block_size_x);
110+
static constexpr uint32_t st_blk_size_x =
111+
((tile_size_x % payload_t::elems_per_CL) == 0)
112+
? payload_t::elems_per_CL
113+
: (((payload_t::elems_per_CL % tile_size_x) == 0) ? tile_size_x
114+
: block_size_x);
123115

124-
static constexpr uint8_t arr_len_candidate = st_block_size_x / block_size_x;
116+
static constexpr uint8_t arr_len_candidate = st_blk_size_x / block_size_x;
125117
static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
126118
(arr_len_candidate == 2) || (arr_len_candidate == 4);
127119

128120
static constexpr uint8_t arr_len =
129121
is_valid_arr_len_candidate ? arr_len_candidate : 1;
130122

131-
// auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
123+
constexpr uint32_t store_block_elems = block_elems * arr_len;
124+
constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x;
132125
#pragma unroll
133126
for (uint32_t i = 0; i < num_block_y; ++i) {
134127
int32_t offset_y = i * block_size_y;
135-
constexpr uint32_t store_block_elems = block_elems * arr_len;
136-
// auto payload_row =
137-
// payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
138-
// detail::reset_tile_desc_core<
139-
// num_block_x,
140-
// block_size_x * arr_len,
141-
// st_block_size_y,
142-
// 1,
143-
// 1,
144-
// false>(payload_row);
145128
#pragma unroll
146129
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
147130
int32_t offset_x = j * block_size_x;
148-
// xetla_tdescriptor tdesc = payload_row.row(j);
149131
auto reg_blk = tile.reg.xetla_select<store_block_elems, 1>(
150132
(i * num_block_x + j) * block_elems);
151133
xetla_vector<dtype, store_block_elems> combine_blk;
152134
auto combine_blk_2d = combine_blk.xetla_format<
153135
native_type_t<dtype>,
154136
block_size_y,
155137
block_size_x * arr_len>();
156-
#pragma unroll
157-
for (uint32_t combine_i = 0; combine_i < arr_len; ++combine_i) {
138+
/* combine_blk_2d
139+
____________ ____________
140+
| || |
141+
| block || block |
142+
| || |
143+
|____________||____________|
144+
*/
145+
#pragma unroll
146+
for (uint32_t block_id = 0; block_id < arr_len; ++block_id) {
158147
combine_blk_2d.xetla_select<block_size_y, 1, block_size_x, 1>(
159-
0, combine_i * block_size_x) =
160-
reg_blk.xetla_select<block_elems, 1>(combine_i * block_elems);
148+
0, block_id * block_size_x) =
149+
reg_blk.xetla_select<block_elems, 1>(block_id * block_elems);
161150
}
162151
#pragma unroll
163-
for (uint32_t ii = 0; ii < block_size_y / st_block_size_y; ++ii) {
164-
constexpr uint32_t store_elems =
165-
st_block_size_y * block_size_x * arr_len;
152+
for (uint32_t ii = 0; ii < block_size_y; ii += st_blk_size_y) {
166153
auto st_blk =
167-
combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
168-
// xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
169-
// tdesc, st_blk);
170-
xetla_store_global<
171-
dtype,
172-
block_size_x * arr_len,
173-
st_block_size_y,
174-
L1,
175-
L2>(
154+
combine_blk.xetla_select<store_elems, 1>(ii * st_blk_size_x);
155+
xetla_store_global<dtype, st_blk_size_x, st_blk_size_y, L1, L2>(
176156
reinterpret_cast<dtype*>(payload.base_ptr),
177157
payload.surface_width,
178158
payload.surface_height,
179159
payload.surface_pitch,
180160
payload.offset_x + offset_x,
181-
payload.offset_y + offset_y + ii * st_block_size_y,
182-
// ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
183-
// ::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc),
161+
payload.offset_y + offset_y + ii,
184162
st_blk);
185-
// xetla_update_tdesc_offsety(
186-
// tdesc.xetla_format<uint32_t>(), st_block_size_y);
187163
}
188164
// exceed hardware limitation
189-
if constexpr ((block_size_y % st_block_size_y) != 0) {
190-
constexpr uint32_t blk_remained_start = block_size_y / st_block_size_y *
191-
st_block_size_y * block_size_x * arr_len;
192-
constexpr uint8_t blk_remained_y = block_size_y % st_block_size_y;
193-
constexpr uint8_t blk_remained_elems =
194-
blk_remained_y * block_size_x * arr_len;
165+
if constexpr ((block_size_y % st_blk_size_y) != 0) {
166+
constexpr uint32_t blk_remained_start =
167+
block_size_y / st_blk_size_y * st_blk_size_y * st_blk_size_x;
168+
constexpr uint8_t blk_remained_y = block_size_y % st_blk_size_y;
169+
constexpr uint8_t blk_remained_elems = blk_remained_y * st_blk_size_x;
195170
auto st_blk =
196171
combine_blk.xetla_select<blk_remained_elems, 1>(blk_remained_start);
197-
// constexpr uint32_t block_widthx_widthy_arrlen =
198-
// (block_size_x * arr_len - 1) | ((blk_remained_y - 1) << 8);
199-
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
200-
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
201-
// xetla_tstore_global<
202-
// dtype,
203-
// blk_remained_elems,
204-
// L1,
205-
// L2,
206-
// payload_t::arch_tag>(tdesc, st_blk);
207-
xetla_store_global<
208-
dtype,
209-
block_size_x * arr_len,
210-
blk_remained_y,
211-
L1,
212-
L2>(
172+
xetla_store_global<dtype, st_blk_size_x, blk_remained_y, L1, L2>(
213173
reinterpret_cast<dtype*>(payload.base_ptr),
214174
payload.surface_width,
215175
payload.surface_height,
216176
payload.surface_pitch,
217177
payload.offset_x + offset_x,
218178
payload.offset_y + offset_y +
219-
block_size_y / st_block_size_y * st_block_size_y,
179+
block_size_y / st_blk_size_y * st_blk_size_y,
220180
st_blk);
221181
}
222182
}
@@ -227,47 +187,34 @@ tile_store(tile_t& tile, payload_t& payload) {
227187
constexpr uint32_t processed_elems =
228188
num_block_y * num_block_x * block_elems;
229189
constexpr uint32_t remained_st_blk_size_y =
230-
st_block_size_y > remained_size_y ? remained_size_y : st_block_size_y;
231-
// auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
232-
// num_block_y * num_block_x, 0);
233-
// detail::reset_tile_desc_core<
234-
// num_block_x,
235-
// block_size_x * arr_len,
236-
// remained_st_blk_size_y,
237-
// 1,
238-
// 1,
239-
// false>(payload_row);
190+
std::min(st_blk_size_y, remained_size_y);
240191
#pragma unroll
241192
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
242193
int offset_x = j * block_size_x;
243-
// xetla_tdescriptor tdesc = payload_row.row(j);
244194
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
245195
processed_elems + j * remained_block_elems);
246196
// Do combination
247197
xetla_vector<dtype, remained_block_elems * arr_len> combine_blk;
248198
auto combine_blk_2d = combine_blk.xetla_format<
249199
native_type_t<dtype>,
250200
remained_size_y,
251-
block_size_x * arr_len>();
201+
st_blk_size_x>();
252202
#pragma unroll
253-
for (uint32_t combine_i = 0; combine_i < arr_len; ++combine_i) {
203+
for (uint32_t block_id = 0; block_id < arr_len; ++block_id) {
254204
combine_blk_2d.xetla_select<remained_size_y, 1, block_size_x, 1>(
255-
0, combine_i * block_size_x) =
205+
0, block_id * block_size_x) =
256206
reg_blk.xetla_select<remained_block_elems, 1>(
257-
combine_i * remained_block_elems);
207+
block_id * remained_block_elems);
258208
}
259209
#pragma unroll
260-
for (uint32_t ii = 0; ii < remained_size_y / remained_st_blk_size_y;
261-
++ii) {
262-
constexpr uint32_t store_elems =
263-
remained_st_blk_size_y * block_size_x * arr_len;
210+
for (uint32_t ii = 0; ii < remained_size_y;
211+
ii += remained_st_blk_size_y) {
212+
constexpr uint32_t store_elems = remained_st_blk_size_y * st_blk_size_x;
264213
auto st_blk =
265-
combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
266-
// xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
267-
// tdesc, st_blk);
214+
combine_blk.xetla_select<store_elems, 1>(ii * st_blk_size_x);
268215
xetla_store_global<
269216
dtype,
270-
block_size_x * arr_len,
217+
st_blk_size_x,
271218
remained_st_blk_size_y,
272219
L1,
273220
L2>(
@@ -276,38 +223,19 @@ tile_store(tile_t& tile, payload_t& payload) {
276223
payload.surface_height,
277224
payload.surface_pitch,
278225
payload.offset_x + offset_x,
279-
payload.offset_y + num_block_y * block_size_y +
280-
ii * remained_st_blk_size_y,
226+
payload.offset_y + num_block_y * block_size_y + ii,
281227
st_blk);
282-
// xetla_update_tdesc_offsety(
283-
// tdesc.xetla_format<uint32_t>(), remained_st_blk_size_y);
284228
}
285229
constexpr uint32_t final_st_blk_size_y =
286230
remained_size_y % remained_st_blk_size_y;
287231
if constexpr (final_st_blk_size_y != 0) {
288232
constexpr uint32_t final_start = remained_size_y /
289-
remained_st_blk_size_y * remained_st_blk_size_y * block_size_x *
290-
arr_len;
233+
remained_st_blk_size_y * remained_st_blk_size_y * st_blk_size_x;
291234
constexpr uint32_t final_store_elems =
292-
final_st_blk_size_y * block_size_x * arr_len;
235+
final_st_blk_size_y * st_blk_size_x;
293236
auto st_blk =
294237
combine_blk.xetla_select<final_store_elems, 1>(final_start);
295-
// constexpr uint32_t block_widthx_widthy_arrlen =
296-
// (block_size_x * arr_len - 1) | ((final_st_blk_size_y - 1) << 8);
297-
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
298-
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
299-
// xetla_tstore_global<
300-
// dtype,
301-
// final_store_elems,
302-
// L1,
303-
// L2,
304-
// payload_t::arch_tag>(tdesc, st_blk);
305-
xetla_store_global<
306-
dtype,
307-
block_size_x * arr_len,
308-
final_st_blk_size_y,
309-
L1,
310-
L2>(
238+
xetla_store_global<dtype, st_blk_size_x, final_st_blk_size_y, L1, L2>(
311239
reinterpret_cast<dtype*>(payload.base_ptr),
312240
payload.surface_width,
313241
payload.surface_height,

0 commit comments

Comments
 (0)