
Commit 55c11c6

itikhono, slyalin, and CuriousPanCake authored
Port SDPA to PagedAttention transformation (openvinotoolkit#24336)
### Details:
Ported the SDPA to PagedAttention transformation from Python to C++. Related PRs: openvinotoolkit#24127, openvinotoolkit#24177

Tested model scope:
- [x] "hf-internal-testing/tiny-random-BloomForCausalLM"
- [x] "hf-internal-testing/tiny-random-FalconForCausalLM"
- [x] "hf-internal-testing/tiny-random-Starcoder2ForCausalLM"
- [x] "hf-internal-testing/tiny-random-GPTJForCausalLM"
- [x] "hf-internal-testing/tiny-random-StableLmForCausalLM"
- [x] "hf-internal-testing/tiny-random-LlamaForCausalLM"
- [x] "hf-internal-testing/tiny-random-MistralForCausalLM"
- [x] "hf-internal-testing/tiny-random-OPTForCausalLM"
- [x] "hf-internal-testing/tiny-random-PhiForCausalLM"
- [x] "facebook/opt-125m"
- [x] "llama2"
- [x] "bigcode/starcoder2-7b"
- [ ] "mosaicml/mpt-7b-chat" (failed with both the Python and the C++ transformation; acceptable for this PR). Issue: RuntimeError: Check '(axis_range_min <= axis) && (axis <= axis_range_max)' failed at src/core/src/validation_util.cpp:386: Concat Parameter axis 2 out of the tensor rank range [0, 0].

[x] means that the response to the dedicated prompt is the same for the Python and the C++ transformation.

### Tickets:
- CVS-138664

Co-authored-by: Sergey Lyalin <sergey.lyalin@intel.com>
Co-authored-by: Andrii Staikov <andrii.staikov@intel.com>
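For reference, the ported pass can be driven from C++ the same way the new Python binding drives it. A minimal sketch (the `main` scaffolding and the IR path are illustrative, not part of this commit):

#include "openvino/pass/manager.hpp"
#include "openvino/pass/sdpa_to_paged_attention.hpp"
#include "openvino/runtime/core.hpp"

int main() {
    ov::Core core;
    // Hypothetical IR path: any SDPA-based LLM from the tested scope, converted to OpenVINO IR.
    std::shared_ptr<ov::Model> model = core.read_model("llama2.xml");

    // Rewrites the stateful SDPA subgraphs to use PagedAttentionExtension, in place.
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAToPagedAttention>();
    manager.run_passes(model);
    return 0;
}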
1 parent: 818d282 · commit: 55c11c6

File tree: 15 files changed (+844, -151)


src/bindings/python/src/openvino/_offline_transformations/__init__.py (+1)

@@ -17,3 +17,4 @@
 from openvino._pyopenvino._offline_transformations import compress_model_transformation
 from openvino._pyopenvino._offline_transformations import compress_quantize_weights_transformation
 from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation
+from openvino._pyopenvino._offline_transformations import paged_attention_transformation

src/bindings/python/src/pyopenvino/core/offline_transformations.cpp (+10)

@@ -8,6 +8,7 @@

 #include <compress_quantize_weights.hpp>
 #include <openvino/pass/make_stateful.hpp>
+#include <openvino/pass/sdpa_to_paged_attention.hpp>
 #include <openvino/pass/serialize.hpp>
 #include <pruning.hpp>
 #include <transformations/common_optimizations/compress_float_constants.hpp>
@@ -127,4 +128,13 @@ void regmodule_offline_transformations(py::module m) {
             manager.run_passes(model);
         },
         py::arg("model"));
+
+    m_offline_transformations.def(
+        "paged_attention_transformation",
+        [](std::shared_ptr<ov::Model> model) {
+            ov::pass::Manager manager;
+            manager.register_pass<ov::pass::SDPAToPagedAttention>();
+            manager.run_passes(model);
+        },
+        py::arg("model"));
 }

src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp (+2, -151)

@@ -5,162 +5,13 @@
 #include "pyopenvino/graph/ops/paged_attention_extension.hpp"

 #include "openvino/op/op.hpp"
+#include "openvino/op/paged_attention.hpp"
 #include "pyopenvino/core/common.hpp"

 namespace py = pybind11;

-namespace {
-
-// This is an experimental operation that is implemented in the plugins.
-// Do not use in user applications, backward compatibility is not guaranteed in future releases.
-class PagedAttentionExtension : public ov::op::Op {
-public:
-    OPENVINO_OP("PagedAttentionExtension");
-
-    PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {
-        constructor_validate_and_infer_types();
-    }
-
-    void validate_and_infer_types() override {
-        auto value_cache_shape = get_input_partial_shape(4);
-        // m_num_kv_heads = value_cache_shape[1];
-        // m_head_size = value_cache_shape[2];
-        // m_block_size = value_cache_shape[3];
-        NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");
-
-        // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
-        auto key_cache_shape = get_input_partial_shape(3);
-        NODE_VALIDATION_CHECK(this,
-                              value_cache_shape.size() == 4,
-                              // value_cache_shape[0] == key_cache_shape[0] && // num_blocks
-                              // key_cache_shape[1] == m_num_kv_heads &&
-                              // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
-                              // m_block_size == key_cache_shape[3], // block_size,
-                              "Key cache shape must be 4 dims");
-
-        // query: shape [batch_size, seq_len, num_heads * head_size]
-        auto query_type = get_input_element_type(0);
-        auto query_shape = get_input_partial_shape(0);
-        NODE_VALIDATION_CHECK(
-            this,
-            // query_type.is_real() &&
-            query_shape.size() == 3,
-            // query_shape[2] == m_num_heads * m_head_size,
-            "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
-            "Got element type ",
-            query_type,
-            ", shape ",
-            query_shape);
-
-        // key: shape [batch_size, seq_len, num_kv_heads * head_size]
-        auto key_type = get_input_element_type(1);
-        auto key_shape = get_input_partial_shape(1);
-        NODE_VALIDATION_CHECK(this,
-                              // query_type == key_type &&
-                              key_shape.size() == 3,
-                              "Key type must be the same as query, shape must be the same as query. "
-                              "Got element type ",
-                              key_type,
-                              ", shape ",
-                              key_shape);
-
-        // value: shape [batch_size, seq_len, num_kv_heads * head_size]
-        // auto value_type = get_input_element_type(2);
-        auto value_shape = get_input_partial_shape(2);
-
-        // is_prompt: boolean scalar
-        NODE_VALIDATION_CHECK(this,
-                              // get_input_element_type(5) == ov::element::boolean &&
-                              get_input_shape(5) == ov::Shape({}),
-                              "is_prompt validation failed. ",
-                              "Got element type ",
-                              get_input_element_type(5),
-                              ", shape ",
-                              get_input_shape(5));
-
-        // slot_mapping: shape [batch_size, max_context_len]
-        auto slot_mapping_shape = get_input_partial_shape(6);
-        NODE_VALIDATION_CHECK(this,
-                              // get_input_element_type(6) == ov::element::i64 &&
-                              slot_mapping_shape.size() == 2,
-                              "slot_mapping validation failed. ",
-                              "Got element type ",
-                              get_input_element_type(6),
-                              ", shape ",
-                              slot_mapping_shape);
-
-        // max_context_len: integer scalar
-        NODE_VALIDATION_CHECK(this,
-                              // get_input_element_type(7) == ov::element::i32 &&
-                              get_input_shape(7) == ov::Shape({}),
-                              "max_context_len validation failed. ",
-                              "Got element type ",
-                              get_input_element_type(7),
-                              ", shape ",
-                              get_input_shape(7));
-
-        // context_lens: shape [batch_size]
-        auto context_lens_shape = get_input_partial_shape(8);
-        NODE_VALIDATION_CHECK(this,
-                              // get_input_element_type(8) == ov::element::i32 &&
-                              context_lens_shape.size() == 1,
-                              "context_lens validation failed. ",
-                              "Got element type ",
-                              get_input_element_type(8),
-                              ", shape ",
-                              context_lens_shape);
-
-        // block_tables: shape [batch_size, max_block_per_request]
-        NODE_VALIDATION_CHECK(this,
-                              // get_input_element_type(9) == ov::element::i32 &&
-                              get_input_partial_shape(9).size() == 2,
-                              "block_tables validation failed. ",
-                              "Got element type ",
-                              get_input_element_type(9),
-                              ", shape ",
-                              get_input_partial_shape(9));
-
-        // scale: float scalar
-        NODE_VALIDATION_CHECK(this,
-                              // get_input_element_type(10) == ov::element::f32 &&
-                              get_input_shape(10) == ov::Shape({}),
-                              "block_tables validation failed. ",
-                              "Got element type ",
-                              get_input_element_type(10),
-                              ", shape ",
-                              get_input_shape(10));
-
-        // alibi_slopes: 1D float tensor
-        NODE_VALIDATION_CHECK(this,
-                              // get_input_element_type(11) == ov::element::f32 &&
-                              get_input_partial_shape(11).rank().get_length() == 1,
-                              "alibi_slopes should be a 1D float tensor. ",
-                              "Got element type ",
-                              get_input_element_type(11),
-                              ", shape ",
-                              get_input_partial_shape(11));
-
-        // sliding_window: int scalar
-        NODE_VALIDATION_CHECK(this,
-                              // get_input_element_type(12) == ov::element::i32 &&
-                              get_input_partial_shape(12).rank().get_length() == 0,
-                              "sliding_window argument should be an i32 scalar. ",
-                              "Got element type ",
-                              get_input_element_type(12),
-                              ", shape ",
-                              get_input_partial_shape(12));
-
-        set_output_type(0, query_type, query_shape);
-    }
-
-    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
-        return std::make_shared<PagedAttentionExtension>(new_args);
-    }
-};
-
-}  // namespace
-
 void regclass_graph_op_PagedAttentionExtension(py::module m) {
+    using ov::op::PagedAttentionExtension;
     py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
         m,
         "_PagedAttentionExtension");
src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp (new file, +27)

@@ -0,0 +1,27 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/cc/pass/itt.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/parameter.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "transformations/utils/utils.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
+namespace pass {
+
+class PositionIDsReplacer;
+
+}  // namespace pass
+}  // namespace ov
+
+class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("PositionIDsReplacer", "0");
+    explicit PositionIDsReplacer(const std::shared_ptr<Output<Node>>& position_ids);
+};
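Note the signature: the replacement output is passed as a `std::shared_ptr<Output<Node>>`. A sketch of how a caller might instantiate the pass, assuming it has already created a `position_ids` Parameter (names and shapes are assumptions):

#include "openvino/core/model.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp"

void wire_position_ids(const std::shared_ptr<ov::Model>& model) {
    // Assumed: a 2D i64 position_ids Parameter the caller has added to (or found in) the model.
    auto position_ids = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
    position_ids->set_friendly_name("position_ids");

    // The declared constructor takes the replacement output wrapped in a shared_ptr.
    auto target = std::make_shared<ov::Output<ov::Node>>(position_ids->output(0));

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::PositionIDsReplacer>(target);
    manager.run_passes(model);
}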
src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp (new file, +27)

@@ -0,0 +1,27 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/cc/pass/itt.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "transformations/utils/utils.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
+namespace pass {
+
+class PrevSequenceLengthPattern;
+
+}  // namespace pass
+}  // namespace ov
+
+class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("PrevSequenceLengthPattern", "0");
+    explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len);
+};
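A sketch of one plausible way to build the `prev_max_seq_len` Subtract this pass expects, as `max_context_len` minus the current chunk length; how the real `SDPAToPagedAttention` driver wires this is not shown in this commit:

#include "openvino/core/model.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"

void wire_prev_seq_len(const std::shared_ptr<ov::Model>& model) {
    // Assumed wiring: previous length = max_context_len - current sequence chunk length.
    auto max_context_len = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{});
    auto cur_seq_len = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{});
    auto prev_max_seq_len = std::make_shared<ov::op::v1::Subtract>(max_context_len, cur_seq_len);

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::PrevSequenceLengthPattern>(prev_max_seq_len);
    manager.run_passes(model);
}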
src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp (new file, +27)

@@ -0,0 +1,27 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/graph_rewrite.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
+namespace pass {
+
+class StateManagementPattern;
+
+}  // namespace pass
+}  // namespace ov
+
+class ov::pass::StateManagementPattern : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("StateManagementPattern", "0");
+    StateManagementPattern(ParameterVector& kv_parameters,
+                           const ParameterVector& model_remaining_params,
+                           const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
+                           ParameterVector& parameters_to_remove,
+                           NodeVector& assignes_to_remove,
+                           int& layer_index);
+};
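A call sketch that matches the declared constructor; every container here is an empty placeholder for state the `SDPAToPagedAttention` driver is expected to accumulate, so treat it as shape-of-the-API only:

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/sdpa_to_paged_attention/state_management_pattern.hpp"

void wire_state_management(const std::shared_ptr<ov::Model>& model) {
    ov::ParameterVector kv_parameters;           // new per-layer key/value cache inputs (placeholder)
    ov::ParameterVector model_remaining_params;  // e.g. block_tables / context_lens style inputs (placeholder)
    auto sliding_window = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
    ov::ParameterVector parameters_to_remove;
    ov::NodeVector assignes_to_remove;           // spelling follows the declaration
    int layer_index = 0;

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::StateManagementPattern>(kv_parameters,
                                                            model_remaining_params,
                                                            sliding_window,
                                                            parameters_to_remove,
                                                            assignes_to_remove,
                                                            layer_index);
    manager.run_passes(model);
}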
src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp (new file, +27)

@@ -0,0 +1,27 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/cc/pass/itt.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/parameter.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "transformations/utils/utils.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
+namespace pass {
+
+class TotalSequenceLengthPattern;
+
+}  // namespace pass
+}  // namespace ov
+
+class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("TotalSequenceLengthPattern", "0");
+    explicit TotalSequenceLengthPattern(const std::shared_ptr<ov::op::v0::Parameter>& max_context_len);
+};
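And the corresponding registration sketch for this pass, assuming a scalar i32 `max_context_len` Parameter supplied by the caller:

#include "openvino/core/model.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"

void wire_total_seq_len(const std::shared_ptr<ov::Model>& model) {
    // Assumed: a scalar i32 max_context_len Parameter added to the model by the caller.
    auto max_context_len = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{});
    max_context_len->set_friendly_name("max_context_len");

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::TotalSequenceLengthPattern>(max_context_len);
    manager.run_passes(model);
}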
src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp (new file, +39)

@@ -0,0 +1,39 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp"
+
+#include "openvino/cc/pass/itt.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "transformations/utils/utils.hpp"
+
+using namespace ov::op;
+
+// TODO: Instead of using the following transformation that matches quite a specific place in a model graph in case
+// when the position_ids parameter is missing, consider replacing the always existing attention_mask parameter with
+// a sub-graph using a new slot_mapping parameter.
+ov::pass::PositionIDsReplacer::PositionIDsReplacer(const std::shared_ptr<Output<Node>>& position_ids) {
+    MATCHER_SCOPE(PositionIDsReplacer);
+
+    auto input_ids = pattern::any_input();
+    auto input_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), input_ids, pattern::any_input()});
+
+    auto position_ids_pattern = pattern::any_input();
+    auto offset = pattern::wrap_type<v0::Constant>();
+    auto add_offset = pattern::wrap_type<v1::Add>({position_ids_pattern, offset});
+    auto convert = pattern::wrap_type<v0::Convert>({add_offset});
+    auto position_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), convert, pattern::any_input()});
+
+    auto add = pattern::wrap_type<v1::Add>({input_embed, position_embed});
+
+    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+        const auto& pattern_map = m.get_pattern_value_map();
+        replace_node(pattern_map.at(position_ids_pattern).get_node_shared_ptr(), position_ids->get_node_shared_ptr());
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(add, matcher_name);
+    register_matcher(m, callback);
+}
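To make the matched pattern concrete, here is a toy word-plus-position embedding fragment this matcher would hit; table sizes, names, and the zero offset are all made up for illustration:

#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"

using namespace ov;
using namespace ov::op;

// Illustrative fragment only, mirroring the pattern above.
std::shared_ptr<Node> toy_embedding_subgraph() {
    auto word_table = std::make_shared<v0::Parameter>(element::f32, PartialShape{50257, 768});
    auto input_ids = std::make_shared<v0::Parameter>(element::i64, PartialShape{-1, -1});
    auto axis = v0::Constant::create(element::i64, Shape{}, {0});
    auto input_embed = std::make_shared<v8::Gather>(word_table, input_ids, axis);

    auto pos_table = std::make_shared<v0::Parameter>(element::f32, PartialShape{2048, 768});
    auto pos_ids_like = std::make_shared<v0::Parameter>(element::i32, PartialShape{-1, -1});
    auto offset = v0::Constant::create(element::i32, Shape{}, {0});
    auto shifted = std::make_shared<v1::Add>(pos_ids_like, offset);
    auto converted = std::make_shared<v0::Convert>(shifted, element::i64);
    auto position_embed = std::make_shared<v8::Gather>(pos_table, converted, axis);

    // Matcher root: the callback rewires pos_ids_like to the real position_ids output.
    return std::make_shared<v1::Add>(input_embed, position_embed);
}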
src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp (new file, +41)

@@ -0,0 +1,41 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"
+
+#include "openvino/cc/pass/itt.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "transformations/utils/utils.hpp"
+
+using namespace ov::op;
+
+ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
+    const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len) {
+    MATCHER_SCOPE(PrevSequenceLengthPattern);
+
+    auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
+    auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
+    auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_gather});
+    auto seq = pattern::wrap_type<v8::Gather>({kv_shape, pattern::any_input(), pattern::any_input()});
+
+    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+        // TODO: Check that seq has an axis that really takes the sequence len and not any other dimension --
+        // use symbolics or look at the constant input
+        auto gather = m.get_match_root();
+        auto target_type = gather->get_output_element_type(0);
+        std::shared_ptr<Node> replacement;
+        if (prev_max_seq_len->get_output_element_type(0) != target_type) {
+            replacement = std::make_shared<v0::Convert>(prev_max_seq_len, target_type);
+        } else {
+            replacement = prev_max_seq_len;
+        }
+        replace_node(gather, replacement);
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);
+    register_matcher(m, callback);
+}
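Similarly, a toy version of the chain this matcher targets: a KV-cache `ReadValue` whose past sequence length is read back via `ShapeOf`. The cache layout, dimension index, and variable name are assumptions:

#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/util/variable.hpp"

using namespace ov;
using namespace ov::op;

// Illustrative fragment only; layout [batch, num_kv_heads, seq_len, head_size] is assumed.
std::shared_ptr<Node> toy_prev_seq_len_subgraph() {
    auto init = std::make_shared<v0::Parameter>(element::f32, PartialShape{-1, 8, -1, 64});
    auto variable = std::make_shared<util::Variable>(
        util::VariableInfo{PartialShape{-1, 8, -1, 64}, element::f32, "past_key_values.0.key"});
    auto kv_past = std::make_shared<v6::ReadValue>(init, variable);

    auto beam_idx = std::make_shared<v0::Parameter>(element::i32, PartialShape{-1});
    auto axis = v0::Constant::create(element::i64, Shape{}, {0});
    auto kv_gather = std::make_shared<v8::Gather>(kv_past, beam_idx, axis);  // beam reordering

    auto kv_shape = std::make_shared<v3::ShapeOf>(kv_gather);
    auto seq_dim = v0::Constant::create(element::i64, Shape{}, {2});
    // Matcher root ("seq" above): reads the past sequence length from the state's shape;
    // the pass swaps it for prev_max_seq_len, inserting a Convert when element types differ.
    return std::make_shared<v8::Gather>(kv_shape, seq_dim, axis);
}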
