Skip to content

Commit f2f5e62

Browse files
use shared bench + profile utils in blockwise fwd bwd bench script
1 parent f7a4a59 commit f2f5e62

File tree

1 file changed

+45
-30
lines changed

1 file changed

+45
-30
lines changed

benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
77

8+
import argparse
89
import itertools
910
from dataclasses import dataclass
1011
from typing import List
@@ -15,7 +16,7 @@
1516
from tqdm import tqdm
1617
from triton.testing import do_bench
1718

18-
from benchmarks.utils import bench_fwd_bwd_microseconds
19+
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
1920
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear
2021

2122
device = torch.device("cuda")
@@ -71,7 +72,7 @@ def get_configs() -> List[ExperimentConfig]:
7172
return configs
7273

7374

74-
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
75+
def run_experiment(config: ExperimentConfig, profile=False, use_compile=False) -> ExperimentResult:
7576
M, N, K = config.m, config.n, config.k
7677
inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
7778
bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda")
@@ -83,49 +84,59 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
8384
)
8485

8586
def warmup(func, *args, **kwargs):
86-
for _ in range(10):
87+
for _ in range(3):
8788
func(*args, **kwargs)
8889

89-
def fwd_bwd(func, inputs, labels, *args, **kwargs):
90-
out = func(inputs, *args, **kwargs)
91-
loss = F.mse_loss(out, labels)
92-
loss.backward()
93-
torch.cuda.synchronize()
9490

95-
# Warmup then run bf16 torch.mm
91+
# bfloat16 bench and profile
9692
labels = inputs.new_empty(M, N).fill_(1.0)
97-
warmup(fwd_bwd, bf16_linear, inputs, labels)
98-
99-
bf16_linear_us = benchmark_cuda_function_in_microseconds(
100-
fwd_bwd, bf16_linear, inputs, labels
93+
bf16_linear_us = bench_fwd_bwd_microseconds(
94+
bf16_linear,
95+
inputs,
96+
labels=labels,
97+
use_compile=use_compile,
10198
)
102-
103-
# Warm up then run triton bench
104-
warmup(
105-
fwd_bwd,
106-
fp8_triton_linear,
107-
inputs,
108-
labels,
99+
if profile:
100+
print("Profiling bf16_linear")
101+
profile_fwd_bwd(
102+
bf16_linear,
103+
inputs,
104+
labels=labels,
105+
profile_name="bf16_linear_profile",
106+
use_compile=use_compile,
109107
)
110108

109+
# FP8 triton bench and profile
111110
fp8_triton_linear_us = bench_fwd_bwd_microseconds(
112111
fp8_triton_linear,
113112
inputs,
114113
labels=labels,
115114
)
115+
if profile:
116+
print("Profiling fp8_triton_linear")
117+
profile_fwd_bwd(
118+
fp8_triton_linear,
119+
inputs,
120+
labels=labels,
121+
profile_name="fp8_triton_linear_profile",
122+
)
116123

117-
warmup(
118-
fwd_bwd,
119-
fp8_scaled_mm_linear,
120-
inputs,
121-
labels,
122-
)
123-
124+
# FP8 torch._scaled_mm bench and profile
124125
fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds(
125126
fp8_scaled_mm_linear,
126127
inputs,
127128
labels=labels,
129+
use_compile=use_compile,
128130
)
131+
if profile:
132+
print("Profiling fp8_scaled_mm_linear")
133+
profile_fwd_bwd(
134+
fp8_scaled_mm_linear,
135+
inputs,
136+
labels=labels,
137+
profile_name="fp8_scaled_mm_linear_profile",
138+
use_compile=use_compile,
139+
)
129140

130141
return ExperimentResult(
131142
bf16_linear_us=bf16_linear_us,
@@ -165,17 +176,21 @@ def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
165176
return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
166177

167178

168-
def main():
179+
def main(args: argparse.Namespace):
169180
torch.random.manual_seed(123)
170181
configs = get_configs()
171182
results = []
172183
for config in tqdm(configs):
173-
result = run_experiment(config)
184+
result = run_experiment(config, profile=args.profile, use_compile=args.compile)
174185
results.append(Experiment(config=config, result=result))
175186

176187
# Use Tabulate to print results
177188
print_results(results)
178189

179190

180191
if __name__ == "__main__":
181-
main()
192+
parser = argparse.ArgumentParser()
193+
parser.add_argument("--profile", action="store_true", help="Enable profiling")
194+
parser.add_argument("--compile", action="store_true", help="Enable compilation")
195+
args = parser.parse_args()
196+
main(args)

0 commit comments

Comments
 (0)