Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Expand matmul decomp cases #2916

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions src/gpu/intel/jit/gemm/gen_gemm.hpp
Original file line number Diff line number Diff line change
@@ -74,14 +74,15 @@ struct gen_gemm_t : public gpu_gemm_t {
wei_decomp_ = (utils::one_of(d->c_type(), f32, f16, bf16, f8_e5m2,
f8_e4m3)
&& utils::one_of(d->a_type(), u8, s8, s4, u4)
&& utils::one_of(d->b_type(), f16, f32, bf16,
f8_e5m2, f8_e4m3))
&& utils::one_of(d->b_type(), u8, s8, s4, u4,
f16, f32, bf16, f8_e5m2, f8_e4m3))
&& attr()->mayiconvert(d->a_type(), f32);
dy_quant_enabled_
= (utils::one_of(d->c_type(), f32, f16, bf16)
&& utils::one_of(d->a_type(), u8, s8, s4, u4)
&& utils::one_of(d->b_type(), u8, s8))
|| all_f8;
= ((utils::one_of(d->c_type(), f32, f16, bf16)
&& utils::one_of(d->a_type(), u8, s8, s4, u4)
&& utils::one_of(d->b_type(), u8, s8))
|| all_f8)
&& !attr()->mayiconvert(d->a_type(), f32);
quant_enabled_ = wei_decomp_ || dy_quant_enabled_;
CHECK(set_default_formats(false));

