@@ -1454,6 +1454,56 @@ def _(scale_tensor):
        padded_cols = n_col_blocks * 4

        return scale_tensor.new_empty((padded_rows, padded_cols))
+
+    @triton.jit
+    def fp32_cast_to_fp4x2_triton_kernel(
+        x_ptr,
+        q_ptr,
+        stride_xm,
+        stride_xn,
+        M,
+        N,
+    ):
+        pid_m = tl.program_id(1)
+        pid_n = tl.program_id(0)
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+        mask = None
+        other = None
+        x = tl.load(
+            x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
+        )  # [128, 64]
+        x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+        # Convert to FP4
+        x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+        mask = (offs_m < M) & (offs_n < N // 2)
+        tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
+
+    def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
+        """
+        Input: a float32 tensor with shape (M, N)
+        Output: a uint8 tensor with shape (M, N // 2), with the values being the result
+        of casting each original value to fp4_e2m1, and then packing fp4x2
+
+        TODO(future PR): optimize performance
+        TODO(future PR): better checks for shapes, etc
+        TODO(future PR): integrate into training/inference
+        TODO(future PR): integrate with compile, ideally allowing fusion
+        """
+        M, N = x.shape
+        xq = x.new_empty(M, N // 2, dtype=torch.uint8)
+        grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
+        fp32_cast_to_fp4x2_triton_kernel[grid](
+            x,
+            xq,
+            x.stride(0),
+            x.stride(1),
+            M,
+            N,
+        )
+        return xq.view(torch.uint8)
else:

    def triton_to_mxfp8_dim1(
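For context, a minimal usage sketch of the new wrapper (illustrative only, not part of the diff above). It assumes a CUDA device with Triton available and that triton_fp32_cast_to_fp4x2 has been imported from the module this diff touches; it only exercises the shape/dtype contract described in the docstring. Note that in the kernel as written the load uses mask=None while only the store is masked, so sizes that are multiples of the 128 x 64 tile (matching grid = (cdiv(N, 64), cdiv(M, 128))) are the safe case to demonstrate.

# Minimal usage sketch; assumes a CUDA device with Triton available and that
# triton_fp32_cast_to_fp4x2 is in scope (imported from the modified module).
import torch

M, N = 256, 128  # multiples of the 128 x 64 tile handled by each Triton program
x = torch.randn(M, N, device="cuda", dtype=torch.float32)

xq = triton_fp32_cast_to_fp4x2(x)

# Contract from the docstring: each original value is cast to fp4_e2m1 and
# pairs of values are packed into one uint8, halving the last dimension.
assert xq.dtype == torch.uint8
assert xq.shape == (M, N // 2)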