
Commit 1a08176

[wip] mx: expose a fast path for casting to fp4x2
Summary: not ready for review yet
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: deefa24
ghstack-comment-id: 3210931181
Pull-Request: #2832
1 parent 248899b commit 1a08176

2 files changed: +76 −0 lines changed

test/prototype/mx_formats/test_kernels.py

Lines changed: 26 additions & 0 deletions
@@ -561,3 +561,29 @@ def test_cuda_mx_dim1_invalid_block_size():
             scale_dim_x=1,
             scale_dim_y=invalid_block_size,
         )
+
+
+def _fp32_to_fp4_reference(
+    data_hp: torch.Tensor,
+) -> torch.Tensor:
+    data_lp = f32_to_f4_unpacked(data_hp.float())
+    data_lp = pack_uint4(data_lp)
+    return data_lp
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="requires CUDA capability 10.0 or greater",
+)
+def test_fp32_cast_to_fp4x2():
+    from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2
+
+    M, K = 16, 16
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    # make x's range be the representable range of fp4
+    x = x * 6.0
+
+    data_ref = _fp32_to_fp4_reference(x)
+    data = triton_fp32_cast_to_fp4x2(x)
+    torch.testing.assert_close(data_ref, data, atol=0, rtol=0)
+    assert data.shape == (M, K // 2)
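For context, the reference path above relies on torchao's f32_to_f4_unpacked and pack_uint4 helpers. The standalone sketch below only illustrates what "cast to fp4 e2m1, then pack two values per byte" means; it is not the library's implementation, and the rounding behavior and nibble order chosen here are assumptions for illustration.

import torch

# The 8 non-negative fp4 e2m1 magnitudes; the sign bit gives 16 codes total.
_F4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fp4_e2m1_pack_sketch(x: torch.Tensor) -> torch.Tensor:
    """Illustrative only: round each value to the nearest e2m1 magnitude,
    encode the sign in bit 3, and pack two 4-bit codes per uint8.
    Nearest-value rounding and first-element-in-high-nibble order are assumptions."""
    x = x.float()
    sign = (x < 0).to(torch.uint8)
    # nearest-value rounding against the e2m1 magnitude table
    dist = (x.abs().unsqueeze(-1) - _F4_E2M1_VALUES.to(x.device)).abs()
    code = dist.argmin(dim=-1).to(torch.uint8) | (sign << 3)
    # pack adjacent pairs along the last dim: two fp4 codes per output byte
    code = code.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    return (code[..., 0] << 4) | code[..., 1]


# Example: 16 values in the representable range become 8 packed bytes.
packed = fp4_e2m1_pack_sketch(torch.linspace(-6, 6, 16))
assert packed.dtype == torch.uint8 and packed.shape == (8,)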

torchao/prototype/mx_formats/kernels.py

Lines changed: 50 additions & 0 deletions
@@ -1454,6 +1454,56 @@ def _(scale_tensor):
         padded_cols = n_col_blocks * 4
 
         return scale_tensor.new_empty((padded_rows, padded_cols))
+
+    @triton.jit
+    def fp32_cast_to_fp4x2_triton_kernel(
+        x_ptr,
+        q_ptr,
+        stride_xm,
+        stride_xn,
+        M,
+        N,
+    ):
+        pid_m = tl.program_id(1)
+        pid_n = tl.program_id(0)
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+        mask = None
+        other = None
+        x = tl.load(
+            x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
+        )  # [128, 64]
+        x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+        # Convert to FP4
+        x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
+        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
+        offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+        mask = (offs_m < M) & (offs_n < N // 2)
+        tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
+
+    def triton_fp32_cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
+        """
+        Input: a float32 tensor with shape (M, N)
+        Output: a uint8 tensor with shape (M, N // 2), with the values being the result
+          of casting each original value to fp4_e2m1, and then packing fp4x2
+
+        TODO(future PR): optimize performance
+        TODO(future PR): better checks for shapes, etc
+        TODO(future PR): integrate into training/inference
+        TODO(future PR): integrate with compile, ideally allowing fusion
+        """
+        M, N = x.shape
+        xq = x.new_empty(M, N // 2, dtype=torch.uint8)
+        grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
+        fp32_cast_to_fp4x2_triton_kernel[grid](
+            x,
+            xq,
+            x.stride(0),
+            x.stride(1),
+            M,
+            N,
+        )
+        return xq.view(torch.uint8)
 else:
 
     def triton_to_mxfp8_dim1(
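A minimal usage sketch of the exposed fast path, mirroring the unit test above. The shape below is an assumption for illustration; each Triton program handles a 128x64 input tile and writes a 128x32 tile of packed bytes, and the load is currently unmasked (mask=None), which the TODO about better shape checks presumably covers. A CUDA device with capability 10.0 or greater is assumed, matching the test's skipif gate.

import torch
from torchao.prototype.mx_formats.kernels import triton_fp32_cast_to_fp4x2

# example shape; N must be even so that pairs of fp4 values pack into bytes
M, N = 128, 256
x = torch.randn(M, N, dtype=torch.float32, device="cuda") * 6.0  # +/-6.0 is the fp4 e2m1 max magnitude

packed = triton_fp32_cast_to_fp4x2(x)
assert packed.dtype == torch.uint8
assert packed.shape == (M, N // 2)  # two fp4_e2m1 values per output byte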