@@ -224,6 +225,9 @@ struct gen_gemm_t : public gpu_gemm_t {

if (!attr()->zero_points_.has_default_values()) {
if (!attr_zps.has_default_values(DNNL_ARG_A)) {
// Only applies to integer inputs.
VDISPATCH_GEMM(utils::one_of(d->a_type(), s4, u4, s8, u8),
VERBOSE_UNSUPPORTED_ZP_CFG);
const int cmask_a = attr_zps.get_mask(DNNL_ARG_A);
ao_dims_ = cmask_a > 0;

@@ -253,10 +257,20 @@ struct gen_gemm_t : public gpu_gemm_t {
VDISPATCH_GEMM(utils::one_of(cmask_a, 0, mask_per_oc,
mask_per_ic),
VERBOSE_UNSUPPORTED_ZP_CFG);
// Weights zp can only be performantly enabled during upconversion
// for cases that perform decompression.
VDISPATCH_GEMM(wei_decomp_
|| utils::one_of(
d->c_type(), s8, u8, s32)
|| utils::one_of(d->a_type(), s4, u4),
VERBOSE_UNSUPPORTED_ZP_CFG);
}
}

if (!attr_zps.has_default_values(DNNL_ARG_B)) {
// Only applies to integer inputs.
VDISPATCH_GEMM(utils::one_of(d->b_type(), s4, u4, s8, u8),
VERBOSE_UNSUPPORTED_ZP_CFG);
const int cmask_b = attr_zps.get_mask(DNNL_ARG_B);
bo_dims_ = cmask_b > 0;

@@ -344,6 +358,7 @@ struct gen_gemm_t : public gpu_gemm_t {
src_scales_2d_ = false;
else {
src_q2d_group_k = scales_group_k;
// 2D src scales are only supported during dequantization.
VDISPATCH_GEMM(dy_quant_enabled_
&& utils::one_of(eff_a_type(), s4, u4),
VERBOSE_UNSUPPORTED_SCALES_CFG);
@@ -390,6 +405,7 @@ struct gen_gemm_t : public gpu_gemm_t {
: data_type::s32;
if (swap_ab_) std::swap(ao_type, bo_type);
bool int_acc = utils::one_of(eff_a_type(), s8, u8);
int_acc &= !wei_scales_2d_;
auto co_type = with_bias() ? d->bias_type()
: with_sum_ab() ? d->sum_ab_type
: int_acc ? s32
@@ -420,12 +436,17 @@ struct gen_gemm_t : public gpu_gemm_t {
// Handle special compute modes.
kernel_desc_t::compute_mode mode = kernel_desc_t::mode_default;

if (attr()->mayiconvert(f32, tf32))
set_mode(mode, kernel_desc_t::mode_tf32);
if (attr()->mayiconvert(f32, bf16))
if (attr()->mayiconvert(u8, bf16))
set_mode(mode, kernel_desc_t::mode_bf16x1);
if (attr()->mayiconvert(f32, f16))
if (attr()->mayiconvert(u8, f16))
set_mode(mode, kernel_desc_t::mode_f16x1);
if (attr()->mayiconvert(f32, tf32)
&& !(mode
& (kernel_desc_t::mode_f16x1
| kernel_desc_t::mode_bf16x1))) {
VDISPATCH_GEMM(!wei_decomp_, VERBOSE_UNSUPPORTED_DT);
set_mode(mode, kernel_desc_t::mode_tf32);
}
if (attr()->mayiconvert(f32, f32))
set_mode(mode, kernel_desc_t::mode_strict);
if (attr()->deterministic_)
69 changes: 39 additions & 30 deletions src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp
Original file line number Diff line number Diff line change
@@ -533,49 +533,58 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
bool fpmath_strict = !(fpmath_tf32 || fpmath_bf16 || fpmath_f16)
&& (mode & mode_strict) && (mode & mode_w_decomp);

auto add_mode_matches = [&](bool has_mode, const char *(*match)(Type)) {
auto add_mode_matches = [&](bool has_mode, bool optional,
const char *(*match)(Type)) {
if (!has_mode) return;
auto &def = base.selector.precisions;
if (match(problem_.Ta)) {
match_params.push_back(base);
if (optional) {
match_params.push_back(base);
match_params.back().selector.precisions[1] = def[1];
}
match_params.back().selector.precisions[0] = match(problem_.Ta);
match_params.back().selector.precisions[1] = def[1];
}
if (match(problem_.Tb)) {
match_params.push_back(base);
match_params.back().selector.precisions[0] = def[0];
if (optional) {
match_params.push_back(base);
match_params.back().selector.precisions[0] = def[0];
}
match_params.back().selector.precisions[1] = match(problem_.Tb);
}
if (match(problem_.Ta) && match(problem_.Tb)) {
match_params.push_back(base);
if (optional) match_params.push_back(base);
match_params.back().selector.precisions[0] = match(problem_.Ta);
match_params.back().selector.precisions[1] = match(problem_.Tb);
}
};

add_mode_matches(fpmath_tf32, [](Type dt) -> const char * {
if (dt == Type::f32) { return "T"; }
return nullptr;
});

add_mode_matches(fpmath_bf16, [](Type dt) -> const char * {
if (dt == Type::f32) { return "[SB]"; }
if (dt.isInt8() || dt.isInt4()) return "[OB]";
if (dt.isF8()) return "B";
return nullptr;
});

add_mode_matches(fpmath_f16, [](Type dt) -> const char * {
if (dt == Type::f32) { return "[SH]"; }
if (dt.isInt8() || dt.isInt4()) return "[OH]";
if (dt.isF8()) return "H";
return nullptr;
});

add_mode_matches(!(fpmath_f16 || fpmath_bf16), [](Type dt) -> const char * {
if (dt.isInt4()) return "[FO]";
return nullptr;
});
add_mode_matches(
fpmath_tf32, /*optional=*/true, [](Type dt) -> const char * {
if (dt == Type::f32) { return "T"; }
return nullptr;
});

add_mode_matches(
fpmath_bf16, /*optional=*/false, [](Type dt) -> const char * {
if (dt == Type::f32) { return "[SB]"; }
if (dt.isInt8() || dt.isInt4()) return "[OB]";
if (dt.isF8()) return "B";
return nullptr;
});

add_mode_matches(
fpmath_f16, /*optional=*/false, [](Type dt) -> const char * {
if (dt == Type::f32) { return "[SH]"; }
if (dt.isInt8() || dt.isInt4()) return "[OH]";
if (dt.isF8()) return "H";
return nullptr;
});

add_mode_matches(!(fpmath_f16 || fpmath_bf16), /*optional=*/false,
[](Type dt) -> const char * {
if (dt.isInt4()) return "[FO]";
return nullptr;
});

if (fpmath_strict) {
if (problem_.Tb.isInt4() && !(fpmath_f16 || fpmath_bf16)) {
@@ -588,7 +597,7 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
= match_params.back().selector.precisions[1];
}
}
add_mode_matches(true, [](Type dt) -> const char * {
add_mode_matches(true, /*optional=*/false, [](Type dt) -> const char * {
if (dt.isFP4()) return "E";
return nullptr;
});
7 changes: 6 additions & 1 deletion src/gpu/intel/jit/gemm/include/kernel_catalog.hpp
Original file line number Diff line number Diff line change
@@ -86,8 +86,13 @@ struct Selector {

friend bool operator<(const Selector &sel1, const Selector &sel2) {
auto tupleize = [](const Selector &sel) {
bool compoundA = sel.precisions[0][0] == '[';
bool compoundB = sel.precisions[1][0] == '[';
return std::make_tuple(sel.hw,
sel.precisions[0][0] & 0x1F, sel.precisions[1][0] & 0x1F,
sel.precisions[0][0] & 0x1F,
compoundA ? sel.precisions[0][2] & 0x1F : 'a',
sel.precisions[1][0] & 0x1F,
compoundB ? sel.precisions[1][2] & 0x1F : 'b',
sel.layouts[0][0], sel.layouts[1][0]);
};
return tupleize(sel1) < tupleize(sel2);
170 changes: 94 additions & 76 deletions src/gpu/intel/jit/gemm/selector/db/kernel.db

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/gpu/intel/ocl/gemm/ref_gemm.hpp
Original file line number Diff line number Diff line change
@@ -214,6 +214,7 @@ struct ref_gemm_t : public gpu_gemm_t {
DNNL_ARG_A, DNNL_ARG_B, DNNL_ARG_C};
for (int arg : supported_args) {
if (!zp.has_default_values(arg)) {
if (arg != DNNL_ARG_C) return false;
const int mask = zp.get_mask(arg);
if (mask > 0) return false;
}