@@ -274,6 +274,8 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
274
274
bool src_df = (src_type == ngen::DataType::df);
275
275
bool src_xf = src_bf || src_f || src_hf || src_df;
276
276
bool src_f4_e2m1 = (src_type == ngen_f4_e2m1 ());
277
+ bool src_f4_e3m0 = (src_type == ngen_f4_e3m0 ());
278
+ bool src_f4 = src_f4_e2m1 || src_f4_e3m0;
277
279
bool f_to_xf = (src_f && (dst_bf || dst_hf));
278
280
bool native_bf16 = host->exec_cfg ().hw ().systolic_support ();
279
281
op_plan_t plan = grf_size;
@@ -363,7 +365,9 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
363
365
host->emov (mod, dst, src);
364
366
};
365
367
366
- if (src_f4_e2m1 && dst_hf) {
368
+ if (src_f4 && dst_hf) {
369
+ const auto scale = src_f4_e2m1 ? 0x7400 : 0x6c00 ;
370
+ const auto shift = src_f4_e2m1 ? 9 : 10 ;
367
371
const int nregs = utils::div_up (2 * width, grf_size);
368
372
auto tmp = lex_scope.alloc_reg_buf_data (nregs).format (
369
373
0 , width, 1 , ngen::DataType::uw);
@@ -376,22 +380,24 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
376
380
auto d = dst.subregister (i, esize, dst_stride);
377
381
auto t = tmp.subregister (0 , esize, 1 , ngen::DataType::uw);
378
382
plan (unpack_u4, esize, t (1 ), s (src_stride));
379
- host->eshl (esize, d.uw ()(dst_stride), t (1 ), 9 );
383
+ host->eshl (esize, d.uw ()(dst_stride), t (1 ), shift );
380
384
host->eshl (esize, t (1 ), t (1 ), 12 );
381
385
host->and_ (esize, d.uw ()(dst_stride), d.uw ()(dst_stride), 0x0E00 );
382
386
host->mul (esize, d (dst_stride), d (dst_stride),
383
- ngen::Immediate::hf (0x7400 ));
387
+ ngen::Immediate::hf (scale ));
384
388
bfn0xCA (esize, d.uw ()(dst_stride), d.uw ()(dst_stride), t (1 ),
385
389
0x8000 );
386
390
}
387
391
return ;
388
392
}
389
393
390
- if (src_f4_e2m1 && (dst_f || dst_bf)) {
391
- // shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
392
- // f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
393
- // position of an f32) to the equivalent f32.
394
- constexpr float shift = 85070591730234615865843651857942052864 .f ;
394
+ if (src_f4 && (dst_f || dst_bf)) {
395
+ const auto scale = src_f4_e2m1 ? 0x7e800000 : 0x7d800000 ;
396
+ const auto shift = src_f4_e2m1 ? 22 : 23 ;
397
+ // shift = 2^N, used to shift a (positive) f32-aligned f4 (i.e. the
398
+ // mantissa bit is at the highest mantissa bit position of an f32) to
399
+ // the equivalent f32.
400
+ const float scale_f = utils::bit_cast<float >(scale);
395
401
int step = get_step ();
396
402
const int nregs = utils::div_up (4 * step, grf_size);
397
403
auto tmp0 = lex_scope.alloc_reg_buf_data (nregs).format (
@@ -415,12 +421,12 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
415
421
}
416
422
auto t0 = tmp0.subregister (0 , esize, 2 , ngen::DataType::uw);
417
423
plan (unpack_u4, esize, t0 (2 ), s (src_stride));
418
- host->eshl (esize, d.ud ()(dst_stride), t0 (2 ), 22 );
424
+ host->eshl (esize, d.ud ()(dst_stride), t0 (2 ), shift );
419
425
host->eshl (esize, t0 (2 ), t0 (2 ), 12 );
420
426
host->and_ (
421
427
esize, d.ud ()(dst_stride), d.ud ()(dst_stride), 0x01C00000 );
422
428
host->mul (esize, d (dst_stride), d (dst_stride),
423
- ngen::Immediate::f (shift ));
429
+ ngen::Immediate::f (scale_f ));
424
430
bfn0xCA (esize, d.uw (1 )(2 * dst_stride), d.uw (1 )(2 * dst_stride),
425
431
t0 (2 ), 0x8000 );
426
432
if (dst_bf) {
0 commit comments