@@ -683,7 +683,7 @@ def forward(
             p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1)
         elif qkv_format in ["bshd", "sbhd"]:
             p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
-        else: # qkv_format == "thd"
+        else:  # qkv_format == "thd"
             p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
         send_recv_reqs = [[], []]
 
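The three `torch.cat` calls above pack K and V into one contiguous buffer so each ring step moves both tensors with a single P2P exchange. Below is a minimal sketch of the three layouts, with illustrative sizes rather than the library's real variables:

```python
# Minimal sketch (not the library code) of the three packing layouts.
import torch

b, sk, np_, hn = 2, 8, 4, 16      # illustrative sizes
k = torch.randn(b, sk, np_, hn)   # "bshd" layout
v = torch.randn(b, sk, np_, hn)

# MLA path: flatten both tensors and concatenate; the original shapes
# are recovered later with view() and saved numel/shape metadata.
flat = torch.cat((k.view(-1), v.view(-1)), dim=-1)
assert flat.numel() == k.numel() + v.numel()

# MHA/GQA "bshd"/"sbhd" path: stack on a new dim just before (np, hn),
# so kv[..., 0, :, :] is K and kv[..., 1, :, :] is V.
kv = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
assert torch.equal(kv[..., 0, :, :], k) and torch.equal(kv[..., 1, :, :], v)

# "thd" path: stack on a new leading dim, so kv[0] is K and kv[1] is V.
kv_thd = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
assert torch.equal(kv_thd[0], k) and torch.equal(kv_thd[1], v)
```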
@@ -736,12 +736,8 @@ def forward(
                             q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                             if enable_mla:
                                 # [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
-                                k_part = k_part.view(
-                                    k_part.shape[0], -1, *k_part.shape[-2:]
-                                )
-                                v_part = v_part.view(
-                                    v_part.shape[0], -1, *v_part.shape[-2:]
-                                )
+                                k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
+                                v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
                             else:
                                 # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                 kv_inputs[i % 2] = kv_inputs[i % 2].view(
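The shape comments above describe the load-balanced causal layout: each CP rank holds two sequence chunks stored behind a size-2 dim, and a single `view()` merges them whenever a full-sequence tensor is needed. A hedged sketch of the reshape, with made-up sizes:

```python
# Sketch of merging the two per-rank sequence chunks (illustrative sizes).
import torch

b, sk, np_, hn = 2, 8, 4, 16
k_part = torch.randn(b, 2, sk // 2, np_, hn)   # "bshd": two chunks

merged = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
assert merged.shape == (b, sk, np_, hn)

# "sbhd" variant: [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_sbhd = torch.randn(2, sk // 2, b, np_, hn)
merged_sbhd = k_sbhd.view(-1, k_sbhd.shape[2], *k_sbhd.shape[-2:])
assert merged_sbhd.shape == (sk, b, np_, hn)
```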
@@ -752,12 +748,8 @@ def forward(
                             q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                             if enable_mla:
                                 # [2, sk//2, b, np, hn] -> [sk, b, np, hn]
-                                k_part = k_part.view(
-                                    -1, k_part.shape[2], *k_part.shape[-2:]
-                                )
-                                v_part = v_part.view(
-                                    -1, v_part.shape[2], *v_part.shape[-2:]
-                                )
+                                k_part = k_part.view(-1, k_part.shape[2], *k_part.shape[-2:])
+                                v_part = v_part.view(-1, v_part.shape[2], *v_part.shape[-2:])
                             else:
                                 # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                 kv_inputs[i % 2] = kv_inputs[i % 2].view(
@@ -1054,12 +1046,8 @@ def forward(
                             q_inputs[i % 2] = q[:, 1, ...]
                             if enable_mla:
                                 # [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
-                                k_part = k_part.view(
-                                    k_part.shape[0], -1, *k_part.shape[-2:]
-                                )
-                                v_part = v_part.view(
-                                    v_part.shape[0], -1, *v_part.shape[-2:]
-                                )
+                                k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
+                                v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
                             else:
                                 # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                                 kv_inputs[i % 2] = kv_inputs[i % 2].view(
@@ -1070,12 +1058,8 @@ def forward(
                             q_inputs[i % 2] = q[1]
                             if enable_mla:
                                 # [2, sk//2, b, np, hn] -> [sk, b, np, hn]
-                                k_part = k_part.view(
-                                    -1, k_part.shape[2], *k_part.shape[-2:]
-                                )
-                                v_part = v_part.view(
-                                    -1, v_part.shape[2], *v_part.shape[-2:]
-                                )
+                                k_part = k_part.view(-1, k_part.shape[2], *k_part.shape[-2:])
+                                v_part = v_part.view(-1, v_part.shape[2], *v_part.shape[-2:])
                             else:
                                 # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                                 kv_inputs[i % 2] = kv_inputs[i % 2].view(
@@ -1336,10 +1320,14 @@ def forward(
                 softmax_lse = torch.clone(softmax_lse_per_step[0])
                 if qkv_format == "thd":
                     if enable_mla:
-                        out = torch.zeros_like(v if not fp8 else out_per_step[0]).view(v_shape)
+                        out = torch.zeros_like(v if not fp8 else out_per_step[0]).view(
+                            v_shape
+                        )
                     else:
                         # MHA or GQA
-                        out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
+                        out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(
+                            q.shape
+                        )
                 elif (i - 1) <= rank or not causal:
                     flash_attn_fwd_softmax_lse_correction(
                         softmax_lse, softmax_lse_per_step[i - 1]
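`flash_attn_fwd_softmax_lse_correction` merges the log-sum-exp statistics of per-step partial attention results. Conceptually (a sketch of the math, not the kernel's implementation), merging two LSE tensors is a numerically stable `logaddexp`:

```python
# Conceptual sketch: merging the LSE of two attention partials.
import torch

lse_running = torch.randn(2, 4, 8)   # [b, np, sq], running statistic
lse_step = torch.randn(2, 4, 8)      # statistic of the newest partial

# torch.logaddexp computes log(exp(a) + exp(b)) without the overflow
# the naive formula would hit for large inputs.
merged = torch.logaddexp(lse_running, lse_step)
reference = torch.log(torch.exp(lse_running) + torch.exp(lse_step))
assert torch.allclose(merged, reference)
```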
@@ -1774,8 +1762,8 @@ def backward(ctx, dout):
             q_, kv_, out_, dout_ = None, None, None, None
             dq_, dk_, dv_ = None, None, None
             if ctx.enable_mla:
-                k_part = kv[:ctx.k_numel].view(*ctx.k_shape)
-                v_part = kv[ctx.k_numel:].view(*ctx.v_shape)
+                k_part = kv[: ctx.k_numel].view(*ctx.k_shape)
+                v_part = kv[ctx.k_numel :].view(*ctx.v_shape)
             # In reversed order of fwd
             if causal:
                 if i == (cp_size - 1):
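In the MLA backward, `ctx.k_numel` plus the saved shapes are all that is needed to split the packed flat buffer back into K and V, which also lets K and V carry different head dims. A small sketch with hypothetical sizes:

```python
# Sketch of the flat-buffer split; names mirror ctx.k_numel / ctx.k_shape
# / ctx.v_shape, sizes are hypothetical.
import math
import torch

k_shape = (2, 2, 4, 4, 16)   # e.g. [b, 2, sk//2, np, hn]
v_shape = (2, 2, 4, 4, 8)    # MLA may use a smaller head dim for V
k_numel = math.prod(k_shape)

kv = torch.randn(k_numel + math.prod(v_shape))
k_part = kv[:k_numel].view(*k_shape)
v_part = kv[k_numel:].view(*v_shape)
assert k_part.shape == k_shape and v_part.shape == v_shape
```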
@@ -1816,8 +1804,12 @@ def backward(ctx, dout):
                             aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
                         q_part = q_
                         if not ctx.enable_mla:
-                            k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
-                            v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                            k_part = (
+                                kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                            )
+                            v_part = (
+                                kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                            )
                         out_part = out_
                         dout_part = dout_
 
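The non-MLA branch unpacks the interleaved KV buffer differently per format: the size-2 K/V dim sits just before (np, hn) for "bshd"/"sbhd" but leads for "thd". A standalone sketch of that selector (function name and shapes are illustrative, not from the library):

```python
# Illustrative selector mirroring the branch above.
import torch

def split_kv(kv_, qkv_format):
    # "bshd"/"sbhd": K/V interleaved just before (np, hn); "thd": leading dim.
    if qkv_format in ["bshd", "sbhd"]:
        return kv_[..., 0, :, :], kv_[..., 1, :, :]
    return kv_[0], kv_[1]

kv_bshd = torch.randn(2, 8, 2, 4, 16)   # [b, sk, 2, np, hn]
k, v = split_kv(kv_bshd, "bshd")
assert k.shape == (2, 8, 4, 16)

kv_thd = torch.randn(2, 16, 4, 16)      # [2, t, np, hn]
k, v = split_kv(kv_thd, "thd")
assert k.shape == (16, 4, 16)
```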
@@ -1965,8 +1957,12 @@ def backward(ctx, dout):
                             aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
                         q_part = q_
                         if not ctx.enable_mla:
-                            k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
-                            v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                            k_part = (
+                                kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                            )
+                            v_part = (
+                                kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                            )
                         out_part = out_
                         dout_part = dout_
 
@@ -2105,8 +2101,12 @@ def backward(ctx, dout):
 
                         q_part = q_
                         if not ctx.enable_mla:
-                            k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
-                            v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                            k_part = (
+                                kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                            )
+                            v_part = (
+                                kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                            )
                         out_part = out_
                         dout_part = dout_
 
@@ -2391,8 +2391,8 @@ def backward(ctx, dout):
                 if ctx.enable_mla:
                     # [b, 2, sk//2, np, hn] or
                     # [2, sk//2, b, np, hn]
-                    dk = dkv[:ctx.k_numel].view(*ctx.k_shape)
-                    dv = dkv[ctx.k_numel:].view(*ctx.v_shape)
+                    dk = dkv[: ctx.k_numel].view(*ctx.k_shape)
+                    dv = dkv[ctx.k_numel :].view(*ctx.v_shape)
                     if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
                         dk_ = dk_.view(*ctx.k_shape)
                         dv_ = dv_.view(*ctx.v_shape)
@@ -2422,7 +2422,7 @@ def backward(ctx, dout):
                     else:
                         dk.copy_(dk_)
                         dv.copy_(dv_)
-                elif ctx.enable_mla and causal: # enable_mla and not fp8
+                elif ctx.enable_mla and causal:  # enable_mla and not fp8
                     if i == (cp_size - 1):
                         if rank == 0:
                             if ctx.qkv_format == "bshd":
@@ -2465,14 +2465,14 @@ def backward(ctx, dout):
                     elif i > 0:
                         dk.add_(dk_)
                         dv.add_(dv_)
-                    else: # i == 0
+                    else:  # i == 0
                         dk.copy_(dk_)
                         dv.copy_(dv_)
-                elif ctx.enable_mla: # enable_mla and not fp8 and not causal
+                elif ctx.enable_mla:  # enable_mla and not fp8 and not causal
                     if i == 0:
                         dk.copy_(dk_)
                         dv.copy_(dv_)
-                    else: # i > 0
+                    else:  # i > 0
                         dk.add_(dk_)
                         dv.add_(dv_)
                 elif ctx.fp8:
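All of these branches follow the same in-place accumulation pattern: the first contribution initializes the buffer with `copy_` (skipping a separate zero-fill), and every later ring step accumulates with `add_`, so no new tensors are allocated inside the loop. A toy sketch:

```python
# Toy sketch of the copy_-then-add_ accumulation pattern (illustrative sizes).
import torch

dk = torch.empty(2, 8, 4, 16)   # reused across ring steps
contributions = [torch.randn(2, 8, 4, 16) for _ in range(3)]

for i, dk_ in enumerate(contributions):
    if i == 0:
        dk.copy_(dk_)           # first step: initialize in place
    else:
        dk.add_(dk_)            # later steps: accumulate in place

assert torch.allclose(dk, sum(contributions))
```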
@@ -2515,12 +2515,12 @@ def backward(ctx, dout):
                         tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
                     elif i > 0:
                         dkv.add_(dkv_)
-                    else: # i == 0
+                    else:  # i == 0
                         dkv.copy_(dkv_)
                 else:
                     if i == 0:
                         dkv.copy_(dkv_)
-                    else: # i > 0
+                    else:  # i > 0
                         dkv.add_(dkv_)
 
         if ctx.fp8 and ctx.use_fused_attention:
@@ -2533,8 +2533,8 @@ def backward(ctx, dout):
 
         if ctx.enable_mla:
             # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
-            dk_fp8 = dkv_fp8[:ctx.k_numel].view(cp_size, *ctx.k_shape)
-            dv_fp8 = dkv_fp8[ctx.k_numel:].view(cp_size, *ctx.v_shape)
+            dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape)
+            dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape)
             dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                 dk_fp8, fake_dtype=torch.float32, internal=True
             )