Skip to content

Commit 8155bf1

Browse files
xe: jit: gemm: enable broader typed scales support
1 parent dd4c792 commit 8155bf1

File tree

4 files changed

+18
-10
lines changed

4 files changed

+18
-10
lines changed

src/gpu/intel/jit/gemm/gen_gemm.hpp

+3-3
Original file line number | Diff line number | Diff line change
@@ -57,9 +57,9 @@ struct gen_gemm_t : public gpu_gemm_t {
5757

5858
// LIMITATIONS:
5959
// - runtime dims are not supported
60-
auto attr_skip_mask = smask_t::scales | smask_t::post_ops
61-
| smask_t::fpmath_mode | smask_t::accumulation_mode
62-
| smask_t::rounding_mode;
60+
auto attr_skip_mask = smask_t::scales_runtime_data_type
61+
| smask_t::post_ops | smask_t::fpmath_mode
62+
| smask_t::accumulation_mode | smask_t::rounding_mode;
6363
auto &attr_zps = attr()->zero_points_;
6464

6565
dev_info_ = compute_engine->device_info();

src/gpu/intel/jit/gemm/gen_gemm_kernel.hpp

+1
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ static inline Type convert_dnnl_to_kernel_type(data_type_t type) {
4646
case data_type::bf16: return Type::bf16;
4747
case data_type::f8_e5m2: return Type::bf8;
4848
case data_type::f8_e4m3: return Type::hf8;
49+
case data_type::e8m0: return Type::f8_e8m0;
4950
case data_type::f4_e2m1: return Type::f4_e2m1;
5051
case data_type::f4_e3m0: return Type::f4_e3m0;
5152
case data_type::s32: return Type::s32;

src/gpu/intel/jit/gemm/generator/pieces/post_ops.cxx

+13-6
Original file line number | Diff line number | Diff line change
@@ -131,12 +131,19 @@ void BLASKernelGenerator<hw>::binaryOp(BinaryOp op, int simd, const RegData &dst
131131

132132
// Apply binary operation to C with a scalar operand.
133133
template <HW hw>
134-
void BLASKernelGenerator<hw>::gemmScalarBinaryOpC(BinaryOp op, const Subregister &offset,
135-
const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state)
134+
void BLASKernelGenerator<hw>::gemmScalarBinaryOpC(BinaryOp op, const GRFMultirange &offsets,
135+
const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco)
136136
{
137-
auto offsetTc = offset.reinterpret(0, state.Tacc.ngen());
138-
if (offset != offsetTc)
139-
emov(1, offsetTc, offset, strategy, state);
137+
auto subOff = offsets[0].sub(0, Tco.ngen());
138+
auto Tacc = state.Tacc;
139+
auto offsetTc = subOff.reinterpret(0, Tacc.ngen());
140+
if (subOff != offsetTc && !one_of(Tco, Type::f8_e8m0, Type::hf8)){
141+
emov(1, offsetTc, subOff, strategy, state);
142+
} else {
143+
vector<RegisterBlock> repackLayout;
144+
makeUnbackedRegLayout(Tacc, repackLayout, 1, 1, false);
145+
copyRegisters(Tco, Tacc, repackLayout, repackLayout, offsets, offsets, strategy, state);
146+
}
140147
if (op == BinaryOp::Div && one_of(state.Tacc, Type::f32, Type::f16)) {
141148
inv(1, offsetTc, offsetTc);
142149
op = BinaryOp::Mul;
@@ -322,7 +329,7 @@ bool BLASKernelGenerator<hw>::gemmBinaryOpC(BinaryOp op, bool row, bool column,
322329
});
323330

324331
if (!row && !column)
325-
gemmScalarBinaryOpC(op, CO_regs[0].sub(0, Tco.ngen()), problem, strategy, state);
332+
gemmScalarBinaryOpC(op, CO_regs, problem, strategy, state, Tco);
326333
else
327334
gemmVectorBinaryOpC(op, column, CO_regs, Subregister(), problem, strategy, state, Tco, CO_layout);
328335
}

src/gpu/intel/jit/gemm/include/generator.hpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,7 @@ class BLASKernelGenerator : public GENERATOR_BASE(hw) {
349349
void gemmAlphaScale(GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, bool cxCombine = true);
350350
void gemmBetaScale(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
351351
void binaryOp(BinaryOp op, int simd, const ngen::RegData &dst, const ngen::RegData &src0, const ngen::RegData &src1, CommonState &state);
352-
void gemmScalarBinaryOpC(BinaryOp op, const ngen::Subregister &offset, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
352+
void gemmScalarBinaryOpC(BinaryOp op, const GRFMultirange &offsets, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco = Type::invalid);
353353
void gemmVectorBinaryOpC(BinaryOp op, bool column, const GRFMultirange &offsets, const ngen::Subregister &scale, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco = Type::invalid, std::vector<RegisterBlock> CO_layout = std::vector<RegisterBlock>(), int y0 = -1, int y1 = -1);
354354
void gemmRank1UpdateC(const GRFMultirange &r, const GRFMultirange &c, const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);
355355
void gemmCalcABOffsetAddrs(const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state);

0 commit comments

Comments
 (0)