Skip to content

Commit 893b1df

Browse files
committed Mar 14, 2025
xe: jit: gemm: handle data type alignment requirements more strictly
1 parent 735ecb8 commit 893b1df

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed
 

‎src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp

+6 -4
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,14 @@ compute::scalar_type_t gen_gemm_kernel_desc_t::scalar_type() const {
6262

6363
status_t gen_gemm_kernel_desc_t::finalize(const char *tags) {
6464
// Update problem alignments to match catalog entry.
65-
if (!isPacked(problem_.A.layout)) {
65+
if (!isPacked(problem_.A.layout)
66+
&& problem_.Ta_ext.paddedSize() >= problem_.Ta.paddedSize()) {
6667
problem_.A.setAlignment(std::max(
6768
problem_.Ta_ext.paddedSize(), entry_->driverInfo.alignment[0]));
6869
}
6970

70-
if (!isPacked(problem_.B.layout)) {
71+
if (!isPacked(problem_.B.layout)
72+
&& problem_.Tb_ext.paddedSize() >= problem_.Tb.paddedSize()) {
7173
problem_.B.setAlignment(std::max(
7274
problem_.Tb_ext.paddedSize(), entry_->driverInfo.alignment[1]));
7375
}
@@ -380,8 +382,8 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
380382
relaxed_acc_ = mode & mode_relaxed_acc;
381383

382384
auto a_type_size = types::data_type_size(a_type);
383-
auto b_type_size = types::data_type_size(a_type);
384-
auto c_type_size = types::data_type_size(a_type);
385+
auto b_type_size = types::data_type_size(b_type);
386+
auto c_type_size = types::data_type_size(c_type);
385387

386388
align_a = nstl::max(align_a, int(a_type_size));
387389
align_b = nstl::max(align_b, int(b_type_size));

0 commit comments

Comments
 (0)