@@ -309,10 +309,15 @@ KERNEL(gemm_tiled_opt)(
309
309
else
310
310
#endif // INDIRECT_INPUT1
311
311
{
312
- #if N_IS_ALIGNED_4BYTE
313
- b_tile [b_load_id ] = BLOCK_READ_B (b_ptr , 0 );
314
- #else
312
+ // #if N_IS_ALIGNED_4BYTE
313
+ // b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
314
+ // #else
315
+ // b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
316
+ // #endif
317
+ #if TILE_N_NOT_DIVISIBLE
315
318
b_tile [b_load_id ] = b_raw_global_id > N - 1 ? 0 : b_ptr [sglid ];
319
+ #else
320
+ b_tile [b_load_id ] = BLOCK_READ_B (b_ptr , 0 );
316
321
#endif
317
322
b_ptr += input1_offset ;
318
323
}
@@ -395,11 +400,16 @@ KERNEL(gemm_tiled_opt)(
395
400
#if INDIRECT_INPUT0
396
401
uint a_idx = FUNC_CALL (get_input0_indirect_index )(OPTIONAL_SHAPE_INFO_TENSOR b , f , w , z , (y + dot_id ), (k * TILE_K + sglid ), beam_table );
397
402
A_FLOATN a_read = input0 [a_idx ];
398
- #elif K_IS_ALIGNED_4BYTE
399
- A_FLOATN a_read = BLOCK_READ_A (a_ptr , 0 );
400
- #else // K_IS_ALIGNED_4BYTE
403
+ // #elif K_IS_ALIGNED_4BYTE
404
+ // A_FLOATN a_read = BLOCK_READ_A(a_ptr, 0);
405
+ // #else // K_IS_ALIGNED_4BYTE
406
+ // A_FLOATN a_read = a_ptr[sglid];
407
+ // #endif // K_IS_ALIGNED_4BYTE
408
+ #elif TILE_K_NOT_DIVISIBLE
401
409
A_FLOATN a_read = a_ptr [sglid ];
402
- #endif // K_IS_ALIGNED_4BYTE
410
+ #else // TILE_K_NOT_DIVISIBLE
411
+ A_FLOATN a_read = BLOCK_READ_A (a_ptr , 0 );
412
+ #endif // TILE_K_NOT_DIVISIBLE
403
413
#endif // IS_DYNAMIC
404
414
a_ptr += input0_offset ;
405
415
@@ -617,11 +627,16 @@ KERNEL(gemm_tiled_opt)(
617
627
else
618
628
#endif
619
629
{
620
- #if N_IS_ALIGNED_4BYTE
621
- b_tile [b_load_id ] = BLOCK_READ_B (b_ptr , 0 );
622
- #else // N_IS_ALIGNED_4BYTE
630
+ // #if N_IS_ALIGNED_4BYTE
631
+ // b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
632
+ // #else // N_IS_ALIGNED_4BYTE
633
+ // b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
634
+ // #endif // N_IS_ALIGNED_4BYTE
635
+ #if TILE_N_NOT_DIVISIBLE
623
636
b_tile [b_load_id ] = b_raw_global_id > N - 1 ? 0 : b_ptr [sglid ];
624
- #endif // N_IS_ALIGNED_4BYTE
637
+ #else // TILE_N_NOT_DIVISIBLE
638
+ b_tile [b_load_id ] = BLOCK_READ_B (b_ptr , 0 );
639
+ #endif // TILE_N_NOT_DIVISIBLE
625
640
b_ptr += input1_offset ;
626
641
}
627
642
#elif TRANSPOSE_INPUT1 == TRANSPOSE_OTHER // TRANSPOSE_INPUT1 == 0
@@ -660,23 +675,24 @@ KERNEL(gemm_tiled_opt)(
660
675
}
661
676
#endif // TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
662
677
663
- #if !INDIRECT_INPUT0 && K_IS_ALIGNED_4BYTE && (TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST )
664
- a_ptr = input0 + FUNC_CALL (get_input0_index )(OPTIONAL_SHAPE_INFO_TENSOR b , f , w , z , y , (K_FULL_ITERATIONS * TILE_K ));
665
- #endif
678
+ // #if !INDIRECT_INPUT0 && K_IS_ALIGNED_4BYTE && (TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST)
679
+ // a_ptr = input0 + FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, (K_FULL_ITERATIONS * TILE_K));
680
+ // #endif
666
681
// Loading leftovers of the matrix A and tile C calculation
667
682
unroll_for (uint dot_id = 0 ; dot_id < tile_m_iterations ; dot_id ++ ) {
668
683
#if INDIRECT_INPUT0
669
684
uint a_idx = FUNC_CALL (get_input0_indirect_index )(OPTIONAL_SHAPE_INFO_TENSOR b , f , w , z , (y + dot_id ), (K_FULL_ITERATIONS * TILE_K + sglid ), beam_table );
670
- INPUT0_TYPE a_read = input0 [a_idx ];
671
- #else // INDIRECT_INPUT0
672
- #if K_IS_ALIGNED_4BYTE && (TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST )
673
- INPUT0_TYPE a_read = BLOCK_READ_A (a_ptr , 0 );
674
- a_ptr += input0_offset ;
685
+ // INPUT0_TYPE a_read = input0[a_idx];
686
+ // #else // INDIRECT_INPUT0
687
+ // #if K_IS_ALIGNED_4BYTE && (TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST)
688
+ // INPUT0_TYPE a_read = BLOCK_READ_A(a_ptr, 0);
689
+ // a_ptr += input0_offset;
675
690
#else
676
691
uint a_idx = FUNC_CALL (get_input0_index )(OPTIONAL_SHAPE_INFO_TENSOR b , f , w , z , (y + dot_id ), (K_FULL_ITERATIONS * TILE_K + sglid ));
692
+ #endif //--kelvin
677
693
INPUT0_TYPE a_read = input0 [a_idx ];
678
- #endif
679
- #endif // INDIRECT_INPUT0
694
+ // #endif
695
+ // #endif // INDIRECT_INPUT0
680
696
unroll_for (uint simd_id = 0 ; simd_id < TILE_K_LEFTOVER ; simd_id ++ ) {
681
697
c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_read , simd_id )), b_tile [simd_id ], c_tile [dot_id ]);
682
698
}
0 commit comments