@@ -238,7 +238,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
238
238
239
239
args.outputs = { instance.output_memory_ptr (0 ) };
240
240
} else if (stage == Stage::PA_SDPA) {
241
- if (kernel_idx == 0 || kernel_idx == 1 ) {
241
+ if (kernel_idx == 0 || kernel_idx == 1 || kernel_idx == 2 ) {
242
242
// 2nd+ token calculation or mixed stage tokens calculation
243
243
args.shape_info = instance.shape_info_memory_ptr ();
244
244
@@ -262,7 +262,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
262
262
if (desc->has_alibi ) {
263
263
args.inputs .push_back (instance.alibi_memory_ptr ());
264
264
}
265
- } else if (kernel_idx == 2 || kernel_idx == 3 ) {
265
+ } else if (kernel_idx == 3 || kernel_idx == 4 ) {
266
266
// Finalization kernel or mixed stage finalization kernel
267
267
args.inputs = { instance.past_lens_memory_ptr () };
268
268
@@ -276,15 +276,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
276
276
args.inputs .push_back (instance.rotation_deltas_memory_ptr ());
277
277
args.inputs .push_back (instance.rotation_trig_lut_memory_ptr ());
278
278
}
279
- } else if (kernel_idx == 4 ) {
279
+ } else if (kernel_idx == 5 ) {
280
280
// Output scores calculation kernel
281
281
args.inputs = { instance.past_lens_memory_ptr (),
282
282
instance.subsequence_begins_memory_ptr () };
283
283
}
284
284
285
285
args.outputs = { instance.output_memory_ptr (0 ) };
286
286
287
- if (kernel_idx == 4 ) {
287
+ if (kernel_idx == 5 ) {
288
288
args.outputs .push_back (instance.output_memory_ptr (1 ));
289
289
}
290
290
}
@@ -660,7 +660,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
660
660
661
661
if (desc->heads_num != desc->kv_heads_num ) {
662
662
config.broadcast_axis = 1 ;
663
- config.group_size = desc->heads_num / desc->kv_heads_num ;
663
+ config.kv_group_size = desc->heads_num / desc->kv_heads_num ;
664
664
}
665
665
666
666
if (desc->has_scores_output () && !is_dynamic) {
0 commit comments