@@ -114,27 +114,45 @@ tile_load(tile_t& tile, payload_t& payload) {
114
114
// static constexpr uint32_t max_load_height_in_elem = trans
115
115
// ? load_store_attr::max_trans_load_height_in_elem
116
116
// : load_store_attr::max_load_height_in_elem;
117
- static constexpr uint32_t max_trans_load_width_in_elem =
118
- load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype);
119
- static constexpr uint32_t max_load_width_in_elem =
120
- load_store_attr::max_load_width_in_bytes / sizeof (dtype);
117
+ // static constexpr uint32_t max_trans_load_width_in_elem =
118
+ // load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119
+ // static constexpr uint32_t max_load_width_in_elem =
120
+ // load_store_attr::max_load_width_in_bytes / sizeof(dtype);
121
121
122
122
// static constexpr uint32_t max_trans_load_height_in_elem =
123
123
// load_store_attr::max_trans_load_height_in_elem;
124
- static constexpr uint32_t max_load_height_in_elem =
125
- load_store_attr::max_load_height_in_elem;
124
+
125
+ // static constexpr uint32_t max_load_height_in_elem =
126
+ // load_store_attr::max_load_height_in_elem;
126
127
127
128
static constexpr uint32_t elems_per_CL =
128
129
load_store_attr::cache_line_size_in_bytes / sizeof (dtype);
129
130
130
131
static constexpr uint32_t elems_per_reg =
131
132
register_bytes_t <arch_tag>::reg_in_bytes / sizeof (dtype);
132
133
133
- static constexpr uint32_t ld_blk_size_y_limit =
134
- mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135
- static constexpr uint32_t ld_blk_size_y = reg_transpose
136
- ? block_size_y
137
- : std::min (ld_blk_size_y_limit, block_size_y);
134
+ static constexpr uint32_t max_ld_blk_width_in_elem = trans
135
+ ? load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype)
136
+ : load_store_attr::max_load_width_in_bytes / sizeof (dtype);
137
+
138
+ static constexpr uint32_t max_ld_blk_height_in_elem = trans
139
+ ? load_store_attr::max_trans_load_height_in_elem
140
+ : load_store_attr::max_load_height_in_elem;
141
+
142
+ static constexpr uint32_t ld_blk_width = std::min (
143
+ (mem_transpose ? block_size_y : block_size_x), max_ld_blk_width_in_elem);
144
+ static constexpr uint32_t ld_blk_height = std::min (
145
+ (mem_transpose ? block_size_x : block_size_y), max_ld_blk_height_in_elem);
146
+
147
+ static constexpr uint32_t ld_blk_size_y =
148
+ mem_transpose ? ld_blk_width : ld_blk_height;
149
+
150
+ static constexpr uint32_t ld_blk_size_y_limit = mem_transpose
151
+ ? load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype)
152
+ : load_store_attr::max_load_height_in_elem;
153
+ // static constexpr uint32_t ld_blk_size_y = reg_transpose
154
+ // ? block_size_y
155
+ // : std::min(ld_blk_size_y_limit, block_size_y);
138
156
139
157
// array len is used to make sure memory load is cache line aligned
140
158
// disabled while register or memory transpose
@@ -198,10 +216,7 @@ tile_load(tile_t& tile, payload_t& payload) {
198
216
constexpr uint32_t load_block_elems = block_elems * arr_len;
199
217
auto reg_blk = tile.reg .xetla_select <load_block_elems, 1 >(
200
218
(i * num_block_x + j) * block_elems);
201
- constexpr uint32_t ld_blk_height = (reg_transpose && trans)
202
- ? detail::getNextPowerOf2<ld_blk_size_y>()
203
- : ld_blk_size_y;
204
- constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
219
+ constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
205
220
xetla_vector<dtype, tmp_size> reg_tmp;
206
221
#pragma unroll
207
222
for (uint32_t ii = 0 ; ii < block_size_y / ld_blk_size_y; ++ii) {
@@ -213,10 +228,8 @@ tile_load(tile_t& tile, payload_t& payload) {
213
228
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
214
229
reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
215
230
native_type_t <load_dtype>,
216
- (trans ? ld_blk_size_y : block_size_x) / scale_factor,
217
- (trans ? block_size_x : ld_blk_size_y),
218
- // block_size_x / scale_factor,
219
- // ld_blk_size_y,
231
+ ld_blk_width / scale_factor,
232
+ ld_blk_height,
220
233
arr_len,
221
234
trans,
222
235
mem_transform,
@@ -261,11 +274,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261
274
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262
275
constexpr uint8_t block_height =
263
276
mem_transpose ? block_size_x : remained_blk_size_y;
264
- // constexpr uint32_t block_widthx_widthy_arrlen =
265
- // (block_width - 1) | ((block_height - 1) << 8);
266
- // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267
- // tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268
-
269
277
reg_blk.xetla_select <load_elems, 1 >(remained_start)
270
278
.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
271
279
native_type_t <load_dtype>,
@@ -283,15 +291,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283
291
payload.surface_pitch ,
284
292
payload.offset_x + offset_x / scale_factor,
285
293
payload.offset_y + offset_y + remained_start_y);
286
-
287
- // xetla_tload_global<
288
- // load_dtype,
289
- // (load_elems / scale_factor),
290
- // L1,
291
- // L2,
292
- // trans,
293
- // mem_transform,
294
- // arch_tag>(tdesc);
295
294
}
296
295
}
297
296
}
@@ -304,24 +303,16 @@ tile_load(tile_t& tile, payload_t& payload) {
304
303
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305
304
? ld_blk_size_y_limit
306
305
: remained_size_y;
307
- // auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308
- // num_block_y * num_block_x, 0);
309
- // detail::reset_tile_desc_core<
310
- // num_block_x,
311
- // block_size_x,
312
- // remained_ld_blk_size_y,
313
- // scale_factor,
314
- // arr_len,
315
- // mem_transpose>(payload_row);
306
+
316
307
#pragma unroll
317
308
for (uint32_t j = 0 ; j < num_block_x; j += arr_len) {
318
309
int32_t offset_x = j * block_size_x;
319
310
// xetla_tdescriptor tdesc = payload_row.row(j);
320
311
auto reg_blk = tile.reg .xetla_select <remained_block_elems * arr_len, 1 >(
321
312
processed_elems + j * remained_block_elems);
322
- constexpr uint32_t ld_blk_height = (reg_transpose && trans)
323
- ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
324
- : remained_ld_blk_size_y;
313
+ // constexpr uint32_t ld_blk_height = (reg_transpose && trans)
314
+ // ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
315
+ // : remained_ld_blk_size_y;
325
316
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
326
317
xetla_vector<dtype, tmp_size> reg_tmp;
327
318
#pragma unroll
0 commit comments