Commit ba52b69

[fp8 blockwise] load 2d chunks for groupwise quant to enable coalesced gmem accesses
stack-info: PR: #2827, branch: danielvegamyhre/stack/46
1 parent: e48e077

2 files changed: +58, -37 lines
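The idea behind the change, shown as a minimal standalone sketch rather than the actual torchao kernels below: instead of one program handling a single (1 x block_size) row group of a row-major input, each program loads a 2D (NUM_GROUPS x BLOCK_SIZE) tile and reduces along the row axis, so adjacent threads read contiguous addresses along K and the global memory accesses coalesce. Everything in this sketch (kernel name, defaults, the absmax-only computation) is an illustrative assumption, not code from this commit.

# Illustrative sketch only; not the kernels changed in this commit.
import torch
import triton
import triton.language as tl


@triton.jit
def rowwise_absmax_kernel(
    x_ptr,
    out_ptr,
    M,
    K,
    x_stride_m,
    x_stride_k,
    BLOCK_SIZE: tl.constexpr,  # size of each (1 x BLOCK_SIZE) scaling group
    NUM_GROUPS: tl.constexpr,  # how many row groups one program handles
):
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # 2D tile: NUM_GROUPS rows x BLOCK_SIZE columns of a row-major input.
    # Adjacent threads read contiguous addresses along K, so the gmem
    # accesses coalesce, unlike a (1 x BLOCK_SIZE) load per program.
    m_offs = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
    k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = m_offs[:, None] * x_stride_m + k_offs[None, :] * x_stride_k
    mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
    x = tl.load(x_ptr + offs, mask=mask)

    # One reduction result per (1 x BLOCK_SIZE) group -> shape (NUM_GROUPS,)
    amax = tl.max(tl.abs(x), axis=1)

    out_offs = m_offs * tl.cdiv(K, BLOCK_SIZE) + pid_k
    tl.store(out_ptr + out_offs, amax, mask=m_offs < M)


def rowwise_absmax(x: torch.Tensor, block_size: int = 128, num_groups: int = 32):
    M, K = x.shape
    out = x.new_empty((M, triton.cdiv(K, block_size)), dtype=torch.float32)
    grid = (triton.cdiv(M, num_groups), triton.cdiv(K, block_size))
    rowwise_absmax_kernel[grid](
        x, out, M, K, x.stride(0), x.stride(1),
        BLOCK_SIZE=block_size, NUM_GROUPS=num_groups,
    )
    return out

The hunks below apply the same 2D-tile pattern to the LHS, RHS, and transposed-LHS activation quant kernels, and add NUM_GROUPS to the autotuned configs.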

benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py

Lines changed: 10 additions & 10 deletions
@@ -12,7 +12,6 @@
 
 import torch
 from tabulate import tabulate
-from torch.nn import functional as F
 from tqdm import tqdm
 from triton.testing import do_bench
 
@@ -72,7 +71,9 @@ def get_configs() -> List[ExperimentConfig]:
     return configs


-def run_experiment(config: ExperimentConfig, profile=False, use_compile=False) -> ExperimentResult:
+def run_experiment(
+    config: ExperimentConfig, profile=False, use_compile=False
+) -> ExperimentResult:
     M, N, K = config.m, config.n, config.k
     inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
     bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda")
@@ -87,24 +88,23 @@ def warmup(func, *args, **kwargs):
         for _ in range(3):
             func(*args, **kwargs)

-
     # bfloat16 bench and profile
     labels = inputs.new_empty(M, N).fill_(1.0)
     bf16_linear_us = bench_fwd_bwd_microseconds(
-        bf16_linear,
-        inputs,
-        labels=labels,
+        bf16_linear,
+        inputs,
+        labels=labels,
         use_compile=use_compile,
     )
     if profile:
         print("Profiling bf16_linear")
         profile_fwd_bwd(
-            bf16_linear,
-            inputs,
+            bf16_linear,
+            inputs,
             labels=labels,
             profile_name="bf16_linear_profile",
             use_compile=use_compile,
-        )
+        )

     # FP8 triton bench and profile
     fp8_triton_linear_us = bench_fwd_bwd_microseconds(
@@ -189,7 +189,7 @@ def main(args: argparse.Namespace):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser()
     parser.add_argument("--profile", action="store_true", help="Enable profiling")
     parser.add_argument("--compile", action="store_true", help="Enable compilation")
     args = parser.parse_args()

torchao/prototype/blockwise_fp8_training/kernels.py

Lines changed: 48 additions & 27 deletions
@@ -264,11 +264,22 @@ def triton_fp8_gemm_1x128_128x1(
         num_stages=stages,
     )
     for warps in [4, 8]
+    for stages in [2, 4]
+]
+
+quant_kernel_configs_with_groups = [
+    triton.Config(
+        {"NUM_GROUPS": groups},
+        num_warps=warps,
+        num_stages=stages,
+    )
+    for groups in [2, 16, 32, 64, 128]
+    for warps in [2, 4, 8]
     for stages in [2, 4, 6]
 ]


-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_lhs_kernel(
     x_ptr,
@@ -283,13 +294,14 @@ def fp8_blockwise_act_quant_lhs_kernel(
     M,
     K: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    NUM_GROUPS: tl.constexpr,
     EPS: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)

-    # Load (1 x block_size) tile of x, where input is row major
-    m_offs = pid_m
+    # Load (num_groups x block_size) tile of x, where input is row major
+    m_offs = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
@@ -298,8 +310,10 @@ def fp8_blockwise_act_quant_lhs_kernel(
     # Perform scaling
     max_fp8_e4m3 = 448.0
     min_fp8_e4m3 = -448.0
-    amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64)
-    scale = (max_fp8_e4m3 / amax).to(tl.float32)
+
+    # Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1)
+    amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64)
+    scale = (max_fp8_e4m3 / amax).to(tl.float32)[:, None]
     y = x * scale
     y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty)

@@ -309,7 +323,7 @@ def fp8_blockwise_act_quant_lhs_kernel(
     tl.store(y_ptr + y_offs, y, mask=y_mask)

     # Write reciprocal scales
-    scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1
+    scale_offs = m_offs[:, None] * s_stride_dim_0 + pid_k * s_stride_dim_1
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


@@ -334,7 +348,10 @@ def fp8_blockwise_act_quant_lhs(
         (M, K // block_size),
         (1, M),
     )
-    grid = lambda meta: (M, triton.cdiv(K, meta["BLOCK_SIZE"]))
+    grid = lambda meta: (
+        triton.cdiv(M, meta["NUM_GROUPS"]),
+        triton.cdiv(K, meta["BLOCK_SIZE"]),
+    )
     fp8_blockwise_act_quant_lhs_kernel[grid](
         x,
         x.stride(0),
@@ -353,7 +370,7 @@ def fp8_blockwise_act_quant_lhs(
     return y, s


-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_rhs_kernel(
     x_ptr,
@@ -368,33 +385,38 @@ def fp8_blockwise_act_quant_rhs_kernel(
     M,
     K: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    NUM_GROUPS: tl.constexpr,
     EPS: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)

-    # Load (block_size x 1) tile of x, where input is row major
+    # Load (block_size x block_size) tile of x, where input is row major.
+    # Each scaling group is (block_size x 1), but we load (block_size x block_size)
+    # to facilitate coalesced gmem accesses and improve efficiency.
     m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    k_offs = pid_k
+    k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     x = tl.load(x_ptr + x_offs, mask=x_mask)

     # Perform scaling
     max_fp8_e4m3 = 448.0
     min_fp8_e4m3 = -448.0
-    amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64)
-    scale = (max_fp8_e4m3 / amax).to(tl.float32)
+
+    # Column-wise scales for RHS operand, shape (1, block_size)
+    amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64)
+    scale = (max_fp8_e4m3 / amax).to(tl.float32)[None, :]
     y = x * scale
     y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty)

-    # Write output to column major fomrat
+    # Write output to column major format
     y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1
     y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     tl.store(y_ptr + y_offs, y, mask=y_mask)

     # Write scales
-    scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1
+    scale_offs = pid_m * s_stride_dim_0 + k_offs[None, :] * s_stride_dim_1
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


@@ -420,7 +442,7 @@ def fp8_blockwise_act_quant_rhs(

     grid = lambda meta: (
         triton.cdiv(M, meta["BLOCK_SIZE"]),
-        K,
+        triton.cdiv(K, meta["NUM_GROUPS"]),
     )
     fp8_blockwise_act_quant_rhs_kernel[grid](
         x,
@@ -440,7 +462,7 @@ def fp8_blockwise_act_quant_rhs(
     return y, s


-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_transposed_lhs_kernel(
     x_ptr,
@@ -454,8 +476,8 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     s_stride_dim_1,
     M,
     K: tl.constexpr,
-    SCALE_BLOCK_SIZE: tl.constexpr,  # For scaling groups, not for grid/parallelization
-    BLOCK_SIZE_K: tl.constexpr,  # For grid/parallelization, not for scaling groups
+    BLOCK_SIZE: tl.constexpr,  # For scaling groups, not for grid/parallelization
+    NUM_GROUPS: tl.constexpr,  # For grid/parallelization, not for scaling groups
     EPS: tl.constexpr,
 ):
     # This kernel reads data in row-major format, and writes to an output tensor with
@@ -465,12 +487,12 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)

-    # Load (block_size x block_size_k) block of input, where input is row major.
+    # Load (block_size x num_groups) block of input, where input is row major.
     # We will be computing (block_size x 1) scaling factors (columns), and computing
-    # `block_size_k` at a time, so we aren't parallelizing with 1 thread per column,
+    # `num_groups` at a time, so we aren't parallelizing with 1 thread per column,
     # which will fail to launch for large tensors, due to max block number of 65535.
-    m_offs = pid_m * SCALE_BLOCK_SIZE + tl.arange(0, SCALE_BLOCK_SIZE)
-    k_offs = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     x = tl.load(x_ptr + x_offs, mask=x_mask)
@@ -496,7 +518,7 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(

     # Scale tensor size is (K, M // SCALE_BLOCK_SIZE)
     scale_offs = scale_k_offs * s_stride_dim_0 + scale_m_off * s_stride_dim_1
-    scale_mask = (scale_k_offs < K) & (scale_m_off < M // SCALE_BLOCK_SIZE)
+    scale_mask = (scale_k_offs < K) & (scale_m_off < M // BLOCK_SIZE)

     # Write out reciprocal scales
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)
@@ -524,8 +546,8 @@ def fp8_blockwise_act_quant_transposed_lhs(
         (1, K),  # stride
     )
     grid = lambda meta: (
-        triton.cdiv(M, meta["SCALE_BLOCK_SIZE"]),
-        triton.cdiv(K, meta["BLOCK_SIZE_K"]),
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(K, meta["NUM_GROUPS"]),
     )

     fp8_blockwise_act_quant_transposed_lhs_kernel[grid](
@@ -540,8 +562,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
         s.stride(1),
         M,
         K=K,
-        SCALE_BLOCK_SIZE=block_size,  # Scaling group size
-        BLOCK_SIZE_K=block_size,  # Just for parallelize the work along K as well
+        BLOCK_SIZE=block_size,  # Scaling group size
         EPS=EPS,
     )
     return y, s
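For orientation, a hedged usage sketch of the LHS entry point touched above. The module path and function name come from this file; the block_size keyword and its default, and the exact fp8 output dtype, are assumptions inferred from these hunks rather than confirmed by the diff.

# Hypothetical usage sketch: the import path and function name come from this
# file; the block_size keyword/default and the fp8 output dtype are assumptions.
import torch

from torchao.prototype.blockwise_fp8_training.kernels import (
    fp8_blockwise_act_quant_lhs,
)

x = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda")
y, s = fp8_blockwise_act_quant_lhs(x, block_size=128)

# y: fp8 output with the same shape as x.
# s: reciprocal scales allocated as (M, K // block_size) with strides (1, M)
#    (column major), matching the wrapper shown in the hunks above.
print(y.shape, y.dtype)
print(s.shape, s.stride())

With the new grid of (cdiv(M, NUM_GROUPS), cdiv(K, BLOCK_SIZE)), each program now covers NUM_GROUPS scale groups instead of one, which is what turns the former (1 x block_size) loads into wide, coalesced 2D tile loads.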
