Skip to content

Commit a42ea4d

Browse files
committed
xe: sdpa: Fix mask loads for unaligned memory
1 parent 1d9b22a commit a42ea4d

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/gpu/intel/ocl/micro_sdpa.cl

+12 -1
Original file line number | Diff line number | Diff line change
@@ -223,6 +223,9 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
223223
A += DST_OFF(b1, b0, 0, 0, 0);
224224
#if WITH_ATTN_MASK
225225
msk += MSK_OFF(b1 % MSK_D0, b0 % MSK_D1, 0, 0);
226+
#ifndef BLOCK_MSK
227+
int mask_aligned = (((size_t)msk) % 4) == 0;
228+
#endif
226229
#endif
227230

228231
#if KEY_SCALES
@@ -320,9 +323,17 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
320323
/* Load mask. No remainder handling needed assuming k block size is a power of 2. */
321324
mask_tile_type mask_tile;
322325
#if BROADCAST_MASK_Q
326+
#if BLOCK_MSK
323327
tile_load_block(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
324328
#else
325-
tile_load_t(&mask_tile, msk, q, k, q, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
329+
if (mask_aligned) {
330+
tile_load_block(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
331+
} else {
332+
tile_load_full(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
333+
}
334+
#endif
335+
#else
336+
tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
326337
#endif
327338
#endif
328339

src/gpu/intel/ocl/micro_sdpa.cpp

+2
Original file line number | Diff line number | Diff line change
@@ -413,6 +413,7 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
413413
auto ldk = gemm_desc_t::get_ld(*pd()->key_md()) * key_mdw.data_type_size();
414414
auto ldv = gemm_desc_t::get_ld(*pd()->val_md()) * val_mdw.data_type_size();
415415
auto lda = gemm_desc_t::get_ld(*pd()->dst_md()) * dst_mdw.data_type_size();
416+
auto ldmsk = pd()->attn_mask_md()->dims[3] * msk_mdw.data_type_size();
416417
kernel_ctx.define_int("Q_ALIGN", jit::alignmentForLD(int(ldq)));
417418
kernel_ctx.define_int("K_ALIGN", jit::alignmentForLD(int(ldk)));
418419
kernel_ctx.define_int("V_ALIGN", jit::alignmentForLD(int(ldv)));
@@ -483,6 +484,7 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
483484
if (d_full) {
484485
if (ldq % 4 == 0) kernel_ctx.define_int("BLOCK_Q", 1);
485486
if (lda % 4 == 0 && v_full) kernel_ctx.define_int("BLOCK_A", 1);
487+
if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);
486488
kernel_ctx.define_int("REMAINDER_Q", (d->queries() % tile_q) != 0);
487489
} else if (pd()->arch() >= compute::gpu_arch_t::xe_hpc) {
488490
auto vbytes = d->values() * val_mdw.data_type_size();

0 commit comments

Comments
 (0)