@@ -62,12 +62,14 @@ compute::scalar_type_t gen_gemm_kernel_desc_t::scalar_type() const {
62
62
63
63
status_t gen_gemm_kernel_desc_t::finalize (const char *tags) {
64
64
// Update problem alignments to match catalog entry.
65
- if (!isPacked (problem_.A .layout )) {
65
+ if (!isPacked (problem_.A .layout )
66
+ && problem_.Ta_ext .paddedSize () >= problem_.Ta .paddedSize ()) {
66
67
problem_.A .setAlignment (std::max (
67
68
problem_.Ta_ext .paddedSize (), entry_->driverInfo .alignment [0 ]));
68
69
}
69
70
70
- if (!isPacked (problem_.B .layout )) {
71
+ if (!isPacked (problem_.B .layout )
72
+ && problem_.Tb_ext .paddedSize () >= problem_.Tb .paddedSize ()) {
71
73
problem_.B .setAlignment (std::max (
72
74
problem_.Tb_ext .paddedSize (), entry_->driverInfo .alignment [1 ]));
73
75
}
@@ -380,8 +382,8 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
380
382
relaxed_acc_ = mode & mode_relaxed_acc;
381
383
382
384
auto a_type_size = types::data_type_size (a_type);
383
- auto b_type_size = types::data_type_size (a_type );
384
- auto c_type_size = types::data_type_size (a_type );
385
+ auto b_type_size = types::data_type_size (b_type );
386
+ auto c_type_size = types::data_type_size (c_type );
385
387
386
388
align_a = nstl::max (align_a, int (a_type_size));
387
389
align_b = nstl::max (align_b, int (b_type_size));
0 commit comments