@@ -273,6 +273,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
273
273
// Flags describing the source data type; each one compares src_type against
// a single type tag.
bool src_hf8 = (src_type == ngen::DataType::hf8);
274
274
bool src_df = (src_type == ngen::DataType::df);
275
275
// Set when any of src_bf, src_f, src_hf or src_df is set. Note that src_hf8
// and src_f4_e2m1 are not part of this group.
bool src_xf = src_bf || src_f || src_hf || src_df;
276
// Set when the source type equals the tag returned by ngen_f4_e2m1().
+ bool src_f4_e2m1 = (src_type == ngen_f4_e2m1 ());
276
277
// f source going to a bf or hf destination.
bool f_to_xf = (src_f && (dst_bf || dst_hf));
277
278
// NOTE(review): systolic support is used here as the indicator for native
// bf16 handling - confirm against the hardware configuration helpers.
bool native_bf16 = host->exec_cfg ().hw ().systolic_support ();
278
279
// Plan object initialized from grf_size; it is later invoked as
// plan(op, esize, dst, src) to emit an operation.
op_plan_t plan = grf_size;
@@ -324,6 +325,30 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
324
325
auto shl16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
325
326
host->eshl (mod, dst, src, 16 );
326
327
};
328
+ auto unpack_u4 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
329
+ auto dhs = dst.getHS ();
330
+ auto shs = src.getHS ();
331
+ auto esize = mod.getExecSize ();
332
+
333
+ if (shs == 1 && esize > 1 ) {
334
+ auto half_esize = esize / 2 ;
335
+ auto s = src, d = dst;
336
+ d.setRegion (dst.getVS (), (dst.getWidth () + 1 ) / 2 , dhs * 2 );
337
+ s.setOffset (src.getOffset () / 2 );
338
+ s.setRegion ((src.getVS () + 1 ) / 2 , (dst.getWidth () + 1 ) / 2 , shs);
339
+ s.setType (ngen::DataType::ub);
340
+ host->and_ (half_esize, d, s, 0xF );
341
+ d.setOffset (d.getOffset () + dhs);
342
+ host->shr (half_esize | host->ExecutionOffset (half_esize), d, s, 4 );
343
+ } else {
344
+ if ((src.getOffset () & 1 ) == 0 ) {
345
+ host->and_ (esize, dst, src, 0xF );
346
+ } else {
347
+ host->shr (esize, dst, src, 4 );
348
+ }
349
+ }
350
+ };
351
+
327
352
auto cvt_f32_to_bf16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
328
353
auto exec_size = mod.getExecSize ();
329
354
host->add (mod, src, src, -0x8000 );
@@ -338,6 +363,75 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
338
363
host->emov (mod, dst, src);
339
364
};
340
365
366
+ if (src_f4_e2m1 && dst_hf) {
367
+ const int nregs = utils::div_up (2 * width, grf_size);
368
+ auto tmp = lex_scope.alloc_reg_buf_data (nregs).format (
369
+ 0 , width, 1 , ngen::DataType::uw);
370
+ int step = get_step ();
371
+ for (int i = 0 ; i < width; i += step) {
372
+ step = std::min (step, width - i);
373
+ step = utils::rnd_down_pow2 (step);
374
+ int esize = step;
375
+ auto s = src.subregister (i, esize, src_stride);
376
+ auto d = dst.subregister (i, esize, dst_stride);
377
+ auto t = tmp.subregister (0 , esize, 1 , ngen::DataType::uw);
378
+ plan (unpack_u4, esize, t (1 ), s (src_stride));
379
+ host->eshl (esize, d.uw ()(dst_stride), t (1 ), 9 );
380
+ host->eshl (esize, t (1 ), t (1 ), 12 );
381
+ host->and_ (esize, d.uw ()(dst_stride), d.uw ()(dst_stride), 0x0E00 );
382
+ host->mul (esize, d (dst_stride), d (dst_stride),
383
+ ngen::Immediate::hf (0x7400 ));
384
+ bfn0xCA (esize, d.uw ()(dst_stride), d.uw ()(dst_stride), t (1 ),
385
+ 0x8000 );
386
+ }
387
+ return ;
388
+ }
389
+
390
+ if (src_f4_e2m1 && (dst_f || dst_bf)) {
391
+ // shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
392
+ // f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
393
+ // position of an f32) to the equivalent f32.
394
+ constexpr float shift = 85070591730234615865843651857942052864 .f ;
395
+ int step = get_step ();
396
+ const int nregs = utils::div_up (4 * step, grf_size);
397
+ auto tmp0 = lex_scope.alloc_reg_buf_data (nregs).format (
398
+ 0 , width, 1 , ngen::DataType::ud);
399
+ reg_buf_data_t tmp1;
400
+ int tmp_stride = 1 ;
401
+ if (dst_bf)
402
+ auto tmp1 = lex_scope.alloc_reg_buf_data (nregs).format (
403
+ 0 , width, 1 , ngen::DataType::f);
404
+ for (int i = 0 ; i < width; i += step) {
405
+ step = std::min (step, width - i);
406
+ step = utils::rnd_down_pow2 (step);
407
+ int esize = step;
408
+ auto s = src.subregister (i, esize, src_stride);
409
+ auto d = dst.subregister (i, esize, dst_stride);
410
+ ngen::Subregister t1;
411
+ if (dst_bf) {
412
+ t1 = tmp1.subregister (0 , esize, 1 );
413
+ std::swap (t1, d);
414
+ std::swap (tmp_stride, dst_stride);
415
+ }
416
+ auto t0 = tmp0.subregister (0 , esize, 2 , ngen::DataType::uw);
417
+ plan (unpack_u4, esize, t0 (2 ), s (src_stride));
418
+ host->eshl (esize, d.ud ()(dst_stride), t0 (2 ), 22 );
419
+ host->eshl (esize, t0 (2 ), t0 (2 ), 12 );
420
+ host->and_ (
421
+ esize, d.ud ()(dst_stride), d.ud ()(dst_stride), 0x01C00000 );
422
+ host->mul (esize, d (dst_stride), d (dst_stride),
423
+ ngen::Immediate::f (shift));
424
+ bfn0xCA (esize, d.uw (1 )(2 * dst_stride), d.uw (1 )(2 * dst_stride),
425
+ t0 (2 ), 0x8000 );
426
+ if (dst_bf) {
427
+ std::swap (tmp_stride, dst_stride);
428
+ std::swap (t1, d);
429
+ host->emov (esize, d.uw ()(dst_stride), t1.uw (1 )(2 ));
430
+ }
431
+ }
432
+ return ;
433
+ }
434
+
341
435
// bf16 -> f32:
342
436
// - bf16 must be packed: use left shift instead.
343
437
if (src_bf && dst_f) {
0 commit comments