File tree 1 file changed +8
-2
lines changed
1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -264,7 +264,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
264
264
265
265
problem_kq.B .layout = MatrixLayout::Pr;
266
266
problem_kq.C .layout = MatrixLayout::T;
267
- problem_kq.A .setAlignment (alignmentForLD (d->head_size () * problem.Ta ));
267
+ const memory_desc_wrapper key_mdw (key_md ());
268
+ auto ldk = static_cast <int >(
269
+ gemm_desc_t::get_ld (*key_md ()) * key_mdw.data_type_size ());
270
+ problem_kq.A .setAlignment (alignmentForLD (ldk));
268
271
problem_kq.B .setAlignment (64 ); // Q is packed in VNNI format in SLM
269
272
problem_kq.B .crosspack = 2 ;
270
273
problem_kq.B .tileR = into<uint16_t >(d_max ());
@@ -331,7 +334,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
331
334
332
335
problem_vs.B .layout = MatrixLayout::Pr;
333
336
problem_vs.C .layout = MatrixLayout::N;
334
- problem_vs.A .setAlignment (alignmentForLD (d->head_size () * problem.Ta ));
337
+ const memory_desc_wrapper val_mdw (val_md ());
338
+ auto ldv = static_cast <int >(
339
+ gemm_desc_t::get_ld (*val_md ()) * val_mdw.data_type_size ());
340
+ problem_vs.A .setAlignment (alignmentForLD (ldv));
335
341
problem_vs.B .setAlignment (64 ); // S is packed in SLM
336
342
problem_vs.B .crosspack = 16 ;
337
343
sizes.m = d->values ();
You can’t perform that action at this time.
0 commit comments