Skip to content

Commit 8a874c5

Browse files
committed
xe: jit: reorder: enable f4_e2m1 up-convert
1 parent ef01d18 commit 8a874c5

File tree

2 files changed

+109
-1
lines changed

2 files changed

+109
-1
lines changed

src/gpu/intel/jit/codegen/reorder.hpp

+94
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
273273
bool src_hf8 = (src_type == ngen::DataType::hf8);
274274
bool src_df = (src_type == ngen::DataType::df);
275275
bool src_xf = src_bf || src_f || src_hf || src_df;
276+
bool src_f4_e2m1 = (src_type == ngen_f4_e2m1());
276277
bool f_to_xf = (src_f && (dst_bf || dst_hf));
277278
bool native_bf16 = host->exec_cfg().hw().systolic_support();
278279
op_plan_t plan = grf_size;
@@ -324,6 +325,30 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
324325
auto shl16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
325326
host->eshl(mod, dst, src, 16);
326327
};
328+
auto unpack_u4 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
329+
auto dhs = dst.getHS();
330+
auto shs = src.getHS();
331+
auto esize = mod.getExecSize();
332+
333+
if (shs == 1 && esize > 1) {
334+
auto half_esize = esize / 2;
335+
auto s = src, d = dst;
336+
d.setRegion(dst.getVS(), (dst.getWidth() + 1) / 2, dhs * 2);
337+
s.setOffset(src.getOffset() / 2);
338+
s.setRegion((src.getVS() + 1) / 2, (dst.getWidth() + 1) / 2, shs);
339+
s.setType(ngen::DataType::ub);
340+
host->and_(half_esize, d, s, 0xF);
341+
d.setOffset(d.getOffset() + dhs);
342+
host->shr(half_esize | host->ExecutionOffset(half_esize), d, s, 4);
343+
} else {
344+
if ((src.getOffset() & 1) == 0) {
345+
host->and_(esize, dst, src, 0xF);
346+
} else {
347+
host->shr(esize, dst, src, 4);
348+
}
349+
}
350+
};
351+
327352
auto cvt_f32_to_bf16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
328353
auto exec_size = mod.getExecSize();
329354
host->add(mod, src, src, -0x8000);
@@ -338,6 +363,75 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
338363
host->emov(mod, dst, src);
339364
};
340365

