@@ -279,6 +279,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
279
279
bool src_hf8 = (src_type == ngen::DataType::hf8);
280
280
bool src_df = (src_type == ngen::DataType::df);
281
281
bool src_xf = src_bf || src_f || src_hf || src_df;
282
+ bool src_f4_e2m1 = (src_type == ngen_f4_e2m1 ());
282
283
bool f_to_xf = (src_f && (dst_bf || dst_hf));
283
284
bool native_bf16 = host->exec_cfg ().hw ().systolic_support ();
284
285
op_plan_t plan = grf_size;
@@ -333,6 +334,48 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
333
334
auto shl16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
334
335
host->eshl (mod, dst, src, 16 );
335
336
};
337
+ auto mov = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
338
+ host->emov (mod, dst, src);
339
+ };
340
+
341
+ auto u4_lower = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
342
+ host->and_ (mod, dst, src, 0xF );
343
+ };
344
+
345
+ auto u4_upper = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
346
+ host->shr (mod, dst, src, 4 );
347
+ };
348
+
349
+ auto cvt_u4_to_uw = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
350
+ auto dhs = dst.getHS ();
351
+ auto shs = src.getHS ();
352
+ auto esize = mod.getExecSize ();
353
+
354
+ auto s = src;
355
+ s.setOffset (src.getOffset () / 2 );
356
+ s.setRegion ((src.getVS () + 1 ) / 2 , (dst.getWidth () + 1 ) / 2 , shs);
357
+ s.setType (ngen::DataType::ub);
358
+
359
+ if (shs == 1 && esize > 1 ) {
360
+ auto half_esize = esize / 2 ;
361
+ auto d = dst;
362
+ d.setRegion (dst.getVS (), (dst.getWidth () + 1 ) / 2 , dhs * 2 );
363
+ plan (u4_lower, half_esize, d, s);
364
+ d.setOffset (d.getOffset () + dhs);
365
+ auto eoff = host->ExecutionOffset (half_esize);
366
+ plan (u4_upper, half_esize | eoff, d, s);
367
+ } else {
368
+ auto d = dst;
369
+ d.setType (ngen::DataType::uw);
370
+ d.setRegion (2 * d.getWidth (), d.getWidth (), 2 );
371
+ if ((src.getOffset () & 1 ) == 0 )
372
+ plan (u4_lower, esize, d, s);
373
+ else
374
+ plan (u4_upper, esize, d, s);
375
+ if (dst.getHS () == 1 ) plan (mov, esize, dst, d);
376
+ }
377
+ };
378
+
336
379
auto cvt_f32_to_bf16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
337
380
auto exec_size = mod.getExecSize ();
338
381
host->add (mod, src, src, -0x8000 );
@@ -343,9 +386,75 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
343
386
host->emov (mod, dst, src);
344
387
host->add (mod | host->f0 , dst, dst, 1 );
345
388
};
346
- auto mov = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
347
- host->emov (mod, dst, src);
348
- };
389
+
390
+ if (src_f4_e2m1 && dst_hf) {
391
+ int step = get_step ();
392
+ const int nregs = utils::div_up (4 * step, grf_size);
393
+ auto tmp = lex_scope.alloc_reg_buf_data (nregs).format (
394
+ 0 , 2 * width, 1 , ngen::DataType::uw);
395
+ for (int i = 0 ; i < width; i += step) {
396
+ step = std::min (step, width - i);
397
+ step = utils::rnd_down_pow2 (step);
398
+ int esize = step;
399
+ auto s = src.subregister (i, esize, src_stride);
400
+ auto d = dst.subregister (i, esize, dst_stride);
401
+ auto t = tmp.subregister (0 , esize, 1 , ngen::DataType::uw);
402
+ plan (cvt_u4_to_uw, esize, t (1 ), s (src_stride));
403
+ host->eshl (esize, d.uw ()(dst_stride), t (1 ), 9 );
404
+ host->eshl (esize, t (1 ), t (1 ), 12 );
405
+ host->and_ (esize, d.uw ()(dst_stride), d.uw ()(dst_stride), 0x0E00 );
406
+ host->mul (esize, d (dst_stride), d (dst_stride),
407
+ ngen::Immediate::hf (0x7400 ));
408
+ bfn0xCA (esize, d.uw ()(dst_stride), d.uw ()(dst_stride), t (1 ),
409
+ 0x8000 );
410
+ }
411
+ return ;
412
+ }
413
+
414
+ if (src_f4_e2m1 && (dst_f || dst_bf)) {
415
+ // shift = 0x7e800000:f = 2^126, used to shift a (positive) f32-aligned
416
+ // f4_e2m1 (i.e. the mantissa bit is at the highest mantissa bit
417
+ // position of an f32) to the equivalent f32.
418
+ constexpr float shift = 85070591730234615865843651857942052864 .f ;
419
+ int step = get_step ();
420
+ const int nregs = utils::div_up (4 * step, grf_size);
421
+ auto tmp0 = lex_scope.alloc_reg_buf_data (nregs).format (
422
+ 0 , 2 * width, 1 , ngen::DataType::uw);
423
+ reg_buf_data_t tmp1;
424
+ int tmp_stride = 1 ;
425
+ if (dst_bf)
426
+ auto tmp1 = lex_scope.alloc_reg_buf_data (nregs).format (
427
+ 0 , width, 1 , ngen::DataType::f);
428
+ for (int i = 0 ; i < width; i += step) {
429
+ step = std::min (step, width - i);
430
+ step = utils::rnd_down_pow2 (step);
431
+ int esize = step;
432
+ auto s = src.subregister (i, esize, src_stride);
433
+ auto d = dst.subregister (i, esize, dst_stride);
434
+ ngen::Subregister t1;
435
+ if (dst_bf) {
436
+ t1 = tmp1.subregister (0 , esize, 1 );
437
+ std::swap (t1, d);
438
+ std::swap (tmp_stride, dst_stride);
439
+ }
440
+ auto t0 = tmp0.subregister (0 , esize, 2 );
441
+ plan (cvt_u4_to_uw, esize, t0 (2 ), s (src_stride));
442
+ host->eshl (esize, d.ud ()(dst_stride), t0 (2 ), 22 );
443
+ host->eshl (esize, t0 (2 ), t0 (2 ), 12 );
444
+ host->and_ (
445
+ esize, d.ud ()(dst_stride), d.ud ()(dst_stride), 0x01C00000 );
446
+ host->mul (esize, d (dst_stride), d (dst_stride),
447
+ ngen::Immediate::f (shift));
448
+ bfn0xCA (esize, d.uw (1 )(2 * dst_stride), d.uw (1 )(2 * dst_stride),
449
+ t0 (2 ), 0x8000 );
450
+ if (dst_bf) {
451
+ std::swap (tmp_stride, dst_stride);
452
+ std::swap (t1, d);
453
+ host->emov (esize, d.uw ()(dst_stride), t1.uw (1 )(2 ));
454
+ }
455
+ }
456
+ return ;
457
+ }
349
458
350
459
// bf16 -> f32:
351
460
// - bf16 must be packed: use left shift instead.
0 commit comments