@@ -280,6 +280,8 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
280
280
bool src_df = (src_type == ngen::DataType::df);
281
281
bool src_xf = src_bf || src_f || src_hf || src_df;
282
282
bool src_f4_e2m1 = (src_type == ngen_f4_e2m1 ());
283
+ bool src_f4_e3m0 = (src_type == ngen_f4_e3m0 ());
284
+ bool src_f4 = src_f4_e2m1 || src_f4_e3m0;
283
285
bool f_to_xf = (src_f && (dst_bf || dst_hf));
284
286
bool native_bf16 = host->exec_cfg ().hw ().systolic_support ();
285
287
op_plan_t plan = grf_size;
@@ -387,7 +389,9 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
387
389
host->add (mod | host->f0 , dst, dst, 1 );
388
390
};
389
391
390
- if (src_f4_e2m1 && dst_hf) {
392
+ if (src_f4 && dst_hf) {
393
+ const auto scale = src_f4_e2m1 ? 0x7400 : 0x6c00 ;
394
+ const auto shift = src_f4_e2m1 ? 9 : 10 ;
391
395
int step = get_step ();
392
396
const int nregs = utils::div_up (4 * step, grf_size);
393
397
auto tmp = lex_scope.alloc_reg_buf_data (nregs).format (
@@ -400,22 +404,25 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
400
404
auto d = dst.subregister (i, esize, dst_stride);
401
405
auto t = tmp.subregister (0 , esize, 1 , ngen::DataType::uw);
402
406
plan (cvt_u4_to_uw, esize, t (1 ), s (src_stride));
403
- host->eshl (esize, d.uw ()(dst_stride), t (1 ), 9 );
407
+ host->eshl (esize, d.uw ()(dst_stride), t (1 ), shift );
404
408
host->eshl (esize, t (1 ), t (1 ), 12 );
405
- host->and_ (esize, d.uw ()(dst_stride), d.uw ()(dst_stride), 0x0E00 );
409
+ host->and_ (esize, d.uw ()(dst_stride), d.uw ()(dst_stride),
410
+ 0x7 << shift);
406
411
host->mul (esize, d (dst_stride), d (dst_stride),
407
- ngen::Immediate::hf (0x7400 ));
412
+ ngen::Immediate::hf (scale ));
408
413
bfn0xCA (esize, d.uw ()(dst_stride), d.uw ()(dst_stride), t (1 ),
409
414
0x8000 );
410
415
}
411
416
return ;
412
417
}
413
418
414
- if (src_f4_e2m1 && (dst_f || dst_bf)) {
415
- // shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
416
- // f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
417
- // position of an f32) to the equivalent f32.
418
- constexpr float shift = 85070591730234615865843651857942052864 .f ;
419
+ if (src_f4 && (dst_f || dst_bf)) {
420
+ const auto scale = src_f4_e2m1 ? 0x7e800000 : 0x7d800000 ;
421
+ const auto shift = src_f4_e2m1 ? 22 : 23 ;
422
+ // shift = 2^N, used to shift a (positive) f32-aligned f4 (i.e. the
423
+ // mantissa bit is at the highest mantissa bit position of an f32) to
424
+ // the equivalent f32.
425
+ const float scale_f = utils::bit_cast<float >(scale);
419
426
int step = get_step ();
420
427
const int nregs = utils::div_up (4 * step, grf_size);
421
428
auto tmp0 = lex_scope.alloc_reg_buf_data (nregs).format (
@@ -439,12 +446,12 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
439
446
}
440
447
auto t0 = tmp0.subregister (0 , esize, 2 );
441
448
plan (cvt_u4_to_uw, esize, t0 (2 ), s (src_stride));
442
- host->eshl (esize, d.ud ()(dst_stride), t0 (2 ), 22 );
449
+ host->eshl (esize, d.ud ()(dst_stride), t0 (2 ), shift );
443
450
host->eshl (esize, t0 (2 ), t0 (2 ), 12 );
444
- host->and_ (
445
- esize, d. ud ()(dst_stride), d. ud ()(dst_stride), 0x01C00000 );
451
+ host->and_ (esize, d. ud ()(dst_stride), d. ud ()(dst_stride),
452
+ 0x7 << shift );
446
453
host->mul (esize, d (dst_stride), d (dst_stride),
447
- ngen::Immediate::f (shift ));
454
+ ngen::Immediate::f (scale_f ));
448
455
bfn0xCA (esize, d.uw (1 )(2 * dst_stride), d.uw (1 )(2 * dst_stride),
449
456
t0 (2 ), 0x8000 );
450
457
if (dst_bf) {
0 commit comments