Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Enable e8m0 JIT gemm scaling #2888

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/gpu/intel/jit/emulation.hpp
Original file line number Diff line number Diff line change
@@ -312,6 +312,12 @@ struct EmulationImplementation { // NOLINT(readability-identifier-naming)
dst.setType(DataType::ud);
src0.setType(DataType::uw);
g.shl(mod, dst, src0, 16, loc);
} else if (src0.getType() == DataType::bf8
&& dst.getType() == DataType::f) {
RegData hfTmp = src0;
hfTmp.setType(DataType::uw);
g.shl(mod, hfTmp, src0.setType(DataType::ub), 8, loc);
g.mov(mod, dst, hfTmp.setType(DataType::hf), loc);
} else
g.mov(mod, dst, src0, loc);
}
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/gen_gemm.hpp
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ struct gen_gemm_t : public gpu_gemm_t {

// LIMITATIONS:
// - runtime dims are not supported
auto attr_skip_mask = smask_t::scales | smask_t::post_ops
auto attr_skip_mask = smask_t::scales_data_type | smask_t::post_ops
| smask_t::fpmath_mode | smask_t::accumulation_mode
| smask_t::rounding_mode;
auto &attr_zps = attr()->zero_points_;
1 change: 1 addition & 0 deletions src/gpu/intel/jit/gemm/gen_gemm_kernel.hpp
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ static inline Type convert_dnnl_to_kernel_type(data_type_t type) {
case data_type::bf16: return Type::bf16;
case data_type::f8_e5m2: return Type::bf8;
case data_type::f8_e4m3: return Type::hf8;
case data_type::e8m0: return Type::f8_e8m0;
case data_type::f4_e2m1: return Type::f4_e2m1;
case data_type::f4_e3m0: return Type::f4_e3m0;
case data_type::s32: return Type::s32;
19 changes: 13 additions & 6 deletions src/gpu/intel/jit/gemm/generator/pieces/post_ops.cxx
Original file line number Diff line number Diff line change
@@ -131,12 +131,19 @@ void BLASKernelGenerator<hw>::binaryOp(BinaryOp op, int simd, const RegData &dst

// Apply binary operation to C with a scalar operand.
template <HW hw>
void BLASKernelGenerator<hw>::gemmScalarBinaryOpC(BinaryOp op, const Subregister &offset,
const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state)
void BLASKernelGenerator<hw>::gemmScalarBinaryOpC(BinaryOp op, const GRFMultirange &offsets,
const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco)
{
auto offsetTc = offset.reinterpret(0, state.Tacc.ngen());
if (offset != offsetTc)
emov(1, offsetTc, offset, strategy, state);
auto subOff = offsets[0].sub(0, Tco.ngen());
auto Tacc = state.Tacc;
auto offsetTc = subOff.reinterpret(0, Tacc.ngen());
if (subOff != offsetTc && !one_of(Tco, Type::f8_e8m0, Type::hf8)){
emov(1, offsetTc, subOff, strategy, state);
} else {
vector<RegisterBlock> repackLayout;
makeUnbackedRegLayout(Tacc, repackLayout, 1, 1, false);
copyRegisters(Tco, Tacc, repackLayout, repackLayout, offsets, offsets, strategy, state);
}
if (op == BinaryOp::Div && one_of(state.Tacc, Type::f32, Type::f16)) {
inv(1, offsetTc, offsetTc);
op = BinaryOp::Mul;
@@ -322,7 +329,7 @@ bool BLASKernelGenerator<hw>::gemmBinaryOpC(BinaryOp op, bool row, bool column,
});

if (!row && !column)
gemmScalarBinaryOpC(op, CO_regs[0].sub(0, Tco.ngen()), problem, strategy, state);
gemmScalarBinaryOpC(op, CO_regs, problem, strategy, state, Tco);
else
gemmVectorBinaryOpC(op, column, CO_regs, Subregister(), problem, strategy, state, Tco, CO_layout);
}
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/include/generator.hpp
Original file line number Diff line number Diff line change
@@ -349,7 +349,7 @@ class BLASKernelGenerator : public GENERATOR_BASE(hw) {
void gemmAlphaScale(GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, bool cxCombine = true);
void gemmBetaScale(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
void binaryOp(BinaryOp op, int simd, const ngen::RegData &dst, const ngen::RegData &src0, const ngen::RegData &src1, CommonState &state);
void gemmScalarBinaryOpC(BinaryOp op, const ngen::Subregister &offset, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
void gemmScalarBinaryOpC(BinaryOp op, const GRFMultirange &offsets, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco = Type::invalid);
void gemmVectorBinaryOpC(BinaryOp op, bool column, const GRFMultirange &offsets, const ngen::Subregister &scale, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco = Type::invalid, std::vector<RegisterBlock> CO_layout = std::vector<RegisterBlock>(), int y0 = -1, int y1 = -1);
void gemmRank1UpdateC(const GRFMultirange &r, const GRFMultirange &c, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
void gemmCalcABOffsetAddrs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);