Skip to content

Commit 7c6f622

Browse files
committed
xe: jit: reorder: enable f4_e2m1 up-convert
1 parent 3f604ab commit 7c6f622

File tree

2 files changed

+135
-4
lines changed

2 files changed

+135
-4
lines changed

src/gpu/intel/jit/codegen/reorder.hpp

+112-3
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
279279
bool src_hf8 = (src_type == ngen::DataType::hf8);
280280
bool src_df = (src_type == ngen::DataType::df);
281281
bool src_xf = src_bf || src_f || src_hf || src_df;
282+
bool src_f4_e2m1 = (src_type == ngen_f4_e2m1());
282283
bool f_to_xf = (src_f && (dst_bf || dst_hf));
283284
bool native_bf16 = host->exec_cfg().hw().systolic_support();
284285
op_plan_t plan = grf_size;
@@ -333,6 +334,48 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
333334
auto shl16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
334335
host->eshl(mod, dst, src, 16);
335336
};
337+
// Plain register move: forwards to host->emov(mod, dst, src).
auto mov = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
338+
host->emov(mod, dst, src);
339+
};
340+
341+
// Keeps only the low 4 bits of each src element (src & 0xF).
auto u4_lower = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
342+
host->and_(mod, dst, src, 0xF);
343+
};
344+
345+
// Shifts each src element right by 4 bits; for a byte-typed src this leaves
// the value of its upper 4 bits in dst.
auto u4_upper = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
346+
host->shr(mod, dst, src, 4);
347+
};
348+
349+
// Expands 4-bit source elements into separate dst elements using the
// u4_lower / u4_upper helpers.
//   mod - instruction modifier; only its execution size is read here.
//   dst - destination region (its horizontal stride, vertical stride, width
//         and offset are read to build the write regions).
//   src - source region addressed in 4-bit elements; it is re-viewed as ub
//         with a halved offset before being read.
auto cvt_u4_to_uw = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
350+
auto dhs = dst.getHS();
351+
auto shs = src.getHS();
352+
auto esize = mod.getExecSize();
353+
354+
// Byte-typed view of the source: offset and vertical stride are halved
// (presumably because two 4-bit elements are packed per byte - confirm).
auto s = src;
355+
s.setOffset(src.getOffset() / 2);
356+
s.setRegion((src.getVS() + 1) / 2, (dst.getWidth() + 1) / 2, shs);
357+
s.setType(ngen::DataType::ub);
358+
359+
if (shs == 1 && esize > 1) {
360+
// Unit-stride source: process esize / 2 bytes. The low nibbles are
// written with a doubled dst stride starting at the original offset...
auto half_esize = esize / 2;
361+
auto d = dst;
362+
d.setRegion(dst.getVS(), (dst.getWidth() + 1) / 2, dhs * 2);
363+
plan(u4_lower, half_esize, d, s);
364+
// ...and the high nibbles into the interleaved slots, one dst stride
// further, issued with an execution offset of half_esize.
d.setOffset(d.getOffset() + dhs);
365+
auto eoff = host->ExecutionOffset(half_esize);
366+
plan(u4_upper, half_esize | eoff, d, s);
367+
} else {
368+
// Strided or single-element source: write through a uw view of dst with
// horizontal stride 2, taking one nibble per byte.
auto d = dst;
369+
d.setType(ngen::DataType::uw);
370+
d.setRegion(2 * d.getWidth(), d.getWidth(), 2);
371+
// The parity of the original (4-bit) source offset selects the nibble.
if ((src.getOffset() & 1) == 0)
372+
plan(u4_lower, esize, d, s);
373+
else
374+
plan(u4_upper, esize, d, s);
375+
// When the caller asked for a unit-stride dst, copy the stride-2 result
// into it.
if (dst.getHS() == 1) plan(mov, esize, dst, d);
376+
}
377+
};
378+
336379
auto cvt_f32_to_bf16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
337380
auto exec_size = mod.getExecSize();
338381
host->add(mod, src, src, -0x8000);
@@ -343,9 +386,75 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
343386
host->emov(mod, dst, src);
344387
host->add(mod | host->f0, dst, dst, 1);
345388
};
346-
auto mov = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
347-
host->emov(mod, dst, src);
348-
};
389+
390+
// f4_e2m1 -> f16 up-convert. Each chunk is expanded to one uw per element,
// then the f16 is assembled with integer shifts/masks plus one multiply.
// Returns from the enclosing function once all chunks are emitted.
if (src_f4_e2m1 && dst_hf) {
391+
int step = get_step();
392+
// Temporary buffer sized in registers for 4 * step bytes.
const int nregs = utils::div_up(4 * step, grf_size);
393+
auto tmp = lex_scope.alloc_reg_buf_data(nregs).format(
394+
0, 2 * width, 1, ngen::DataType::uw);
395+
for (int i = 0; i < width; i += step) {
396+
// Clamp the chunk to the remaining elements and round it down to a
// power of two; the loop increment uses this updated step.
step = std::min(step, width - i);
397+
step = utils::rnd_down_pow2(step);
398+
int esize = step;
399+
auto s = src.subregister(i, esize, src_stride);
400+
auto d = dst.subregister(i, esize, dst_stride);
401+
auto t = tmp.subregister(0, esize, 1, ngen::DataType::uw);
402+
// t <- one 4-bit source value per uw element.
plan(cvt_u4_to_uw, esize, t(1), s(src_stride));
403+
// d <- t << 9: the low three bits of t land in bits 9..11 of the word.
host->eshl(esize, d.uw()(dst_stride), t(1), 9);
404+
// t <- t << 12: bit 3 of the 4-bit value moves to bit 15.
host->eshl(esize, t(1), t(1), 12);
405+
// Keep only bits 9..11 (mask 0x0E00) in d.
host->and_(esize, d.uw()(dst_stride), d.uw()(dst_stride), 0x0E00);
406+
// Scale by the hf immediate 0x7400 (as an f16 bit pattern that is 2^14;
// assumes Immediate::hf takes raw bits - confirm), presumably to rebias
// the exponent from f4_e2m1 to f16.
host->mul(esize, d(dst_stride), d(dst_stride),
407+
ngen::Immediate::hf(0x7400));
408+
// Merge t into d under mask 0x8000 (bfn0xCA is defined outside this view;
// presumably a bitwise select that copies bit 15 from t - confirm).
bfn0xCA(esize, d.uw()(dst_stride), d.uw()(dst_stride), t(1),
409+
0x8000);
410+
}
411+
return;
412+
}
413+
414+
// f4_e2m1 -> f32 / bf16 up-convert. Each chunk is expanded to one 4-bit value
// per element, assembled into an f32 with integer shifts/masks plus one
// multiply, and for bf16 destinations the upper word of each f32 is then
// copied out. Returns from the enclosing function once all chunks are emitted.
if (src_f4_e2m1 && (dst_f || dst_bf)) {
415+
// shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
416+
// f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
417+
// position of an f32) to the equivalent f32.
418+
constexpr float shift = 85070591730234615865843651857942052864.f;
419+
int step = get_step();
420+
const int nregs = utils::div_up(4 * step, grf_size);
421+
// tmp0 holds the unpacked 4-bit values; it is accessed with stride 2 below.
auto tmp0 = lex_scope.alloc_reg_buf_data(nregs).format(
422+
0, 2 * width, 1, ngen::DataType::uw);
423+
// tmp1 is the f32 staging buffer, needed only for bf16 destinations.
reg_buf_data_t tmp1;
424+
int tmp_stride = 1;
425+
// Assign to the tmp1 declared above. Declaring a new `auto tmp1` here would
// only shadow it inside the `if`, leaving the buffer that the loop reads
// default-constructed.
if (dst_bf)
426+
tmp1 = lex_scope.alloc_reg_buf_data(nregs).format(
427+
0, width, 1, ngen::DataType::f);
428+
for (int i = 0; i < width; i += step) {
429+
// Clamp the chunk to the remaining elements and round it down to a
// power of two; the loop increment uses this updated step.
step = std::min(step, width - i);
430+
step = utils::rnd_down_pow2(step);
431+
int esize = step;
432+
auto s = src.subregister(i, esize, src_stride);
433+
auto d = dst.subregister(i, esize, dst_stride);
434+
ngen::Subregister t1;
435+
if (dst_bf) {
436+
// bf16: build the f32 in tmp1 instead of dst. After the swaps `d` and
// `dst_stride` refer to the temporary and `t1` keeps the real dst.
t1 = tmp1.subregister(0, esize, 1);
437+
std::swap(t1, d);
438+
std::swap(tmp_stride, dst_stride);
439+
}
440+
auto t0 = tmp0.subregister(0, esize, 2);
441+
// t0 <- one 4-bit source value per element (uw, stride 2).
plan(cvt_u4_to_uw, esize, t0(2), s(src_stride));
442+
// d <- t0 << 22: the low three bits of t0 land in bits 22..24.
host->eshl(esize, d.ud()(dst_stride), t0(2), 22);
443+
// t0 <- t0 << 12: bit 3 of the 4-bit value moves to bit 15 of the word.
host->eshl(esize, t0(2), t0(2), 12);
444+
// Keep only bits 22..24 (mask 0x01C00000) in d.
host->and_(
445+
esize, d.ud()(dst_stride), d.ud()(dst_stride), 0x01C00000);
446+
// Rescale by 2^126 (see the comment on `shift`).
host->mul(esize, d(dst_stride), d(dst_stride),
447+
ngen::Immediate::f(shift));
448+
// Merge t0 into the upper word of each f32 under mask 0x8000 (bfn0xCA is
// defined outside this view; presumably a bitwise select that copies
// bit 15 from t0 - confirm).
bfn0xCA(esize, d.uw(1)(2 * dst_stride), d.uw(1)(2 * dst_stride),
449+
t0(2), 0x8000);
450+
if (dst_bf) {
451+
// Undo the swaps, then copy the upper word of every staged f32 into
// the real destination.
std::swap(tmp_stride, dst_stride);
452+
std::swap(t1, d);
453+
host->emov(esize, d.uw()(dst_stride), t1.uw(1)(2));
454+
}
455+
}
456+
return;
457+
}
349458

