
Commit f5309e1

[GPU] GQA optimization of PagedAttention OCL kernel for long sequences (#29383)
### Details:
- GQA optimization of PagedAttention OCL kernel for long sequences

### Tickets:
- [CVS-162048](https://jira.devtools.intel.com/browse/CVS-162048)
1 parent fb8dbbe commit f5309e1

6 files changed (+271, −114 lines)


src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+5 −5)

```diff
@@ -238,7 +238,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
 
             args.outputs = { instance.output_memory_ptr(0) };
         } else if (stage == Stage::PA_SDPA) {
-            if (kernel_idx == 0 || kernel_idx == 1) {
+            if (kernel_idx == 0 || kernel_idx == 1 || kernel_idx == 2) {
                 // 2nd+ token calculation or mixed stage tokens calculation
                 args.shape_info = instance.shape_info_memory_ptr();
 
@@ -262,7 +262,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                 if (desc->has_alibi) {
                     args.inputs.push_back(instance.alibi_memory_ptr());
                 }
-            } else if (kernel_idx == 2 || kernel_idx == 3) {
+            } else if (kernel_idx == 3 || kernel_idx == 4) {
                 // Finalization kernel or mixed stage finalization kernel
                 args.inputs = { instance.past_lens_memory_ptr() };
 
@@ -276,15 +276,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                     args.inputs.push_back(instance.rotation_deltas_memory_ptr());
                     args.inputs.push_back(instance.rotation_trig_lut_memory_ptr());
                 }
-            } else if (kernel_idx == 4) {
+            } else if (kernel_idx == 5) {
                 // Output scores calculation kernel
                 args.inputs = { instance.past_lens_memory_ptr(),
                                 instance.subsequence_begins_memory_ptr() };
             }
 
             args.outputs = { instance.output_memory_ptr(0) };
 
-            if (kernel_idx == 4) {
+            if (kernel_idx == 5) {
                 args.outputs.push_back(instance.output_memory_ptr(1));
             }
         }
```
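The kernel_idx shifts above suggest that one additional tokens-calculation kernel was added to the PA_SDPA stage, pushing the finalization and scores kernels from indices 2–4 up to 3–5. Below is a minimal sketch of the implied kernel layout; the enum and kernel names are hypothetical and only mirror the comments in the diff, and which index the new GQA-optimized kernel actually occupies is an assumption.

```cpp
// Hypothetical naming of the PA_SDPA kernel sequence implied by the updated
// kernel_idx checks; the real plugin may use different identifiers and ordering.
enum PASDPAKernelIdx : int {
    TOKENS_2ND_PLUS    = 0,  // 2nd+ token calculation
    TOKENS_MIXED       = 1,  // mixed stage tokens calculation
    TOKENS_GQA         = 2,  // assumed: new GQA-optimized kernel for long sequences
    FINALIZATION       = 3,  // finalization kernel
    FINALIZATION_MIXED = 4,  // mixed stage finalization kernel
    OUTPUT_SCORES      = 5   // output scores calculation kernel
};
```

The remaining hunk in this file, shown below, renames config.group_size to config.kv_group_size in the SDPA configuration; the computed value, heads_num / kv_heads_num, is unchanged.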
```diff
@@ -660,7 +660,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
 
         if (desc->heads_num != desc->kv_heads_num) {
             config.broadcast_axis = 1;
-            config.group_size = desc->heads_num / desc->kv_heads_num;
+            config.kv_group_size = desc->heads_num / desc->kv_heads_num;
         }
 
         if (desc->has_scores_output() && !is_dynamic) {
```

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp (+1 −1)

```diff
@@ -252,7 +252,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_product_attention> {
         if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
             if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
                 config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
-                config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
+                config.kv_group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
             }
         }
 
```
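In both files, kv_group_size expresses how many query heads share one key/value head under GQA: heads_num / kv_heads_num, or equivalently the ratio of the heads dimension of Q to that of K/V. The sketch below illustrates that relationship, assuming the conventional GQA layout in which each KV head serves a contiguous block of query heads; the function and variable names are illustrative only and are not part of the plugin's API.

```cpp
#include <cstddef>
#include <iostream>

// Illustrative mapping from a query head to the KV head it reads under GQA,
// assuming contiguous grouping: query heads [g*group, (g+1)*group) share KV head g.
static std::size_t kv_head_for(std::size_t q_head, std::size_t kv_group_size) {
    return q_head / kv_group_size;
}

int main() {
    const std::size_t heads_num     = 32;  // example query-head count
    const std::size_t kv_heads_num  = 8;   // example KV-head count
    const std::size_t kv_group_size = heads_num / kv_heads_num;  // == 4

    for (std::size_t q = 0; q < heads_num; ++q) {
        std::cout << "query head " << q << " -> kv head "
                  << kv_head_for(q, kv_group_size) << "\n";
    }
    return 0;
}
```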
