Skip to content

Commit 1d9b22a

Browse files
committedDec 19, 2024
xe: sdpa: Fix alignment for the K and V tensors
1 parent 6fec1cd commit 1d9b22a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed
 

‎src/gpu/intel/ocl/micro_sdpa.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
264264

265265
problem_kq.B.layout = MatrixLayout::Pr;
266266
problem_kq.C.layout = MatrixLayout::T;
267-
problem_kq.A.setAlignment(alignmentForLD(d->head_size() * problem.Ta));
267+
const memory_desc_wrapper key_mdw(key_md());
268+
auto ldk = static_cast<int>(
269+
gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size());
270+
problem_kq.A.setAlignment(alignmentForLD(ldk));
268271
problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM
269272
problem_kq.B.crosspack = 2;
270273
problem_kq.B.tileR = into<uint16_t>(d_max());
@@ -331,7 +334,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
331334

332335
problem_vs.B.layout = MatrixLayout::Pr;
333336
problem_vs.C.layout = MatrixLayout::N;
334-
problem_vs.A.setAlignment(alignmentForLD(d->head_size() * problem.Ta));
337+
const memory_desc_wrapper val_mdw(val_md());
338+
auto ldv = static_cast<int>(
339+
gemm_desc_t::get_ld(*val_md()) * val_mdw.data_type_size());
340+
problem_vs.A.setAlignment(alignmentForLD(ldv));
335341
problem_vs.B.setAlignment(64); // S is packed in SLM
336342
problem_vs.B.crosspack = 16;
337343
sizes.m = d->values();

0 commit comments

Comments
 (0)
Please sign in to comment.