Skip to content

Commit 6fec1cd

Browse files
committedDec 18, 2024
xe: sdpa: Allow non-transposed scalars for the KQ matmul
1 parent b8303a5 commit 6fec1cd

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed
 

‎src/gpu/intel/ocl/micro_sdpa.cl

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
181181
uint lda = DST_S2;
182182

183183
#if KEY_SCALES || KEY_ZERO_POINTS
184-
uint ldkq = div_up(d, KEY_GROUP_SIZE);
184+
uint ldkq = KEY_D3;
185185
#endif
186186
#if VAL_SCALES || VAL_ZERO_POINTS
187187
uint ldvq = div_up(d, VAL_GROUP_SIZE);

‎src/gpu/intel/ocl/micro_sdpa.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -242,14 +242,14 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
242242
auto scale_dt = key_scales_dt();
243243
problem_kq.Ta_scale = jit::convert_dnnl_to_kernel_type(scale_dt);
244244
problem_kq.A_scale.alignment = uint8_t(types::data_type_size(scale_dt));
245-
problem_kq.A_scale.layout = MatrixLayout::T;
245+
problem_kq.A_scale.layout = MatrixLayout::N;
246246
problem_kq.aScale2D = true;
247247
}
248248
if (with_key_zp()) {
249249
auto zp_dt = key_zp_dt();
250250
problem_kq.Tao = jit::convert_dnnl_to_kernel_type(zp_dt);
251251
problem_kq.AO.alignment = uint8_t(types::data_type_size(zp_dt));
252-
problem_kq.AO.layout = MatrixLayout::T;
252+
problem_kq.AO.layout = MatrixLayout::N;
253253
problem_kq.aoPtrDims = kq_common_zp ? 0 : 2;
254254
problem_kq.aOffset = ABOffset::Calc;
255255
}

0 commit comments

Comments
 (0)