Skip to content

Commit d660868

Browse files
committed
xe: jit: reorder: enable f4_e3m0 up-convert
1 parent 83551da commit d660868

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

src/gpu/intel/jit/codegen/reorder.hpp

+20-13
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
274274
bool src_df = (src_type == ngen::DataType::df);
275275
bool src_xf = src_bf || src_f || src_hf || src_df;
276276
bool src_f4_e2m1 = (src_type == ngen_f4_e2m1());
277+
bool src_f4_e3m0 = (src_type == ngen_f4_e3m0());
278+
bool src_f4 = src_f4_e2m1 || src_f4_e3m0;
277279
bool f_to_xf = (src_f && (dst_bf || dst_hf));
278280
bool native_bf16 = host->exec_cfg().hw().systolic_support();
279281
op_plan_t plan = grf_size;
@@ -363,7 +365,9 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
363365
host->emov(mod, dst, src);
364366
};
365367

366-
if (src_f4_e2m1 && dst_hf) {
368+
if (src_f4 && dst_hf) {
369+
const auto scale = src_f4_e2m1 ? 0x7400 : 0x6c00;
370+
const auto shift = src_f4_e2m1 ? 9 : 10;
367371
const int nregs = utils::div_up(2 * width, grf_size);
368372
auto tmp = lex_scope.alloc_reg_buf_data(nregs).format(
369373
0, width, 1, ngen::DataType::uw);
@@ -376,22 +380,25 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
376380
auto d = dst.subregister(i, esize, dst_stride);
377381
auto t = tmp.subregister(0, esize, 1, ngen::DataType::uw);
378382
plan(unpack_u4, esize, t(1), s(src_stride));
379-
host->eshl(esize, d.uw()(dst_stride), t(1), 9);
383+
host->eshl(esize, d.uw()(dst_stride), t(1), shift);
380384
host->eshl(esize, t(1), t(1), 12);
381-
host->and_(esize, d.uw()(dst_stride), d.uw()(dst_stride), 0x0E00);
385+
host->and_(esize, d.uw()(dst_stride), d.uw()(dst_stride),
386+
0x7 << shift);
382387
host->mul(esize, d(dst_stride), d(dst_stride),
383-
ngen::Immediate::hf(0x7400));
388+
ngen::Immediate::hf(scale));
384389
bfn0xCA(esize, d.uw()(dst_stride), d.uw()(dst_stride), t(1),
385390
0x8000);
386391
}
387392
return;
388393
}
389394

390-
if (src_f4_e2m1 && (dst_f || dst_bf)) {
391-
// shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
392-
// f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
393-
// position of an f32) to the equivalent f32.
394-
constexpr float shift = 85070591730234615865843651857942052864.f;
395+
if (src_f4 && (dst_f || dst_bf)) {
396+
const auto scale = src_f4_e2m1 ? 0x7e800000 : 0x7d800000;
397+
const auto shift = src_f4_e2m1 ? 22 : 23;
398+
// scale_f = 2^N, used to shift a (positive) f32-aligned f4 (i.e. the
399+
// mantissa bit is at the highest mantissa bit position of an f32) to
400+
// the equivalent f32.
401+
const float scale_f = utils::bit_cast<float>(scale);
395402
int step = get_step();
396403
const int nregs = utils::div_up(4 * step, grf_size);
397404
auto tmp0 = lex_scope.alloc_reg_buf_data(nregs).format(
@@ -415,12 +422,12 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
415422
}
416423
auto t0 = tmp0.subregister(0, esize, 2, ngen::DataType::uw);
417424
plan(unpack_u4, esize, t0(2), s(src_stride));
418-
host->eshl(esize, d.ud()(dst_stride), t0(2), 22);
425+
host->eshl(esize, d.ud()(dst_stride), t0(2), shift);
419426
host->eshl(esize, t0(2), t0(2), 12);
420-
host->and_(
421-
esize, d.ud()(dst_stride), d.ud()(dst_stride), 0x01C00000);
427+
host->and_(esize, d.ud()(dst_stride), d.ud()(dst_stride),
428+
0x7 << shift);
422429
host->mul(esize, d(dst_stride), d(dst_stride),
423-
ngen::Immediate::f(shift));
430+
ngen::Immediate::f(scale_f));
424431
bfn0xCA(esize, d.uw(1)(2 * dst_stride), d.uw(1)(2 * dst_stride),
425432
t0(2), 0x8000);
426433
if (dst_bf) {

src/gpu/intel/jit/reorder/gen_reorder.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
8080
src_engine == dst_engine && src_engine->kind() == engine_kind::gpu,
8181
VERBOSE_BAD_ENGINE_KIND);
8282
VDISPATCH_REORDER(utils::one_of(src_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
83-
f4_e2m1, s32, s8, u8, f64),
83+
f4_e3m0, f4_e2m1, s32, s8, u8, f64),
8484
VERBOSE_UNSUPPORTED_DT);
8585
VDISPATCH_REORDER(utils::one_of(dst_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
8686
s32, s8, u8, f64),
8787
VERBOSE_UNSUPPORTED_DT);
88-
VDISPATCH_REORDER(IMPLICATION(src_dt == f4_e2m1,
88+
VDISPATCH_REORDER(IMPLICATION(utils::one_of(src_dt, f4_e3m0, f4_e2m1),
8989
utils::one_of(dst_dt, f32, f16, bf16)),
9090
VERBOSE_UNSUPPORTED_DT);
9191
VDISPATCH_REORDER(IMPLICATION(src_dt == f16 || dst_dt == f16,
@@ -173,7 +173,7 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
173173
cfg = std::make_shared<reorder_config_t>(exec_cfg, src_layout, dst_layout);
174174
cfg->set_zp_cfg(zp_cfg);
175175

176-
if (src_dt == f4_e2m1) {
176+
if (utils::one_of(src_dt, f4_e2m1, f4_e3m0)) {
177177
auto dims = cfg->tiles().front().dims();
178178
dim_t contiguous_inner_elems = 1;
179179
for (auto &b : cfg->src_layout().user().blocks()) {

0 commit comments

Comments
 (0)