Skip to content

Commit 7eedf84

Browse files
[CPU] Add score output for PagedAttention (openvinotoolkit#25594)
### Details: - *Add score output for PagedAttention* - *...* ### Tickets: - *[146969](https://jira.devtools.intel.com/browse/CVS-146969)*
1 parent a6413b4 commit 7eedf84

File tree

7 files changed

+125
-82
lines changed

7 files changed

+125
-82
lines changed

src/plugins/intel_cpu/src/graph.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,8 @@ void Graph::EnforceInferencePrecision() {
17911791
Type::RNNSeq, // recurrent nets
17921792
Type::MatMul, // bert nets
17931793
Type::ROIPooling, // object detection nets
1794-
Type::Interpolate)) // super resolution nets
1794+
Type::Interpolate, // super resolution nets
1795+
Type::PagedAttention)) // paged attention
17951796
continue; // stop at significant nodes
17961797
} else if (inferPrec == ov::element::f16) {
17971798
/* list of node types that must be forced to be executed in FP16 precision
@@ -1895,6 +1896,9 @@ void Graph::EnforceInferencePrecision() {
18951896
const auto &child = node->getChildEdgeAt(i)->getChild();
18961897
if (child->getType() == Type::Range && node->getType() == Type::Convert)
18971898
continue;
1899+
// skip second output of PagedAttention
1900+
if (node->getType() == Type::PagedAttention && (i != 0))
1901+
continue;
18981902

18991903
DEBUG_LOG("#",
19001904
node->getExecIndex(),

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp

+91-14
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,7 @@ struct MHAHelper {
759759
PlainTensor _wv_scratch_a;
760760
PlainTensor _wv_scratch_b;
761761
PlainTensor _alibi_lookup;
762+
PlainTensor _score_output;
762763
std::vector<size_t> _wsp;
763764
size_t _wsp_size_per_thread = 0;
764765

@@ -772,6 +773,8 @@ struct MHAHelper {
772773
// second token for bhl loop
773774
PlainTensor _weight_bhl;
774775
PlainTensor _output_bhl;
776+
PlainTensor _score_offsets_aligned;
777+
PlainTensor _score_offsets;
775778

776779
MHAHelper() {
777780
_weight.resize<float>({size_t{1}, size_t{1}, size_t{1}, size_t{1}});
@@ -867,6 +870,26 @@ struct MHAHelper {
867870
_wv_scratch_b.resize<DATA_TYPE>({batch, kv_len_in_blocks, _Hk, _block_size * rnd_up(_S, _block_size)});
868871
}
869872

873+
void init_score_buffers(const PlainTensor& past_lens, const PlainTensor& subsequence_begins) {
874+
static constexpr int cache_line_size = dnnl::impl::cpu::platform::get_cache_line_size();
875+
auto seq_cout = static_cast<int32_t>(past_lens.m_dims[0]);
876+
_score_offsets_aligned.resize<int32_t>({past_lens.m_dims[0]});
877+
_score_offsets.resize<int32_t>({past_lens.m_dims[0]});
878+
int32_t total_kv_len_aligned = 0;
879+
int32_t total_kv_len = 0;
880+
for (int32_t i = 0; i < seq_cout; i++) {
881+
auto q_len = subsequence_begins.ptr<int32_t>()[i + 1] - subsequence_begins.ptr<int32_t>()[i];
882+
auto kv_len = past_lens.ptr<int32_t>()[i] + q_len;
883+
_score_offsets_aligned.ptr<int32_t>()[i] = total_kv_len_aligned;
884+
_score_offsets.ptr<int32_t>()[i] = total_kv_len;
885+
// aligned to cache line to avoid false sharing
886+
total_kv_len_aligned += rnd_up(kv_len, cache_line_size / sizeof(float));
887+
total_kv_len += kv_len;
888+
}
889+
890+
_score_output.resize<float>({total_kv_len_aligned * _H});
891+
}
892+
870893
// compute one block(such as 32 tokens) of query in M dimension: softmax(q_block*k')*v
871894
// all tensors such as query... have no batch dimension because batch dimension is varying
872895
// query: [H, L, S]
@@ -875,8 +898,8 @@ struct MHAHelper {
875898
// qk_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size]
876899
// wv_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size]
877900
void exec_kernel_multiple(const PlainTensor& query, const PlainTensor& present_value, const PlainTensor& output_emb,
878-
const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b,
879-
const int32_t* block_table, size_t ithr, size_t q_blk, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes) {
901+
const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b, const int32_t* block_table, size_t ithr, size_t q_blk,
902+
size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
880903
auto q_start = q_blk * _block_size;
881904
auto q_end = std::min(q_start + _block_size, q_len);
882905
auto q_cnt = q_end - q_start;
@@ -947,6 +970,9 @@ struct MHAHelper {
947970
precision_of<DATA_TYPE>::value,
948971
alibi_slope);
949972
}
973+
if (score_output) {
974+
cvt_copy(score_output + h * rnd_up(cur_kv_len, 16), reinterpret_cast<DATA_TYPE*>(score), cur_kv_len);
975+
}
950976
}
951977

952978
// reuse float buffer, need to use float to compute offset
@@ -998,7 +1024,7 @@ struct MHAHelper {
9981024
// weight: [nthr, H, 32, rnd_up(kv_len, block_size)]
9991025
// output: [nthr, 32, H, S]
10001026
void exec_kernel_one_bh(const PlainTensor& query, const PlainTensor& present_key, const PlainTensor& present_value, const PlainTensor& output_emb,
1001-
const int32_t* block_table, size_t ithr, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes) {
1027+
const int32_t* block_table, size_t ithr, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
10021028
if (_fastpath_valid) {
10031029
_gemv->tile_config();
10041030
for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
@@ -1044,6 +1070,9 @@ struct MHAHelper {
10441070
ov::element::f32,
10451071
ov::element::f32,
10461072
alibi_slope);
1073+
if (score_output) {
1074+
memcpy(score_output + h * rnd_up(cur_kv_len, 16), _weight.ptr<float>(ithr, h, pq), cur_kv_len * sizeof(float));
1075+
}
10471076
}
10481077
}
10491078

@@ -1078,6 +1107,7 @@ struct MHAHelper {
10781107
const PlainTensor& present_key,
10791108
const PlainTensor& present_value,
10801109
const PlainTensor& output_emb,
1110+
const PlainTensor& output_score,
10811111
size_t max_context_len,
10821112
const PlainTensor& past_lens,
10831113
const PlainTensor& subsequence_begins,
@@ -1141,6 +1171,16 @@ struct MHAHelper {
11411171
alibi_slope);
11421172
});
11431173

1174+
if (output_score) {
1175+
parallel_for2d_dynamic(B, q_len, [&](size_t b, size_t pq) {
1176+
auto cur_kv_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
1177+
auto* src = _weight_bhl.ptr<float>(b, 0, pq);
1178+
size_t src_stride = _weight_bhl.stride(2);
1179+
auto* dst = output_score.ptr<float>() + _score_offsets.ptr<int32_t>()[b];
1180+
attn_reduce(dst, src, _H, cur_kv_len, src_stride);
1181+
});
1182+
}
1183+
11441184
// attn_w * V
11451185
_output_bhl.resize<float>({static_cast<size_t>(_nthr), B, q_len, _H, _S});
11461186
// m_attn_w {B, H, q_len, kv_len}
@@ -1284,6 +1324,7 @@ struct MHA {
12841324
const PlainTensor& k_cache,
12851325
const PlainTensor& v_cache,
12861326
const PlainTensor& output_emb,
1327+
const PlainTensor& output_score,
12871328
size_t max_context_len,
12881329
const PlainTensor& past_lens,
12891330
const PlainTensor& subsequence_begins,
@@ -1343,16 +1384,30 @@ struct MHA {
13431384

13441385
if (q_len == 1) {
13451386
const auto cur_kv_len = static_cast<size_t>(past_lens.ptr<int32_t>()[batch_in_seq]) + 1;
1387+
float* score_output = nullptr;
1388+
if (output_score) {
1389+
auto score_offset = _helper._score_offsets_aligned.template ptr<int32_t>()[batch_in_seq];
1390+
score_output = _helper._score_output.template ptr<float>() + score_offset * _helper._H;
1391+
}
13461392

13471393
_helper.exec_kernel_one_bh(q.slice(0, batch_in_token, batch_in_token), k_cache, v_cache,
13481394
output_emb.slice(0, batch_in_token, batch_in_token),
13491395
block_indices.ptr<int32_t>() + block_indices_begins.ptr<int32_t>()[batch_in_seq],
1350-
ithr, hk, 1ul, cur_kv_len, alibi_slopes);
1396+
ithr, hk, 1ul, cur_kv_len, alibi_slopes,
1397+
score_output);
13511398
} else {
13521399
const auto batch_in_reorder = item.batch_in_reorder;
13531400
const auto q_blk = item.q_block_id;
13541401
const auto q_cnt = std::min(_helper._block_size, q_len - q_blk * _helper._block_size);
13551402
const auto cur_kv_len = static_cast<size_t>(past_lens.ptr<int32_t>()[batch_in_seq]) + q_blk * _helper._block_size + q_cnt;
1403+
float* score_output = nullptr;
1404+
if (output_score) {
1405+
// last block
1406+
if (q_len - q_blk * _helper._block_size <= _helper._block_size) {
1407+
auto score_offset = _helper._score_offsets_aligned.template ptr<int32_t>()[batch_in_seq];
1408+
score_output = _helper._score_output.template ptr<float>() + score_offset * _helper._H;
1409+
}
1410+
}
13561411

13571412
PlainTensor sub_query;
13581413
sub_query.resize({q_len, _helper._H, _helper._S}, q.ptr<DATA_TYPE>(batch_in_token));
@@ -1368,31 +1423,47 @@ struct MHA {
13681423
hk,
13691424
q_len,
13701425
cur_kv_len,
1371-
alibi_slopes);
1426+
alibi_slopes,
1427+
score_output);
13721428
}
13731429
});
1430+
if (output_score) {
1431+
parallel_for2d_dynamic(past_lens.m_dims[0], 1, [&](size_t b, size_t pq) {
1432+
auto seq_len = static_cast<size_t>(subsequence_begins.ptr<int32_t>()[b + 1] - subsequence_begins.ptr<int32_t>()[b]);
1433+
auto cur_kv_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + seq_len;
1434+
auto src_offset = _helper._score_offsets_aligned.template ptr<int32_t>()[b];
1435+
auto* src = _helper._score_output.template ptr<float>() + src_offset * _helper._H;
1436+
size_t src_stride = rnd_up(cur_kv_len, 16);
1437+
auto dst_offset = _helper._score_offsets.template ptr<int32_t>()[b];
1438+
auto* dst = output_score.ptr<float>() + dst_offset;
1439+
attn_reduce(dst, src, _helper._H, cur_kv_len, src_stride);
1440+
});
1441+
}
13741442
}
13751443

13761444
// Q, K, V is ready, do attention
13771445
void operator()(PlainTensor& query,
13781446
PlainTensor& present_key,
13791447
PlainTensor& present_value,
13801448
PlainTensor& output_emb,
1449+
PlainTensor& output_score,
13811450
size_t max_context_len,
13821451
const PlainTensor& past_lens,
13831452
const PlainTensor& subsequence_begins,
13841453
const PlainTensor& block_indices,
13851454
const PlainTensor& block_indices_begins,
13861455
const PlainTensor& alibi_slopes) {
13871456
_workitems.reset(query, past_lens, subsequence_begins, _helper._block_size);
1457+
if (output_score)
1458+
_helper.init_score_buffers(past_lens, subsequence_begins);
13881459

13891460
auto nthr = static_cast<size_t>(parallel_get_max_threads());
13901461

13911462
if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
1392-
exec_loop_mixed(query, present_key, present_value, output_emb, max_context_len, past_lens, subsequence_begins,
1463+
exec_loop_mixed(query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins,
13931464
block_indices, block_indices_begins, alibi_slopes);
13941465
} else {
1395-
_helper.exec_loop_bhl(query, present_key, present_value, output_emb, max_context_len, past_lens, subsequence_begins,
1466+
_helper.exec_loop_bhl(query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins,
13961467
block_indices, block_indices_begins, alibi_slopes);
13971468
}
13981469
}
@@ -1406,9 +1477,9 @@ struct AttentionExecutor : public PagedAttentionExecutor {
14061477

14071478
AttentionExecutor() : _kernel(_helper) {}
14081479

1409-
void init(const std::vector<MemoryPtr>& inputs, const MemoryPtr& output, PlainTensor& q, PlainTensor& k, PlainTensor& v, PlainTensor& k_cache,
1480+
void init(const std::vector<MemoryPtr>& inputs, const std::vector<MemoryPtr>& outputs, PlainTensor& q, PlainTensor& k, PlainTensor& v, PlainTensor& k_cache,
14101481
PlainTensor& v_cache, PlainTensor& past_lens, PlainTensor& subsequence_begins, PlainTensor& block_indices, PlainTensor& block_indices_begins,
1411-
float& scale, size_t& sliding_window, PlainTensor& alibi_slopes, size_t& max_context_len, PlainTensor& output_emb) {
1482+
float& scale, size_t& sliding_window, PlainTensor& alibi_slopes, size_t& max_context_len, PlainTensor& output_emb, PlainTensor& output_score) {
14121483
q.reset(inputs[ID_Q]); // [B_token, H * S]
14131484
k.reset(inputs[ID_K]);
14141485
v.reset(inputs[ID_V]);
@@ -1423,7 +1494,9 @@ struct AttentionExecutor : public PagedAttentionExecutor {
14231494
if (!inputs[ID_ALIBI_SLOPES]->getShape().hasZeroDims())
14241495
alibi_slopes.reset(inputs[ID_ALIBI_SLOPES]);
14251496
max_context_len = static_cast<size_t>(*inputs[ID_MAX_CONTEXT_LEN]->getDataAs<int32_t>());
1426-
output_emb.reset(output);
1497+
output_emb.reset(outputs[0]);
1498+
if (outputs.size() == 2)
1499+
output_score.reset(outputs[1]);
14271500

14281501
auto B_token = q.size(0);
14291502
auto Hk = k_cache.size(1);
@@ -1496,20 +1569,22 @@ struct AttentionExecutor : public PagedAttentionExecutor {
14961569
}
14971570
}
14981571

1499-
void execute(const std::vector<MemoryPtr>& inputs, const MemoryPtr output) override {
1572+
void execute(const std::vector<MemoryPtr>& inputs, const std::vector<MemoryPtr> outputs) override {
15001573
PlainTensor q, k, v, k_cache, v_cache;
15011574
PlainTensor past_lens, subsequence_begins, block_indices, block_indices_begins;
15021575
float scale;
15031576
size_t sliding_window;
15041577
PlainTensor alibi_slopes;
15051578
size_t max_context_len;
15061579
PlainTensor output_emb;
1580+
PlainTensor output_score;
15071581

1508-
init(inputs, output, q, k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins,
1509-
scale, sliding_window, alibi_slopes, max_context_len, output_emb);
1582+
init(inputs, outputs, q, k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins,
1583+
scale, sliding_window, alibi_slopes, max_context_len, output_emb, output_score);
15101584
concat_pastkv(k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins);
15111585

1512-
_kernel(q, k_cache, v_cache, output_emb, max_context_len, past_lens, subsequence_begins, block_indices, block_indices_begins, alibi_slopes);
1586+
_kernel(q, k_cache, v_cache, output_emb, output_score, max_context_len, past_lens, subsequence_begins, block_indices,
1587+
block_indices_begins, alibi_slopes);
15131588
}
15141589
};
15151590
#endif
@@ -1523,6 +1598,7 @@ std::shared_ptr<PagedAttentionExecutor> make_pa_executor(ov::element::Type data_
15231598
if (kvcache_type == ov::element::u8) {
15241599
executor = std::make_shared<AttentionExecutor<ov::bfloat16, uint8_t>>();
15251600
} else {
1601+
OPENVINO_ASSERT(kvcache_type == ov::element::bf16, "expect kvcache type bf16, current: ", kvcache_type);
15261602
executor = std::make_shared<AttentionExecutor<ov::bfloat16, ov::bfloat16>>();
15271603
}
15281604
#else
@@ -1534,6 +1610,7 @@ std::shared_ptr<PagedAttentionExecutor> make_pa_executor(ov::element::Type data_
15341610
} else if (kvcache_type == ov::element::f16) {
15351611
executor = std::make_shared<AttentionExecutor<float, ov::float16>>();
15361612
} else {
1613+
OPENVINO_ASSERT(kvcache_type == ov::element::f32, "expect kvcache type f32, current: ", kvcache_type);
15371614
executor = std::make_shared<AttentionExecutor<float, float>>();
15381615
}
15391616
} else {

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct PagedAttentionExecutor {
3333
static const size_t ID_SLIDING_WINDOW = 10; // []
3434
static const size_t ID_ALIBI_SLOPES = 11; // [H|0], float
3535
static const size_t ID_MAX_CONTEXT_LEN = 12; // []
36-
virtual void execute(const std::vector<ov::intel_cpu::MemoryPtr>& inputs, const ov::intel_cpu::MemoryPtr output) = 0;
36+
virtual void execute(const std::vector<ov::intel_cpu::MemoryPtr>& inputs, const std::vector<ov::intel_cpu::MemoryPtr> outputs) = 0;
3737
};
3838

3939
#ifdef OPENVINO_ARCH_X86_64

src/plugins/intel_cpu/src/nodes/paged_attn.cpp

+26-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "onednn/dnnl.h"
1414
#include "openvino/core/parallel.hpp"
1515
#include "openvino/util/common_util.hpp"
16-
#include "shape_inference/custom/paged_attn.hpp"
1716
#include "shape_inference/shape_inference_internal_dyn.hpp"
1817

1918
#include "utils/plain_tensor.hpp"
@@ -55,11 +54,13 @@ bool PagedAttentionKey::operator==(const PagedAttentionKey& rhs) const {
5554
}
5655

5756
PagedAttention::PagedAttention(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
58-
: Node(op, context, PAShapeInferFactory(op)) {
57+
: Node(op, context, InternalDynShapeInferFactory()) {
5958
std::string errorMessage;
6059
if (!isSupportedOperation(op, errorMessage)) {
6160
OPENVINO_THROW("CPU: " + errorMessage);
6261
}
62+
// output score may have no child
63+
m_hasScore = !op->get_output_target_inputs(1).empty();
6364
}
6465

6566
void PagedAttention::initSupportedPrimitiveDescriptors() {
@@ -113,6 +114,8 @@ void PagedAttention::initSupportedPrimitiveDescriptors() {
113114

114115
config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
115116
rtPrecision, getOutputShapeAtPort(0)));
117+
config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
118+
ov::element::f32, getOutputShapeAtPort(1)));
116119

117120
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any);
118121
}
@@ -143,12 +146,31 @@ void PagedAttention::createPrimitive() {
143146
void PagedAttention::execute(dnnl::stream strm) {
144147
auto orginInputNumber = getOriginalInputsNumber();
145148
std::vector<MemoryPtr> inputs(orginInputNumber);
146-
auto output = getDstMemoryAtPort(0);
149+
std::vector<MemoryPtr> outputs(m_hasScore ? 2 : 1);
150+
147151
for (size_t i = 0; i < orginInputNumber; i++) {
148152
inputs[i] = getSrcMemoryAtPort(i);
149153
}
150154

151-
m_executor->execute(inputs, output);
155+
const auto& queryDims = inputs[0]->getStaticDims();
156+
if (m_hasScore) {
157+
size_t len = 0;
158+
const auto& pastLensDims = inputs[5]->getStaticDims();
159+
auto pastLens = inputs[5]->getDataAs<const int32_t>();
160+
for (size_t i = 0; i < pastLensDims[0]; i++)
161+
len += pastLens[i];
162+
len += queryDims[0];
163+
VectorDims scoreDims{len};
164+
redefineOutputMemory({queryDims, scoreDims});
165+
} else {
166+
redefineOutputMemory(0, queryDims);
167+
}
168+
169+
outputs[0] = getDstMemoryAtPort(0);
170+
if (m_hasScore)
171+
outputs[1] = getDstMemoryAtPort(1);
172+
173+
m_executor->execute(inputs, outputs);
152174
}
153175

154176
bool PagedAttention::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {

src/plugins/intel_cpu/src/nodes/paged_attn.h

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class PagedAttention : public Node {
4343
std::shared_ptr<ov::Extensions::Cpu::PagedAttentionExecutor> m_executor;
4444
template <typename T> struct AttentionExecutor;
4545
friend struct PagedAttentionKey;
46+
47+
bool m_hasScore = false;
4648
};
4749

4850
} // namespace node

0 commit comments

Comments
 (0)