5
5
# LICENSE file in the root directory of this source tree.
6
6
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7
7
8
+ import argparse
8
9
import itertools
9
10
from dataclasses import dataclass
10
11
from typing import List
15
16
from tqdm import tqdm
16
17
from triton .testing import do_bench
17
18
18
- from benchmarks .utils import bench_fwd_bwd_microseconds
19
+ from benchmarks .utils import bench_fwd_bwd_microseconds , profile_fwd_bwd
19
20
from torchao .prototype .blockwise_fp8_training .linear import Float8BlockwiseLinear
20
21
21
22
device = torch .device ("cuda" )
@@ -71,7 +72,7 @@ def get_configs() -> List[ExperimentConfig]:
71
72
return configs
72
73
73
74
74
- def run_experiment (config : ExperimentConfig ) -> ExperimentResult :
75
+ def run_experiment (config : ExperimentConfig , profile = False , use_compile = False ) -> ExperimentResult :
75
76
M , N , K = config .m , config .n , config .k
76
77
inputs = torch .randn (M , K , dtype = config .out_dtype , device = "cuda" )
77
78
bf16_linear = torch .nn .Linear (K , N , dtype = config .out_dtype , device = "cuda" )
@@ -83,49 +84,59 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
83
84
)
84
85
85
86
def warmup(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` three times and discard the results.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func`` on every call.
        **kwargs: Keyword arguments forwarded to ``func`` on every call.
    """
    remaining = 3
    while remaining > 0:
        func(*args, **kwargs)
        remaining -= 1
88
89
89
- def fwd_bwd (func , inputs , labels , * args , ** kwargs ):
90
- out = func (inputs , * args , ** kwargs )
91
- loss = F .mse_loss (out , labels )
92
- loss .backward ()
93
- torch .cuda .synchronize ()
94
90
95
- # Warmup then run bf16 torch.mm
91
+ # bfloat16 bench and profile
96
92
labels = inputs .new_empty (M , N ).fill_ (1.0 )
97
- warmup (fwd_bwd , bf16_linear , inputs , labels )
98
-
99
- bf16_linear_us = benchmark_cuda_function_in_microseconds (
100
- fwd_bwd , bf16_linear , inputs , labels
93
+ bf16_linear_us = bench_fwd_bwd_microseconds (
94
+ bf16_linear ,
95
+ inputs ,
96
+ labels = labels ,
97
+ use_compile = use_compile ,
101
98
)
102
-
103
- # Warm up then run triton bench
104
- warmup (
105
- fwd_bwd ,
106
- fp8_triton_linear ,
107
- inputs ,
108
- labels ,
99
+ if profile :
100
+ print ("Profiling bf16_linear" )
101
+ profile_fwd_bwd (
102
+ bf16_linear ,
103
+ inputs ,
104
+ labels = labels ,
105
+ profile_name = "bf16_linear_profile" ,
106
+ use_compile = use_compile ,
109
107
)
110
108
109
+ # FP8 triton bench and profile
111
110
fp8_triton_linear_us = bench_fwd_bwd_microseconds (
112
111
fp8_triton_linear ,
113
112
inputs ,
114
113
labels = labels ,
115
114
)
115
+ if profile :
116
+ print ("Profiling fp8_triton_linear" )
117
+ profile_fwd_bwd (
118
+ fp8_triton_linear ,
119
+ inputs ,
120
+ labels = labels ,
121
+ profile_name = "fp8_triton_linear_profile" ,
122
+ )
116
123
117
- warmup (
118
- fwd_bwd ,
119
- fp8_scaled_mm_linear ,
120
- inputs ,
121
- labels ,
122
- )
123
-
124
+ # FP8 torch._scaled_mm bench and profile
124
125
fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds (
125
126
fp8_scaled_mm_linear ,
126
127
inputs ,
127
128
labels = labels ,
129
+ use_compile = use_compile ,
128
130
)
131
+ if profile :
132
+ print ("Profiling fp8_scaled_mm_linear" )
133
+ profile_fwd_bwd (
134
+ fp8_scaled_mm_linear ,
135
+ inputs ,
136
+ labels = labels ,
137
+ profile_name = "fp8_scaled_mm_linear_profile" ,
138
+ use_compile = use_compile ,
139
+ )
129
140
130
141
return ExperimentResult (
131
142
bf16_linear_us = bf16_linear_us ,
@@ -165,17 +176,21 @@ def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
165
176
return do_bench (lambda : f (* args , ** kwargs ), return_mode = "median" ) * 1e3
166
177
167
178
168
def main(args: argparse.Namespace):
    """Run each benchmark configuration and print the collected results.

    Args:
        args: Parsed command-line options. ``args.profile`` and
            ``args.compile`` are forwarded to ``run_experiment`` as
            ``profile`` and ``use_compile`` respectively.
    """
    # Seed torch's RNG before any configuration is built or run.
    torch.random.manual_seed(123)

    # One Experiment per configuration, in the order get_configs() yields
    # them; tqdm wraps the iteration to show progress.
    results = [
        Experiment(
            config=cfg,
            result=run_experiment(
                cfg, profile=args.profile, use_compile=args.compile
            ),
        )
        for cfg in tqdm(get_configs())
    ]

    # Hand the full list to the results printer.
    print_results(results)
178
189
179
190
180
191
if __name__ == "__main__":
    # Script entry point: both options are boolean switches (store_true),
    # so each is False unless given on the command line.
    parser = argparse.ArgumentParser()
    for flag, help_text in (
        ("--profile", "Enable profiling"),
        ("--compile", "Enable compilation"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)
    args = parser.parse_args()
    main(args)
0 commit comments