@@ -264,11 +264,22 @@ def triton_fp8_gemm_1x128_128x1(
         num_stages=stages,
     )
     for warps in [4, 8]
+    for stages in [2, 4]
+]
+
+quant_kernel_configs_with_groups = [
+    triton.Config(
+        {"NUM_GROUPS": groups},
+        num_warps=warps,
+        num_stages=stages,
+    )
+    for groups in [2, 16, 32, 64, 128]
+    for warps in [2, 4, 8]
     for stages in [2, 4, 6]
 ]
 
 
-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_lhs_kernel(
     x_ptr,
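The grouped search space above is 5 groups x 3 warp counts x 3 stage counts = 45 candidate configs per autotuning key. For orientation, here is a minimal sketch (mine, not part of this commit; the toy kernel and its names are hypothetical) of how a tuned meta-parameter like NUM_GROUPS reaches both the launch grid and the kernel body:

import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({"NUM_GROUPS": g}, num_warps=4) for g in [2, 16]],
    key=["K"],  # re-benchmark whenever K changes
)
@triton.jit
def _toy_zero_kernel(x_ptr, K: tl.constexpr, NUM_GROUPS: tl.constexpr):
    # Each program handles NUM_GROUPS contiguous elements.
    pid = tl.program_id(axis=0)
    offs = pid * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
    tl.store(x_ptr + offs, 0.0, mask=offs < K)

# The grid lambda receives the same meta dict, so the grid shape co-varies
# with whichever config is being benchmarked, exactly as in the wrappers below:
# _toy_zero_kernel[lambda meta: (triton.cdiv(K, meta["NUM_GROUPS"]),)](x, K=K)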
@@ -283,13 +294,14 @@ def fp8_blockwise_act_quant_lhs_kernel(
     M,
     K: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    NUM_GROUPS: tl.constexpr,
     EPS: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)
 
-    # Load (1 x block_size) tile of x, where input is row major
-    m_offs = pid_m
+    # Load (num_groups x block_size) tile of x, where input is row major
+    m_offs = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
@@ -298,8 +310,10 @@ def fp8_blockwise_act_quant_lhs_kernel(
     # Perform scaling
     max_fp8_e4m3 = 448.0
     min_fp8_e4m3 = -448.0
-    amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64)
-    scale = (max_fp8_e4m3 / amax).to(tl.float32)
+
+    # Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1)
+    amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64)
+    scale = (max_fp8_e4m3 / amax).to(tl.float32)[:, None]
     y = x * scale
     y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty)
 
@@ -309,7 +323,7 @@ def fp8_blockwise_act_quant_lhs_kernel(
     tl.store(y_ptr + y_offs, y, mask=y_mask)
 
     # Write reciprocal scales
-    scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1
+    scale_offs = m_offs[:, None] * s_stride_dim_0 + pid_k * s_stride_dim_1
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))
 
 
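To make the grouped LHS semantics concrete, here is a hedged PyTorch reference of what the kernel computes per (1 x block_size) group. This is my reconstruction from the diff, not code from the commit; the eps value and fp8 dtype are assumptions, and it ignores the column-major scale layout the wrapper uses.

import torch

def ref_act_quant_lhs(x: torch.Tensor, block_size: int = 128, eps: float = 1e-12):
    # View each row as (1 x block_size) groups: (M, K // block_size, block_size).
    # Assumes x is contiguous row-major and K is divisible by block_size.
    M, K = x.shape
    g = x.view(M, K // block_size, block_size).float()
    # Per-group amax, clamped from below by eps as in the kernel.
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=eps).double()
    scale = (448.0 / amax).float()
    y = (g * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(M, K)
    # The kernel stores reciprocal scales, shape (M, K // block_size).
    s = (1.0 / scale).squeeze(-1)
    return y, s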
@@ -334,7 +348,10 @@ def fp8_blockwise_act_quant_lhs(
         (M, K // block_size),
         (1, M),
     )
-    grid = lambda meta: (M, triton.cdiv(K, meta["BLOCK_SIZE"]))
+    grid = lambda meta: (
+        triton.cdiv(M, meta["NUM_GROUPS"]),
+        triton.cdiv(K, meta["BLOCK_SIZE"]),
+    )
     fp8_blockwise_act_quant_lhs_kernel[grid](
         x,
         x.stride(0),
@@ -353,7 +370,7 @@ def fp8_blockwise_act_quant_lhs(
     return y, s
 
 
-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_rhs_kernel(
     x_ptr,
@@ -368,33 +385,38 @@ def fp8_blockwise_act_quant_rhs_kernel(
     M,
     K: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    NUM_GROUPS: tl.constexpr,
     EPS: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)
 
-    # Load (block_size x 1) tile of x, where input is row major
+    # Load (block_size x num_groups) tile of x, where input is row major.
+    # Each scaling group is (block_size x 1), but we load num_groups columns
+    # at once to facilitate coalesced gmem accesses and improve efficiency.
     m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    k_offs = pid_k
+    k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     x = tl.load(x_ptr + x_offs, mask=x_mask)
 
     # Perform scaling
     max_fp8_e4m3 = 448.0
     min_fp8_e4m3 = -448.0
-    amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64)
-    scale = (max_fp8_e4m3 / amax).to(tl.float32)
+
+    # Column-wise scales for RHS operand, shape (1, num_groups)
+    amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64)
+    scale = (max_fp8_e4m3 / amax).to(tl.float32)[None, :]
     y = x * scale
     y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty)
 
-    # Write output to column major fomrat
+    # Write output to column major format
     y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1
     y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     tl.store(y_ptr + y_offs, y, mask=y_mask)
 
     # Write scales
-    scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1
+    scale_offs = pid_m * s_stride_dim_0 + k_offs[None, :] * s_stride_dim_1
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))
 
 
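As with the LHS path, a hedged PyTorch reference for the RHS semantics (again my reconstruction; eps and the fp8 dtype are assumptions, and the kernel additionally writes y in column-major layout, which this sketch omits):

import torch

def ref_act_quant_rhs(x: torch.Tensor, block_size: int = 128, eps: float = 1e-12):
    # Groups are (block_size x 1): block_size rows of each column share a scale.
    # Assumes x is contiguous row-major and M is divisible by block_size.
    M, K = x.shape
    g = x.view(M // block_size, block_size, K).float()
    amax = g.abs().amax(dim=1, keepdim=True).clamp(min=eps).double()
    scale = (448.0 / amax).float()
    y = (g * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(M, K)
    # Reciprocal scales, shape (M // block_size, K), matching the kernel's layout.
    s = (1.0 / scale).squeeze(1)
    return y, s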
@@ -420,7 +442,7 @@ def fp8_blockwise_act_quant_rhs(
 
     grid = lambda meta: (
         triton.cdiv(M, meta["BLOCK_SIZE"]),
-        K,
+        triton.cdiv(K, meta["NUM_GROUPS"]),
     )
     fp8_blockwise_act_quant_rhs_kernel[grid](
         x,
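A worked example of this launch-shape change, with illustrative numbers not taken from the commit: with K = 8192 and a tuned NUM_GROUPS of 64, axis 1 of the grid shrinks from 8192 programs to triton.cdiv(8192, 64) = 128, so each program quantizes 64 column groups instead of one. Per the comment in the transposed kernel below, this also keeps grid dimensions safely under the 65535 block limit for large tensors.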
@@ -440,7 +462,7 @@ def fp8_blockwise_act_quant_rhs(
     return y, s
 
 
-@triton.autotune(configs=quant_kernel_configs, key=["K"])
+@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
 def fp8_blockwise_act_quant_transposed_lhs_kernel(
     x_ptr,
@@ -454,8 +476,8 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     s_stride_dim_1,
     M,
     K: tl.constexpr,
-    SCALE_BLOCK_SIZE: tl.constexpr,  # For scaling groups, not for grid/parallelization
-    BLOCK_SIZE_K: tl.constexpr,  # For grid/parallelization, not for scaling groups
+    BLOCK_SIZE: tl.constexpr,  # For scaling groups, not for grid/parallelization
+    NUM_GROUPS: tl.constexpr,  # For grid/parallelization, not for scaling groups
     EPS: tl.constexpr,
 ):
     # This kernel reads data in row-major format, and writes to an output tensor with
@@ -465,12 +487,12 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)
 
-    # Load (block_size x block_size_k) block of input, where input is row major.
+    # Load (block_size x num_groups) block of input, where input is row major.
     # We will be computing (block_size x 1) scaling factors (columns), and computing
-    # `block_size_k` at a time, so we aren't parallelizing with 1 thread per column,
+    # `num_groups` at a time, so we aren't parallelizing with 1 thread per column,
     # which will fail to launch for large tensors, due to max block number of 65535.
-    m_offs = pid_m * SCALE_BLOCK_SIZE + tl.arange(0, SCALE_BLOCK_SIZE)
-    k_offs = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
     x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
     x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
     x = tl.load(x_ptr + x_offs, mask=x_mask)
@@ -496,7 +518,7 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
 
-    # Scale tensor size is (K, M // SCALE_BLOCK_SIZE)
+    # Scale tensor size is (K, M // BLOCK_SIZE)
     scale_offs = scale_k_offs * s_stride_dim_0 + scale_m_off * s_stride_dim_1
-    scale_mask = (scale_k_offs < K) & (scale_m_off < M // SCALE_BLOCK_SIZE)
+    scale_mask = (scale_k_offs < K) & (scale_m_off < M // BLOCK_SIZE)
 
     # Write out reciprocal scales
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)
@@ -524,8 +546,8 @@ def fp8_blockwise_act_quant_transposed_lhs(
         (1, K),  # stride
     )
     grid = lambda meta: (
-        triton.cdiv(M, meta["SCALE_BLOCK_SIZE"]),
-        triton.cdiv(K, meta["BLOCK_SIZE_K"]),
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(K, meta["NUM_GROUPS"]),
     )
 
     fp8_blockwise_act_quant_transposed_lhs_kernel[grid](
@@ -540,8 +562,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
         s.stride(1),
         M,
         K=K,
-        SCALE_BLOCK_SIZE=block_size,  # Scaling group size
-        BLOCK_SIZE_K=block_size,  # Just for parallelize the work along K as well
+        BLOCK_SIZE=block_size,  # Scaling group size
         EPS=EPS,
     )
     return y, s
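For orientation, a hypothetical end-to-end usage of the three wrappers touched by this commit. The shapes and the block_size keyword are assumptions inferred from the diff, not confirmed API:

import torch

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)

# Row-wise (1 x 128) groups for the GEMM LHS; NUM_GROUPS is chosen by autotune.
y_lhs, s_lhs = fp8_blockwise_act_quant_lhs(x, block_size=128)

# Column-wise (128 x 1) groups for the RHS, written in column-major layout.
y_rhs, s_rhs = fp8_blockwise_act_quant_rhs(x, block_size=128)

# Transposed-LHS variant: quantizes x in (128 x 1) column groups and writes
# the result transposed, with scales of shape (K, M // 128).
y_t, s_t = fp8_blockwise_act_quant_transposed_lhs(x, block_size=128)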