Skip to content

Commit 344a7af

Browse files
committed
xe: jit: reorder: enable f4_e3m0 up-convert
1 parent 2cafd27 commit 344a7af

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

src/gpu/intel/jit/codegen/reorder.hpp

+16-10
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
274274
bool src_df = (src_type == ngen::DataType::df);
275275
bool src_xf = src_bf || src_f || src_hf || src_df;
276276
bool src_f4_e2m1 = (src_type == ngen_f4_e2m1());
277+
bool src_f4_e3m0 = (src_type == ngen_f4_e3m0());
278+
bool src_f4 = src_f4_e2m1 || src_f4_e3m0;
277279
bool f_to_xf = (src_f && (dst_bf || dst_hf));
278280
bool native_bf16 = host->exec_cfg().hw().systolic_support();
279281
op_plan_t plan = grf_size;
@@ -363,7 +365,9 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
363365
host->emov(mod, dst, src);
364366
};
365367

366-
if (src_f4_e2m1 && dst_hf) {
368+
if (src_f4 && dst_hf) {
369+
const auto scale = src_f4_e2m1 ? 0x7400 : 0x6c00;
370+
const auto shift = src_f4_e2m1 ? 9 : 10;
367371
const int nregs = utils::div_up(2 * width, grf_size);
368372
auto tmp = lex_scope.alloc_reg_buf_data(nregs).format(
369373
0, width, 1, ngen::DataType::uw);
@@ -376,22 +380,24 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
376380
auto d = dst.subregister(i, esize, dst_stride);
377381
auto t = tmp.subregister(0, esize, 1, ngen::DataType::uw);
378382
plan(unpack_u4, esize, t(1), s(src_stride));
379-
host->eshl(esize, d.uw()(dst_stride), t(1), 9);
383+
host->eshl(esize, d.uw()(dst_stride), t(1), shift);
380384
host->eshl(esize, t(1), t(1), 12);
381385
host->and_(esize, d.uw()(dst_stride), d.uw()(dst_stride), 0x0E00);
382386
host->mul(esize, d(dst_stride), d(dst_stride),
383-
ngen::Immediate::hf(0x7400));
387+
ngen::Immediate::hf(scale));
384388
bfn0xCA(esize, d.uw()(dst_stride), d.uw()(dst_stride), t(1),
385389
0x8000);
386390
}
387391
return;
388392
}
389393

390-
if (src_f4_e2m1 && (dst_f || dst_bf)) {
391-
// shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
392-
// f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
393-
// position of an f32) to the equivalent f32.
394-
constexpr float shift = 85070591730234615865843651857942052864.f;
394+
if (src_f4 && (dst_f || dst_bf)) {
395+
const auto scale = src_f4_e2m1 ? 0x7e800000 : 0x7d800000;
396+
const auto shift = src_f4_e2m1 ? 22 : 23;
397+
// scale = 2^N (N = 126 for f4_e2m1, 124 for f4_e3m0), used to scale a
398+
// (positive) f32-aligned f4 (i.e. the mantissa bit is at the highest
399+
// mantissa bit position of an f32) to the equivalent f32.
400+
const float scale_f = utils::bit_cast<float>(scale);
395401
int step = get_step();
396402
const int nregs = utils::div_up(4 * step, grf_size);
397403
auto tmp0 = lex_scope.alloc_reg_buf_data(nregs).format(
@@ -415,12 +421,12 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
415421
}
416422
auto t0 = tmp0.subregister(0, esize, 2, ngen::DataType::uw);
417423
plan(unpack_u4, esize, t0(2), s(src_stride));
418-
host->eshl(esize, d.ud()(dst_stride), t0(2), 22);
424+
host->eshl(esize, d.ud()(dst_stride), t0(2), shift);
419425
host->eshl(esize, t0(2), t0(2), 12);
420426
host->and_(
421427
esize, d.ud()(dst_stride), d.ud()(dst_stride), 0x01C00000);
422428
host->mul(esize, d(dst_stride), d(dst_stride),
423-
ngen::Immediate::f(shift));
429+
ngen::Immediate::f(scale_f));
424430
bfn0xCA(esize, d.uw(1)(2 * dst_stride), d.uw(1)(2 * dst_stride),
425431
t0(2), 0x8000);
426432
if (dst_bf) {

src/gpu/intel/jit/reorder/gen_reorder.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
8080
src_engine == dst_engine && src_engine->kind() == engine_kind::gpu,
8181
VERBOSE_BAD_ENGINE_KIND);
8282
VDISPATCH_REORDER(utils::one_of(src_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
83-
f4_e2m1, s32, s8, u8, f64),
83+
f4_e3m0, f4_e2m1, s32, s8, u8, f64),
8484
VERBOSE_UNSUPPORTED_DT);
8585
VDISPATCH_REORDER(utils::one_of(dst_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
8686
s32, s8, u8, f64),
8787
VERBOSE_UNSUPPORTED_DT);
88-
VDISPATCH_REORDER(IMPLICATION(src_dt == f4_e2m1,
88+
VDISPATCH_REORDER(IMPLICATION(utils::one_of(src_dt, f4_e3m0, f4_e2m1),
8989
utils::one_of(dst_dt, f32, f16, bf16)),
9090
VERBOSE_UNSUPPORTED_DT);
9191
VDISPATCH_REORDER(IMPLICATION(src_dt == f16 || dst_dt == f16,

0 commit comments

Comments
 (0)