@@ -242,24 +242,28 @@ reusable_simple_concat(__global DATA_T *dst, const ulong dst_offset0,
242
242
#define BBLOCK_WRITE BLOCK_WRITE
243
243
#define DDATA_T DATA1_T
244
244
#define AS_VEC as_ulong
245
+ #define ITEMMASK (n ) (0xFFFFFFFFFFFFFFFFUL)
245
246
#elif DATA_TYPE_SIZE == 4
246
247
#define NPERSG 2
247
248
#define BBLOCK_READ BLOCK_READ2
248
249
#define BBLOCK_WRITE BLOCK_WRITE2
249
250
#define DDATA_T DATA2_T
250
251
#define AS_VEC as_uint2
252
+ #define ITEMMASK (n ) (0xFFFFFFFFUL << ((n)*DATA_TYPE_SIZE * 8))
251
253
#elif DATA_TYPE_SIZE == 2
252
254
#define NPERSG 4
253
255
#define BBLOCK_READ BLOCK_READ4
254
256
#define BBLOCK_WRITE BLOCK_WRITE4
255
257
#define DDATA_T DATA4_T
256
258
#define AS_VEC as_ushort4
259
+ #define ITEMMASK (n ) (0xFFFFUL << ((n)*DATA_TYPE_SIZE * 8))
257
260
#elif DATA_TYPE_SIZE == 1
258
261
#define NPERSG 8
259
262
#define BBLOCK_READ BLOCK_READ8
260
263
#define BBLOCK_WRITE BLOCK_WRITE8
261
264
#define DDATA_T DATA8_T
262
265
#define AS_VEC as_uchar8
266
+ #define ITEMMASK (n ) (0xFFUL << ((n)*DATA_TYPE_SIZE * 8))
263
267
#endif
264
268
265
269
/*
@@ -287,7 +291,7 @@ internal_padding_block_concat2(__global DATA_T *dst,
287
291
const idx_t src_concat_axis1 , const idx_t src_padded_concat_axis1 ,
288
292
const idx_t inner_dim ) {
289
293
290
- const size_t dtsize = DATA_TYPE_SIZE ;
294
+ const int dtsize = DATA_TYPE_SIZE ;
291
295
const int loads_per_sg = NPERSG ;
292
296
const int elems_per_sg = SIMD * loads_per_sg ;
293
297
@@ -297,10 +301,10 @@ internal_padding_block_concat2(__global DATA_T *dst,
297
301
const unsigned B0 = 1 ;
298
302
#endif
299
303
300
- const size_t first_boundary_block
304
+ const idx_t first_boundary_block
301
305
= ((DIV_UP (src_concat_axis0 , B0 ) - 1 ) * inner_dim );
302
- const size_t last_boundary_block = first_boundary_block + inner_dim - 1 ;
303
- const size_t tot_boundary_block
306
+ const idx_t last_boundary_block = first_boundary_block + inner_dim - 1 ;
307
+ const idx_t tot_boundary_block
304
308
= (DIV_UP (dst_padded_concat_axis , B0 ) * inner_dim );
305
309
306
310
// TODO: add a host-side check: if blocks_per_sg exceeds the number of blocks in src0, fall back to the legacy method instead
@@ -315,33 +319,34 @@ internal_padding_block_concat2(__global DATA_T *dst,
315
319
__global const DATA_T * src ;
316
320
317
321
// index within the block, e.g. 0-7, 0-15, 0-23, or 0-31
318
- size_t blid = get_local_id (0 ) % B0 ;
322
+ idx_t blid = get_local_id (0 ) % B0 ;
319
323
320
- size_t id_start = get_group_id (0 ) * elems_per_sg + get_local_id (0 );
321
- long bii = id_start / B0 ;
324
+ idx_t id_start = get_group_id (0 ) * elems_per_sg + get_local_id (0 );
325
+ idx_t bii = id_start / B0 ;
322
326
323
- size_t sg_first_bii = get_group_id (0 ) * elems_per_sg / B0 ;
327
+ idx_t sg_first_bii = get_group_id (0 ) * elems_per_sg
328
+ / B0 ; //TODO: replace elems/B0 w/blocks_per_subgroup?
324
329
325
330
// index along concat dimension
326
- long ic = (bii / inner_dim ) * B0 + blid ;
327
- long ic_end = ((bii + blocks_per_sg ) / inner_dim ) * B0 + blid ;
331
+ idx_t ic = (bii / inner_dim ) * B0 + blid ;
332
+ idx_t ic_end = ((bii + blocks_per_sg ) / inner_dim ) * B0 + blid ;
328
333
329
- size_t sg_last_bii = sg_first_bii + blocks_per_sg
334
+ idx_t sg_last_bii = sg_first_bii + blocks_per_sg
330
335
- 1 ; //TODO: verify w/B0 > sg_tot_elems
331
336
332
- size_t ccsz
337
+ idx_t ccsz
333
338
= src_concat_axis0 ; // (31) padded concat dimension for calculating batched src
334
- size_t padded_ccsz
339
+ idx_t padded_ccsz
335
340
= padded_offset1 ; // (32) padded concat dimension for calculating batched src
336
341
337
- size_t cutoff
342
+ idx_t cutoff
338
343
= 0 ; // determines boundary offset for write sg that spans multiple layout blocks
339
344
bool boundary = false; // current sg spans boundary between two inputs?
340
345
341
- size_t batch_offset = inner_dim * padded_ccsz * get_global_id (1 );
342
- size_t batch_offset1
346
+ idx_t batch_offset = inner_dim * padded_ccsz * get_global_id (1 );
347
+ idx_t batch_offset1
343
348
= inner_dim * src_padded_concat_axis1 * get_global_id (1 );
344
- size_t ext_batch_offset
349
+ idx_t ext_batch_offset
345
350
= inner_dim * dst_padded_concat_axis * get_global_id (1 );
346
351
347
352
// completely aligned r/w blocks, no boundary src case nor misaligned reads
@@ -350,14 +355,16 @@ internal_padding_block_concat2(__global DATA_T *dst,
350
355
DDATA_T val = BBLOCK_READ (src0 + batch_offset + sg_first_bii * B0 );
351
356
BBLOCK_WRITE (dst + ext_batch_offset + sg_first_bii * B0 , val );
352
357
} else if (sg_first_bii > last_boundary_block ) {
358
+
353
359
// if the subgroup's first block (sg_first_bii) lies fully within the second source, rebase src_bii and the other index variables onto src1
354
- long src_blid = (ic - src_concat_axis0 ) % B0 ;
355
- long src_bii = sg_first_bii
360
+ idx_t src_blid = (ic - src_concat_axis0 ) % B0 ;
361
+ idx_t src_bii = sg_first_bii
356
362
- ((padded_offset1 / B0 )
357
363
* inner_dim ); // update inner logical block idx to 0-nblocks read range for src1
358
364
if (src_bii < 0 ) { src_bii += inner_dim ; }
359
365
360
- long sg_last_src_bii = src_bii + blocks_per_sg - 1 ;
366
+ idx_t sg_last_src_bii = src_bii + blocks_per_sg - 1 ;
367
+
361
368
ccsz = src_concat_axis1 - src_concat_axis0 ;
362
369
// TODO change dst_padded_concat_axis to padded_offset2? //TODO!!!! THIS IS WRONG
363
370
padded_ccsz = dst_padded_concat_axis
@@ -366,13 +373,13 @@ internal_padding_block_concat2(__global DATA_T *dst,
366
373
367
374
cutoff = (B0 + (offset1 - padded_offset1 ) % B0 )
368
375
% B0 ; // positive modulo( offset/padded_offset , block size) |-|-------|
369
- // this cutoff will continue in following misaligned(multi-cacheline/ block src)
376
+ // this cutoff will continue in following misaligned(multi-cacheline/ block src)
370
377
371
378
DDATA_T aVal = BBLOCK_READ (src1 + batch_offset1 + src_bii * B0 );
372
- long next_block = src_bii
379
+ idx_t next_block = src_bii
373
380
+ inner_dim ; //offset to read next required block for second half of cutoff
374
- long next_last_block = next_block + blocks_per_sg - 1 ;
375
- const size_t tot_src1_block
381
+ idx_t next_last_block = next_block + blocks_per_sg - 1 ;
382
+ const idx_t tot_src1_block
376
383
= (DIV_UP (src_padded_concat_axis1 , B0 ) * inner_dim );
377
384
DDATA_T bVal ;
378
385
if (next_last_block > tot_src1_block ) {
@@ -411,17 +418,20 @@ internal_padding_block_concat2(__global DATA_T *dst,
411
418
aVal = bVal ;
412
419
}
413
420
414
- if ( ic >= dst_concat_axis ) {
415
- aVal = 0 ;
416
- } //depends todo: depends where each read falls, whether in range or not TODO: match for loop in BB case
421
+ const unsigned blocks_per_simd1
422
+ = SIMD / B0 ; // WRONG! TODO: what if SIMD > B0?
423
+ if ( ic >= dst_concat_axis ) { aVal = 0 ; }
417
424
418
- unsigned blocks_per_simd1 = SIMD / B0 ; //WRONG! TODO: what if SIMD > B0?
419
- if (ic_end >= dst_concat_axis ) {
425
+ if (ic < dst_concat_axis && ic_end >= dst_concat_axis ) {
420
426
#if NPERSG > 1 // TODO: reformulate as &= 0x00FFFF...
421
- for (int i = 0 ; i < NPERSG ; ++ i ) {
427
+ unroll_for (int i = 0 ; i < NPERSG ; ++ i ) {
422
428
if ((((bii + i * blocks_per_simd1 ) / inner_dim ) * B0 + blid )
423
429
>= dst_concat_axis ) {
424
- aVal [i ] = 0 ; // NOT OK! slow af
430
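+ //aVal = AS_VEC(as_ulong(aVal) & ITEMMASK(i)); // aVal[i] element access is slow, especially per-byte (u8). NOTE(review): the right-shift mask below clears the top i elements of the packed 64-bit value; if element 0 occupies the low bits, zeroing elements i..NPERSG-1 needs a shift by (NPERSG - i) elements instead -- the two only agree for NPERSG == 2, confirm for NPERSG > 2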
431
+ aVal = AS_VEC (as_ulong (aVal )
432
+ & 0xFFFFFFFFFFFFFFFF >> i * DATA_TYPE_SIZE
433
+ * 8 ); // aVal[i] element access is slow, especially per-byte (u8)
434
+ break ;
425
435
}
426
436
}
427
437
#else
@@ -451,8 +461,8 @@ internal_padding_block_concat2(__global DATA_T *dst,
451
461
DDATA_T aVal = BBLOCK_READ (src0 + batch_offset + sg_first_bii * B0 );
452
462
// these are "boundary" blocks, so also load the corresponding blocks from the next source. TODO: what if the subgroup span exceeds the boundary (due to long subgroup reads)? Maybe set a minimum problem size to avoid this edge-case logic.
453
463
454
- long next_blid = (ic - src_concat_axis0 ) % B0 ;
455
- long next_bii = sg_first_bii
464
+ idx_t next_blid = (ic - src_concat_axis0 ) % B0 ;
465
+ idx_t next_bii = sg_first_bii
456
466
- ((padded_offset1 / B0 )
457
467
* inner_dim ); // update inner logical block idx to 0-nblocks read range for src1
458
468
if (next_bii < 0 ) {
@@ -470,7 +480,7 @@ internal_padding_block_concat2(__global DATA_T *dst,
470
480
? sg_last_bii - last_boundary_block
471
481
: 0 ; //sg span ends after boundary span
472
482
473
- long sg_last_next_bii
483
+ idx_t sg_last_next_bii
474
484
= next_bii + blocks_per_sg - 1 ; //sg span in next src
475
485
476
486
ccsz = src_concat_axis1 - src_concat_axis0 ;
@@ -481,7 +491,7 @@ internal_padding_block_concat2(__global DATA_T *dst,
481
491
482
492
cutoff = (B0 + (offset1 - padded_offset1 ) % B0 )
483
493
% B0 ; // positive modulo( offset/padded_offset , block size) |-|-------|
484
- // this cutoff will continue in following misaligned(multi-cacheline/ block src)
494
+ // this cutoff will continue in following misaligned(multi-cacheline/ block src)
485
495
486
496
unsigned blocks_per_simd1 = SIMD / B0 ; //WRONG! TODO: what if SIMD > B0?
487
497
DDATA_T bVal = BBLOCK_READ (src1 + batch_offset1 + next_bii * B0 );
@@ -502,7 +512,7 @@ internal_padding_block_concat2(__global DATA_T *dst,
502
512
DDATA_T cVal ;
503
513
if (trailing_boundary_shift ) {
504
514
// long trailing_bii = sg_last_next_bii - blocks_per_sg - trailing_boundary_shift;
505
- long trailing_bii = 0 ;
515
+ idx_t trailing_bii = 0 ;
506
516
cVal = BBLOCK_READ (src1 + batch_offset1 + trailing_bii * B0 );
507
517
508
518
int block_dt = (blocks_per_sg - trailing_boundary_shift );
@@ -526,17 +536,15 @@ internal_padding_block_concat2(__global DATA_T *dst,
526
536
trailmask = (0xFFFFFFFFFFFFFFFF << (ntrail * DATA_TYPE_SIZE * 8 ))
527
537
>> (ntrail * DATA_TYPE_SIZE
528
538
* 8 ); // explicitly ignore any reads past boundary blocks
529
-
530
539
if (cutoff > 0
531
540
&& (ic % B0 )
532
541
>= cutoff ) { // TODO: should change together with sg_shuffle_dt sizeof(B0)
533
542
aVal = AS_VEC (as_ulong (aVal ) & trailmask );
534
543
aVal |= offset_vec_data ;
535
544
}
536
545
537
- if (cutoff > 0
538
- && (ic % B0 )
539
- < cutoff ) { // TODO: should change together with sg_shuffle_dt sizeof(B0)
546
+ if (cutoff > 0 && (ic % B0 ) < cutoff ) {
547
+ // TODO: should change together with sg_shuffle_dt sizeof(B0)
540
548
aVal = AS_VEC (as_ulong (aVal ) & trailmask );
541
549
}
542
550
@@ -549,15 +557,20 @@ internal_padding_block_concat2(__global DATA_T *dst,
549
557
if (cutoff > 0
550
558
&& (get_local_id (0 ) % B0 )
551
559
< cutoff ) { // TODO: should change together with sg_shuffle_dt sizeof(B0)
552
- aVal |= cVal ;
560
+ aVal = AS_VEC ( as_ulong ( aVal ) | ( as_ulong ( cVal ) & ~ trailmask )) ;
553
561
}
554
562
555
- if (ic_end >= dst_concat_axis ) {
556
- #if NPERSG > 1 // TODO: reformulate as &= 0x00FFFF...
563
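+ // TODO: double-check needed: the test below uses ic > dst_concat_axis and the next one ic < dst_concat_axis, so lanes with ic == dst_concat_axis are zeroed by neither (the fully-in-src1 branch uses ic >= dst_concat_axis)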
564
+ if (ic > dst_concat_axis ) { aVal = 0 ; } // TODO: doublecheck needed
565
+ if (ic < dst_concat_axis && ic_end >= dst_concat_axis ) {
566
+ #if NPERSG > 1
557
567
for (int i = 0 ; i < NPERSG ; ++ i ) {
558
568
if ((((bii + i ) / inner_dim ) * B0 + blid )
559
569
>= dst_concat_axis ) {
560
- aVal [i ] = 0 ;
570
+ aVal = AS_VEC (as_ulong (aVal )
571
+ & 0xFFFFFFFFFFFFFFFF >> i * DATA_TYPE_SIZE
572
+ * 8 ); // aVal[i] element access is slow, especially per-byte (u8)
573
+ break ;
561
574
}
562
575
}
563
576
#else
0 commit comments