Skip to content

Commit 83cdcfb

Browse files
committed
xe: jit: reorder: enable f4_e3m0 up-convert
1 parent 7c6f622 commit 83cdcfb

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

src/gpu/intel/jit/codegen/reorder.hpp

+20-13
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
280280
bool src_df = (src_type == ngen::DataType::df);
281281
bool src_xf = src_bf || src_f || src_hf || src_df;
282282
bool src_f4_e2m1 = (src_type == ngen_f4_e2m1());
283+
bool src_f4_e3m0 = (src_type == ngen_f4_e3m0());
284+
bool src_f4 = src_f4_e2m1 || src_f4_e3m0;
283285
bool f_to_xf = (src_f && (dst_bf || dst_hf));
284286
bool native_bf16 = host->exec_cfg().hw().systolic_support();
285287
op_plan_t plan = grf_size;
@@ -387,7 +389,9 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
387389
host->add(mod | host->f0, dst, dst, 1);
388390
};
389391

390-
if (src_f4_e2m1 && dst_hf) {
392+
if (src_f4 && dst_hf) {
393+
const auto scale = src_f4_e2m1 ? 0x7400 : 0x6c00;
394+
const auto shift = src_f4_e2m1 ? 9 : 10;
391395
int step = get_step();
392396
const int nregs = utils::div_up(4 * step, grf_size);
393397
auto tmp = lex_scope.alloc_reg_buf_data(nregs).format(
@@ -400,22 +404,25 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
400404
auto d = dst.subregister(i, esize, dst_stride);
401405
auto t = tmp.subregister(0, esize, 1, ngen::DataType::uw);
402406
plan(cvt_u4_to_uw, esize, t(1), s(src_stride));
403-
host->eshl(esize, d.uw()(dst_stride), t(1), 9);
407+
host->eshl(esize, d.uw()(dst_stride), t(1), shift);
404408
host->eshl(esize, t(1), t(1), 12);
405-
host->and_(esize, d.uw()(dst_stride), d.uw()(dst_stride), 0x0E00);
409+
host->and_(esize, d.uw()(dst_stride), d.uw()(dst_stride),
410+
0x7 << shift);
406411
host->mul(esize, d(dst_stride), d(dst_stride),
407-
ngen::Immediate::hf(0x7400));
412+
ngen::Immediate::hf(scale));
408413
bfn0xCA(esize, d.uw()(dst_stride), d.uw()(dst_stride), t(1),
409414
0x8000);
410415
}
411416
return;
412417
}
413418

414-
if (src_f4_e2m1 && (dst_f || dst_bf)) {
415-
// shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
416-
// f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
417-
// position of an f32) to the equivalent f32.
418-
constexpr float shift = 85070591730234615865843651857942052864.f;
419+
if (src_f4 && (dst_f || dst_bf)) {
420+
const auto scale = src_f4_e2m1 ? 0x7e800000 : 0x7d800000;
421+
const auto shift = src_f4_e2m1 ? 22 : 23;
422+
// shift = 2^N, used to shift a (positive) f32-aligned f4 (i.e. the
423+
// mantissa bit is at the highest mantissa bit position of an f32) to
424+
// the equivalent f32.
425+
const float scale_f = utils::bit_cast<float>(scale);
419426
int step = get_step();
420427
const int nregs = utils::div_up(4 * step, grf_size);
421428
auto tmp0 = lex_scope.alloc_reg_buf_data(nregs).format(
@@ -439,12 +446,12 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
439446
}
440447
auto t0 = tmp0.subregister(0, esize, 2);
441448
plan(cvt_u4_to_uw, esize, t0(2), s(src_stride));
442-
host->eshl(esize, d.ud()(dst_stride), t0(2), 22);
449+
host->eshl(esize, d.ud()(dst_stride), t0(2), shift);
443450
host->eshl(esize, t0(2), t0(2), 12);
444-
host->and_(
445-
esize, d.ud()(dst_stride), d.ud()(dst_stride), 0x01C00000);
451+
host->and_(esize, d.ud()(dst_stride), d.ud()(dst_stride),
452+
0x7 << shift);
446453
host->mul(esize, d(dst_stride), d(dst_stride),
447-
ngen::Immediate::f(shift));
454+
ngen::Immediate::f(scale_f));
448455
bfn0xCA(esize, d.uw(1)(2 * dst_stride), d.uw(1)(2 * dst_stride),
449456
t0(2), 0x8000);
450457
if (dst_bf) {

src/gpu/intel/jit/reorder/gen_reorder.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
8080
src_engine == dst_engine && src_engine->kind() == engine_kind::gpu,
8181
VERBOSE_BAD_ENGINE_KIND);
8282
VDISPATCH_REORDER(utils::one_of(src_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
83-
f4_e2m1, s32, s8, u8, f64),
83+
f4_e3m0, f4_e2m1, s32, s8, u8, f64),
8484
VERBOSE_UNSUPPORTED_DT);
8585
VDISPATCH_REORDER(utils::one_of(dst_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
8686
s32, s8, u8, f64),
8787
VERBOSE_UNSUPPORTED_DT);
88-
VDISPATCH_REORDER(IMPLICATION(src_dt == f4_e2m1,
88+
VDISPATCH_REORDER(IMPLICATION(utils::one_of(src_dt, f4_e3m0, f4_e2m1),
8989
utils::one_of(dst_dt, f32, f16, bf16)),
9090
VERBOSE_UNSUPPORTED_DT);
9191
VDISPATCH_REORDER(IMPLICATION(src_dt == f16 || dst_dt == f16,
@@ -173,7 +173,7 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
173173
cfg = std::make_shared<reorder_config_t>(exec_cfg, src_layout, dst_layout);
174174
cfg->set_zp_cfg(zp_cfg);
175175

176-
if (src_dt == f4_e2m1) {
176+
if (utils::one_of(src_dt, f4_e2m1, f4_e3m0)) {
177177
auto dims = cfg->tiles().front().dims();
178178
dim_t contiguous_inner_elems = 1;
179179
for (auto &b : cfg->src_layout().user().blocks()) {

0 commit comments

Comments
 (0)