Commit e2ac535

[GPU] Add scores output support for PagedAttention (#28205)
### Details:
- Added scores output support for PagedAttention
- Added PagedAttention unit tests

### Tickets:
- [CVS-153660](https://jira.devtools.intel.com/browse/CVS-153660)
1 parent c040b7b commit e2ac535

14 files changed (+1428 / -161 lines)
src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp

+4
@@ -24,6 +24,10 @@ struct paged_attention : public primitive_base<paged_attention> {
         OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
     }

+    bool has_scores_output() const {
+        return num_outputs == 2;
+    }
+
     bool operator==(const primitive& rhs) const override {
         return compare_common_params(rhs);
     }

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+208 -71
Large diffs are not rendered by default.

src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h

+6 -6

@@ -7,14 +7,11 @@
 #include "intel_gpu/primitives/paged_attention.hpp"
 #include "primitive_inst.h"

+#include "sdpa/pa_sdpa_kernel_opt.h"
+
 namespace cldnn {

-enum PagedAttentionStage {
-    GENERATE = 0,
-    PREFILL = 1,
-    MIXED = 2,
-    UNKNOWN = 3
-};
+using PagedAttentionStage = kernel_selector::PagedAttentionStage;

 PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param);

@@ -61,6 +58,9 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<p
     memory::ptr block_indices_memory_ptr() const { return input_memory_ptr(7); }
     memory::ptr block_indices_begins_memory_ptr() const { return input_memory_ptr(8); }
     memory::ptr alibi_memory_ptr() const { return input_memory_ptr(11); }
+    memory::ptr rotated_block_indices_memory_ptr() const { return input_memory_ptr(13); }
+    memory::ptr rotation_deltas_memory_ptr() const { return input_memory_ptr(14); }
+    memory::ptr rotation_trig_lut_memory_ptr() const { return input_memory_ptr(15); }

     std::shared_ptr<network> prefill_network;

src/plugins/intel_gpu/src/graph/paged_attention.cpp

+75 -12

@@ -48,14 +48,38 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no

 template<typename ShapeType>
 std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
-    auto out_layout = impl_param.get_input_layout(0);
+    auto data_layout = impl_param.get_input_layout(0);

     const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
     bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
     OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation. "
                     "Expected ", paged_attention::block_size, ", but got ", key_cache_ps[3].get_length());

-    return {out_layout};
+    std::vector<layout> output_layouts{ data_layout };
+
+    const auto& desc = impl_param.typed_desc<paged_attention>();
+    if (desc->has_scores_output()) {
+        const auto past_lens_idx = 5;
+        const auto output_dt = data_layout.data_type;
+        if (impl_param.get_input_layout(past_lens_idx).is_static()) {
+            const auto& memory_deps = impl_param.memory_deps;
+            const auto past_lens_mem = memory_deps.at(past_lens_idx);
+            mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);
+
+            long int total_size = 0;
+            for (size_t i = 0; i < past_lens_mem_lock.size(); i++) {
+                total_size += past_lens_mem_lock[i];
+            }
+
+            total_size += static_cast<long int>(impl_param.get_input_layout(0).get_shape()[0]);
+
+            output_layouts.push_back(layout{ov::PartialShape{total_size}, output_dt, format::bfyx});
+        } else {
+            output_layouts.push_back(layout{ov::PartialShape::dynamic(1), output_dt, format::bfyx});
+        }
+    }
+
+    return output_layouts;
 }

 template std::vector<layout>
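
Put differently, the new second output is a 1D tensor whose length is the total number of attended tokens across all subsequences: the sum of `past_lens` plus the number of query tokens in the current request. A minimal standalone sketch of that length computation (hypothetical helper name, assuming int32 `past_lens` as in the code above, not part of the commit):

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper mirroring the shape logic above: the scores output covers
// every previously cached token (past_lens) plus every token of the current
// request (the first dimension of input 0), summed over all subsequences.
int64_t scores_output_length(const std::vector<int32_t>& past_lens, int64_t num_query_tokens) {
    const int64_t cached = std::accumulate(past_lens.begin(), past_lens.end(), int64_t{0});
    return cached + num_query_tokens;
}

// Example: past_lens = {10, 0} and 5 query tokens in total -> scores layout {15}.
```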
@@ -81,45 +105,79 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
 }

 void paged_attention_inst::on_execute() {
-    auto stage = get_paged_attention_stage(*_impl_params);
+    const auto& desc = _impl_params->typed_desc<paged_attention>();
+    const bool has_scores_output = desc->has_scores_output();
+    const auto stage = get_paged_attention_stage(*_impl_params);

-    if (stage == PagedAttentionStage::UNKNOWN ||
-        stage == PagedAttentionStage::GENERATE)
+    if ((stage == PagedAttentionStage::UNKNOWN) ||
+        (stage == PagedAttentionStage::GENERATE && !has_scores_output))
         return;

+    auto& stream = get_network().get_stream();
+    const auto past_lens_mem = past_lens_memory_ptr();
+    const auto subsequence_begins_mem = subsequence_begins_memory_ptr();
+    mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
+    mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
+    std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> subsequence_offsets_lock = nullptr;
+
+    if (has_scores_output) {
+        const size_t subsequence_offsets_idx = 4;
+
+        OPENVINO_ASSERT(_intermediates_memory.size() > subsequence_offsets_idx,
+                        "[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation");
+
+        auto subsequence_offsets_mem = _intermediates_memory[subsequence_offsets_idx];
+        subsequence_offsets_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(subsequence_offsets_mem, stream));
+    }
+
+    if (stage == PagedAttentionStage::GENERATE) {
+        // For the generate stage it's not necessary to configure any other intermediate
+        // buffers. Simply calculate the offsets and exit
+        size_t subsequence_offsets_acc = 0;
+        for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+            const auto past_len = past_lens_mem_lock[i];
+            const auto seq_start = subsequence_begins_mem_lock[i];
+            const auto seq_end = subsequence_begins_mem_lock[i + 1];
+            const auto seq_length = seq_end - seq_start;
+
+            if (subsequence_offsets_lock) {
+                subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
+                subsequence_offsets_acc += seq_length + past_len;
+            }
+        }
+
+        return;
+    }
+
     OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");

     const auto blocks_indexes_start_idx = 0;
     const auto blocks_indexes_end_idx = 1;
     const auto blocked_gws_subseq_mapping_idx = 2;

-    const auto past_lens_mem = past_lens_memory_ptr();
-    auto subsequence_begins_mem = subsequence_begins_memory_ptr();
     auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
     auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
     auto blocked_gws_subseq_mapping_mem = _intermediates_memory[blocked_gws_subseq_mapping_idx];

     OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);

-    auto& stream = get_network().get_stream();
-    mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
-    mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
     mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
     mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
     mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
     std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;

     if (stage == PagedAttentionStage::MIXED) {
-        const auto sequential_gws_subseq_mapping_idx = 6;
+        const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6;

         OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx,
-                        "Unexpected number of intermediates buffers for Paged Attention for mixed stage");
+                        "[GPU] Unexpected number of intermediates buffers for Paged Attention for mixed stage");

         auto sequential_gws_subseq_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx];
         sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
     }

     size_t index = 0;
+    size_t subsequence_offsets_acc = 0;
     const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
     for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
         const auto past_len = past_lens_mem_lock[i];
@@ -159,6 +217,11 @@ void paged_attention_inst::on_execute() {
                 sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
             }
         }
+
+        if (subsequence_offsets_lock) {
+            subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
+            subsequence_offsets_acc += seq_length + past_len;
+        }
     }
 }

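The `subsequence_offsets` intermediate buffer populated above tells the scores kernel where each subsequence's scores begin in the flat scores output. A small host-side sketch of the same cumulative-offset logic (hypothetical helper, assuming int32 inputs and `subsequence_begins` having one more element than `past_lens`):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: offsets[i] is the running total of (new tokens + cached tokens)
// of all preceding subsequences, matching the accumulation in on_execute().
std::vector<int32_t> subsequence_offsets(const std::vector<int32_t>& past_lens,
                                         const std::vector<int32_t>& subsequence_begins) {
    std::vector<int32_t> offsets(past_lens.size(), 0);
    int32_t acc = 0;
    for (size_t i = 0; i + 1 < subsequence_begins.size(); ++i) {
        offsets[i] = acc;
        const int32_t seq_length = subsequence_begins[i + 1] - subsequence_begins[i];
        acc += seq_length + past_lens[i];
    }
    return offsets;
}

// Example: past_lens = {10, 0}, subsequence_begins = {0, 1, 5}
// -> offsets = {0, 11}: subsequence 1 writes its scores starting at element 11,
//    and the total scores length is 15, matching the layout computed earlier.
```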

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+182
@@ -44,6 +44,10 @@ KERNEL(pa_sdpa_opt)(
     const __global ALIBI_INPUT_TYPE* alibi_slopes,
 #endif
     __global OUTPUT_TYPE* output,
+#if PAGED_ATTENTION_SCORES_OUTPUT
+    __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results,
+    const __global int* subsequence_offsets,
+#endif
     __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
     __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
     __global OUTPUT_TYPE* tmp_out
@@ -276,6 +280,28 @@ KERNEL(pa_sdpa_opt)(
         const uint max_logits_offset = exp_sums_offset;
         max_logits[max_logits_offset] = qk_max;
     }
+
+#if PAGED_ATTENTION_SCORES_OUTPUT
+#if MULTI_TOKENS_PROCESSING
+    const uint subsequence_idx = gws_subseq_mapping[seq_idx];
+    const uint subsequence_start_pos = subsequence_begins[subsequence_idx];
+    const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
+    const bool save_softmax_results = seq_idx == subsequence_end_pos - 1;
+#else
+    const uint subsequence_idx = seq_idx;
+    const bool save_softmax_results = true;
+#endif // MULTI_TOKENS_PROCESSING
+    // PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
+    // so save SEQ_LEN_PARTITION_SIZE elements for each partition
+    if (save_softmax_results) {
+        const uint output_offset = subsequence_idx * HEADS_NUM * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
+                                   head_num_idx * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
+                                   partition_idx * SEQ_LEN_PARTITION_SIZE;
+        for (uint i = sgid * SUBGROUP_SIZE + sglid; i < SEQ_LEN_PARTITION_SIZE; i += SUBGROUPS_PER_WG * SUBGROUP_SIZE) {
+            softmax_results[output_offset + i] = slm_qk_vals[i];
+        }
+    }
+#endif // PAGED_ATTENTION_SCORES_OUTPUT
     }
 }

@@ -370,6 +396,10 @@ KERNEL(pa_sdpa_finalization_stage)(
     const __global INPUT6_TYPE* subsequence_begins,
 #endif
     __global OUTPUT_TYPE* output,
+#if PAGED_ATTENTION_SCORES_OUTPUT
+    __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results,
+    const __global int* subsequence_offsets,
+#endif
     const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
     const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
     const __global OUTPUT_TYPE* tmp_out,
@@ -500,3 +530,155 @@ KERNEL(pa_sdpa_finalization_stage)(
 }

 #endif
+
+#ifdef SDPA_STAGE_2
+#define MAX_PARTITIONS_NUM 128
+
+REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
+KERNEL(pa_sdpa_scores_calculation)(
+    const __global INPUT3_TYPE* past_lens,
+    const __global INPUT6_TYPE* subsequence_begins,
+    __global OUTPUT1_TYPE* scores_output,
+    const __global SOFTMAX_ACCUMULATOR_TYPE* softmax_output,
+    const __global int* subsequence_offsets,
+    const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
+    const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
+    const __global OUTPUT_TYPE* tmp_out,
+    const uint is_mixed_mode) {
+    const uint subsequence_idx = get_global_id(2);
+    const uint partition_global_idx = get_global_id(0);
+    const uint local_id = get_local_id(0);
+    const uint partition_idx = get_group_id(0);
+    const uint partition_size = get_local_size(0);
+    const uint max_seq_len = get_global_size(0);
+    const uint partitions_num = get_num_groups(0);
+    const uint sgid = get_sub_group_id();
+    const uint sgid_num = get_num_sub_groups();
+    const uint sglid = get_sub_group_local_id();
+
+    const int subsequence_begin = subsequence_begins[subsequence_idx];
+    const int subsequence_end = subsequence_begins[subsequence_idx + 1];
+    const uint seq_len = (subsequence_end - subsequence_begin) + past_lens[subsequence_idx];
+
+    const uint num_of_partitions = CEIL_DIV(seq_len, partition_size);
+
+    if (partition_idx >= num_of_partitions)
+        return;
+
+    __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sums[HEADS_NUM];
+    __local SOFTMAX_ACCUMULATOR_TYPE slm_global_exp_sum[HEADS_NUM];
+
+    SOFTMAX_ACCUMULATOR_TYPE total_score = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+    if (seq_len <= partition_size) {
+        // If seq_len is less than the partition size, just reduce the results over the heads
+        for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
+            const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
+            SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];
+            total_score += softmax_value;
+        }
+    } else if (seq_len <= partition_size * MAX_PARTITIONS_NUM) {
+        // Optimized version for longer prompts (up to partition_size * MAX_PARTITIONS_NUM, ~64K tokens)
+
+        // Depending on the previous kernel exp_sums and max_logits might have different structure:
+        // For ordinary 1st and 2nd token kernels, there is only a single entry per subsequence.
+        // However, for mixed mode execution, exp_sums and max_logits include information for all
+        // tokens of each subsequence, but only the last one is needed for score calculation.
+        const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx;
+
+        for (uint head_idx = sgid; head_idx < HEADS_NUM; head_idx += sgid_num) {
+            SOFTMAX_ACCUMULATOR_TYPE max_logit[MAX_PARTITIONS_NUM / SUBGROUP_SIZE];
+            SOFTMAX_ACCUMULATOR_TYPE exp_sum[MAX_PARTITIONS_NUM / SUBGROUP_SIZE];
+
+            const uint exp_sums_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
+            for (int i = 0; i < partitions_num / SUBGROUP_SIZE; i++) {
+                max_logit[i] = max_logits[exp_sums_offset + i * SUBGROUP_SIZE + sglid];
+                exp_sum[i] = exp_sums[exp_sums_offset + i * SUBGROUP_SIZE + sglid];
+            }
+
+            const uint partitions_leftovers = partitions_num % SUBGROUP_SIZE;
+            if (partitions_leftovers != 0) {
+                const uint idx = partitions_num / SUBGROUP_SIZE;
+                max_logit[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[exp_sums_offset + idx * SUBGROUP_SIZE + sglid];
+                exp_sum[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_offset + idx * SUBGROUP_SIZE + sglid];
+            }
+
+            SOFTMAX_ACCUMULATOR_TYPE global_max_logit = max_logit[0];
+            for (uint i = 1; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
+                global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit[i]);
+            }
+
+            global_max_logit = sub_group_reduce_max(global_max_logit);
+
+            SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+            for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
+                SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum[i] * native_exp(max_logit[i] - global_max_logit);
+                // slm_exp_sums[head_idx][i * SUBGROUP_SIZE + sglid] = adjusted_exp_sum;
+                if (i * SUBGROUP_SIZE + sglid == partition_idx)
+                    slm_exp_sums[head_idx] = adjusted_exp_sum;
+                global_exp_sum += adjusted_exp_sum;
+            }
+
+            global_exp_sum = sub_group_reduce_add(global_exp_sum);
+
+            slm_global_exp_sum[head_idx] = global_exp_sum;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
+            SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = slm_exp_sums[head_idx];
+            SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = slm_global_exp_sum[head_idx];
+
+            const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
+            SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];
+
+            softmax_value = softmax_value * adjusted_exp_sum / global_exp_sum;
+            total_score += softmax_value;
+        }
+    } else {
+        // Non optimized fallback version
+        const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx;
+        for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
+            SOFTMAX_ACCUMULATOR_TYPE global_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
+            const uint max_logits_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
+            for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
+                const uint partition_offset = i * SUBGROUP_SIZE + sglid;
+                SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset];
+                global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit);
+            }
+
+            global_max_logit = sub_group_reduce_max(global_max_logit);
+
+            SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+            SOFTMAX_ACCUMULATOR_TYPE partition_adjusted_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+            const uint exp_sums_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
+            for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
+                const uint partition_offset = i * SUBGROUP_SIZE + sglid;
+                SOFTMAX_ACCUMULATOR_TYPE exp_sum = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_base_offset + partition_offset];
+                SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset];
+                SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum * native_exp(max_logit - global_max_logit);
+                global_exp_sum += adjusted_exp_sum;
+
+                // Save and broadcast the adjusted exp_sum for the currently being processed partition
+                if (i == partition_idx / SUBGROUP_SIZE)
+                    partition_adjusted_exp_sum = sub_group_broadcast(adjusted_exp_sum, partition_idx % SUBGROUP_SIZE);
+            }
+
+            global_exp_sum = sub_group_reduce_add(global_exp_sum);
+
+            const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
+            SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];
+
+            softmax_value = softmax_value * partition_adjusted_exp_sum / global_exp_sum;
+            total_score += softmax_value;
+        }
+    }
+
+    const uint output_offset = subsequence_offsets[subsequence_idx];
+    if (partition_global_idx < seq_len) {
+        scores_output[output_offset + partition_global_idx] = total_score;
+    }
+}
+
+#undef MAX_PARTITIONS_NUM
+#endif
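
For reference, the rescaling done in both the optimized and fallback branches of `pa_sdpa_scores_calculation` is the standard combination of per-partition softmax results, assuming the values saved by the main kernel are partition-local softmax values. In sketch notation (not taken verbatim from the source), for head h and partition p with partial max logit m_{h,p} and partial exponent sum s_{h,p}:

```math
m_h = \max_{p} m_{h,p}, \qquad
\tilde{s}_{h,p} = s_{h,p}\, e^{\,m_{h,p} - m_h}, \qquad
\mathrm{score}(t) = \sum_{h} \mathrm{softmax}^{\mathrm{local}}_{h}(t)\cdot
\frac{\tilde{s}_{h,p(t)}}{\sum_{p'} \tilde{s}_{h,p'}}
```

where p(t) is the partition containing token t; this corresponds to `softmax_value * adjusted_exp_sum / global_exp_sum` accumulated over heads in the kernel above.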
