@@ -61,8 +61,14 @@ sdpa_config_t xehpg_h64_s128 = {16, 16, 16, 16, 4, 8, 4, 8};
 sdpa_config_t xehpg_h64_s64 = {32, 16, 16, 8, 8, 4, 4, 8};
 sdpa_config_t xehpg_h64_2nd = {8, 16, 16, 8, 8, 1, 4, 2};
 
-sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 4, 4, 4};
-sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 1, 8, 2};
+sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 8, 4, 8};
+sdpa_config_t xehpg_q_h64_s128 = {16, 16, 16, 8, 8, 4, 4, 8};
+sdpa_config_t xehpg_q_h64_s64 = {32, 8, 32, 8, 2, 8, 2, 8};
+sdpa_config_t xehpg_q_h64_s32 = {8, 8, 16, 8, 4, 8, 4, 8};
+
+sdpa_config_t xehpg_q_h64_s64_2nd = {8, 8, 8, 8, 8, 2, 8, 2};
+sdpa_config_t xehpg_q_h64_s128_2nd = {16, 8, 8, 8, 8, 4, 8, 4};
+sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 2, 8, 4};
 
 sdpa_config_t xehpg_h128 = {16, 16, 32, 8, 8, 4, 4, 8};
 sdpa_config_t xehpg_h128_s32 = {16, 16, 16, 8, 16, 2, 8, 4};
@@ -90,7 +96,23 @@ sdpa_config_t xehpc_h64_s32 = {16, 16, 16, 16, 4, 2, 4, 2};
 sdpa_config_t xehpc_h64_2nd = {32, 32, 32, 16, 4, 1, 2, 2};
 sdpa_config_t xehpc_h64_s64_2nd = {16, 16, 16, 16, 4, 1, 4, 1};
 
-sdpa_config_t xehpc_q_h64 = {16, 64, 32, 16, 8, 4, 2, 16};
+sdpa_config_t xehpc_q_h64_s64 = {16, 16, 16, 16, 4, 4, 4, 4};
+sdpa_config_t xehpc_q_h64_s384 = {16, 64, 16, 32, 8, 2, 4, 4};
+sdpa_config_t xehpc_q_h64_s1024 = {16, 64, 16, 16, 16, 1, 4, 4};
+sdpa_config_t xehpc_q_h64 = {16, 64, 16, 32, 8, 1, 4, 2};
+
+sdpa_config_t xehpc_q_h64_s128_integrated = {16, 16, 16, 16, 4, 4, 4, 4};
+sdpa_config_t xehpc_q_h64_s384_integrated = {16, 64, 16, 16, 16, 1, 4, 4};
+sdpa_config_t xehpc_q_h64_integrated = {16, 64, 16, 32, 16, 1, 8, 2};
+
+sdpa_config_t xehpc_q_h64_s96_2nd = {16, 16, 16, 16, 8, 1, 4, 1};
+sdpa_config_t xehpc_q_h64_s256_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
+sdpa_config_t xehpc_q_h64_s1152_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
+sdpa_config_t xehpc_q_h64_2nd = {64, 16, 16, 16, 16, 2, 16, 2};
+
+sdpa_config_t xehpc_q_h64_s96_2nd_integrated = {16, 16, 16, 16, 8, 1, 4, 1};
+sdpa_config_t xehpc_q_h64_s384_2nd_integrated = {64, 16, 16, 16, 4, 1, 4, 1};
+sdpa_config_t xehpc_q_h64_2nd_integrated = {16, 16, 16, 16, 8, 1, 8, 1};
 
 sdpa_config_t xehpc_h128 = {16, 64, 32, 16, 16, 2, 4, 8};
 sdpa_config_t xehpc_h128_s64 = {16, 32, 32, 32, 4, 2, 4, 2};
@@ -121,8 +143,16 @@ sdpa_config_t *choose_config_xehpg(
         return &xehpg_h32;
     } else if (head_size <= 64) {
        if (quantized) {
-            if (thin_q) return &xehpg_q_h64_2nd;
-            return &xehpg_q_h64;
+            if (thin_q) {
+                if (seq <= 64) return &xehpg_q_h64_s64_2nd;
+                if (seq <= 128) return &xehpg_q_h64_s128_2nd;
+                return &xehpg_q_h64_2nd;
+            } else {
+                if (seq <= 32) return &xehpg_q_h64_s32;
+                if (seq <= 64) return &xehpg_q_h64_s64;
+                if (seq <= 128) return &xehpg_q_h64_s128;
+                return &xehpg_q_h64;
+            }
         }
         if (thin_q) return &xehpg_h64_2nd;
         if (seq <= 64) return &xehpg_h64_s64;
@@ -164,10 +194,32 @@ sdpa_config_t *choose_config_xehpc(dim_t head_size, dim_t seq, bool thin_q,
         return &xehpc_h32;
     } else if (head_size <= 64) {
         if (thin_q) {
+            if (quantized) {
+                if (is_integrated) {
+                    if (seq <= 96) return &xehpc_q_h64_s96_2nd_integrated;
+                    if (seq <= 384) return &xehpc_q_h64_s384_2nd_integrated;
+                    return &xehpc_q_h64_2nd_integrated;
+                }
+                if (seq <= 96) return &xehpc_q_h64_s96_2nd;
+                if (seq <= 256) return &xehpc_q_h64_s256_2nd;
+                if (seq <= 1152) return &xehpc_q_h64_s1152_2nd;
+                return &xehpc_q_h64_2nd;
+            }
+
             if (seq <= 64) return &xehpc_h64_s64_2nd;
             return &xehpc_h64_2nd;
         }
-        if (quantized && seq >= 256) return &xehpc_q_h64;
+        if (quantized) {
+            if (is_integrated) {
+                if (seq <= 128) return &xehpc_q_h64_s128_integrated;
+                if (seq <= 384) return &xehpc_q_h64_s384_integrated;
+                return &xehpc_q_h64_integrated;
+            }
+            if (seq <= 64) return &xehpc_q_h64_s64;
+            if (seq <= 384) return &xehpc_q_h64_s384;
+            if (seq <= 1024) return &xehpc_q_h64_s1024;
+            return &xehpc_q_h64;
+        }
         if (seq <= 32) return &xehpc_h64_s32;
         if (seq <= 64) return &xehpc_h64_s64;
         return &xehpc_h64;
0 commit comments