@@ -61,8 +61,14 @@ sdpa_config_t xehpg_h64_s128 = {16, 16, 16, 16, 4, 8, 4, 8};
sdpa_config_t xehpg_h64_s64 = {32, 16, 16, 8, 8, 4, 4, 8};
sdpa_config_t xehpg_h64_2nd = {8, 16, 16, 8, 8, 1, 4, 2};

-sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 4, 4, 4};
-sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 1, 8, 2};
+sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 8, 4, 8};
+sdpa_config_t xehpg_q_h64_s128 = {16, 16, 16, 8, 8, 4, 4, 8};
+sdpa_config_t xehpg_q_h64_s64 = {32, 8, 32, 8, 2, 8, 2, 8};
+sdpa_config_t xehpg_q_h64_s32 = {8, 8, 16, 8, 4, 8, 4, 8};
+
+sdpa_config_t xehpg_q_h64_s64_2nd = {8, 8, 8, 8, 8, 2, 8, 2};
+sdpa_config_t xehpg_q_h64_s128_2nd = {16, 8, 8, 8, 8, 4, 8, 4};
+sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 2, 8, 4};

sdpa_config_t xehpg_h128 = {16, 16, 32, 8, 8, 4, 4, 8};
sdpa_config_t xehpg_h128_s32 = {16, 16, 16, 8, 16, 2, 8, 4};
@@ -90,7 +96,24 @@ sdpa_config_t xehpc_h64_s32 = {16, 16, 16, 16, 4, 2, 4, 2};
sdpa_config_t xehpc_h64_2nd = {32, 32, 32, 16, 4, 1, 2, 2};
sdpa_config_t xehpc_h64_s64_2nd = {16, 16, 16, 16, 4, 1, 4, 1};

-sdpa_config_t xehpc_q_h64 = {16, 64, 32, 16, 8, 4, 2, 16};
+sdpa_config_t xehpc_q_h64_s64 = {16, 16, 16, 16, 4, 4, 4, 4};
+sdpa_config_t xehpc_q_h64_s384 = {16, 64, 16, 32, 8, 2, 4, 4};
+sdpa_config_t xehpc_q_h64_s1024 = {16, 64, 16, 16, 16, 1, 4, 4};
+sdpa_config_t xehpc_q_h64 = {16, 64, 16, 32, 8, 1, 4, 2};
+
+sdpa_config_t xehpc_q_h64_s128_integrated = {16, 16, 16, 16, 4, 4, 4, 4};
+sdpa_config_t xehpc_q_h64_s384_integrated = {16, 64, 16, 16, 16, 1, 4, 4};
+sdpa_config_t xehpc_q_h64_s1024_integrated = {16, 64, 16, 32, 8, 4, 4, 8};
+sdpa_config_t xehpc_q_h64_integrated = {16, 64, 16, 32, 16, 1, 8, 2};
+
+sdpa_config_t xehpc_q_h64_s96_2nd = {16, 16, 16, 16, 8, 1, 4, 1};
+sdpa_config_t xehpc_q_h64_s256_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
+sdpa_config_t xehpc_q_h64_s1152_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
+sdpa_config_t xehpc_q_h64_2nd = {64, 16, 16, 16, 16, 2, 16, 2};
+
+sdpa_config_t xehpc_q_h64_s96_2nd_integrated = {16, 16, 16, 16, 8, 1, 4, 1};
+sdpa_config_t xehpc_q_h64_s384_2nd_integrated = {64, 16, 16, 16, 4, 1, 4, 1};
+sdpa_config_t xehpc_q_h64_2nd_integrated = {16, 16, 16, 16, 8, 1, 8, 1};

sdpa_config_t xehpc_h128 = {16, 64, 32, 16, 16, 2, 4, 8};
sdpa_config_t xehpc_h128_s64 = {16, 32, 32, 32, 4, 2, 4, 2};
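
As a reading aid for the eight-integer initializers above, here is a minimal sketch of the sdpa_config_t layout they appear to assume. The field names and comments are my annotation based on the surrounding oneDNN micro-SDPA sources; the struct itself is not touched by this patch.

// Sketch only: field names are assumed from the surrounding oneDNN
// micro-SDPA sources, not introduced or changed by this diff.
struct sdpa_config_t {
    int unroll_m_kq, unroll_n_kq; // subgroup tile (unroll) for the K*Q GEMM
    int unroll_m_vs, unroll_n_vs; // subgroup tile (unroll) for the V*S GEMM
    int wg_m_kq, wg_n_kq; // workgroup tiling for the K*Q GEMM
    int wg_m_vs, wg_n_vs; // workgroup tiling for the V*S GEMM
};

// Read this way, xehpc_q_h64_s384 = {16, 64, 16, 32, 8, 2, 4, 4} unrolls K*Q
// by 16x64 and V*S by 16x32, with 8x2 and 4x4 workgroup tilings respectively.
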
@@ -121,8 +144,16 @@ sdpa_config_t *choose_config_xehpg(
        return &xehpg_h32;
    } else if (head_size <= 64) {
        if (quantized) {
-            if (thin_q) return &xehpg_q_h64_2nd;
-            return &xehpg_q_h64;
+            if (thin_q) {
+                if (seq <= 64) return &xehpg_q_h64_s64_2nd;
+                if (seq <= 128) return &xehpg_q_h64_s128_2nd;
+                return &xehpg_q_h64_2nd;
+            } else {
+                if (seq <= 32) return &xehpg_q_h64_s32;
+                if (seq <= 64) return &xehpg_q_h64_s64;
+                if (seq <= 128) return &xehpg_q_h64_s128;
+                return &xehpg_q_h64;
+            }
        }
        if (thin_q) return &xehpg_h64_2nd;
        if (seq <= 64) return &xehpg_h64_s64;
@@ -164,10 +195,33 @@ sdpa_config_t *choose_config_xehpc(dim_t head_size, dim_t seq, bool thin_q,
        return &xehpc_h32;
    } else if (head_size <= 64) {
        if (thin_q) {
+            if (quantized) {
+                if (is_integrated) {
+                    if (seq <= 96) return &xehpc_q_h64_s96_2nd_integrated;
+                    if (seq <= 384) return &xehpc_q_h64_s384_2nd_integrated;
+                    return &xehpc_q_h64_2nd_integrated;
+                }
+                if (seq <= 96) return &xehpc_q_h64_s96_2nd;
+                if (seq <= 256) return &xehpc_q_h64_s256_2nd;
+                if (seq <= 1152) return &xehpc_q_h64_s1152_2nd;
+                return &xehpc_q_h64_2nd;
+            }
+
            if (seq <= 64) return &xehpc_h64_s64_2nd;
            return &xehpc_h64_2nd;
        }
-        if (quantized && seq >= 256) return &xehpc_q_h64;
+        if (quantized) {
+            if (is_integrated) {
+                if (seq <= 128) return &xehpc_q_h64_s128_integrated;
+                if (seq <= 384) return &xehpc_q_h64_s384_integrated;
+                if (seq <= 1024) return &xehpc_q_h64_s1024_integrated;
+                return &xehpc_q_h64_integrated;
+            }
+            if (seq <= 64) return &xehpc_q_h64_s64;
+            if (seq <= 384) return &xehpc_q_h64_s384;
+            if (seq <= 1024) return &xehpc_q_h64_s1024;
+            return &xehpc_q_h64;
+        }
        if (seq <= 32) return &xehpc_h64_s32;
        if (seq <= 64) return &xehpc_h64_s64;
        return &xehpc_h64;
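
To make the new dispatch concrete, here is a hypothetical call into the updated XeHPC selector. The excerpt truncates the prototype after thin_q, so the order of the remaining quantized and is_integrated parameters is an assumption of this sketch, not something the diff confirms.

// Hypothetical call site; parameter order after thin_q is assumed, and
// dim_t is oneDNN's signed index type.
dim_t head_size = 64, seq = 512;
sdpa_config_t *cfg = choose_config_xehpc(head_size, seq, /*thin_q=*/false,
        /*quantized=*/true, /*is_integrated=*/true);
// With the branches added above, seq = 512 skips the <= 128 and <= 384
// checks and lands in the <= 1024 bucket, so cfg points at
// xehpc_q_h64_s1024_integrated = {16, 64, 16, 32, 8, 4, 4, 8}.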