Skip to content

Commit a8e776b

Browse files
authored
[Snippets][CPU] Added KVCacheMatcher check for LLM in MHATokenization (#28812)
### Details:
- Extracted the `is_LLM` check from the GPU plugin into the common transformations utilities so other plugins can reuse it
- Used the extracted `is_large_language_model` function in the LLM check inside MHATokenization in the CPU plugin

### Tickets:
- CVS-160999
1 parent ec9dfae commit a8e776b

File tree

4 files changed

+36
-38
lines changed

4 files changed

+36
-38
lines changed

src/common/transformations/include/transformations/utils/utils.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ TRANSFORMATIONS_API bool constantIsEqualTo(const std::shared_ptr<ov::op::v0::Con
193193

194194
TRANSFORMATIONS_API bool has_f16_constants(const std::shared_ptr<const ov::Model>& function);
195195

196+
TRANSFORMATIONS_API bool is_large_language_model(const ov::Model& model);
197+
196198
/**
197199
* \brief Check if 'other_shape' can be broadcasted to 'ref_shape'
198200
*

src/common/transformations/src/transformations/utils/utils.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
#include "openvino/core/validation_util.hpp"
1313
#include "openvino/op/add.hpp"
1414
#include "openvino/op/broadcast.hpp"
15+
#include "openvino/op/concat.hpp"
1516
#include "openvino/op/constant.hpp"
17+
#include "openvino/op/convert.hpp"
1618
#include "openvino/op/divide.hpp"
1719
#include "openvino/op/gather.hpp"
1820
#include "openvino/op/multiply.hpp"
21+
#include "openvino/op/paged_attention.hpp"
1922
#include "openvino/op/parameter.hpp"
23+
#include "openvino/op/read_value.hpp"
2024
#include "openvino/op/relu.hpp"
2125
#include "openvino/op/reshape.hpp"
2226
#include "openvino/op/shape_of.hpp"
@@ -25,6 +29,9 @@
2529
#include "openvino/op/tanh.hpp"
2630
#include "openvino/op/util/multi_subgraph_base.hpp"
2731
#include "openvino/op/util/shape_of_base.hpp"
32+
#include "openvino/pass/pattern/op/optional.hpp"
33+
#include "openvino/pass/pattern/op/or.hpp"
34+
#include "openvino/pass/pattern/op/wrap_type.hpp"
2835

2936
namespace ov {
3037
namespace op {
@@ -133,6 +140,28 @@ bool has_f16_constants(const std::shared_ptr<const ov::Model>& function) {
133140
return false;
134141
}
135142

143+
// Heuristic LLM detection: a model is treated as a large language model when it either
// contains a PagedAttentionExtension op or matches the stateful KV-cache pattern
//   ReadValue -> [Convert] -> [Gather(beam_idx)] -> [Convert] -> Concat -> [Convert] -> Assign
// (optional nodes in brackets). Returns true on the first op that matches either check.
bool is_large_language_model(const ov::Model& model) {
    using namespace ov::pass::pattern;

    // Past KV state is read from a stateful ReadValue, possibly converted to another precision.
    const auto kv_state = wrap_type<ov::op::v6::ReadValue>();
    const auto kv_state_cvt = ov::pass::pattern::optional<ov::op::v0::Convert>(kv_state);
    // Beam-search reordering: Gather over the state by a beam_idx Parameter with a constant axis.
    const auto beam_index = wrap_type<ov::op::v0::Parameter>();
    const auto reordered_state =
        wrap_type<ov::op::v8::Gather>({kv_state_cvt, beam_index, wrap_type<ov::op::v0::Constant>()});
    const auto reordered_state_cvt = ov::pass::pattern::optional<ov::op::v0::Convert>(reordered_state);
    // The Concat consumes either the (possibly converted) past state directly or its gathered form.
    const auto concat_input =
        std::make_shared<ov::pass::pattern::op::Or>(OutputVector{kv_state_cvt, reordered_state_cvt});
    const auto updated_state = wrap_type<ov::op::v0::Concat>({concat_input, any_input()});
    const auto updated_state_cvt = ov::pass::pattern::optional<ov::op::v0::Convert>(updated_state);
    // Present KV state is written back through Assign, closing the cache loop.
    const auto store_state = wrap_type<ov::op::v6::Assign>({updated_state_cvt});

    // Pattern is built once; the same matcher instance is reused for every op below.
    const auto kvcache_matcher = std::make_shared<ov::pass::pattern::Matcher>(store_state, "KVCacheMatcher");

    for (const auto& node : model.get_ops()) {
        if (kvcache_matcher->match(node->output(0)) || ov::is_type<ov::op::PagedAttentionExtension>(node)) {
            return true;
        }
    }
    return false;
}
164+
136165
bool check_for_broadcast(const ov::PartialShape& ref_shape, const ov::PartialShape& other_shape) {
137166
if (ref_shape.rank().is_dynamic() || other_shape.rank().is_dynamic()) {
138167
return false;

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1053,10 +1053,10 @@ void Transformations::MainSnippets(void) {
10531053
#if defined(OPENVINO_ARCH_X86_64)
10541054
// Currently, Snippets don't provide efficient execution for single token inference in LLM case.
10551055
// To avoid performance degradations, we disable MHA tokenization into Subgraphs in LLMs'.
1056-
// We consider the presence of `ScaledDotProductAttentionWithKVCache` and `PagedAttentionExtension` ops
1056+
// We consider the presence of `ScaledDotProductAttentionWithKVCache` ops
10571057
// in the model as a sign that this model is LLM.
1058-
const auto is_LLM = ov::op::util::has_op_with_type<intel_cpu::ScaledDotProductAttentionWithKVCache>(model) ||
1059-
ov::op::util::has_op_with_type<ov::op::PagedAttentionExtension>(model);
1058+
const auto is_LLM = ov::op::util::is_large_language_model(*model.get()) ||
1059+
ov::op::util::has_op_with_type<intel_cpu::ScaledDotProductAttentionWithKVCache>(model);
10601060

10611061
// CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
10621062
const auto is_infer_prc_supported_by_MHA =

src/plugins/intel_gpu/src/runtime/execution_config.cpp

+2-35
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,16 @@
66
#include "intel_gpu/plugin/remote_context.hpp"
77
#include "openvino/core/any.hpp"
88
#include "openvino/core/model.hpp"
9-
#include "openvino/op/concat.hpp"
10-
#include "openvino/op/convert.hpp"
11-
#include "openvino/op/gather.hpp"
129
#include "openvino/op/loop.hpp"
1310
#include "openvino/op/lstm_sequence.hpp"
14-
#include "openvino/op/paged_attention.hpp"
1511
#include "openvino/op/search_sorted.hpp"
1612
#include "openvino/op/stft.hpp"
17-
#include "openvino/pass/pattern/matcher.hpp"
18-
#include "openvino/pass/pattern/op/label.hpp"
19-
#include "openvino/pass/pattern/op/or.hpp"
20-
#include "openvino/pass/pattern/op/wrap_type.hpp"
2113
#include "ov_ops/dynamic_quantize.hpp"
2214
#include "openvino/runtime/internal_properties.hpp"
2315
#include "intel_gpu/runtime/internal_properties.hpp"
2416
#include "openvino/runtime/plugin_config.hpp"
2517
#include "openvino/runtime/properties.hpp"
18+
#include "transformations/utils/utils.hpp"
2619

2720

2821
namespace ov::intel_gpu {
@@ -86,32 +79,6 @@ bool requires_new_shape_infer(const std::shared_ptr<ov::Node>& op) {
8679
return false;
8780
}
8881

89-
bool is_llm(const ov::Model& model) {
90-
using namespace ov::pass::pattern;
91-
92-
auto past = wrap_type<ov::op::v6::ReadValue>();
93-
auto convert_past = wrap_type<ov::op::v0::Convert>({past});
94-
auto gather_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past, convert_past});
95-
auto beam_idx = wrap_type<ov::op::v0::Parameter>();
96-
auto gather_past = wrap_type<ov::op::v8::Gather>({gather_input, beam_idx, wrap_type<ov::op::v0::Constant>()});
97-
auto gather_convert = wrap_type<ov::op::v0::Convert>({gather_past});
98-
auto concat_past_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past, convert_past, gather_past, gather_convert});
99-
auto concat = wrap_type<ov::op::v0::Concat>({concat_past_input, any_input()});
100-
auto convert_present = wrap_type<ov::op::v0::Convert>({concat});
101-
auto present_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{concat, convert_present});
102-
auto present = wrap_type<ov::op::v6::Assign>({present_input});
103-
104-
auto kvcache_matcher = std::make_shared<ov::pass::pattern::Matcher>(present, "KVCacheMatcher");
105-
106-
for (auto& op : model.get_ordered_ops()) {
107-
if (kvcache_matcher->match(op) || ov::is_type<ov::op::PagedAttentionExtension>(op)) {
108-
return true;
109-
}
110-
}
111-
112-
return false;
113-
}
114-
11582
} // namespace
11683

11784
ExecutionConfig::ExecutionConfig() : ov::PluginConfig() { }
@@ -163,7 +130,7 @@ void ExecutionConfig::apply_rt_info(const IRemoteContext* context, const ov::RTM
163130
}
164131

165132
void ExecutionConfig::apply_model_specific_options(const IRemoteContext* context, const ov::Model& model) {
166-
apply_rt_info(context, get_rt_info(model), is_llm(model));
133+
apply_rt_info(context, get_rt_info(model), ov::op::util::is_large_language_model(model));
167134

168135
const auto& ops = model.get_ops();
169136

0 commit comments

Comments
 (0)