
Commit a75457b

[fp8 blockwise] wrap triton quantization kernels in custom ops for torch.compile compatibility
stack-info: PR: #2829, branch: danielvegamyhre/stack/47
1 parent ba52b69 commit a75457b
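For context, the pattern applied throughout this commit is the torch.library custom-op wrapper for Triton kernels: the Python launcher is registered with triton_op, and the raw kernel launch goes through wrap_triton so torch.compile can trace it instead of graph-breaking. A minimal sketch of that pattern with a hypothetical elementwise kernel (names, shapes, and the "mylib" namespace are illustrative, not from this commit):

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _scale_kernel(x_ptr, out_ptr, factor, n_elements, BLOCK_SIZE: tl.constexpr):
    # Hypothetical kernel: multiply every element by `factor`.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * factor, mask=mask)


@triton_op("mylib::scale_by", mutates_args={})
def scale_by(x: torch.Tensor, factor: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    # wrap_triton keeps the kernel launch visible to torch.compile.
    wrap_triton(_scale_kernel)[grid](x, out, factor, n, BLOCK_SIZE=1024)
    return out


if __name__ == "__main__":
    x = torch.randn(4096, device="cuda")
    compiled = torch.compile(lambda t: scale_by(t, 2.0), fullgraph=True)
    torch.testing.assert_close(compiled(x), x * 2.0)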

File tree

6 files changed: 76 additions, 66 deletions


benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py

Lines changed: 4 additions & 4 deletions
@@ -15,8 +15,8 @@
 from triton.testing import do_bench

 from torchao.prototype.blockwise_fp8_training.kernels import (
-    fp8_blockwise_act_quant_lhs,
-    fp8_blockwise_weight_quant_transposed_rhs,
+    triton_fp8_blockwise_act_quant_lhs,
+    triton_fp8_blockwise_weight_quant_transposed_rhs,
     triton_fp8_gemm_1x128_128x128,
 )

@@ -78,8 +78,8 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     M, N, K = config.m, config.n, config.k
     A = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
     B = torch.randn(N, K, dtype=config.out_dtype, device="cuda")
-    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
-    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(
+    A_q, A_s = triton_fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
+    B_t_q, B_t_s = triton_fp8_blockwise_weight_quant_transposed_rhs(
         B, dtype=torch.float8_e4m3fn
     )

benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py

Lines changed: 6 additions & 4 deletions
@@ -15,8 +15,8 @@
 from triton.testing import do_bench

 from torchao.prototype.blockwise_fp8_training.kernels import (
-    fp8_blockwise_act_quant_rhs,
-    fp8_blockwise_act_quant_transposed_lhs,
+    triton_fp8_blockwise_act_quant_rhs,
+    triton_fp8_blockwise_act_quant_transposed_lhs,
     triton_fp8_gemm_1x128_128x1,
 )

@@ -78,8 +78,10 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     M, N, K = config.m, config.n, config.k
     A = torch.randn(M, N, dtype=config.out_dtype, device="cuda")
     B = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
-    A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=torch.float8_e4m3fn)
-    B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=torch.float8_e4m3fn)
+    A_t_q, A_t_s = triton_fp8_blockwise_act_quant_transposed_lhs(
+        A, dtype=torch.float8_e4m3fn
+    )
+    B_q, B_s = triton_fp8_blockwise_act_quant_rhs(B, dtype=torch.float8_e4m3fn)

 def warmup(func, *args, **kwargs):
     for _ in range(10):

benchmarks/utils.py

Lines changed: 12 additions & 12 deletions
@@ -1,26 +1,22 @@
-import statistics
-from time import perf_counter_ns
-
 import torch
 from torch.nn import functional as F
+from triton.testing import do_bench


 def bench_fwd_bwd_microseconds(
     fn, *args, labels=None, use_compile=False, fullgraph=True, **kwargs
 ):
     assert labels is not None
-    fn = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
-    times = []
-    for _ in range(10):
-        start_ns = perf_counter_ns()
+
+    def fwd_bwd():
         out = fn(*args, **kwargs)
         loss = F.mse_loss(out, labels)
         loss.backward()
-        torch.cuda.synchronize()
-        end_ns = perf_counter_ns()
-        duration_us = (end_ns - start_ns) / 1000
-        times.append(duration_us)
-    return statistics.median(times)
+
+    fwd_bwd_compiled = (
+        torch.compile(fwd_bwd, fullgraph=fullgraph) if use_compile else fwd_bwd
+    )
+    return benchmark_cuda_function_in_microseconds(fwd_bwd_compiled)


 def profile_fwd_bwd(

@@ -56,3 +52,7 @@ def profile_fwd_bwd(
     # Save profiler results
     prof.export_chrome_trace(f"{profile_name}.json")
     print(f"Saved: {profile_name}.json")
+
+
+def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
+    return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
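For reference, the new do_bench-based helper can be exercised on its own. A small hedged usage sketch; the tensor shapes are arbitrary and the import path assumes benchmarks/utils.py is importable as benchmarks.utils from the repo root:

import torch

# Assumption: the repo layout exposes this module as benchmarks.utils.
from benchmarks.utils import benchmark_cuda_function_in_microseconds

a = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
b = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# do_bench handles warmup and CUDA synchronization and returns a median time in
# milliseconds; the helper converts that to microseconds.
matmul_us = benchmark_cuda_function_in_microseconds(torch.matmul, a, b)
print(f"bf16 4096x4096 matmul: {matmul_us:.1f} us")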

test/prototype/blockwise_fp8_training/test_blockwise_kernels.py

Lines changed: 14 additions & 14 deletions
@@ -12,14 +12,14 @@
 from packaging import version
 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.blockwise_fp8_training.kernels import (
-    fp8_blockwise_act_quant_lhs,
-    fp8_blockwise_act_quant_rhs,
-    fp8_blockwise_act_quant_transposed_lhs,
-    fp8_blockwise_weight_quant_rhs,
-    fp8_blockwise_weight_quant_transposed_rhs,
     torch_blockwise_scale_act_quant_lhs,
     torch_blockwise_scale_act_quant_rhs,
     torch_blockwise_scale_weight_quant,
+    triton_fp8_blockwise_act_quant_lhs,
+    triton_fp8_blockwise_act_quant_rhs,
+    triton_fp8_blockwise_act_quant_transposed_lhs,
+    triton_fp8_blockwise_weight_quant_rhs,
+    triton_fp8_blockwise_weight_quant_transposed_rhs,
     triton_fp8_gemm_1x128_128x1,
     triton_fp8_gemm_1x128_128x128,
 )

@@ -51,8 +51,8 @@ def test_triton_fp8_gemm_1x128_128x128(M, N, K, dtype):
     A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
     B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
     C = A @ B.T
-    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype)
-    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype)
+    A_q, A_s = triton_fp8_blockwise_act_quant_lhs(A, dtype=dtype)
+    B_t_q, B_t_s = triton_fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype)
     C_q = triton_fp8_gemm_1x128_128x128(
         A_q, B_t_q, A_s, B_t_s, out_dtype=torch.bfloat16
     )

@@ -76,8 +76,8 @@ def test_triton_fp8_gemm_1x128_128x1(M, N, K, dtype):
     A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda")
     B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
     C = A.T @ B
-    A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype)
-    B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype)
+    A_t_q, A_t_s = triton_fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype)
+    B_q, B_s = triton_fp8_blockwise_act_quant_rhs(B, dtype=dtype)
     C_q = triton_fp8_gemm_1x128_128x1(A_t_q, B_q, A_t_s, B_s, out_dtype=torch.bfloat16)

     assert not C_q.isnan().any(), "C_q must not contain NaNs"

@@ -102,7 +102,7 @@ def test_triton_quantize_fp8_act_quant_lhs(block_size):
     x[0, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_act_quant_lhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_lhs(
         x,
         block_size=block_size,
     )

@@ -149,7 +149,7 @@ def test_triton_quantize_fp8_act_quant_rhs(block_size: int):
     x[:block_size, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_act_quant_rhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_rhs(
         x,
         block_size=block_size,
     )

@@ -196,7 +196,7 @@ def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int):
     x[0, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_act_quant_transposed_lhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_transposed_lhs(
         x,
         block_size=block_size,
     )

@@ -245,7 +245,7 @@ def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int):
     x[:block_size, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_weight_quant_rhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_weight_quant_rhs(
         x,
         block_size=block_size,
     )

@@ -292,7 +292,7 @@ def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int):
     x[:block_size, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_weight_quant_transposed_rhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_weight_quant_transposed_rhs(
         x,
         block_size=block_size,
     )
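Beyond the renames above, one way to sanity-check the new custom-op registrations (not part of this diff) is torch.library.opcheck, which exercises the schema, fake-tensor, and autograd-registration paths that torch.compile relies on. A hedged sketch, assuming a CUDA device, a PyTorch build with opcheck, and 128-divisible shapes:

import torch
from torch.library import opcheck

from torchao.prototype.blockwise_fp8_training.kernels import (
    triton_fp8_blockwise_act_quant_lhs,
)

# Hypothetical smoke test; shapes chosen to be divisible by the 128 block size.
x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
opcheck(triton_fp8_blockwise_act_quant_lhs, (x,), {"dtype": torch.float8_e4m3fn})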

torchao/prototype/blockwise_fp8_training/kernels.py

Lines changed: 25 additions & 19 deletions
@@ -9,6 +9,7 @@
 import torch
 import triton
 import triton.language as tl
+from torch.library import triton_op, wrap_triton

 from torchao.prototype.moe_training.utils import (
     _is_column_major,

@@ -119,7 +120,7 @@ def triton_fp8_gemm_1x128_128x128(
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
-    triton_fp8_gemm_1x128_128x128_kernel[grid](
+    wrap_triton(triton_fp8_gemm_1x128_128x128_kernel)[grid](
         a,
         a.stride(0),
         a.stride(1),

@@ -234,7 +235,7 @@ def triton_fp8_gemm_1x128_128x1(
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
-    triton_fp8_gemm_1x128_128x1_kernel[grid](
+    wrap_triton(triton_fp8_gemm_1x128_128x1_kernel)[grid](
         a,
         a.stride(0),
         a.stride(1),

@@ -281,7 +282,7 @@ def triton_fp8_gemm_1x128_128x1(

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_lhs_kernel(
+def triton_fp8_blockwise_act_quant_lhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,

@@ -327,7 +328,8 @@ def fp8_blockwise_act_quant_lhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


-def fp8_blockwise_act_quant_lhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_lhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_lhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """

@@ -352,7 +354,7 @@ def fp8_blockwise_act_quant_lhs(
         triton.cdiv(M, meta["NUM_GROUPS"]),
         triton.cdiv(K, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_act_quant_lhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_lhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),

@@ -372,7 +374,7 @@ def fp8_blockwise_act_quant_lhs(

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_rhs_kernel(
+def triton_fp8_blockwise_act_quant_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,

@@ -420,7 +422,8 @@ def fp8_blockwise_act_quant_rhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


-def fp8_blockwise_act_quant_rhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_rhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_rhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """

@@ -444,7 +447,7 @@ def fp8_blockwise_act_quant_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(K, meta["NUM_GROUPS"]),
     )
-    fp8_blockwise_act_quant_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),

@@ -464,7 +467,7 @@ def fp8_blockwise_act_quant_rhs(

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_transposed_lhs_kernel(
+def triton_fp8_blockwise_act_quant_transposed_lhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,

@@ -524,7 +527,8 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)


-def fp8_blockwise_act_quant_transposed_lhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_transposed_lhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_transposed_lhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"

@@ -550,7 +554,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
         triton.cdiv(K, meta["NUM_GROUPS"]),
     )

-    fp8_blockwise_act_quant_transposed_lhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_transposed_lhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),

@@ -570,7 +574,7 @@ def fp8_blockwise_act_quant_transposed_lhs(

 @triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
 @triton.jit
-def fp8_blockwise_weight_quant_rhs_kernel(
+def triton_fp8_blockwise_weight_quant_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,

@@ -615,8 +619,9 @@ def fp8_blockwise_weight_quant_rhs_kernel(
     tl.store(s_ptr + scale_m_off + scale_n_off, tl.div_rn(1.0, scale))


-def fp8_blockwise_weight_quant_rhs(
-    x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn
+@triton_op("torchao::triton_fp8_blockwise_weight_quant_rhs", mutates_args={})
+def triton_fp8_blockwise_weight_quant_rhs(
+    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"
     assert x.dim() == 2, "Input tensor must have 2 dimensions"

@@ -638,7 +643,7 @@ def fp8_blockwise_weight_quant_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(N, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_weight_quant_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_weight_quant_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),

@@ -658,7 +663,7 @@ def fp8_blockwise_weight_quant_rhs(

 @triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
 @triton.jit
-def fp8_blockwise_weight_quant_transposed_rhs_kernel(
+def triton_fp8_blockwise_weight_quant_transposed_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,

@@ -719,8 +724,9 @@ def fp8_blockwise_weight_quant_transposed_rhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)


-def fp8_blockwise_weight_quant_transposed_rhs(
-    x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn
+@triton_op("torchao::triton_fp8_blockwise_weight_quant_transposed_rhs", mutates_args={})
+def triton_fp8_blockwise_weight_quant_transposed_rhs(
+    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"
     assert x.dim() == 2, "Input tensor must have 2 dimensions"

@@ -742,7 +748,7 @@ def fp8_blockwise_weight_quant_transposed_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(N, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_weight_quant_transposed_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_weight_quant_transposed_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
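With the wrappers in place, the quantization ops should trace end to end under torch.compile, which is the point of this change. A minimal hedged sketch of that usage (hypothetical snippet, assuming a CUDA device and a PyTorch build that provides torch.library.triton_op):

import torch
from torchao.prototype.blockwise_fp8_training.kernels import (
    triton_fp8_blockwise_act_quant_lhs,
)


@torch.compile(fullgraph=True)  # fullgraph=True: fail loudly on any graph break
def quantize(x: torch.Tensor):
    return triton_fp8_blockwise_act_quant_lhs(x, dtype=torch.float8_e4m3fn)


# Shapes chosen to be divisible by the default 128 block size.
x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
x_fp8, x_scale = quantize(x)
assert x_fp8.dtype == torch.float8_e4m3fn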
