Commit f290c61

add UT for MLA CP
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>
1 parent aa4fbe8 commit f290c61

1 file changed: +17 -0 lines changed


tests/pytorch/fused_attn/test_fused_attn_with_cp.py

@@ -107,6 +107,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     "cp_2_4": ModelConfig(
         2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
     ),  # GQA
+    "cp_3_0": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64,
+    ),  # MLA
+    "cp_3_1": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64,
+    ),  # MLA
+    "cp_3_2": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64,
+    ),  # MLA
+    "cp_3_3": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64,
+    ),  # MLA
+    "cp_3_4": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64, window_size=(512, 0),
+    ),  # MLA
 }
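For orientation, here is a rough sketch of the tensor shapes implied by the new "cp_3_0" entry, assuming the positional ModelConfig arguments follow the same pattern as the existing cp_* entries (batch, attention heads, KV/GQA groups, head_dim_qk, query and KV sequence lengths, dropout, mask type, bias type); the variable names and dtype below are illustrative, not taken from the test file. The point is that head_dim_v=64 differs from head_dim_qk=128, which is what makes these configs exercise the MLA-style (mismatched Q/K and V head dims) path under context parallelism.

# Illustrative shapes for the new "cp_3_0" entry, assuming the positional
# ModelConfig arguments follow the existing cp_* entries:
# (batch, num_heads, num_gqa_groups, head_dim_qk, max_seqlen_q, max_seqlen_kv,
#  dropout, mask_type, bias_type). Variable names here are hypothetical.
import torch

batch, seqlen = 2, 4096
num_heads = num_kv_heads = 12
head_dim_qk, head_dim_v = 128, 64  # head_dim_v != head_dim_qk marks the MLA-style case

q = torch.randn(batch, seqlen, num_heads, head_dim_qk, dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, num_kv_heads, head_dim_qk, dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, num_kv_heads, head_dim_v, dtype=torch.bfloat16)

# The attention output inherits the value head dim, so its last axis is 64 here,
# while Q/K dot products still use head dim 128.
out_shape = (batch, seqlen, num_heads, head_dim_v)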

@@ -159,6 +174,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
     )
     if dtype != "fp8" and fp8_mha:
         pytest.skip("Only fp8 works with fp8_mha=True!")
+    if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
+        pytest.skip("MLA CP currently only support KV P2P!")
 
     subprocess.run(
         get_bash_arguments(
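The new guard only affects configs whose Q/K head dim differs from the V head dim. Below is a minimal sketch of its behavior, assuming cp_comm_type takes values such as "p2p", "all_gather", "a2a", and "a2a+p2p" (an assumption about the test parametrization, not shown in this diff); should_skip_mla_cp is a hypothetical helper written only to illustrate the condition. Because the check is a substring match, the hierarchical "a2a+p2p" scheme still counts as P2P and is not skipped.

# Hedged sketch of the new skip condition; should_skip_mla_cp is a hypothetical
# helper, not part of the test file.
def should_skip_mla_cp(cp_comm_type: str, head_dim_qk: int, head_dim_v: int) -> bool:
    # MLA-shaped config (different Q/K and V head dims) combined with a CP
    # communication scheme that does not use P2P for KV exchange -> skip.
    return "p2p" not in cp_comm_type and head_dim_qk != head_dim_v

assert should_skip_mla_cp("all_gather", 128, 64)       # MLA + all-gather: skipped
assert not should_skip_mla_cp("p2p", 128, 64)          # MLA + KV P2P: runs
assert not should_skip_mla_cp("a2a+p2p", 128, 64)      # substring match: still uses P2P
assert not should_skip_mla_cp("all_gather", 128, 128)  # non-MLA configs unaffected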
