diff --git a/src/gpu/intel/jit/emulation.hpp b/src/gpu/intel/jit/emulation.hpp index 6570f425b53..dd7b74f60ab 100644 --- a/src/gpu/intel/jit/emulation.hpp +++ b/src/gpu/intel/jit/emulation.hpp @@ -312,6 +312,12 @@ struct EmulationImplementation { // NOLINT(readability-identifier-naming) dst.setType(DataType::ud); src0.setType(DataType::uw); g.shl(mod, dst, src0, 16, loc); + } else if (src0.getType() == DataType::bf8 + && dst.getType() == DataType::f) { + RegData hfTmp = src0; + hfTmp.setType(DataType::uw); + g.shl(mod, hfTmp, src0.setType(DataType::ub), 8, loc); + g.mov(mod, dst, hfTmp.setType(DataType::hf), loc); } else g.mov(mod, dst, src0, loc); } diff --git a/src/gpu/intel/jit/gemm/gen_gemm.hpp b/src/gpu/intel/jit/gemm/gen_gemm.hpp index 8f9b1b3bdb3..b9cbcd84910 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm.hpp +++ b/src/gpu/intel/jit/gemm/gen_gemm.hpp @@ -57,7 +57,7 @@ struct gen_gemm_t : public gpu_gemm_t { // LIMITATIONS: // - runtime dims are not supported - auto attr_skip_mask = smask_t::scales | smask_t::post_ops + auto attr_skip_mask = smask_t::scales_data_type | smask_t::post_ops | smask_t::fpmath_mode | smask_t::accumulation_mode | smask_t::rounding_mode; auto &attr_zps = attr()->zero_points_; diff --git a/src/gpu/intel/jit/gemm/gen_gemm_kernel.hpp b/src/gpu/intel/jit/gemm/gen_gemm_kernel.hpp index c66f9d8608d..563be252fbc 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm_kernel.hpp +++ b/src/gpu/intel/jit/gemm/gen_gemm_kernel.hpp @@ -46,6 +46,7 @@ static inline Type convert_dnnl_to_kernel_type(data_type_t type) { case data_type::bf16: return Type::bf16; case data_type::f8_e5m2: return Type::bf8; case data_type::f8_e4m3: return Type::hf8; + case data_type::e8m0: return Type::f8_e8m0; case data_type::f4_e2m1: return Type::f4_e2m1; case data_type::f4_e3m0: return Type::f4_e3m0; case data_type::s32: return Type::s32; diff --git a/src/gpu/intel/jit/gemm/generator/pieces/post_ops.cxx b/src/gpu/intel/jit/gemm/generator/pieces/post_ops.cxx index 
4c6172bbd08..d8da5d21175 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/post_ops.cxx +++ b/src/gpu/intel/jit/gemm/generator/pieces/post_ops.cxx @@ -131,12 +131,19 @@ void BLASKernelGenerator::binaryOp(BinaryOp op, int simd, const RegData &dst // Apply binary operation to C with a scalar operand. template -void BLASKernelGenerator::gemmScalarBinaryOpC(BinaryOp op, const Subregister &offset, - const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state) +void BLASKernelGenerator::gemmScalarBinaryOpC(BinaryOp op, const GRFMultirange &offsets, + const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco) { - auto offsetTc = offset.reinterpret(0, state.Tacc.ngen()); - if (offset != offsetTc) - emov(1, offsetTc, offset, strategy, state); + auto subOff = offsets[0].sub(0, Tco.ngen()); + auto Tacc = state.Tacc; + auto offsetTc = subOff.reinterpret(0, Tacc.ngen()); + if (one_of(Tco, Type::f8_e8m0, Type::hf8)) { + vector repackLayout; + makeUnbackedRegLayout(Tacc, repackLayout, 1, 1, false); + copyRegisters(Tco, Tacc, repackLayout, repackLayout, offsets, offsets, strategy, state); + } else if (subOff != offsetTc) { + emov(1, offsetTc, subOff, strategy, state); + } if (op == BinaryOp::Div && one_of(state.Tacc, Type::f32, Type::f16)) { inv(1, offsetTc, offsetTc); op = BinaryOp::Mul; @@ -322,7 +329,7 @@ bool BLASKernelGenerator::gemmBinaryOpC(BinaryOp op, bool row, bool column, }); if (!row && !column) - gemmScalarBinaryOpC(op, CO_regs[0].sub(0, Tco.ngen()), problem, strategy, state); + gemmScalarBinaryOpC(op, CO_regs, problem, strategy, state, Tco); else gemmVectorBinaryOpC(op, column, CO_regs, Subregister(), problem, strategy, state, Tco, CO_layout); } diff --git a/src/gpu/intel/jit/gemm/include/generator.hpp b/src/gpu/intel/jit/gemm/include/generator.hpp index 14ffe231715..9c9929449c0 100644 --- a/src/gpu/intel/jit/gemm/include/generator.hpp +++ b/src/gpu/intel/jit/gemm/include/generator.hpp @@ -349,7 +349,7 @@ class
BLASKernelGenerator : public GENERATOR_BASE(hw) { void gemmAlphaScale(GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, bool cxCombine = true); void gemmBetaScale(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); void binaryOp(BinaryOp op, int simd, const ngen::RegData &dst, const ngen::RegData &src0, const ngen::RegData &src1, CommonState &state); - void gemmScalarBinaryOpC(BinaryOp op, const ngen::Subregister &offset, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); + void gemmScalarBinaryOpC(BinaryOp op, const GRFMultirange &offsets, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco = Type::invalid); void gemmVectorBinaryOpC(BinaryOp op, bool column, const GRFMultirange &offsets, const ngen::Subregister &scale, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco = Type::invalid, std::vector CO_layout = std::vector(), int y0 = -1, int y1 = -1); void gemmRank1UpdateC(const GRFMultirange &r, const GRFMultirange &c, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state); void gemmCalcABOffsetAddrs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);