Skip to content

Commit b194730

Browse files
authored
[CPU]Support "dynamic" KV cache precision for PA (#29018)
### Details: - *Support "dynamic" KV cache precision for PagedAttention case* ### Tickets: - *CVS-161326* --------- Signed-off-by: Zhang Yi <yi3.zhang@intel.com>
1 parent a8aba4e commit b194730

File tree

6 files changed

+520
-23
lines changed

6 files changed

+520
-23
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "openvino/core/type/element_type.hpp"
8+
#include "openvino/pass/graph_rewrite.hpp"
9+
#include "transformations_visibility.hpp"
10+
11+
namespace ov {
12+
namespace pass {
13+
class TRANSFORMATIONS_API ConvertPagedAttnInputs;
14+
15+
/**
16+
* @ingroup ov_transformation_common_api
17+
* @brief Sets the precision and shape of the KV cache inputs of the PagedAttention op based on runtime options
18+
*/
19+
20+
class ConvertPagedAttnInputs : public ov::pass::MatcherPass {
21+
public:
22+
struct KVCacheConfig {
23+
ov::element::Type keyCachePrecision;
24+
ov::element::Type valueCachePrecision;
25+
ov::element::Type inferencePrecision;
26+
size_t keyCacheBlockSize = 32;
27+
size_t valueCacheBlockSize = 32;
28+
size_t keyCacheGroupSize = 0;
29+
size_t valueCacheGroupSize = 0;
30+
bool keyCacheQuantBychannel = false;
31+
bool valueCacheQuantBychannel = false;
32+
std::vector<size_t> keyCacheDimOrder = {0, 1, 2, 3};
33+
std::vector<size_t> valueCacheDimOrder = {0, 1, 2, 3};
34+
};
35+
36+
OPENVINO_MATCHER_PASS_RTTI("ConvertPagedAttnInputs");
37+
ConvertPagedAttnInputs(const KVCacheConfig& config);
38+
39+
void setKVCacheConfig(const KVCacheConfig& config);
40+
41+
const KVCacheConfig& getKVCacheConfig() const;
42+
43+
private:
44+
KVCacheConfig m_config;
45+
};
46+
47+
} // namespace pass
48+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/common_optimizations/convert_pagedattn_inputs.hpp"
6+
7+
#include <cstdint>
8+
#include <memory>
9+
#include <transformations/utils/gen_pattern.hpp>
10+
11+
#include "itt.hpp"
12+
#include "openvino/core/rt_info.hpp"
13+
#include "openvino/op/add.hpp"
14+
#include "openvino/op/constant.hpp"
15+
#include "openvino/op/paged_attention.hpp"
16+
#include "openvino/util/log.hpp"
17+
#include "transformations/utils/utils.hpp"
18+
using namespace ov::gen_pattern;
19+
20+
// Registers a matcher for PagedAttentionExtension nodes (13-input and 16-input
// forms) whose key/value cache inputs are Parameters. On a match the callback
// applies the configured precision to both cache Parameters and, when the node
// carries the num_k_heads/k_head_size/num_v_heads/v_head_size rt_info entries,
// also sets their 4D partial shapes.
//
// @param config  KV cache options; copied into m_config and read each time the
//                callback runs.
ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config) : m_config(config) {
    MATCHER_SCOPE(ConvertPagedAttnInputs);

    auto Q = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto K = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto V = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto key_cache_0 = makePattern<ov::op::v0::Parameter>({});
    auto value_cache_0 = makePattern<ov::op::v0::Parameter>({});
    auto past_lens = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto subsequence_begins = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto block_indices = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto block_indices_begins = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto scale = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto sliding_window = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto alibi_slopes = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto max_context_len = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto rotated_block_indices = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto rotation_deltas = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
    auto rotation_trig_lut = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());

    // PagedAttention without the cache-rotation inputs.
    auto pa_1 = makePattern<op::PagedAttentionExtension>({Q,
                                                          K,
                                                          V,
                                                          key_cache_0,
                                                          value_cache_0,
                                                          past_lens,
                                                          subsequence_begins,
                                                          block_indices,
                                                          block_indices_begins,
                                                          scale,
                                                          sliding_window,
                                                          alibi_slopes,
                                                          max_context_len});

    // PagedAttention with the three additional cache-rotation inputs.
    auto pa_2 = makePattern<op::PagedAttentionExtension>({Q,
                                                          K,
                                                          V,
                                                          key_cache_0,
                                                          value_cache_0,
                                                          past_lens,
                                                          subsequence_begins,
                                                          block_indices,
                                                          block_indices_begins,
                                                          scale,
                                                          sliding_window,
                                                          alibi_slopes,
                                                          max_context_len,
                                                          rotated_block_indices,
                                                          rotation_deltas,
                                                          rotation_trig_lut});
    auto result = pa_1 | pa_2;

    // Only `this` is needed inside the callback (for m_config), so capture it explicitly.
    ov::matcher_pass_callback callback = [this](ov::pass::pattern::Matcher& m) {
        const auto pa_op = m.get_match_root();
        auto key_cache = ov::as_type_ptr<ov::op::v0::Parameter>(pa_op->get_input_node_shared_ptr(3));
        auto value_cache = ov::as_type_ptr<ov::op::v0::Parameter>(pa_op->get_input_node_shared_ptr(4));

        // An f16 cache is switched to bf16 when inference itself runs in bf16.
        auto format_cache_precision = [](ov::element::Type cache_precision, ov::element::Type infer_precision) {
            return cache_precision == ov::element::f16 && infer_precision == ov::element::bf16 ? infer_precision
                                                                                               : cache_precision;
        };

        // Builds the 4D cache shape. `orders` gives the axis position of, in order:
        // the dynamic dimension, head count, block size, head size.
        auto init_cache_shape = [](const size_t head_nums,
                                   const size_t head_size,
                                   const size_t block_size,
                                   const ov::element::Type precision,
                                   const size_t group_size,
                                   const bool bychannel,
                                   const std::vector<size_t>& orders) {
            // The order vector is indexed at [0..3] and its values index a rank-4 shape.
            if (orders.size() != 4) {
                OPENVINO_THROW("cache dim order must contain 4 elements, got ", orders.size());
            }
            for (const auto axis : orders) {
                if (axis >= 4) {
                    OPENVINO_THROW("cache dim order contains out-of-range axis ", axis);
                }
            }

            size_t _block_size = block_size;
            ov::Dimension::value_type _head_nums = head_nums;
            ov::Dimension::value_type _head_size = head_size;
            ov::Dimension::value_type _group_size = group_size;
            // A group size of 0 means one group spanning the whole head.
            _group_size = _group_size ? _group_size : _head_size;
            if (_group_size == 0) {
                // Would otherwise divide by zero below.
                OPENVINO_THROW("cache head_size and group_size cannot both be 0");
            }
            if (!bychannel) {
                if (_head_size % _group_size != 0) {
                    OPENVINO_THROW("cache head_size ", head_size, " cannot be divided by group_size ", group_size);
                }
            }
            size_t group_num = _head_size / _group_size;
            // Quantized precisions reserve extra room in the shape
            // (presumably for per-group/per-channel float scale and zero-point - confirm against the kernel).
            if (precision == ov::element::u8) {
                if (bychannel) {
                    _block_size += 2 * sizeof(float);
                } else {
                    _head_size += sizeof(float) * 2 * group_num;
                }
            } else if (precision == ov::element::u4) {
                _head_size += sizeof(float) * 2 * group_num * 2;
            }
            auto block_shape = ov::PartialShape::dynamic(4);

            block_shape[orders[0]] = -1;
            block_shape[orders[1]] = _head_nums;
            block_shape[orders[2]] = _block_size;
            block_shape[orders[3]] = _head_size;

            return block_shape;
        };

        auto key_cache_precision = format_cache_precision(m_config.keyCachePrecision, m_config.inferencePrecision);
        auto value_cache_precision = format_cache_precision(m_config.valueCachePrecision, m_config.inferencePrecision);
        // The precision is applied even when the shape information below is missing.
        key_cache->set_element_type(key_cache_precision);
        value_cache->set_element_type(value_cache_precision);

        auto& rt_info = pa_op->get_rt_info();
        // All four entries are required; v_head_size must be checked as well,
        // otherwise the lookup below would read a missing entry.
        if (!rt_info.count("num_k_heads") || !rt_info.count("k_head_size") || !rt_info.count("num_v_heads") ||
            !rt_info.count("v_head_size")) {
            OPENVINO_DEBUG("PagedAttn ",
                           pa_op->get_friendly_name(),
                           " doesn't have rtinfo for num_k_heads/k_head_size/num_v_heads/v_head_size");
            return false;
        }
        const auto key_cache_shape = init_cache_shape(rt_info["num_k_heads"].as<size_t>(),
                                                      rt_info["k_head_size"].as<size_t>(),
                                                      m_config.keyCacheBlockSize,
                                                      key_cache_precision,
                                                      m_config.keyCacheGroupSize,
                                                      m_config.keyCacheQuantBychannel,
                                                      m_config.keyCacheDimOrder);
        const auto value_cache_shape = init_cache_shape(rt_info["num_v_heads"].as<size_t>(),
                                                        rt_info["v_head_size"].as<size_t>(),
                                                        m_config.valueCacheBlockSize,
                                                        value_cache_precision,
                                                        m_config.valueCacheGroupSize,
                                                        m_config.valueCacheQuantBychannel,
                                                        m_config.valueCacheDimOrder);

        key_cache->set_partial_shape(key_cache_shape);
        value_cache->set_partial_shape(value_cache_shape);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(result, matcher_name);
    this->register_matcher(m, callback);
}
149+
150+
void ov::pass::ConvertPagedAttnInputs::setKVCacheConfig(const KVCacheConfig& config) {
151+
m_config = config;
152+
}
153+
154+
// Returns a read-only reference to the KV cache configuration held by this pass.
const ov::pass::ConvertPagedAttnInputs::KVCacheConfig& ov::pass::ConvertPagedAttnInputs::getKVCacheConfig() const {
    return this->m_config;
}

0 commit comments

Comments
 (0)