Skip to content

Commit ce45811

Browse files
committed
xe: sdpa: Update config for quantized sdpa with head_size of 64
1 parent 6e9f67b commit ce45811

File tree

1 file changed

+60
-6
lines changed

1 file changed

+60
-6
lines changed

src/gpu/intel/ocl/micro_sdpa.cpp

+60-6
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,14 @@ sdpa_config_t xehpg_h64_s128 = {16, 16, 16, 16, 4, 8, 4, 8};
6161
sdpa_config_t xehpg_h64_s64 = {32, 16, 16, 8, 8, 4, 4, 8};
6262
sdpa_config_t xehpg_h64_2nd = {8, 16, 16, 8, 8, 1, 4, 2};
6363

64-
sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 4, 4, 4};
65-
sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 1, 8, 2};
64+
sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 8, 4, 8};
65+
sdpa_config_t xehpg_q_h64_s128 = {16, 16, 16, 8, 8, 4, 4, 8};
66+
sdpa_config_t xehpg_q_h64_s64 = {32, 8, 32, 8, 2, 8, 2, 8};
67+
sdpa_config_t xehpg_q_h64_s32 = {8, 8, 16, 8, 4, 8, 4, 8};
68+
69+
sdpa_config_t xehpg_q_h64_s64_2nd = {8, 8, 8, 8, 8, 2, 8, 2};
70+
sdpa_config_t xehpg_q_h64_s128_2nd = {16, 8, 8, 8, 8, 4, 8, 4};
71+
sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 2, 8, 4};
6672

6773
sdpa_config_t xehpg_h128 = {16, 16, 32, 8, 8, 4, 4, 8};
6874
sdpa_config_t xehpg_h128_s32 = {16, 16, 16, 8, 16, 2, 8, 4};
@@ -90,7 +96,24 @@ sdpa_config_t xehpc_h64_s32 = {16, 16, 16, 16, 4, 2, 4, 2};
9096
sdpa_config_t xehpc_h64_2nd = {32, 32, 32, 16, 4, 1, 2, 2};
9197
sdpa_config_t xehpc_h64_s64_2nd = {16, 16, 16, 16, 4, 1, 4, 1};
9298

93-
sdpa_config_t xehpc_q_h64 = {16, 64, 32, 16, 8, 4, 2, 16};
99+
sdpa_config_t xehpc_q_h64_s64 = {16, 16, 16, 16, 4, 4, 4, 4};
100+
sdpa_config_t xehpc_q_h64_s384 = {16, 64, 16, 32, 8, 2, 4, 4};
101+
sdpa_config_t xehpc_q_h64_s1024 = {16, 64, 16, 16, 16, 1, 4, 4};
102+
sdpa_config_t xehpc_q_h64 = {16, 64, 16, 32, 8, 1, 4, 2};
103+
104+
sdpa_config_t xehpc_q_h64_s128_integrated = {16, 16, 16, 16, 4, 4, 4, 4};
105+
sdpa_config_t xehpc_q_h64_s384_integrated = {16, 64, 16, 16, 16, 1, 4, 4};
106+
sdpa_config_t xehpc_q_h64_s1024_integrated = {16, 64, 16, 32, 8, 4, 4, 8};
107+
sdpa_config_t xehpc_q_h64_integrated = {16, 64, 16, 32, 16, 1, 8, 2};
108+
109+
sdpa_config_t xehpc_q_h64_s96_2nd = {16, 16, 16, 16, 8, 1, 4, 1};
110+
sdpa_config_t xehpc_q_h64_s256_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
111+
sdpa_config_t xehpc_q_h64_s1152_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
112+
sdpa_config_t xehpc_q_h64_2nd = {64, 16, 16, 16, 16, 2, 16, 2};
113+
114+
sdpa_config_t xehpc_q_h64_s96_2nd_integrated = {16, 16, 16, 16, 8, 1, 4, 1};
115+
sdpa_config_t xehpc_q_h64_s384_2nd_integrated = {64, 16, 16, 16, 4, 1, 4, 1};
116+
sdpa_config_t xehpc_q_h64_2nd_integrated = {16, 16, 16, 16, 8, 1, 8, 1};
94117

95118
sdpa_config_t xehpc_h128 = {16, 64, 32, 16, 16, 2, 4, 8};
96119
sdpa_config_t xehpc_h128_s64 = {16, 32, 32, 32, 4, 2, 4, 2};
@@ -121,8 +144,16 @@ sdpa_config_t *choose_config_xehpg(
121144
return &xehpg_h32;
122145
} else if (head_size <= 64) {
123146
if (quantized) {
124-
if (thin_q) return &xehpg_q_h64_2nd;
125-
return &xehpg_q_h64;
147+
if (thin_q) {
148+
if (seq <= 64) return &xehpg_q_h64_s64_2nd;
149+
if (seq <= 128) return &xehpg_q_h64_s128_2nd;
150+
return &xehpg_q_h64_2nd;
151+
} else {
152+
if (seq <= 32) return &xehpg_q_h64_s32;
153+
if (seq <= 64) return &xehpg_q_h64_s64;
154+
if (seq <= 128) return &xehpg_q_h64_s128;
155+
return &xehpg_q_h64;
156+
}
126157
}
127158
if (thin_q) return &xehpg_h64_2nd;
128159
if (seq <= 64) return &xehpg_h64_s64;
@@ -164,10 +195,33 @@ sdpa_config_t *choose_config_xehpc(dim_t head_size, dim_t seq, bool thin_q,
164195
return &xehpc_h32;
165196
} else if (head_size <= 64) {
166197
if (thin_q) {
198+
if (quantized) {
199+
if (is_integrated) {
200+
if (seq <= 96) return &xehpc_q_h64_s96_2nd_integrated;
201+
if (seq <= 384) return &xehpc_q_h64_s384_2nd_integrated;
202+
return &xehpc_q_h64_2nd_integrated;
203+
}
204+
if (seq <= 96) return &xehpc_q_h64_s96_2nd;
205+
if (seq <= 256) return &xehpc_q_h64_s256_2nd;
206+
if (seq <= 1152) return &xehpc_q_h64_s1152_2nd;
207+
return &xehpc_q_h64_2nd;
208+
}
209+
167210
if (seq <= 64) return &xehpc_h64_s64_2nd;
168211
return &xehpc_h64_2nd;
169212
}
170-
if (quantized && seq >= 256) return &xehpc_q_h64;
213+
if (quantized) {
214+
if (is_integrated) {
215+
if (seq <= 128) return &xehpc_q_h64_s128_integrated;
216+
if (seq <= 384) return &xehpc_q_h64_s384_integrated;
217+
if (seq <= 1024) return &xehpc_q_h64_s1024_integrated;
218+
return &xehpc_q_h64_integrated;
219+
}
220+
if (seq <= 64) return &xehpc_q_h64_s64;
221+
if (seq <= 384) return &xehpc_q_h64_s384;
222+
if (seq <= 1024) return &xehpc_q_h64_s1024;
223+
return &xehpc_q_h64;
224+
}
171225
if (seq <= 32) return &xehpc_h64_s32;
172226
if (seq <= 64) return &xehpc_h64_s64;
173227
return &xehpc_h64;

0 commit comments

Comments
 (0)