Skip to content

Commit 4599fbd

Browse files
committed
xe: sdpa: Update config for quantized sdpa with head_size of 64
1 parent 1c64cb3 commit 4599fbd

File tree

1 file changed

+58
-6
lines changed

1 file changed

+58
-6
lines changed

src/gpu/intel/ocl/micro_sdpa.cpp

+58-6
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,14 @@ sdpa_config_t xehpg_h64_s128 = {16, 16, 16, 16, 4, 8, 4, 8};
6161
sdpa_config_t xehpg_h64_s64 = {32, 16, 16, 8, 8, 4, 4, 8};
6262
sdpa_config_t xehpg_h64_2nd = {8, 16, 16, 8, 8, 1, 4, 2};
6363

64-
sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 4, 4, 4};
65-
sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 1, 8, 2};
64+
sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 8, 4, 8};
65+
sdpa_config_t xehpg_q_h64_s128 = {16, 16, 16, 8, 8, 4, 4, 8};
66+
sdpa_config_t xehpg_q_h64_s64 = {32, 8, 32, 8, 2, 8, 2, 8};
67+
sdpa_config_t xehpg_q_h64_s32 = {8, 8, 16, 8, 4, 8, 4, 8};
68+
69+
sdpa_config_t xehpg_q_h64_s64_2nd = {8, 8, 8, 8, 8, 2, 8, 2};
70+
sdpa_config_t xehpg_q_h64_s128_2nd = {16, 8, 8, 8, 8, 4, 8, 4};
71+
sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 2, 8, 4};
6672

6773
sdpa_config_t xehpg_h128 = {16, 16, 32, 8, 8, 4, 4, 8};
6874
sdpa_config_t xehpg_h128_s32 = {16, 16, 16, 8, 16, 2, 8, 4};
@@ -90,7 +96,23 @@ sdpa_config_t xehpc_h64_s32 = {16, 16, 16, 16, 4, 2, 4, 2};
9096
sdpa_config_t xehpc_h64_2nd = {32, 32, 32, 16, 4, 1, 2, 2};
9197
sdpa_config_t xehpc_h64_s64_2nd = {16, 16, 16, 16, 4, 1, 4, 1};
9298

93-
sdpa_config_t xehpc_q_h64 = {16, 64, 32, 16, 8, 4, 2, 16};
99+
sdpa_config_t xehpc_q_h64_s64 = {16, 16, 16, 16, 4, 4, 4, 4};
100+
sdpa_config_t xehpc_q_h64_s384 = {16, 64, 16, 32, 8, 2, 4, 4};
101+
sdpa_config_t xehpc_q_h64_s1024 = {16, 64, 16, 16, 16, 1, 4, 4};
102+
sdpa_config_t xehpc_q_h64 = {16, 64, 16, 32, 8, 1, 4, 2};
103+
104+
sdpa_config_t xehpc_q_h64_s128_integrated = {16, 16, 16, 16, 4, 4, 4, 4};
105+
sdpa_config_t xehpc_q_h64_s384_integrated = {16, 64, 16, 16, 16, 1, 4, 4};
106+
sdpa_config_t xehpc_q_h64_integrated = {16, 64, 16, 32, 16, 1, 8, 2};
107+
108+
sdpa_config_t xehpc_q_h64_s96_2nd = {16, 16, 16, 16, 8, 1, 4, 1};
109+
sdpa_config_t xehpc_q_h64_s256_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
110+
sdpa_config_t xehpc_q_h64_s1152_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
111+
sdpa_config_t xehpc_q_h64_2nd = {64, 16, 16, 16, 16, 2, 16, 2};
112+
113+
sdpa_config_t xehpc_q_h64_s96_2nd_integrated = {16, 16, 16, 16, 8, 1, 4, 1};
114+
sdpa_config_t xehpc_q_h64_s384_2nd_integrated = {64, 16, 16, 16, 4, 1, 4, 1};
115+
sdpa_config_t xehpc_q_h64_2nd_integrated = {16, 16, 16, 16, 8, 1, 8, 1};
94116

95117
sdpa_config_t xehpc_h128 = {16, 64, 32, 16, 16, 2, 4, 8};
96118
sdpa_config_t xehpc_h128_s64 = {16, 32, 32, 32, 4, 2, 4, 2};
@@ -121,8 +143,16 @@ sdpa_config_t *choose_config_xehpg(
121143
return &xehpg_h32;
122144
} else if (head_size <= 64) {
123145
if (quantized) {
124-
if (thin_q) return &xehpg_q_h64_2nd;
125-
return &xehpg_q_h64;
146+
if (thin_q) {
147+
if (seq <= 64) return &xehpg_q_h64_s64_2nd;
148+
if (seq <= 128) return &xehpg_q_h64_s128_2nd;
149+
return &xehpg_q_h64_2nd;
150+
} else {
151+
if (seq <= 32) return &xehpg_q_h64_s32;
152+
if (seq <= 64) return &xehpg_q_h64_s64;
153+
if (seq <= 128) return &xehpg_q_h64_s128;
154+
return &xehpg_q_h64;
155+
}
126156
}
127157
if (thin_q) return &xehpg_h64_2nd;
128158
if (seq <= 64) return &xehpg_h64_s64;
@@ -164,10 +194,32 @@ sdpa_config_t *choose_config_xehpc(dim_t head_size, dim_t seq, bool thin_q,
164194
return &xehpc_h32;
165195
} else if (head_size <= 64) {
166196
if (thin_q) {
197+
if (quantized) {
198+
if (is_integrated) {
199+
if (seq <= 96) return &xehpc_q_h64_s96_2nd_integrated;
200+
if (seq <= 384) return &xehpc_q_h64_s384_2nd_integrated;
201+
return &xehpc_q_h64_2nd_integrated;
202+
}
203+
if (seq <= 96) return &xehpc_q_h64_s96_2nd;
204+
if (seq <= 256) return &xehpc_q_h64_s256_2nd;
205+
if (seq <= 1152) return &xehpc_q_h64_s1152_2nd;
206+
return &xehpc_q_h64_2nd;
207+
}
208+
167209
if (seq <= 64) return &xehpc_h64_s64_2nd;
168210
return &xehpc_h64_2nd;
169211
}
170-
if (quantized && seq >= 256) return &xehpc_q_h64;
212+
if (quantized) {
213+
if (is_integrated) {
214+
if (seq <= 128) return &xehpc_q_h64_s128_integrated;
215+
if (seq <= 384) return &xehpc_q_h64_s384_integrated;
216+
return &xehpc_q_h64_integrated;
217+
}
218+
if (seq <= 64) return &xehpc_q_h64_s64;
219+
if (seq <= 384) return &xehpc_q_h64_s384;
220+
if (seq <= 1024) return &xehpc_q_h64_s1024;
221+
return &xehpc_q_h64;
222+
}
171223
if (seq <= 32) return &xehpc_h64_s32;
172224
if (seq <= 64) return &xehpc_h64_s64;
173225
return &xehpc_h64;

0 commit comments

Comments
 (0)