@@ -44,6 +44,10 @@ KERNEL(pa_sdpa_opt)(
44
44
const __global ALIBI_INPUT_TYPE * alibi_slopes ,
45
45
#endif
46
46
__global OUTPUT_TYPE * output ,
47
+ #if PAGED_ATTENTION_SCORES_OUTPUT
48
+ __global SOFTMAX_ACCUMULATOR_TYPE * softmax_results ,
49
+ const __global int * subsequence_offsets ,
50
+ #endif
47
51
__global SOFTMAX_ACCUMULATOR_TYPE * exp_sums ,
48
52
__global SOFTMAX_ACCUMULATOR_TYPE * max_logits ,
49
53
__global OUTPUT_TYPE * tmp_out
@@ -276,6 +280,28 @@ KERNEL(pa_sdpa_opt)(
276
280
const uint max_logits_offset = exp_sums_offset ;
277
281
max_logits [max_logits_offset ] = qk_max ;
278
282
}
283
+
284
+ #if PAGED_ATTENTION_SCORES_OUTPUT
285
+ #if MULTI_TOKENS_PROCESSING
286
+ const uint subsequence_idx = gws_subseq_mapping [seq_idx ];
287
+ const uint subsequence_start_pos = subsequence_begins [subsequence_idx ];
288
+ const uint subsequence_end_pos = subsequence_begins [subsequence_idx + 1 ];
289
+ const bool save_softmax_results = seq_idx == subsequence_end_pos - 1 ;
290
+ #else
291
+ const uint subsequence_idx = seq_idx ;
292
+ const bool save_softmax_results = true;
293
+ #endif // MULTI_TOKENS_PROCESSING
294
+ // PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
295
+ // so save SEQ_LEN_PARTITION_SIZE elements for each partition
296
+ if (save_softmax_results ) {
297
+ const uint output_offset = subsequence_idx * HEADS_NUM * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
298
+ head_num_idx * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
299
+ partition_idx * SEQ_LEN_PARTITION_SIZE ;
300
+ for (uint i = sgid * SUBGROUP_SIZE + sglid ; i < SEQ_LEN_PARTITION_SIZE ; i += SUBGROUPS_PER_WG * SUBGROUP_SIZE ) {
301
+ softmax_results [output_offset + i ] = slm_qk_vals [i ];
302
+ }
303
+ }
304
+ #endif // PAGED_ATTENTION_SCORES_OUTPUT
279
305
}
280
306
}
281
307
@@ -370,6 +396,10 @@ KERNEL(pa_sdpa_finalization_stage)(
370
396
const __global INPUT6_TYPE * subsequence_begins ,
371
397
#endif
372
398
__global OUTPUT_TYPE * output ,
399
+ #if PAGED_ATTENTION_SCORES_OUTPUT
400
+ __global SOFTMAX_ACCUMULATOR_TYPE * softmax_results ,
401
+ const __global int * subsequence_offsets ,
402
+ #endif
373
403
const __global SOFTMAX_ACCUMULATOR_TYPE * exp_sums ,
374
404
const __global SOFTMAX_ACCUMULATOR_TYPE * max_logits ,
375
405
const __global OUTPUT_TYPE * tmp_out ,
@@ -500,3 +530,155 @@ KERNEL(pa_sdpa_finalization_stage)(
500
530
}
501
531
502
532
#endif
533
+
534
+ #ifdef SDPA_STAGE_2
535
+ #define MAX_PARTITIONS_NUM 128
536
+
537
+ REQD_SUB_GROUP_SIZE (SUBGROUP_SIZE )
538
+ KERNEL (pa_sdpa_scores_calculation )(
539
+ const __global INPUT3_TYPE * past_lens ,
540
+ const __global INPUT6_TYPE * subsequence_begins ,
541
+ __global OUTPUT1_TYPE * scores_output ,
542
+ const __global SOFTMAX_ACCUMULATOR_TYPE * softmax_output ,
543
+ const __global int * subsequence_offsets ,
544
+ const __global SOFTMAX_ACCUMULATOR_TYPE * exp_sums ,
545
+ const __global SOFTMAX_ACCUMULATOR_TYPE * max_logits ,
546
+ const __global OUTPUT_TYPE * tmp_out ,
547
+ const uint is_mixed_mode ) {
548
+ const uint subsequence_idx = get_global_id (2 );
549
+ const uint partition_global_idx = get_global_id (0 );
550
+ const uint local_id = get_local_id (0 );
551
+ const uint partition_idx = get_group_id (0 );
552
+ const uint partition_size = get_local_size (0 );
553
+ const uint max_seq_len = get_global_size (0 );
554
+ const uint partitions_num = get_num_groups (0 );
555
+ const uint sgid = get_sub_group_id ();
556
+ const uint sgid_num = get_num_sub_groups ();
557
+ const uint sglid = get_sub_group_local_id ();
558
+
559
+ const int subsequence_begin = subsequence_begins [subsequence_idx ];
560
+ const int subsequence_end = subsequence_begins [subsequence_idx + 1 ];
561
+ const uint seq_len = (subsequence_end - subsequence_begin ) + past_lens [subsequence_idx ];
562
+
563
+ const uint num_of_partitions = CEIL_DIV (seq_len , partition_size );
564
+
565
+ if (partition_idx >= num_of_partitions )
566
+ return ;
567
+
568
+ __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sums [HEADS_NUM ];
569
+ __local SOFTMAX_ACCUMULATOR_TYPE slm_global_exp_sum [HEADS_NUM ];
570
+
571
+ SOFTMAX_ACCUMULATOR_TYPE total_score = SOFTMAX_ACCUMULATOR_VAL_ZERO ;
572
+ if (seq_len <= partition_size ) {
573
+ // If seq_len is less than the partition size, just reduce the results over the heads
574
+ for (uint head_idx = 0 ; head_idx < HEADS_NUM ; head_idx ++ ) {
575
+ const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx ;
576
+ SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output [input_offset ];
577
+ total_score += softmax_value ;
578
+ }
579
+ } else if (seq_len <= partition_size * MAX_PARTITIONS_NUM ) {
580
+ // Optimized version for longer prompts (up to partition_size * MAX_PARTITIONS_NUM, ~64K tokens)
581
+
582
+ // Depending on the previous kernel exp_sums and max_logits might have different structure:
583
+ // For ordinary 1st and 2nd token kernels, there is only a single entry per subsequence.
584
+ // However, for mixed mode execution, exp_sums and max_logits include information for all
585
+ // tokens of each subsequence, but only the last one is needed for score calculation.
586
+ const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx ;
587
+
588
+ for (uint head_idx = sgid ; head_idx < HEADS_NUM ; head_idx += sgid_num ) {
589
+ SOFTMAX_ACCUMULATOR_TYPE max_logit [MAX_PARTITIONS_NUM / SUBGROUP_SIZE ];
590
+ SOFTMAX_ACCUMULATOR_TYPE exp_sum [MAX_PARTITIONS_NUM / SUBGROUP_SIZE ];
591
+
592
+ const uint exp_sums_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num ;
593
+ for (int i = 0 ; i < partitions_num / SUBGROUP_SIZE ; i ++ ) {
594
+ max_logit [i ] = max_logits [exp_sums_offset + i * SUBGROUP_SIZE + sglid ];
595
+ exp_sum [i ] = exp_sums [exp_sums_offset + i * SUBGROUP_SIZE + sglid ];
596
+ }
597
+
598
+ const uint partitions_leftovers = partitions_num % SUBGROUP_SIZE ;
599
+ if (partitions_leftovers != 0 ) {
600
+ const uint idx = partitions_num / SUBGROUP_SIZE ;
601
+ max_logit [idx ] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits [exp_sums_offset + idx * SUBGROUP_SIZE + sglid ];
602
+ exp_sum [idx ] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums [exp_sums_offset + idx * SUBGROUP_SIZE + sglid ];
603
+ }
604
+
605
+ SOFTMAX_ACCUMULATOR_TYPE global_max_logit = max_logit [0 ];
606
+ for (uint i = 1 ; i < CEIL_DIV (partitions_num , SUBGROUP_SIZE ); i ++ ) {
607
+ global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC (global_max_logit , max_logit [i ]);
608
+ }
609
+
610
+ global_max_logit = sub_group_reduce_max (global_max_logit );
611
+
612
+ SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO ;
613
+ for (uint i = 0 ; i < CEIL_DIV (partitions_num , SUBGROUP_SIZE ); i ++ ) {
614
+ SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum [i ] * native_exp (max_logit [i ] - global_max_logit );
615
+ // slm_exp_sums[head_idx][i * SUBGROUP_SIZE + sglid] = adjusted_exp_sum;
616
+ if (i * SUBGROUP_SIZE + sglid == partition_idx )
617
+ slm_exp_sums [head_idx ] = adjusted_exp_sum ;
618
+ global_exp_sum += adjusted_exp_sum ;
619
+ }
620
+
621
+ global_exp_sum = sub_group_reduce_add (global_exp_sum );
622
+
623
+ slm_global_exp_sum [head_idx ] = global_exp_sum ;
624
+ }
625
+
626
+ barrier (CLK_LOCAL_MEM_FENCE );
627
+
628
+ for (uint head_idx = 0 ; head_idx < HEADS_NUM ; head_idx ++ ) {
629
+ SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = slm_exp_sums [head_idx ];
630
+ SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = slm_global_exp_sum [head_idx ];
631
+
632
+ const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx ;
633
+ SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output [input_offset ];
634
+
635
+ softmax_value = softmax_value * adjusted_exp_sum / global_exp_sum ;
636
+ total_score += softmax_value ;
637
+ }
638
+ } else {
639
+ // Non optimized fallback version
640
+ const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx ;
641
+ for (uint head_idx = 0 ; head_idx < HEADS_NUM ; head_idx ++ ) {
642
+ SOFTMAX_ACCUMULATOR_TYPE global_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN ;
643
+ const uint max_logits_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num ;
644
+ for (uint i = 0 ; i < CEIL_DIV (partitions_num , SUBGROUP_SIZE ); i ++ ) {
645
+ const uint partition_offset = i * SUBGROUP_SIZE + sglid ;
646
+ SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits [max_logits_base_offset + partition_offset ];
647
+ global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC (global_max_logit , max_logit );
648
+ }
649
+
650
+ global_max_logit = sub_group_reduce_max (global_max_logit );
651
+
652
+ SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO ;
653
+ SOFTMAX_ACCUMULATOR_TYPE partition_adjusted_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO ;
654
+ const uint exp_sums_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num ;
655
+ for (uint i = 0 ; i < CEIL_DIV (partitions_num , SUBGROUP_SIZE ); i ++ ) {
656
+ const uint partition_offset = i * SUBGROUP_SIZE + sglid ;
657
+ SOFTMAX_ACCUMULATOR_TYPE exp_sum = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums [exp_sums_base_offset + partition_offset ];
658
+ SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits [max_logits_base_offset + partition_offset ];
659
+ SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum * native_exp (max_logit - global_max_logit );
660
+ global_exp_sum += adjusted_exp_sum ;
661
+
662
+ // Save and broadcast the adjusted exp_sum for the currently being processed partition
663
+ if (i == partition_idx / SUBGROUP_SIZE )
664
+ partition_adjusted_exp_sum = sub_group_broadcast (adjusted_exp_sum , partition_idx % SUBGROUP_SIZE );
665
+ }
666
+
667
+ global_exp_sum = sub_group_reduce_add (global_exp_sum );
668
+
669
+ const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx ;
670
+ SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output [input_offset ];
671
+
672
+ softmax_value = softmax_value * partition_adjusted_exp_sum / global_exp_sum ;
673
+ total_score += softmax_value ;
674
+ }
675
+ }
676
+
677
+ const uint output_offset = subsequence_offsets [subsequence_idx ];
678
+ if (partition_global_idx < seq_len ) {
679
+ scores_output [output_offset + partition_global_idx ] = total_score ;
680
+ }
681
+ }
682
+
683
+ #undef MAX_PARTITIONS_NUM
684
+ #endif
0 commit comments