@@ -223,6 +223,9 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
223
223
A += DST_OFF (b1 , b0 , 0 , 0 , 0 );
224
224
#if WITH_ATTN_MASK
225
225
msk += MSK_OFF (b1 % MSK_D0 , b0 % MSK_D1 , 0 , 0 );
226
+ #ifndef BLOCK_MSK
227
+ int mask_aligned = (((size_t )msk ) % 4 ) == 0 ;
228
+ #endif
226
229
#endif
227
230
228
231
#if KEY_SCALES
@@ -320,9 +323,17 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
320
323
/* Load mask. No remainder handling needed assuming k block size is a power of 2. */
321
324
mask_tile_type mask_tile ;
322
325
#if BROADCAST_MASK_Q
326
+ #if BLOCK_MSK
323
327
tile_load_block (& mask_tile , msk , 0 , k0 + sg_i0_kq , 0 );
324
328
#else
325
- tile_load_t (& mask_tile , msk , q , k , q , sg_j0_kq + wg_j0 , k0 + sg_i0_kq );
329
+ if (mask_aligned ) {
330
+ tile_load_block (& mask_tile , msk , 0 , k0 + sg_i0_kq , 0 );
331
+ } else {
332
+ tile_load_full (& mask_tile , msk , 0 , k0 + sg_i0_kq , 0 );
333
+ }
334
+ #endif
335
+ #else
336
+ tile_load_t (& mask_tile , msk , q , k , sg_j0_kq + wg_j0 , k0 + sg_i0_kq );
326
337
#endif
327
338
#endif
328
339
0 commit comments