Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ebf1fa9

Browse files
committed Mar 24, 2025 ·
xe: jit: gemm: expand decomp cases, enforce fpmath
1 parent 608baa3 commit ebf1fa9

File tree

3 files changed

+75
-68
lines changed

3 files changed

+75
-68
lines changed
 

‎src/gpu/intel/jit/gemm/gen_gemm.hpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ struct gen_gemm_t : public gpu_gemm_t {
7474
wei_decomp_ = (utils::one_of(d->c_type(), f32, f16, bf16, f8_e5m2,
7575
f8_e4m3)
7676
&& utils::one_of(d->a_type(), u8, s8, s4, u4)
77-
&& utils::one_of(d->b_type(), f16, f32, bf16,
78-
f8_e5m2, f8_e4m3))
77+
&& utils::one_of(d->b_type(), u8, s8, s4, u4,
78+
f16, f32, bf16, f8_e5m2, f8_e4m3))
7979
&& attr()->mayiconvert(d->a_type(), f32);
8080
dy_quant_enabled_
8181
= (utils::one_of(d->c_type(), f32, f16, bf16)
@@ -224,6 +224,9 @@ struct gen_gemm_t : public gpu_gemm_t {
224224

225225
if (!attr()->zero_points_.has_default_values()) {
226226
if (!attr_zps.has_default_values(DNNL_ARG_A)) {
227+
// Only applies to integer inputs.
228+
VDISPATCH_GEMM(utils::one_of(d->a_type(), s4, u4, s8, u8),
229+
VERBOSE_UNSUPPORTED_ZP_CFG);
227230
const int cmask_a = attr_zps.get_mask(DNNL_ARG_A);
228231
ao_dims_ = cmask_a > 0;
229232

@@ -253,10 +256,17 @@ struct gen_gemm_t : public gpu_gemm_t {
253256
VDISPATCH_GEMM(utils::one_of(cmask_a, 0, mask_per_oc,
254257
mask_per_ic),
255258
VERBOSE_UNSUPPORTED_ZP_CFG);
259+
// Weights zero points can only be enabled with good performance during upconversion.
260+
VDISPATCH_GEMM(wei_decomp_
261+
|| utils::one_of(d->b_type(), s4, u4),
262+
VERBOSE_UNSUPPORTED_ZP_CFG);
256263
}
257264
}
258265

259266
if (!attr_zps.has_default_values(DNNL_ARG_B)) {
267+
// Only applies to integer inputs.
268+
VDISPATCH_GEMM(utils::one_of(d->b_type(), s4, u4, s8, u8),
269+
VERBOSE_UNSUPPORTED_ZP_CFG);
260270
const int cmask_b = attr_zps.get_mask(DNNL_ARG_B);
261271
bo_dims_ = cmask_b > 0;
262272

@@ -390,6 +400,7 @@ struct gen_gemm_t : public gpu_gemm_t {
390400
: data_type::s32;
391401
if (swap_ab_) std::swap(ao_type, bo_type);
392402
bool int_acc = utils::one_of(eff_a_type(), s8, u8);
403+
int_acc &= !wei_scales_2d_;
393404
auto co_type = with_bias() ? d->bias_type()
394405
: with_sum_ab() ? d->sum_ab_type
395406
: int_acc ? s32

‎src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -535,19 +535,13 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
535535

536536
auto add_mode_matches = [&](bool has_mode, const char *(*match)(Type)) {
537537
if (!has_mode) return;
538-
auto &def = base.selector.precisions;
539538
if (match(problem_.Ta)) {
540-
match_params.push_back(base);
541539
match_params.back().selector.precisions[0] = match(problem_.Ta);
542-
match_params.back().selector.precisions[1] = def[1];
543540
}
544541
if (match(problem_.Tb)) {
545-
match_params.push_back(base);
546-
match_params.back().selector.precisions[0] = def[0];
547542
match_params.back().selector.precisions[1] = match(problem_.Tb);
548543
}
549544
if (match(problem_.Ta) && match(problem_.Tb)) {
550-
match_params.push_back(base);
551545
match_params.back().selector.precisions[0] = match(problem_.Ta);
552546
match_params.back().selector.precisions[1] = match(problem_.Tb);
553547
}

‎src/gpu/intel/jit/gemm/selector/db/kernel.db

+62-60
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)
Please sign in to comment.