Skip to content

Commit 1a39fab

Browse files
committed
xe: concat: separate inner-padding and reusable conf
1 parent c3b89fe commit 1a39fab

File tree

3 files changed

+277
-133
lines changed

3 files changed

+277
-133
lines changed

src/gpu/intel/ocl/concat_utils.hpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@ class normalization_t {
122122
padding_type_ = padding_t::external;
123123

124124
if (padding_type_ == padding_t::internal) {
125-
// may use different kernel, requires different partition for blocks
126125
chunk_size_ = math::gcd(chunk_size_, source_chunk);
127-
//chunk_size_ = 1;
128126
} else {
129127
chunk_size_ = math::gcd(chunk_size_, pdim);
130128
padded_chunk_size_ = math::gcd(padded_chunk_size_, source_chunk);
@@ -134,6 +132,7 @@ class normalization_t {
134132

135133
data_type_t data_type() const { return data_type_; }
136134
size_t data_type_size() const { return types::data_type_size(data_type_); }
135+
void set_pessimistic_chunk_size() { chunk_size_ = 1; }
137136

138137
dim_t max_write_size() const {
139138
dim_t write_size = 1;

src/gpu/intel/ocl/reusable_simple_concat.cl

+57-44
Original file line numberDiff line numberDiff line change
@@ -242,24 +242,28 @@ reusable_simple_concat(__global DATA_T *dst, const ulong dst_offset0,
242242
#define BBLOCK_WRITE BLOCK_WRITE
243243
#define DDATA_T DATA1_T
244244
#define AS_VEC as_ulong
245+
#define ITEMMASK(n) (0xFFFFFFFFFFFFFFFFUL)
245246
#elif DATA_TYPE_SIZE == 4
246247
#define NPERSG 2
247248
#define BBLOCK_READ BLOCK_READ2
248249
#define BBLOCK_WRITE BLOCK_WRITE2
249250
#define DDATA_T DATA2_T
250251
#define AS_VEC as_uint2
252+
#define ITEMMASK(n) (0xFFFFFFFFUL << ((n)*DATA_TYPE_SIZE * 8))
251253
#elif DATA_TYPE_SIZE == 2
252254
#define NPERSG 4
253255
#define BBLOCK_READ BLOCK_READ4
254256
#define BBLOCK_WRITE BLOCK_WRITE4
255257
#define DDATA_T DATA4_T
256258
#define AS_VEC as_ushort4
259+
#define ITEMMASK(n) (0xFFFFUL << ((n)*DATA_TYPE_SIZE * 8))
257260
#elif DATA_TYPE_SIZE == 1
258261
#define NPERSG 8
259262
#define BBLOCK_READ BLOCK_READ8
260263
#define BBLOCK_WRITE BLOCK_WRITE8
261264
#define DDATA_T DATA8_T
262265
#define AS_VEC as_uchar8
266+
#define ITEMMASK(n) (0xFFUL << ((n)*DATA_TYPE_SIZE * 8))
263267
#endif
264268

265269
/*
@@ -287,7 +291,7 @@ internal_padding_block_concat2(__global DATA_T *dst,
287291
const idx_t src_concat_axis1, const idx_t src_padded_concat_axis1,
288292
const idx_t inner_dim) {
289293

290-
const size_t dtsize = DATA_TYPE_SIZE;
294+
const int dtsize = DATA_TYPE_SIZE;
291295
const int loads_per_sg = NPERSG;
292296
const int elems_per_sg = SIMD * loads_per_sg;
293297

@@ -297,10 +301,10 @@ internal_padding_block_concat2(__global DATA_T *dst,
297301
const unsigned B0 = 1;
298302
#endif
299303

300-
const size_t first_boundary_block
304+
const idx_t first_boundary_block
301305
= ((DIV_UP(src_concat_axis0, B0) - 1) * inner_dim);
302-
const size_t last_boundary_block = first_boundary_block + inner_dim - 1;
303-
const size_t tot_boundary_block
306+
const idx_t last_boundary_block = first_boundary_block + inner_dim - 1;
307+
const idx_t tot_boundary_block
304308
= (DIV_UP(dst_padded_concat_axis, B0) * inner_dim);
305309

306310
// TODO: host side check, if blocks_per_sg > #blocks in src0, use legacy method instead
@@ -315,33 +319,34 @@ internal_padding_block_concat2(__global DATA_T *dst,
315319
__global const DATA_T *src;
316320

317321
// within block idx, ex: 0-7, 0-15, 0-23, 0-31
318-
size_t blid = get_local_id(0) % B0;
322+
idx_t blid = get_local_id(0) % B0;
319323

320-
size_t id_start = get_group_id(0) * elems_per_sg + get_local_id(0);
321-
long bii = id_start / B0;
324+
idx_t id_start = get_group_id(0) * elems_per_sg + get_local_id(0);
325+
idx_t bii = id_start / B0;
322326

323-
size_t sg_first_bii = get_group_id(0) * elems_per_sg / B0;
327+
idx_t sg_first_bii = get_group_id(0) * elems_per_sg
328+
/ B0; //TODO: replace elems/B0 w/blocks_per_subgroup?
324329

325330
// index along concat dimension
326-
long ic = (bii / inner_dim) * B0 + blid;
327-
long ic_end = ((bii + blocks_per_sg) / inner_dim) * B0 + blid;
331+
idx_t ic = (bii / inner_dim) * B0 + blid;
332+
idx_t ic_end = ((bii + blocks_per_sg) / inner_dim) * B0 + blid;
328333

329-
size_t sg_last_bii = sg_first_bii + blocks_per_sg
334+
idx_t sg_last_bii = sg_first_bii + blocks_per_sg
330335
- 1; //TODO: verify w/B0 > sg_tot_elems
331336

332-
size_t ccsz
337+
idx_t ccsz
333338
= src_concat_axis0; // (31) padded concat dimension for calculating batched src
334-
size_t padded_ccsz
339+
idx_t padded_ccsz
335340
= padded_offset1; // (32) padded concat dimension for calculating batched src
336341

337-
size_t cutoff
342+
idx_t cutoff
338343
= 0; // determines boundary offset for write sg that spans multiple layout blocks
339344
bool boundary = false; // current sg spans boundary between two inputs?
340345

341-
size_t batch_offset = inner_dim * padded_ccsz * get_global_id(1);
342-
size_t batch_offset1
346+
idx_t batch_offset = inner_dim * padded_ccsz * get_global_id(1);
347+
idx_t batch_offset1
343348
= inner_dim * src_padded_concat_axis1 * get_global_id(1);
344-
size_t ext_batch_offset
349+
idx_t ext_batch_offset
345350
= inner_dim * dst_padded_concat_axis * get_global_id(1);
346351

347352
// completely aligned r/w blocks, no boundary src case nor misaligned reads
@@ -350,14 +355,16 @@ internal_padding_block_concat2(__global DATA_T *dst,
350355
DDATA_T val = BBLOCK_READ(src0 + batch_offset + sg_first_bii * B0);
351356
BBLOCK_WRITE(dst + ext_batch_offset + sg_first_bii * B0, val);
352357
} else if (sg_first_bii > last_boundary_block) {
358+
353359
// if sg_bii0 fully within second source, update src_bii+other idx vars to match src1
354-
long src_blid = (ic - src_concat_axis0) % B0;
355-
long src_bii = sg_first_bii
360+
idx_t src_blid = (ic - src_concat_axis0) % B0;
361+
idx_t src_bii = sg_first_bii
356362
- ((padded_offset1 / B0)
357363
* inner_dim); // update inner logical block idx to 0-nblocks read range for src1
358364
if (src_bii < 0) { src_bii += inner_dim; }
359365

360-
long sg_last_src_bii = src_bii + blocks_per_sg - 1;
366+
idx_t sg_last_src_bii = src_bii + blocks_per_sg - 1;
367+
361368
ccsz = src_concat_axis1 - src_concat_axis0;
362369
// TODO change dst_padded_concat_axis to padded_offset2? //TODO!!!! THIS IS WRONG
363370
padded_ccsz = dst_padded_concat_axis
@@ -366,13 +373,13 @@ internal_padding_block_concat2(__global DATA_T *dst,
366373

367374
cutoff = (B0 + (offset1 - padded_offset1) % B0)
368375
% B0; // positive modulo( offset/padded_offset , block size) |-|-------|
369-
// this cutoff will continue in following misaligned(multi-cacheline/ block src)
376+
// this cutoff will continue in following misaligned(multi-cacheline/ block src)
370377

371378
DDATA_T aVal = BBLOCK_READ(src1 + batch_offset1 + src_bii * B0);
372-
long next_block = src_bii
379+
idx_t next_block = src_bii
373380
+ inner_dim; //offset to read next required block for second half of cutoff
374-
long next_last_block = next_block + blocks_per_sg - 1;
375-
const size_t tot_src1_block
381+
idx_t next_last_block = next_block + blocks_per_sg - 1;
382+
const idx_t tot_src1_block
376383
= (DIV_UP(src_padded_concat_axis1, B0) * inner_dim);
377384
DDATA_T bVal;
378385
if (next_last_block > tot_src1_block) {
@@ -411,17 +418,20 @@ internal_padding_block_concat2(__global DATA_T *dst,
411418
aVal = bVal;
412419
}
413420

414-
if (ic >= dst_concat_axis) {
415-
aVal = 0;
416-
} //depends todo: depends where each read falls, whether in range or not TODO: match for loop in BB case
421+
const unsigned blocks_per_simd1
422+
= SIMD / B0; // WRONG! TODO: what if SIMD > B0?
423+
if (ic >= dst_concat_axis) { aVal = 0; }
417424

418-
unsigned blocks_per_simd1 = SIMD / B0; //WRONG! TODO: what if SIMD > B0?
419-
if (ic_end >= dst_concat_axis) {
425+
if (ic < dst_concat_axis && ic_end >= dst_concat_axis) {
420426
#if NPERSG > 1 // TODO: reformulate as &= 0x00FFFF...
421-
for (int i = 0; i < NPERSG; ++i) {
427+
unroll_for(int i = 0; i < NPERSG; ++i) {
422428
if ((((bii + i * blocks_per_simd1) / inner_dim) * B0 + blid)
423429
>= dst_concat_axis) {
424-
aVal[i] = 0; // NOT OK! slow af
430+
//aVal = AS_VEC(as_ulong(aVal) & ITEMMASK(i)); // aVal[i] element access is slow, especially per-byte (u8)
431+
aVal = AS_VEC(as_ulong(aVal)
432+
& 0xFFFFFFFFFFFFFFFF >> i * DATA_TYPE_SIZE
433+
* 8); // aVal[i] element access is slow, especially per-byte (u8)
434+
break;
425435
}
426436
}
427437
#else
@@ -451,8 +461,8 @@ internal_padding_block_concat2(__global DATA_T *dst,
451461
DDATA_T aVal = BBLOCK_READ(src0 + batch_offset + sg_first_bii * B0);
452462
// since these are "boundary" blocks load corresponding blocks from next source, TODO: what if sg span > boundary? due to long sg reads, maybe set minimum problem sized to avoid edge case logic?
453463

454-
long next_blid = (ic - src_concat_axis0) % B0;
455-
long next_bii = sg_first_bii
464+
idx_t next_blid = (ic - src_concat_axis0) % B0;
465+
idx_t next_bii = sg_first_bii
456466
- ((padded_offset1 / B0)
457467
* inner_dim); // update inner logical block idx to 0-nblocks read range for src1
458468
if (next_bii < 0) {
@@ -470,7 +480,7 @@ internal_padding_block_concat2(__global DATA_T *dst,
470480
? sg_last_bii - last_boundary_block
471481
: 0; //sg span ends after boundary span
472482

473-
long sg_last_next_bii
483+
idx_t sg_last_next_bii
474484
= next_bii + blocks_per_sg - 1; //sg span in next src
475485

476486
ccsz = src_concat_axis1 - src_concat_axis0;
@@ -481,7 +491,7 @@ internal_padding_block_concat2(__global DATA_T *dst,
481491

482492
cutoff = (B0 + (offset1 - padded_offset1) % B0)
483493
% B0; // positive modulo( offset/padded_offset , block size) |-|-------|
484-
// this cutoff will continue in following misaligned(multi-cacheline/ block src)
494+
// this cutoff will continue in following misaligned(multi-cacheline/ block src)
485495

486496
unsigned blocks_per_simd1 = SIMD / B0; //WRONG! TODO: what if SIMD > B0?
487497
DDATA_T bVal = BBLOCK_READ(src1 + batch_offset1 + next_bii * B0);
@@ -502,7 +512,7 @@ internal_padding_block_concat2(__global DATA_T *dst,
502512
DDATA_T cVal;
503513
if (trailing_boundary_shift) {
504514
// long trailing_bii = sg_last_next_bii - blocks_per_sg - trailing_boundary_shift;
505-
long trailing_bii = 0;
515+
idx_t trailing_bii = 0;
506516
cVal = BBLOCK_READ(src1 + batch_offset1 + trailing_bii * B0);
507517

508518
int block_dt = (blocks_per_sg - trailing_boundary_shift);
@@ -526,17 +536,15 @@ internal_padding_block_concat2(__global DATA_T *dst,
526536
trailmask = (0xFFFFFFFFFFFFFFFF << (ntrail * DATA_TYPE_SIZE * 8))
527537
>> (ntrail * DATA_TYPE_SIZE
528538
* 8); // explicitly ignore any reads past boundary blocks
529-
530539
if (cutoff > 0
531540
&& (ic % B0)
532541
>= cutoff) { // TODO: should change together with sg_shuffle_dt sizeof(B0)
533542
aVal = AS_VEC(as_ulong(aVal) & trailmask);
534543
aVal |= offset_vec_data;
535544
}
536545

537-
if (cutoff > 0
538-
&& (ic % B0)
539-
< cutoff) { // TODO: should change together with sg_shuffle_dt sizeof(B0)
546+
if (cutoff > 0 && (ic % B0) < cutoff) {
547+
// TODO: should change together with sg_shuffle_dt sizeof(B0)
540548
aVal = AS_VEC(as_ulong(aVal) & trailmask);
541549
}
542550

@@ -549,15 +557,20 @@ internal_padding_block_concat2(__global DATA_T *dst,
549557
if (cutoff > 0
550558
&& (get_local_id(0) % B0)
551559
< cutoff) { // TODO: should change together with sg_shuffle_dt sizeof(B0)
552-
aVal |= cVal;
560+
aVal = AS_VEC(as_ulong(aVal) | (as_ulong(cVal) & ~trailmask));
553561
}
554562

555-
if (ic_end >= dst_concat_axis) {
556-
#if NPERSG > 1 // TODO: reformulate as &= 0x00FFFF...
563+
// TODO: doublecheck needed
564+
if (ic > dst_concat_axis) { aVal = 0; } // TODO: doublecheck needed
565+
if (ic < dst_concat_axis && ic_end >= dst_concat_axis) {
566+
#if NPERSG > 1
557567
for (int i = 0; i < NPERSG; ++i) {
558568
if ((((bii + i) / inner_dim) * B0 + blid)
559569
>= dst_concat_axis) {
560-
aVal[i] = 0;
570+
aVal = AS_VEC(as_ulong(aVal)
571+
& 0xFFFFFFFFFFFFFFFF >> i * DATA_TYPE_SIZE
572+
* 8); // aVal[i] element access is slow, especially per-byte (u8)
573+
break;
561574
}
562575
}
563576
#else

0 commit comments

Comments
 (0)