@@ -108,81 +108,86 @@ tile_load(tile_t& tile, payload_t& payload) {
108
108
109
109
using load_store_attr = load_store_attr_t <msg_type::block_2d, arch_tag>;
110
110
111
- // static constexpr uint32_t max_load_width_in_elem = trans
112
- // ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
113
- // : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
111
+ // static constexpr uint32_t max_load_width_in_elem = trans
112
+ // ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
113
+ // : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
114
114
// static constexpr uint32_t max_load_height_in_elem = trans
115
115
// ? load_store_attr::max_trans_load_height_in_elem
116
116
// : load_store_attr::max_load_height_in_elem;
117
- static constexpr uint32_t max_trans_load_width_in_elem =
118
- load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype);
119
- static constexpr uint32_t max_load_width_in_elem =
120
- load_store_attr::max_load_width_in_bytes / sizeof (dtype);
117
+ // static constexpr uint32_t max_trans_load_width_in_elem =
118
+ // load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119
+ // static constexpr uint32_t max_load_width_in_elem =
120
+ // load_store_attr::max_load_width_in_bytes / sizeof(dtype);
121
121
122
122
// static constexpr uint32_t max_trans_load_height_in_elem =
123
123
// load_store_attr::max_trans_load_height_in_elem;
124
- static constexpr uint32_t max_load_height_in_elem =
125
- load_store_attr::max_load_height_in_elem;
124
+
125
+ // static constexpr uint32_t max_load_height_in_elem =
126
+ // load_store_attr::max_load_height_in_elem;
126
127
127
128
static constexpr uint32_t elems_per_CL =
128
129
load_store_attr::cache_line_size_in_bytes / sizeof (dtype);
129
130
130
131
static constexpr uint32_t elems_per_reg =
131
132
register_bytes_t <arch_tag>::reg_in_bytes / sizeof (dtype);
132
133
133
- static constexpr uint32_t ld_blk_size_y_limit =
134
- mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135
- static constexpr uint32_t ld_blk_size_y = reg_transpose
136
- ? block_size_y
137
- : std::min (ld_blk_size_y_limit, block_size_y);
134
+ static constexpr uint32_t max_load_width_in_elem = trans
135
+ ? load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype)
136
+ : load_store_attr::max_load_width_in_bytes / sizeof (dtype);
137
+
138
+ static constexpr uint32_t max_load_blk_height_in_elem = trans
139
+ ? load_store_attr::max_trans_load_height_in_elem
140
+ : load_store_attr::max_load_height_in_elem;
141
+
142
+ static constexpr uint32_t ld_blk_width = std::min (
143
+ (mem_transpose ? block_size_y : block_size_x), max_load_width_in_elem);
144
+
145
+ static constexpr uint32_t ld_blk_height = std::min (
146
+ (mem_transpose ? block_size_x : block_size_y),
147
+ max_load_blk_height_in_elem);
148
+
149
+ static constexpr uint32_t ld_blk_size_y =
150
+ mem_transpose ? ld_blk_width : ld_blk_height;
151
+
152
+ static constexpr uint32_t ld_blk_size_y_limit = mem_transpose
153
+ ? load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype)
154
+ : load_store_attr::max_load_height_in_elem;
138
155
139
156
// array len is used to make sure memory load is cache line aligned
140
157
// disabled while register or memory transpose
141
158
static constexpr uint8_t arr_len_candidate =
142
- (reg_transpose ||
143
- mem_transpose
159
+ ((reg_transpose || mem_transpose)
144
160
// block elements should be integer
145
161
// times of register bytes
146
- || ((block_size_y * block_size_x ) % elems_per_reg != 0 )
162
+ || ((block_elems ) % elems_per_reg != 0 )
147
163
// tail blocks also need to meet above condition
148
- ||
149
- (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0 )) ||
150
- (block_size_y > ld_blk_size_y_limit)
164
+ || (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0 ))
165
+ // || (block_size_y > load_store_attr::max_load_height_in_elem)
151
166
? 1
152
167
: (((tile_size_x % elems_per_CL) == 0 )
153
168
? (((elems_per_CL % block_size_x) == 0 )
154
169
? elems_per_CL / block_size_x
155
170
: 1 )
156
171
: ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
157
172
: 1 ));
158
- static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1 ) ||
159
- (arr_len_candidate == 2 ) || (arr_len_candidate == 4 );
160
-
161
- static constexpr uint8_t arr_len =
162
- is_valid_arr_len_candidate ? arr_len_candidate : 1 ;
163
-
164
- static_assert (
165
- reg_transpose || mem_transpose ||
166
- (!mem_transpose &&
167
- (block_size_x * arr_len) <= max_load_width_in_elem),
168
- " When reg_transpose was disabled, check 2d block width "
169
- " restriction" );
170
- static_assert (
171
- !reg_transpose ||
172
- (!mem_transpose &&
173
- (block_size_x * arr_len) <= max_trans_load_width_in_elem) ||
174
- (mem_transpose && (block_size_y * arr_len) <= max_load_width_in_elem),
175
- " When reg_transpose was enabled, check 2d block width "
176
- " restriction" );
177
- static_assert (
178
- !reg_transpose ||
179
- (!mem_transpose && (block_size_y <= max_load_height_in_elem)) ||
180
- (mem_transpose && (block_size_x) <= max_load_height_in_elem),
181
- " When reg_transpose was enabled, check 2d block height "
182
- " restriction" );
183
- static_assert (
184
- tile_size_x % (block_size_x * arr_len) == 0 ,
185
- " tile_size_x should be a multiple of (block_size_x * arr_len)" );
173
+ // NBlocks must be {1,2,4} for bytes and words, {1,2} for dwords, 1 for
174
+ // qwords.
175
+ static constexpr bool arr_len =
176
+ ((arr_len_candidate == 1 ) ||
177
+ (arr_len_candidate == 2 && sizeof (dtype) <= 4 ) ||
178
+ (arr_len_candidate == 4 && sizeof (dtype) <= 2 ))
179
+ ? arr_len_candidate
180
+ : 1 ;
181
+
182
+ if constexpr (!trans && !mem_transform) {
183
+ static_assert (
184
+ (ld_blk_width * arr_len) <= max_load_width_in_elem,
185
+ " When Transposed and Transformed are both set to false, BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8 for qwords" );
186
+ } else if constexpr (mem_transform) {
187
+ static_assert (
188
+ (ld_blk_width * arr_len) <= max_load_width_in_elem,
189
+ " When Transformed is true then, BlockWidth * NBlocks must not exceed 64 for bytes and 32 for words." );
190
+ }
186
191
static_assert (
187
192
(reg_transpose &&
188
193
((block_size_x * sizeof (dtype)) % sizeof (load_dtype) == 0 )) ||
@@ -198,10 +203,7 @@ tile_load(tile_t& tile, payload_t& payload) {
198
203
constexpr uint32_t load_block_elems = block_elems * arr_len;
199
204
auto reg_blk = tile.reg .xetla_select <load_block_elems, 1 >(
200
205
(i * num_block_x + j) * block_elems);
201
- constexpr uint32_t ld_blk_height = (reg_transpose && trans)
202
- ? detail::getNextPowerOf2<ld_blk_size_y>()
203
- : ld_blk_size_y;
204
- constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
206
+ constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
205
207
xetla_vector<dtype, tmp_size> reg_tmp;
206
208
#pragma unroll
207
209
for (uint32_t ii = 0 ; ii < block_size_y / ld_blk_size_y; ++ii) {
@@ -213,10 +215,8 @@ tile_load(tile_t& tile, payload_t& payload) {
213
215
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
214
216
reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
215
217
native_type_t <load_dtype>,
216
- (trans ? ld_blk_size_y : block_size_x) / scale_factor,
217
- (trans ? block_size_x : ld_blk_size_y),
218
- // block_size_x / scale_factor,
219
- // ld_blk_size_y,
218
+ ld_blk_width / scale_factor,
219
+ ld_blk_height,
220
220
arr_len,
221
221
trans,
222
222
mem_transform,
@@ -261,11 +261,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261
261
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262
262
constexpr uint8_t block_height =
263
263
mem_transpose ? block_size_x : remained_blk_size_y;
264
- // constexpr uint32_t block_widthx_widthy_arrlen =
265
- // (block_width - 1) | ((block_height - 1) << 8);
266
- // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267
- // tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268
-
269
264
reg_blk.xetla_select <load_elems, 1 >(remained_start)
270
265
.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
271
266
native_type_t <load_dtype>,
@@ -283,15 +278,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283
278
payload.surface_pitch ,
284
279
payload.offset_x + offset_x / scale_factor,
285
280
payload.offset_y + offset_y + remained_start_y);
286
-
287
- // xetla_tload_global<
288
- // load_dtype,
289
- // (load_elems / scale_factor),
290
- // L1,
291
- // L2,
292
- // trans,
293
- // mem_transform,
294
- // arch_tag>(tdesc);
295
281
}
296
282
}
297
283
}
@@ -304,24 +290,16 @@ tile_load(tile_t& tile, payload_t& payload) {
304
290
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305
291
? ld_blk_size_y_limit
306
292
: remained_size_y;
307
- // auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308
- // num_block_y * num_block_x, 0);
309
- // detail::reset_tile_desc_core<
310
- // num_block_x,
311
- // block_size_x,
312
- // remained_ld_blk_size_y,
313
- // scale_factor,
314
- // arr_len,
315
- // mem_transpose>(payload_row);
293
+
316
294
#pragma unroll
317
295
for (uint32_t j = 0 ; j < num_block_x; j += arr_len) {
318
296
int32_t offset_x = j * block_size_x;
319
297
// xetla_tdescriptor tdesc = payload_row.row(j);
320
298
auto reg_blk = tile.reg .xetla_select <remained_block_elems * arr_len, 1 >(
321
299
processed_elems + j * remained_block_elems);
322
- constexpr uint32_t ld_blk_height = (reg_transpose && trans)
323
- ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
324
- : remained_ld_blk_size_y;
300
+ // constexpr uint32_t ld_blk_height = (reg_transpose && trans)
301
+ // ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
302
+ // : remained_ld_blk_size_y;
325
303
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
326
304
xetla_vector<dtype, tmp_size> reg_tmp;
327
305
#pragma unroll
@@ -490,7 +468,8 @@ tile_load(tile_t& tile, payload_t& payload) {
490
468
491
469
// / @brief This function loads data from unaligned-2D memory surface.
492
470
// / Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
493
- // / registers. Each block will be loaded serially by its corresponding payload.
471
+ // / registers. Each block will be loaded serially by its corresponding
472
+ // / payload.
494
473
// / @tparam tile_t Is the tile_t struct contains registers.
495
474
// / These registers will be the destination of load operation.
496
475
// / @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -614,7 +593,8 @@ tile_load(tile_t& tile, payload_t& payload) {
614
593
615
594
// / @brief This function loads data from unaligned-2D memory surface.
616
595
// / Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
617
- // / registers. Each block will be loaded serially by its corresponding payload.
596
+ // / registers. Each block will be loaded serially by its corresponding
597
+ // / payload.
618
598
// / @tparam tile_t Is the tile_t struct contains registers.
619
599
// / These registers will be the destination of load operation.
620
600
// / @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -679,7 +659,8 @@ tile_load(tile_t& tile, payload_t& payload) {
679
659
680
660
// / @brief This function loads data from unaligned-2D memory surface.
681
661
// / Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
682
- // / registers. Each block will be loaded serially by its corresponding payload.
662
+ // / registers. Each block will be loaded serially by its corresponding
663
+ // / payload.
683
664
// / @tparam tile_t Is the tile_t struct contains registers.
684
665
// / These registers will be the destination of load operation.
685
666
// / @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -819,8 +800,8 @@ tile_load(
819
800
}
820
801
821
802
// / @brief Is the data load func from local shared memory to register file,
822
- // / which supports the memory surface is 1d or 2d scenario. And we always assume
823
- // / data in SLM is row major.
803
+ // / which supports the memory surface is 1d or 2d scenario. And we always
804
+ // / assume data in SLM is row major.
824
805
// / @tparam tile_t Is the tile_t struct contains registers
825
806
// / These registers will be the destination of load operation.
826
807
// / @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -902,8 +883,8 @@ tile_load(tile_t& tile, payload_t& payload) {
902
883
}
903
884
904
885
// / @brief Is the data load func from shared local memory to register file,
905
- // / which supports the memory surface is 1d scenario. And the src memory layout
906
- // / is always row major.
886
+ // / which supports the memory surface is 1d scenario. And the src memory
887
+ // / layout is always row major.
907
888
// / @tparam tile_t Is the tile_t struct contains registers.
908
889
// / These registers will be the destination of load operation.
909
890
// / @tparam payload_t Is the mem_payload_t struct describing the memory
0 commit comments