Skip to content

Commit e3a9503

Browse files
xe: jit: gemm: expand decomp cases, enforce fpmath
1 parent 5f0373c commit e3a9503

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/gpu/intel/jit/gemm/gen_gemm.hpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ struct gen_gemm_t : public gpu_gemm_t {
7373
d->c_type(), f8_e5m2, f8_e4m3, f16, bf16, f32));
7474
wei_decomp_ = (utils::one_of(d->c_type(), f32, f16, bf16, f8_e5m2,
7575
f8_e4m3)
76-
&& utils::one_of(d->a_type(), u8, s8, s4, u4)
77-
&& utils::one_of(d->b_type(), f16, f32, bf16,
78-
f8_e5m2, f8_e4m3))
76+
&& utils::one_of(d->a_type(), u8, s8, s4, u4))
7977
&& attr()->mayiconvert(d->a_type(), f32);
8078
dy_quant_enabled_
8179
= (utils::one_of(d->c_type(), f32, f16, bf16)

src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -561,13 +561,23 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
561561
add_mode_matches(fpmath_bf16, [](Type dt) -> const char * {
562562
if (dt == Type::f32) { return "[SB]"; }
563563
if (dt.isInt8() || dt.isInt4()) return "[OB]";
564+
return nullptr;
565+
});
566+
567+
add_mode_matches(fpmath_bf16, [](Type dt) -> const char * {
568+
if (dt.isInt8() || dt.isInt4()) return "B";
564569
if (dt.isF8()) return "B";
565570
return nullptr;
566571
});
567572

568573
add_mode_matches(fpmath_f16, [](Type dt) -> const char * {
569574
if (dt == Type::f32) { return "[SH]"; }
570575
if (dt.isInt8() || dt.isInt4()) return "[OH]";
576+
return nullptr;
577+
});
578+
579+
add_mode_matches(fpmath_f16, [](Type dt) -> const char * {
580+
if (dt.isInt8() || dt.isInt4()) return "H";
571581
if (dt.isF8()) return "H";
572582
return nullptr;
573583
});

0 commit comments

Comments
 (0)