@@ -98,125 +98,85 @@ tile_store(tile_t& tile, payload_t& payload) {
  static constexpr uint32_t num_block_x = tile_desc::num_block_x;
  static constexpr uint32_t num_block_y = tile_desc::num_block_y;
-   // static constexpr uint32_t num_block = tile_desc::num_block;
-   using load_store_attr = typename arch_attr_t<
-       payload_t::arch_tag>::template load_store_attr<msg_type::block_2d>;
-
-   static constexpr uint32_t max_block_width =
-       load_store_attr::max_store_width_in_bytes / sizeof(dtype);
-   static constexpr uint32_t max_block_height =
-       load_store_attr::max_store_height_in_elem;
  static_assert(
-       (max_block_width % block_size_x) == 0,
-       "max_block_width should be a multiply of block_size_x.");
-   static constexpr uint32_t elems_per_CL =
-       load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
-   static constexpr uint32_t st_block_size_y =
-       std::min(block_size_y, max_block_height);
+       (payload_t::max_store_width_in_elem % block_size_x) == 0,
+       "max_store_width_in_elem should be a multiply of block_size_x.");
+
+   static constexpr uint32_t st_blk_size_y =
+       std::min(block_size_y, payload_t::max_store_height_in_elem);

  // to make sure full CL store
-   static constexpr uint32_t st_block_size_x =
-       ((tile_size_x % elems_per_CL) == 0)
-       ? elems_per_CL
-       : (((elems_per_CL % tile_size_x) == 0) ? tile_size_x : block_size_x);
+   static constexpr uint32_t st_blk_size_x =
+       ((tile_size_x % payload_t::elems_per_CL) == 0)
+       ? payload_t::elems_per_CL
+       : (((payload_t::elems_per_CL % tile_size_x) == 0) ? tile_size_x
+                                                         : block_size_x);

-   static constexpr uint8_t arr_len_candidate = st_block_size_x / block_size_x;
+   static constexpr uint8_t arr_len_candidate = st_blk_size_x / block_size_x;
  static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
      (arr_len_candidate == 2) || (arr_len_candidate == 4);

  static constexpr uint8_t arr_len =
      is_valid_arr_len_candidate ? arr_len_candidate : 1;

-   // auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
+   constexpr uint32_t store_block_elems = block_elems * arr_len;
+   constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x;
#pragma unroll
  for (uint32_t i = 0; i < num_block_y; ++i) {
    int32_t offset_y = i * block_size_y;
-     constexpr uint32_t store_block_elems = block_elems * arr_len;
-     // auto payload_row =
-     //     payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
-     // detail::reset_tile_desc_core<
-     //     num_block_x,
-     //     block_size_x * arr_len,
-     //     st_block_size_y,
-     //     1,
-     //     1,
-     //     false>(payload_row);
#pragma unroll
    for (uint32_t j = 0; j < num_block_x; j += arr_len) {
      int32_t offset_x = j * block_size_x;
-       // xetla_tdescriptor tdesc = payload_row.row(j);
      auto reg_blk = tile.reg.xetla_select<store_block_elems, 1>(
          (i * num_block_x + j) * block_elems);
      xetla_vector<dtype, store_block_elems> combine_blk;
      auto combine_blk_2d = combine_blk.xetla_format<
          native_type_t<dtype>,
          block_size_y,
          block_size_x * arr_len>();
- #pragma unroll
-       for (uint32_t combine_i = 0; combine_i < arr_len; ++combine_i) {
+       /* combine_blk_2d
+           ____________  ____________
+          |            ||            |
+          |   block    ||   block    |
+          |            ||            |
+          |____________||____________|
+       */
+ #pragma unroll
+       for (uint32_t block_id = 0; block_id < arr_len; ++block_id) {
        combine_blk_2d.xetla_select<block_size_y, 1, block_size_x, 1>(
-             0, combine_i * block_size_x) =
-             reg_blk.xetla_select<block_elems, 1>(combine_i * block_elems);
+             0, block_id * block_size_x) =
+             reg_blk.xetla_select<block_elems, 1>(block_id * block_elems);
      }
#pragma unroll
-       for (uint32_t ii = 0; ii < block_size_y / st_block_size_y; ++ii) {
-         constexpr uint32_t store_elems =
-             st_block_size_y * block_size_x * arr_len;
+       for (uint32_t ii = 0; ii < block_size_y; ii += st_blk_size_y) {
        auto st_blk =
-             combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
-         // xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
-         //     tdesc, st_blk);
-         xetla_store_global<
-             dtype,
-             block_size_x * arr_len,
-             st_block_size_y,
-             L1,
-             L2>(
+             combine_blk.xetla_select<store_elems, 1>(ii * st_blk_size_x);
+         xetla_store_global<dtype, st_blk_size_x, st_blk_size_y, L1, L2>(
            reinterpret_cast<dtype*>(payload.base_ptr),
            payload.surface_width,
            payload.surface_height,
            payload.surface_pitch,
            payload.offset_x + offset_x,
-             payload.offset_y + offset_y + ii * st_block_size_y,
-             // ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
-             // ::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc),
+             payload.offset_y + offset_y + ii,
            st_blk);
-         // xetla_update_tdesc_offsety(
-         //     tdesc.xetla_format<uint32_t>(), st_block_size_y);
      }
      // exceed hardware limitation
-       if constexpr ((block_size_y % st_block_size_y) != 0) {
-         constexpr uint32_t blk_remained_start = block_size_y / st_block_size_y *
-             st_block_size_y * block_size_x * arr_len;
-         constexpr uint8_t blk_remained_y = block_size_y % st_block_size_y;
-         constexpr uint8_t blk_remained_elems =
-             blk_remained_y * block_size_x * arr_len;
+       if constexpr ((block_size_y % st_blk_size_y) != 0) {
+         constexpr uint32_t blk_remained_start =
+             block_size_y / st_blk_size_y * st_blk_size_y * st_blk_size_x;
+         constexpr uint8_t blk_remained_y = block_size_y % st_blk_size_y;
+         constexpr uint8_t blk_remained_elems = blk_remained_y * st_blk_size_x;
        auto st_blk =
            combine_blk.xetla_select<blk_remained_elems, 1>(blk_remained_start);
-         // constexpr uint32_t block_widthx_widthy_arrlen =
-         //     (block_size_x * arr_len - 1) | ((blk_remained_y - 1) << 8);
-         // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
-         //     tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
-         // xetla_tstore_global<
-         //     dtype,
-         //     blk_remained_elems,
-         //     L1,
-         //     L2,
-         //     payload_t::arch_tag>(tdesc, st_blk);
-         xetla_store_global<
-             dtype,
-             block_size_x * arr_len,
-             blk_remained_y,
-             L1,
-             L2>(
+         xetla_store_global<dtype, st_blk_size_x, blk_remained_y, L1, L2>(
            reinterpret_cast<dtype*>(payload.base_ptr),
            payload.surface_width,
            payload.surface_height,
            payload.surface_pitch,
            payload.offset_x + offset_x,
            payload.offset_y + offset_y +
-                 block_size_y / st_block_size_y * st_block_size_y,
+                 block_size_y / st_blk_size_y * st_blk_size_y,
            st_blk);
      }
    }
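Side note on the constants this hunk introduces: the sketch below walks the same selection logic with concrete numbers. Everything in it is an assumption made up for illustration (a 64-element-wide tile of 2-byte elements, 16x16 register blocks, a 64-byte cache line, an 8-row 2D-store height cap); it is plain standalone C++, not XeTLA, and the names mirror the patch only so the expressions can be compared line by line.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed, illustrative sizes (not taken from any real payload_t).
  constexpr uint32_t tile_size_x = 64;              // tile width in elements
  constexpr uint32_t block_size_x = 16;             // register block width
  constexpr uint32_t block_size_y = 16;             // register block height
  constexpr uint32_t elems_per_CL = 64 / 2;         // 64B cache line / 2B element
  constexpr uint32_t max_store_height_in_elem = 8;  // assumed 2D-store height cap

  // Same shape as the patch: prefer a store width that is a whole cache line,
  // otherwise fall back to the tile width or a single block width.
  constexpr uint32_t st_blk_size_x =
      ((tile_size_x % elems_per_CL) == 0)
      ? elems_per_CL
      : (((elems_per_CL % tile_size_x) == 0) ? tile_size_x : block_size_x);
  constexpr uint32_t st_blk_size_y =
      std::min(block_size_y, max_store_height_in_elem);

  // arr_len: how many adjacent register blocks get packed into one store.
  constexpr uint32_t arr_len_candidate = st_blk_size_x / block_size_x;
  constexpr bool is_valid = (arr_len_candidate == 1) ||
      (arr_len_candidate == 2) || (arr_len_candidate == 4);
  constexpr uint32_t arr_len = is_valid ? arr_len_candidate : 1;

  // With these numbers: st_blk_size_x = 32, st_blk_size_y = 8, arr_len = 2,
  // i.e. two 16x16 register blocks are combined into a 32x16 region and
  // written as two 32x8 stores, each row covering one full cache line.
  std::printf("st_blk_size_x=%u st_blk_size_y=%u arr_len=%u\n",
              static_cast<unsigned>(st_blk_size_x),
              static_cast<unsigned>(st_blk_size_y),
              static_cast<unsigned>(arr_len));
  return 0;
}
```

Choosing st_blk_size_x this way is what the "to make sure full CL store" comment refers to: whenever the tile width allows it, each stored row spans whole cache lines, and arr_len register blocks are fused into one wider store to get there.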
@@ -227,47 +187,34 @@ tile_store(tile_t& tile, payload_t& payload) {
    constexpr uint32_t processed_elems =
        num_block_y * num_block_x * block_elems;
    constexpr uint32_t remained_st_blk_size_y =
-         st_block_size_y > remained_size_y ? remained_size_y : st_block_size_y;
-     // auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
-     //     num_block_y * num_block_x, 0);
-     // detail::reset_tile_desc_core<
-     //     num_block_x,
-     //     block_size_x * arr_len,
-     //     remained_st_blk_size_y,
-     //     1,
-     //     1,
-     //     false>(payload_row);
+         std::min(st_blk_size_y, remained_size_y);
#pragma unroll
    for (uint32_t j = 0; j < num_block_x; j += arr_len) {
      int offset_x = j * block_size_x;
-       // xetla_tdescriptor tdesc = payload_row.row(j);
      auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
          processed_elems + j * remained_block_elems);
      // Do combination
      xetla_vector<dtype, remained_block_elems * arr_len> combine_blk;
      auto combine_blk_2d = combine_blk.xetla_format<
          native_type_t<dtype>,
          remained_size_y,
-           block_size_x * arr_len>();
+           st_blk_size_x>();
#pragma unroll
-       for (uint32_t combine_i = 0; combine_i < arr_len; ++combine_i) {
+       for (uint32_t block_id = 0; block_id < arr_len; ++block_id) {
        combine_blk_2d.xetla_select<remained_size_y, 1, block_size_x, 1>(
-             0, combine_i * block_size_x) =
+             0, block_id * block_size_x) =
            reg_blk.xetla_select<remained_block_elems, 1>(
-               combine_i * remained_block_elems);
+               block_id * remained_block_elems);
      }
#pragma unroll
-       for (uint32_t ii = 0; ii < remained_size_y / remained_st_blk_size_y;
-            ++ii) {
-         constexpr uint32_t store_elems =
-             remained_st_blk_size_y * block_size_x * arr_len;
+       for (uint32_t ii = 0; ii < remained_size_y;
+            ii += remained_st_blk_size_y) {
+         constexpr uint32_t store_elems = remained_st_blk_size_y * st_blk_size_x;
        auto st_blk =
-             combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
-         // xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
-         //     tdesc, st_blk);
+             combine_blk.xetla_select<store_elems, 1>(ii * st_blk_size_x);
        xetla_store_global<
            dtype,
-             block_size_x * arr_len,
+             st_blk_size_x,
            remained_st_blk_size_y,
            L1,
            L2>(
@@ -276,38 +223,19 @@ tile_store(tile_t& tile, payload_t& payload) {
            payload.surface_height,
            payload.surface_pitch,
            payload.offset_x + offset_x,
-             payload.offset_y + num_block_y * block_size_y +
-                 ii * remained_st_blk_size_y,
+             payload.offset_y + num_block_y * block_size_y + ii,
            st_blk);
-         // xetla_update_tdesc_offsety(
-         //     tdesc.xetla_format<uint32_t>(), remained_st_blk_size_y);
      }
      constexpr uint32_t final_st_blk_size_y =
          remained_size_y % remained_st_blk_size_y;
      if constexpr (final_st_blk_size_y != 0) {
        constexpr uint32_t final_start = remained_size_y /
-             remained_st_blk_size_y * remained_st_blk_size_y * block_size_x *
-             arr_len;
+             remained_st_blk_size_y * remained_st_blk_size_y * st_blk_size_x;
        constexpr uint32_t final_store_elems =
-             final_st_blk_size_y * block_size_x * arr_len;
+             final_st_blk_size_y * st_blk_size_x;
        auto st_blk =
            combine_blk.xetla_select<final_store_elems, 1>(final_start);
-         // constexpr uint32_t block_widthx_widthy_arrlen =
-         //     (block_size_x * arr_len - 1) | ((final_st_blk_size_y - 1) << 8);
-         // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
-         //     tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
-         // xetla_tstore_global<
-         //     dtype,
-         //     final_store_elems,
-         //     L1,
-         //     L2,
-         //     payload_t::arch_tag>(tdesc, st_blk);
-         xetla_store_global<
-             dtype,
-             block_size_x * arr_len,
-             final_st_blk_size_y,
-             L1,
-             L2>(
+         xetla_store_global<dtype, st_blk_size_x, final_st_blk_size_y, L1, L2>(
            reinterpret_cast<dtype*>(payload.base_ptr),
            payload.surface_width,
            payload.surface_height,
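A closing note on the combination step that repeats in every hunk above: the combine_blk_2d.xetla_select(...) = reg_blk.xetla_select(...) assignments take arr_len register blocks that sit back to back in block-major order and interleave them into one row-major buffer, so a single wider 2D store can write them. The scalar sketch below reproduces just that data movement with plain arrays; the 4x2 block shape and arr_len = 2 are arbitrary assumptions for illustration and no XeTLA types are involved.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Assumed, illustrative sizes: two 4x2 blocks combined side by side.
constexpr uint32_t block_size_x = 4;
constexpr uint32_t block_size_y = 2;
constexpr uint32_t arr_len = 2;
constexpr uint32_t block_elems = block_size_x * block_size_y;

int main() {
  // reg_blk: arr_len blocks stored back to back, block-major
  // (each block's rows are contiguous before the next block starts).
  std::array<int, block_elems * arr_len> reg_blk{};
  for (uint32_t i = 0; i < reg_blk.size(); ++i)
    reg_blk[i] = static_cast<int>(i);

  // combine_blk viewed as block_size_y rows of (block_size_x * arr_len) cols,
  // mirroring combine_blk_2d in the patch.
  std::array<int, block_elems * arr_len> combine_blk{};
  for (uint32_t block_id = 0; block_id < arr_len; ++block_id)
    for (uint32_t y = 0; y < block_size_y; ++y)
      for (uint32_t x = 0; x < block_size_x; ++x)
        combine_blk[y * (block_size_x * arr_len) + block_id * block_size_x + x] =
            reg_blk[block_id * block_elems + y * block_size_x + x];

  // Row 0 prints 0 1 2 3 8 9 10 11: block 0's first row followed by block 1's.
  for (uint32_t y = 0; y < block_size_y; ++y) {
    for (uint32_t x = 0; x < block_size_x * arr_len; ++x)
      std::printf("%4d", combine_blk[y * block_size_x * arr_len + x]);
    std::printf("\n");
  }
  return 0;
}
```

Each printed row holds one row from every combined block in order, which is the row-major layout the widened xetla_store_global calls write out in a single 2D store.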