366+
if (src_f4_e2m1 && dst_hf) {
367+
const int nregs = utils::div_up(2 * width, grf_size);
368+
auto tmp = lex_scope.alloc_reg_buf_data(nregs).format(
369+
0, width, 1, ngen::DataType::uw);
370+
int step = get_step();
371+
for (int i = 0; i < width; i += step) {
372+
step = std::min(step, width - i);
373+
step = utils::rnd_down_pow2(step);
374+
int esize = step;
375+
auto s = src.subregister(i, esize, src_stride);
376+
auto d = dst.subregister(i, esize, dst_stride);
377+
auto t = tmp.subregister(0, esize, 1, ngen::DataType::uw);
378+
plan(unpack_u4, esize, t(1), s(src_stride));
379+
host->eshl(esize, d.uw()(dst_stride), t(1), 9);
380+
host->eshl(esize, t(1), t(1), 12);
381+
host->and_(esize, d.uw()(dst_stride), d.uw()(dst_stride), 0x0E00);
382+
host->mul(esize, d(dst_stride), d(dst_stride),
383+
ngen::Immediate::hf(0x7400));
384+
bfn0xCA(esize, d.uw()(dst_stride), d.uw()(dst_stride), t(1),
385+
0x8000);
386+
}
387+
return;
388+
}
389+
390+
if (src_f4_e2m1 && (dst_f || dst_bf)) {
391+
// shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
392+
// f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
393+
// position of an f32) to the equivalent f32.
394+
constexpr float shift = 85070591730234615865843651857942052864.f;
395+
int step = get_step();
396+
const int nregs = utils::div_up(4 * step, grf_size);
397+
auto tmp0 = lex_scope.alloc_reg_buf_data(nregs).format(
398+
0, width, 1, ngen::DataType::ud);
399+
reg_buf_data_t tmp1;
400+
int tmp_stride = 1;
401+
if (dst_bf)
402+
auto tmp1 = lex_scope.alloc_reg_buf_data(nregs).format(
403+
0, width, 1, ngen::DataType::f);
404+
for (int i = 0; i < width; i += step) {
405+
step = std::min(step, width - i);
406+
step = utils::rnd_down_pow2(step);
407+
int esize = step;
408+
auto s = src.subregister(i, esize, src_stride);
409+
auto d = dst.subregister(i, esize, dst_stride);
410+
ngen::Subregister t1;
411+
if (dst_bf) {
412+
t1 = tmp1.subregister(0, esize, 1);
413+
std::swap(t1, d);
414+
std::swap(tmp_stride, dst_stride);
415+
}
416+
auto t0 = tmp0.subregister(0, esize, 2, ngen::DataType::uw);
417+
plan(unpack_u4, esize, t0(2), s(src_stride));
418+
host->eshl(esize, d.ud()(dst_stride), t0(2), 22);
419+
host->eshl(esize, t0(2), t0(2), 12);
420+
host->and_(
421+
esize, d.ud()(dst_stride), d.ud()(dst_stride), 0x01C00000);
422+
host->mul(esize, d(dst_stride), d(dst_stride),
423+
ngen::Immediate::f(shift));
424+
bfn0xCA(esize, d.uw(1)(2 * dst_stride), d.uw(1)(2 * dst_stride),
425+
t0(2), 0x8000);
426+
if (dst_bf) {
427+
std::swap(tmp_stride, dst_stride);
428+
std::swap(t1, d);
429+
host->emov(esize, d.uw()(dst_stride), t1.uw(1)(2));
430+
}
431+
}
432+
return;
433+
}
434+
341435
// bf16 -> f32:
342436
// - bf16 must be packed: use left shift instead.
343437
if (src_bf && dst_f) {

src/gpu/intel/jit/reorder/gen_reorder.cpp

+15-1
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,14 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
8080
src_engine == dst_engine && src_engine->kind() == engine_kind::gpu,
8181
VERBOSE_BAD_ENGINE_KIND);
8282
VDISPATCH_REORDER(utils::one_of(src_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
83-
s32, s8, u8, f64),
83+
f4_e2m1, s32, s8, u8, f64),
8484
VERBOSE_UNSUPPORTED_DT);
8585
VDISPATCH_REORDER(utils::one_of(dst_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
8686
s32, s8, u8, f64),
8787
VERBOSE_UNSUPPORTED_DT);
88+
VDISPATCH_REORDER(IMPLICATION(src_dt == f4_e2m1,
89+
utils::one_of(dst_dt, f32, f16, bf16)),
90+
VERBOSE_UNSUPPORTED_DT);
8891
VDISPATCH_REORDER(IMPLICATION(src_dt == f16 || dst_dt == f16,
8992
device_info->has_native(f16)),
9093
VERBOSE_UNSUPPORTED_DT_CFG);
@@ -104,6 +107,7 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
104107
VDISPATCH_REORDER(
105108
IMPLICATION(dst_dt == f64, utils::one_of(src_dt, f32, f64)),
106109
VERBOSE_UNSUPPORTED_DT_CFG);
110+
107111
using sm = dnnl_primitive_attr::skip_mask_t;
108112
auto skip_mask
109113
= sm::post_ops | sm::zero_points | sm::scales | sm::rounding_mode;
@@ -168,6 +172,16 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
168172
exec_cfg.set_simd(16);
169173
cfg = std::make_shared<reorder_config_t>(exec_cfg, src_layout, dst_layout);
170174
cfg->set_zp_cfg(zp_cfg);
175+
176+
if (src_dt == f4_e2m1) {
177+
auto &tile = cfg->tiles().front();
178+
for (auto &b : cfg->src_layout().user().blocks()) {
179+
if (b.block == 1) continue;
180+
VDISPATCH_REORDER(tile.dims()[b.dim_idx] % 8 == 0,
181+
VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src");
182+
break;
183+
}
184+
}
171185
VDISPATCH_REORDER_SC(
172186
init_kernel_info(), "kernel initialization unsuccessful");
173187
VDISPATCH_REORDER_SC(maybe_create_zp_precompute_conv_pd(dst_engine),

0 commit comments

Comments
 (0)