 import torch
 import triton
 import triton.language as tl
+from torch.library import triton_op, wrap_triton

 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -119,7 +120,7 @@ def triton_fp8_gemm_1x128_128x128(
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
-    triton_fp8_gemm_1x128_128x128_kernel[grid](
+    wrap_triton(triton_fp8_gemm_1x128_128x128_kernel)[grid](
         a,
         a.stride(0),
         a.stride(1),
@@ -234,7 +235,7 @@ def triton_fp8_gemm_1x128_128x1(
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
-    triton_fp8_gemm_1x128_128x1_kernel[grid](
+    wrap_triton(triton_fp8_gemm_1x128_128x1_kernel)[grid](
         a,
         a.stride(0),
         a.stride(1),
@@ -281,7 +282,7 @@ def triton_fp8_gemm_1x128_128x1(

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_lhs_kernel(
+def triton_fp8_blockwise_act_quant_lhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -327,7 +328,8 @@ def fp8_blockwise_act_quant_lhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


-def fp8_blockwise_act_quant_lhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_lhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_lhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -352,7 +354,7 @@ def fp8_blockwise_act_quant_lhs(
         triton.cdiv(M, meta["NUM_GROUPS"]),
         triton.cdiv(K, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_act_quant_lhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_lhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
@@ -372,7 +374,7 @@ def fp8_blockwise_act_quant_lhs(

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_rhs_kernel(
+def triton_fp8_blockwise_act_quant_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -420,7 +422,8 @@ def fp8_blockwise_act_quant_rhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


-def fp8_blockwise_act_quant_rhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_rhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_rhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -444,7 +447,7 @@ def fp8_blockwise_act_quant_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(K, meta["NUM_GROUPS"]),
     )
-    fp8_blockwise_act_quant_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
@@ -464,7 +467,7 @@ def fp8_blockwise_act_quant_rhs(

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_transposed_lhs_kernel(
+def triton_fp8_blockwise_act_quant_transposed_lhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -524,7 +527,8 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)


-def fp8_blockwise_act_quant_transposed_lhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_transposed_lhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_transposed_lhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"
@@ -550,7 +554,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
         triton.cdiv(K, meta["NUM_GROUPS"]),
     )

-    fp8_blockwise_act_quant_transposed_lhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_transposed_lhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
@@ -570,7 +574,7 @@ def fp8_blockwise_act_quant_transposed_lhs(

 @triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
 @triton.jit
-def fp8_blockwise_weight_quant_rhs_kernel(
+def triton_fp8_blockwise_weight_quant_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -615,8 +619,9 @@ def fp8_blockwise_weight_quant_rhs_kernel(
     tl.store(s_ptr + scale_m_off + scale_n_off, tl.div_rn(1.0, scale))


-def fp8_blockwise_weight_quant_rhs(
-    x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn
+@triton_op("torchao::triton_fp8_blockwise_weight_quant_rhs", mutates_args={})
+def triton_fp8_blockwise_weight_quant_rhs(
+    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"
     assert x.dim() == 2, "Input tensor must have 2 dimensions"
@@ -638,7 +643,7 @@ def fp8_blockwise_weight_quant_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(N, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_weight_quant_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_weight_quant_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
@@ -658,7 +663,7 @@ def fp8_blockwise_weight_quant_rhs(

 @triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
 @triton.jit
-def fp8_blockwise_weight_quant_transposed_rhs_kernel(
+def triton_fp8_blockwise_weight_quant_transposed_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -719,8 +724,9 @@ def fp8_blockwise_weight_quant_transposed_rhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)


-def fp8_blockwise_weight_quant_transposed_rhs(
-    x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn
+@triton_op("torchao::triton_fp8_blockwise_weight_quant_transposed_rhs", mutates_args={})
+def triton_fp8_blockwise_weight_quant_transposed_rhs(
+    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"
     assert x.dim() == 2, "Input tensor must have 2 dimensions"
@@ -742,7 +748,7 @@ def fp8_blockwise_weight_quant_transposed_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(N, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_weight_quant_transposed_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_weight_quant_transposed_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
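Note on the pattern this diff adopts: `torch.library.triton_op` registers the Python wrapper as a custom operator, and `torch.library.wrap_triton` replaces the raw `kernel[grid](...)` launch so the Triton call can be traced (for example by `torch.compile`) rather than treated as opaque. Below is a minimal, self-contained sketch of the same pattern on a hypothetical elementwise-add kernel; the kernel, the op name `mylib::triton_add`, and the block size are illustrative placeholders and are not part of this PR.

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the flattened tensors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


# Register the wrapper as a custom op. mutates_args={} because the op only
# writes into a tensor it allocates itself.
@triton_op("mylib::triton_add", mutates_args={})
def triton_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # wrap_triton(...) stands in for a direct add_kernel[grid](...) launch,
    # making the kernel call visible to tracing instead of a graph break.
    wrap_triton(add_kernel)[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out


# Usage sketch: the op composes with torch.compile like any other operator.
# a = torch.randn(4096, device="cuda"); b = torch.randn(4096, device="cuda")
# c = torch.compile(triton_add)(a, b)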