Commit 091edf3

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b6a9700 commit 091edf3

File tree: 2 files changed (+96, -50 lines)

tests/pytorch/fused_attn/test_fused_attn_with_cp.py (+51, -5)
@@ -108,19 +108,65 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
         2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
     ),  # GQA
     "cp_3_0": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64,
+        2,
+        12,
+        12,
+        128,
+        4096,
+        4096,
+        0.0,
+        "causal",
+        "no_bias",
+        head_dim_v=64,
     ),  # MLA
     "cp_3_1": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64,
+        2,
+        12,
+        12,
+        128,
+        4096,
+        4096,
+        0.0,
+        "no_mask",
+        "no_bias",
+        head_dim_v=64,
     ),  # MLA
     "cp_3_2": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64,
+        2,
+        12,
+        12,
+        128,
+        4096,
+        4096,
+        0.0,
+        "causal",
+        "post_scale_bias",
+        head_dim_v=64,
     ),  # MLA
     "cp_3_3": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64,
+        2,
+        12,
+        12,
+        128,
+        4096,
+        4096,
+        0.0,
+        "no_mask",
+        "post_scale_bias",
+        head_dim_v=64,
     ),  # MLA
     "cp_3_4": ModelConfig(
-        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64, window_size=(512, 0),
+        2,
+        12,
+        12,
+        128,
+        4096,
+        4096,
+        0.0,
+        "causal",
+        "no_bias",
+        head_dim_v=64,
+        window_size=(512, 0),
     ),  # MLA
 }

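The test-file hunk is pure reformatting. Each original one-line ModelConfig(...) call ends in a trailing comma, and the formatter run by pre-commit.ci (presumably black) honors that "magic trailing comma" by exploding the call to one argument per line instead of re-joining it. A minimal sketch of that behavior, using a hypothetical make_config stand-in rather than the real ModelConfig signature:

    # Hypothetical stand-in for ModelConfig; the real signature lives in the test helpers.
    def make_config(*args, **kwargs):
        return (args, kwargs)

    # Because the call keeps its trailing comma, black leaves it exploded,
    # one argument per line, exactly as in the diff above.
    cfg = make_config(
        2,
        12,
        12,
        128,
        4096,
        4096,
        0.0,
        "causal",
        "no_bias",
        head_dim_v=64,
    )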
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (+45, -45)
@@ -683,7 +683,7 @@ def forward(
             p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1)
         elif qkv_format in ["bshd", "sbhd"]:
             p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
-        else: # qkv_format == "thd"
+        else:  # qkv_format == "thd"
             p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
         send_recv_reqs = [[], []]
 
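The first branch above (presumably the enable_mla path; its condition sits just above the hunk) packs K and V into a single flat send/recv buffer, because under MLA the two tensors have different head dimensions and cannot be stacked along a new axis. The backward pass later splits that buffer by K's element count, as the kv[: ctx.k_numel] / kv[ctx.k_numel :] lines further down show. A minimal sketch with made-up shapes:

    import torch

    # Hypothetical MLA shapes: K and V share [b, s, np] but differ in head dim.
    b, s, np_, hn_k, hn_v = 2, 8, 4, 128, 64
    k = torch.randn(b, s, np_, hn_k)
    v = torch.randn(b, s, np_, hn_v)

    # Pack both tensors into one flat P2P buffer, as the flat torch.cat above does.
    buf = torch.cat((k.view(-1), v.view(-1)), dim=-1)

    # Unpack by K's element count and restore the original shapes, mirroring
    # kv[: ctx.k_numel].view(*ctx.k_shape) / kv[ctx.k_numel :].view(*ctx.v_shape).
    k_part = buf[: k.numel()].view(k.shape)
    v_part = buf[k.numel() :].view(v.shape)
    assert torch.equal(k_part, k) and torch.equal(v_part, v)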
@@ -736,12 +736,8 @@ def forward(
                 q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
                 if enable_mla:
                     # [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
-                    k_part = k_part.view(
-                        k_part.shape[0], -1, *k_part.shape[-2:]
-                    )
-                    v_part = v_part.view(
-                        v_part.shape[0], -1, *v_part.shape[-2:]
-                    )
+                    k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
+                    v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
                 else:
                     # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                     kv_inputs[i % 2] = kv_inputs[i % 2].view(
@@ -752,12 +748,8 @@ def forward(
                 q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
                 if enable_mla:
                     # [2, sk//2, b, np, hn] -> [sk, b, np, hn]
-                    k_part = k_part.view(
-                        -1, k_part.shape[2], *k_part.shape[-2:]
-                    )
-                    v_part = v_part.view(
-                        -1, v_part.shape[2], *v_part.shape[-2:]
-                    )
+                    k_part = k_part.view(-1, k_part.shape[2], *k_part.shape[-2:])
+                    v_part = v_part.view(-1, v_part.shape[2], *v_part.shape[-2:])
                 else:
                     # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                     kv_inputs[i % 2] = kv_inputs[i % 2].view(
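The comments in these two hunks describe the context-parallel chunk layout: each rank holds two sequence chunks of K/V, stored as [b, 2, sk//2, np, hn] (bshd) or [2, sk//2, b, np, hn] (sbhd), and the view() calls merge the chunk axis back into the sequence axis before the kernel runs. A short sketch with made-up sizes confirming both reshapes:

    import torch

    # Hypothetical per-rank sizes: two KV chunks of length sk//2 each.
    b, sk, np_, hn = 2, 8, 4, 64

    # bshd: [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
    k_bshd = torch.randn(b, 2, sk // 2, np_, hn)
    assert k_bshd.view(k_bshd.shape[0], -1, *k_bshd.shape[-2:]).shape == (b, sk, np_, hn)

    # sbhd: [2, sk//2, b, np, hn] -> [sk, b, np, hn]
    k_sbhd = torch.randn(2, sk // 2, b, np_, hn)
    assert k_sbhd.view(-1, k_sbhd.shape[2], *k_sbhd.shape[-2:]).shape == (sk, b, np_, hn)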
@@ -1054,12 +1046,8 @@ def forward(
                 q_inputs[i % 2] = q[:, 1, ...]
                 if enable_mla:
                     # [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
-                    k_part = k_part.view(
-                        k_part.shape[0], -1, *k_part.shape[-2:]
-                    )
-                    v_part = v_part.view(
-                        v_part.shape[0], -1, *v_part.shape[-2:]
-                    )
+                    k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
+                    v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
                 else:
                     # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
                     kv_inputs[i % 2] = kv_inputs[i % 2].view(
@@ -1070,12 +1058,8 @@ def forward(
                 q_inputs[i % 2] = q[1]
                 if enable_mla:
                     # [2, sk//2, b, np, hn] -> [sk, b, np, hn]
-                    k_part = k_part.view(
-                        -1, k_part.shape[2], *k_part.shape[-2:]
-                    )
-                    v_part = v_part.view(
-                        -1, v_part.shape[2], *v_part.shape[-2:]
-                    )
+                    k_part = k_part.view(-1, k_part.shape[2], *k_part.shape[-2:])
+                    v_part = v_part.view(-1, v_part.shape[2], *v_part.shape[-2:])
                 else:
                     # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
                     kv_inputs[i % 2] = kv_inputs[i % 2].view(
@@ -1336,10 +1320,14 @@ def forward(
             softmax_lse = torch.clone(softmax_lse_per_step[0])
             if qkv_format == "thd":
                 if enable_mla:
-                    out = torch.zeros_like(v if not fp8 else out_per_step[0]).view(v_shape)
+                    out = torch.zeros_like(v if not fp8 else out_per_step[0]).view(
+                        v_shape
+                    )
                 else:
                     # MHA or GQA
-                    out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
+                    out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(
+                        q.shape
+                    )
         elif (i - 1) <= rank or not causal:
             flash_attn_fwd_softmax_lse_correction(
                 softmax_lse, softmax_lse_per_step[i - 1]
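The MLA/else split in this hunk exists because the attention output inherits V's head dimension: with head_dim_v smaller than the Q/K head dimension, zeros_like(q) would allocate the wrong size, so the MLA branch allocates the output like v (or like out_per_step[0] under FP8). A minimal sketch of why the output follows V, with made-up shapes:

    import torch

    # Hypothetical MLA shapes: Q/K use head dim 128, V uses 64.
    b, s, np_, hn_qk, hn_v = 2, 8, 4, 128, 64
    q = torch.randn(b, s, np_, hn_qk).transpose(1, 2)  # [b, np, s, hn_qk]
    k = torch.randn(b, s, np_, hn_qk).transpose(1, 2)
    v = torch.randn(b, s, np_, hn_v).transpose(1, 2)   # [b, np, s, hn_v]

    # out = softmax(q k^T / sqrt(d)) v picks up V's head dim, so the accumulation
    # buffer must be shaped like v, not like q.
    attn = torch.softmax(q @ k.transpose(-2, -1) / hn_qk**0.5, dim=-1)
    out = attn @ v                                      # [b, np, s, hn_v]
    assert out.shape[-1] == v.shape[-1] != q.shape[-1]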
@@ -1774,8 +1762,8 @@ def backward(ctx, dout):
             q_, kv_, out_, dout_ = None, None, None, None
             dq_, dk_, dv_ = None, None, None
             if ctx.enable_mla:
-                k_part = kv[:ctx.k_numel].view(*ctx.k_shape)
-                v_part = kv[ctx.k_numel:].view(*ctx.v_shape)
+                k_part = kv[: ctx.k_numel].view(*ctx.k_shape)
+                v_part = kv[ctx.k_numel :].view(*ctx.v_shape)
             # In reversed order of fwd
             if causal:
                 if i == (cp_size - 1):
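The only change in this hunk is slice spacing: PEP 8 treats the colon in a slice as a binary operator, so when a bound is a more complex expression (here an attribute lookup) the formatter puts a space on each side of it and drops the space next to an omitted bound. That is what turns kv[:ctx.k_numel] into kv[: ctx.k_numel]. A tiny illustration, with a hypothetical Ctx class standing in for the autograd ctx:

    k_numel = 16
    kv = list(range(32))

    # A simple name keeps the compact form after formatting:
    head = kv[:k_numel]

    class Ctx:
        k_numel = 16

    ctx = Ctx()
    # An attribute lookup counts as a complex bound, so the colon is spaced
    # like an operator: kv[: ctx.k_numel] and kv[ctx.k_numel :].
    head = kv[: ctx.k_numel]
    tail = kv[ctx.k_numel :]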
@@ -1816,8 +1804,12 @@ def backward(ctx, dout):
                     aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
                 q_part = q_
                 if not ctx.enable_mla:
-                    k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
-                    v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                    k_part = (
+                        kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                    )
+                    v_part = (
+                        kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                    )
                 out_part = out_
                 dout_part = dout_

@@ -1965,8 +1957,12 @@ def backward(ctx, dout):
                     aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
                 q_part = q_
                 if not ctx.enable_mla:
-                    k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
-                    v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                    k_part = (
+                        kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                    )
+                    v_part = (
+                        kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                    )
                 out_part = out_
                 dout_part = dout_

@@ -2105,8 +2101,12 @@ def backward(ctx, dout):
 
                 q_part = q_
                 if not ctx.enable_mla:
-                    k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
-                    v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                    k_part = (
+                        kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                    )
+                    v_part = (
+                        kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                    )
                 out_part = out_
                 dout_part = dout_

@@ -2391,8 +2391,8 @@ def backward(ctx, dout):
             if ctx.enable_mla:
                 # [b, 2, sk//2, np, hn] or
                 # [2, sk//2, b, np, hn]
-                dk = dkv[:ctx.k_numel].view(*ctx.k_shape)
-                dv = dkv[ctx.k_numel:].view(*ctx.v_shape)
+                dk = dkv[: ctx.k_numel].view(*ctx.k_shape)
+                dv = dkv[ctx.k_numel :].view(*ctx.v_shape)
                 if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
                     dk_ = dk_.view(*ctx.k_shape)
                     dv_ = dv_.view(*ctx.v_shape)
@@ -2422,7 +2422,7 @@ def backward(ctx, dout):
                 else:
                     dk.copy_(dk_)
                     dv.copy_(dv_)
-            elif ctx.enable_mla and causal: # enable_mla and not fp8
+            elif ctx.enable_mla and causal:  # enable_mla and not fp8
                 if i == (cp_size - 1):
                     if rank == 0:
                         if ctx.qkv_format == "bshd":
@@ -2465,14 +2465,14 @@ def backward(ctx, dout):
                 elif i > 0:
                     dk.add_(dk_)
                     dv.add_(dv_)
-                else: # i == 0
+                else:  # i == 0
                     dk.copy_(dk_)
                     dv.copy_(dv_)
-            elif ctx.enable_mla: # enable_mla and not fp8 and not causal
+            elif ctx.enable_mla:  # enable_mla and not fp8 and not causal
                 if i == 0:
                     dk.copy_(dk_)
                     dv.copy_(dv_)
-                else: # i > 0
+                else:  # i > 0
                     dk.add_(dk_)
                     dv.add_(dv_)
             elif ctx.fp8:
@@ -2515,12 +2515,12 @@ def backward(ctx, dout):
                     tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
                 elif i > 0:
                     dkv.add_(dkv_)
-                else: # i == 0
+                else:  # i == 0
                     dkv.copy_(dkv_)
             else:
                 if i == 0:
                     dkv.copy_(dkv_)
-                else: # i > 0
+                else:  # i > 0
                     dkv.add_(dkv_)
 
     if ctx.fp8 and ctx.use_fused_attention:
@@ -2533,8 +2533,8 @@ def backward(ctx, dout):
 
         if ctx.enable_mla:
             # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
-            dk_fp8 = dkv_fp8[:ctx.k_numel].view(cp_size, *ctx.k_shape)
-            dv_fp8 = dkv_fp8[ctx.k_numel:].view(cp_size, *ctx.v_shape)
+            dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape)
+            dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape)
             dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                 dk_fp8, fake_dtype=torch.float32, internal=True
             )
