@@ -107,6 +107,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     "cp_2_4": ModelConfig(
         2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
     ),  # GQA
+    "cp_3_0": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64,
+    ),  # MLA
+    "cp_3_1": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64,
+    ),  # MLA
+    "cp_3_2": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64,
+    ),  # MLA
+    "cp_3_3": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64,
+    ),  # MLA
+    "cp_3_4": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64, window_size=(512, 0),
+    ),  # MLA
 }

@@ -159,6 +174,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         )
     if dtype != "fp8" and fp8_mha:
         pytest.skip("Only fp8 works with fp8_mha=True!")
+    if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
+        pytest.skip("MLA CP currently only support KV P2P!")

     subprocess.run(
         get_bash_arguments(