350459
// bf16 -> f32:
351460
// - bf16 must be packed: use left shift instead.

src/gpu/intel/jit/reorder/gen_reorder.cpp

+23-1
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,14 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
8080
src_engine == dst_engine && src_engine->kind() == engine_kind::gpu,
8181
VERBOSE_BAD_ENGINE_KIND);
8282
VDISPATCH_REORDER(utils::one_of(src_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
83-
s32, s8, u8, f64),
83+
f4_e2m1, s32, s8, u8, f64),
8484
VERBOSE_UNSUPPORTED_DT);
8585
VDISPATCH_REORDER(utils::one_of(dst_dt, f32, f16, bf16, f8_e5m2, f8_e4m3,
8686
s32, s8, u8, f64),
8787
VERBOSE_UNSUPPORTED_DT);
88+
VDISPATCH_REORDER(IMPLICATION(src_dt == f4_e2m1,
89+
utils::one_of(dst_dt, f32, f16, bf16)),
90+
VERBOSE_UNSUPPORTED_DT);
8891
VDISPATCH_REORDER(IMPLICATION(src_dt == f16 || dst_dt == f16,
8992
device_info->has_native(f16)),
9093
VERBOSE_UNSUPPORTED_DT_CFG);
@@ -104,6 +107,7 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
104107
VDISPATCH_REORDER(
105108
IMPLICATION(dst_dt == f64, utils::one_of(src_dt, f32, f64)),
106109
VERBOSE_UNSUPPORTED_DT_CFG);
110+
107111
using sm = dnnl_primitive_attr::skip_mask_t;
108112
auto skip_mask
109113
= sm::post_ops | sm::zero_points | sm::scales | sm::rounding_mode;
@@ -168,6 +172,24 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine,
168172
exec_cfg.set_simd(16);
169173
cfg = std::make_shared<reorder_config_t>(exec_cfg, src_layout, dst_layout);
170174
cfg->set_zp_cfg(zp_cfg);
175+
176+
// Extra dispatch check for f4_e2m1 sources: count how many elements of the
// first tile are covered by densely strided blocks of the user source layout,
// and reject the layout unless that count is a multiple of 8.
// NOTE(review): the factor 8 presumably reflects a packing/alignment
// requirement of the 4-bit up-convert path - confirm.
if (src_dt == f4_e2m1) {
177+
// Working copy of the first tile's dims; it is consumed as blocks are
// matched below.
auto dims = cfg->tiles().front().dims();
178+
dim_t contiguous_inner_elems = 1;
179+
for (auto &b : cfg->src_layout().user().blocks()) {
180+
// Size-1 blocks do not affect contiguity.
if (b.block == 1) continue;
181+
// Stop at the first block whose stride differs from the element count
// accumulated so far.
if ((dim_t)b.stride != contiguous_inner_elems) break;
182+
if (b.block > dims[b.dim_idx]) {
183+
// The block is larger than what remains of the tile in this dim:
// count the remaining tile extent only if it divides the block evenly,
// then stop.
if (b.block % dims[b.dim_idx] == 0)
184+
contiguous_inner_elems *= dims[b.dim_idx];
185+
break;
186+
}
187+
// The whole block fits in the tile: count it and shrink the remaining
// tile extent in that dim.
contiguous_inner_elems *= b.block;
188+
dims[b.dim_idx] /= b.block;
189+
}
190+
VDISPATCH_REORDER(contiguous_inner_elems % 8 == 0,
191+
VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src");
192+
}
171193
VDISPATCH_REORDER_SC(
172194
init_kernel_info(), "kernel initialization unsuccessful");
173195
VDISPATCH_REORDER_SC(maybe_create_zp_precompute_conv_pd(dst_engine),

0 commit comments

Comments
 (0)