Skip to content

Commit 2aea2e0

Browse files
[TRANSFORMATIONS] Make TotalSequenceLengthPattern pattern stricter (#25434)
Make TotalSequenceLengthPattern stricter, to match one of the cases when 'scale' is calculated from shape. Tickets: CVS-138933. Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>
1 parent 19a5b95 commit 2aea2e0

File tree

8 files changed

+236
-52
lines changed

8 files changed

+236
-52
lines changed

src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp

+1-5
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,14 @@
44

55
#pragma once
66

7-
#include "openvino/cc/pass/itt.hpp"
87
#include "openvino/op/add.hpp"
9-
#include "openvino/op/parameter.hpp"
108
#include "openvino/pass/graph_rewrite.hpp"
11-
#include "openvino/pass/pattern/op/wrap_type.hpp"
12-
#include "transformations/utils/utils.hpp"
139
#include "transformations_visibility.hpp"
1410

1511
namespace ov {
1612
namespace pass {
1713

18-
class PositionIDsReplacer;
14+
class TRANSFORMATIONS_API PositionIDsReplacer;
1915

2016
} // namespace pass
2117
} // namespace ov

src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp

+1-5
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,14 @@
44

55
#pragma once
66

7-
#include "openvino/cc/pass/itt.hpp"
8-
#include "openvino/op/shape_of.hpp"
9-
#include "openvino/op/subtract.hpp"
107
#include "openvino/pass/graph_rewrite.hpp"
11-
#include "openvino/pass/pattern/op/wrap_type.hpp"
128
#include "transformations/utils/utils.hpp"
139
#include "transformations_visibility.hpp"
1410

1511
namespace ov {
1612
namespace pass {
1713

18-
class PrevSequenceLengthPattern;
14+
class TRANSFORMATIONS_API PrevSequenceLengthPattern;
1915

2016
} // namespace pass
2117
} // namespace ov

src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
namespace ov {
1111
namespace pass {
1212

13-
class StateManagementPattern;
13+
class TRANSFORMATIONS_API StateManagementPattern;
1414

1515
} // namespace pass
1616
} // namespace ov

src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp

+1-6
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,13 @@
44

55
#pragma once
66

7-
#include "openvino/cc/pass/itt.hpp"
8-
#include "openvino/op/concat.hpp"
9-
#include "openvino/op/parameter.hpp"
107
#include "openvino/pass/graph_rewrite.hpp"
11-
#include "openvino/pass/pattern/op/wrap_type.hpp"
12-
#include "transformations/utils/utils.hpp"
138
#include "transformations_visibility.hpp"
149

1510
namespace ov {
1611
namespace pass {
1712

18-
class TotalSequenceLengthPattern;
13+
class TRANSFORMATIONS_API TotalSequenceLengthPattern;
1914

2015
} // namespace pass
2116
} // namespace ov

src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#include "openvino/op/gather.hpp"
1010
#include "openvino/op/reshape.hpp"
1111
#include "openvino/op/shape_of.hpp"
12+
#include "openvino/op/subtract.hpp"
1213
#include "openvino/pass/pattern/op/wrap_type.hpp"
13-
#include "transformations/utils/utils.hpp"
1414

1515
using namespace ov::op;
1616

src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp

+62-9
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
66

77
#include "openvino/cc/pass/itt.hpp"
8+
#include "openvino/core/validation_util.hpp"
89
#include "openvino/op/concat.hpp"
910
#include "openvino/op/gather.hpp"
10-
#include "openvino/op/reshape.hpp"
1111
#include "openvino/op/shape_of.hpp"
1212
#include "openvino/pass/pattern/op/wrap_type.hpp"
1313
#include "transformations/utils/utils.hpp"
@@ -23,21 +23,74 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
2323
auto kv_current = pattern::any_input();
2424
auto kv_concat = pattern::wrap_type<v0::Concat>({kv_gather, kv_current});
2525
auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_concat});
26-
auto seq = pattern::wrap_type<v8::Gather>({kv_shape, pattern::any_input(), pattern::any_input()});
26+
auto gather_idx_label = pattern::wrap_type<v0::Constant>();
27+
auto seq = pattern::wrap_type<v8::Gather>({kv_shape, gather_idx_label, pattern::any_input()});
2728

2829
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
2930
// TODO: Check that seq has axis that really takes sequence len but not any other dimension --
3031
// use symbolic infra or look at the constant input
31-
auto gather = m.get_match_root();
32-
auto target_type = gather->get_output_element_type(0);
32+
const auto& pattern_map = m.get_pattern_value_map();
33+
34+
auto concat = std::dynamic_pointer_cast<v0::Concat>(pattern_map.at(kv_concat).get_node_shared_ptr());
35+
auto gather = pattern_map.at(seq).get_node_shared_ptr();
36+
auto gather_idx =
37+
std::dynamic_pointer_cast<v0::Constant>(pattern_map.at(gather_idx_label).get_node_shared_ptr());
38+
39+
if (!concat || !gather || !gather_idx || !gather_idx) {
40+
return false;
41+
}
42+
43+
auto gather_idx_data = gather_idx->cast_vector<int64_t>();
44+
45+
if (gather_idx_data.size() != 1) {
46+
return false;
47+
}
48+
49+
int64_t gather_idx_to_compare = gather_idx_data[0];
50+
51+
if (gather_idx_data[0] < 0) {
52+
if (gather->input(0).get_partial_shape().is_static()) {
53+
const auto& gather_data_shape = gather->input(0).get_shape();
54+
gather_idx_to_compare = ov::util::normalize(gather_idx_data[0], gather_data_shape[0]);
55+
} else {
56+
return false;
57+
}
58+
}
59+
3360
std::shared_ptr<Node> replacement = max_context_len;
34-
if (replacement->get_output_element_type(0) != target_type) {
35-
replacement = std::make_shared<v0::Convert>(replacement, target_type);
61+
62+
int64_t concat_axis_to_compare = concat->get_axis();
63+
if (concat_axis_to_compare < 0) {
64+
// If it's dynamic, leave it negative as we cannot take dynamic
65+
// dimension here so the next comparison would fail
66+
if (concat->get_output_partial_shape(0).is_static()) {
67+
const auto& concat_output_shape = concat->output(0).get_partial_shape();
68+
concat_axis_to_compare =
69+
ov::util::normalize(concat_axis_to_compare, concat_output_shape.rank().get_length());
70+
}
3671
}
37-
auto required_shape = gather->get_output_partial_shape(0);
38-
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
39-
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
72+
73+
if (concat_axis_to_compare == gather_idx_to_compare) {
74+
auto target_type = gather->get_output_element_type(0);
75+
76+
if (replacement->get_output_element_type(0) != target_type) {
77+
replacement = std::make_shared<v0::Convert>(replacement, target_type);
78+
}
79+
80+
auto required_shape = gather->get_output_partial_shape(0);
81+
82+
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
83+
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
84+
}
85+
} else {
86+
// TODO: change in the future when we start supporting dynamic shapes here
87+
replacement = ov::util::get_constant_from_source(gather->output(0));
88+
OPENVINO_ASSERT(replacement,
89+
"TotalSequenceLengthPattern transformation failed to determine the dimension value after ",
90+
"the Gather operation. Most probably, the required dimension is dynamic: ",
91+
concat);
4092
}
93+
4194
replace_node(gather, replacement);
4295
return true;
4396
};

src/common/transformations/tests/sdpa_to_paged_attention_test.cpp

+127
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,17 @@
88

99
#include "common_test_utils/test_common.hpp"
1010
#include "openvino/op/add.hpp"
11+
#include "openvino/op/concat.hpp"
12+
#include "openvino/op/constant.hpp"
13+
#include "openvino/op/convert.hpp"
14+
#include "openvino/op/divide.hpp"
15+
#include "openvino/op/gather.hpp"
16+
#include "openvino/op/power.hpp"
17+
#include "openvino/op/read_value.hpp"
1118
#include "openvino/op/scaled_dot_product_attention.hpp"
19+
#include "openvino/op/shape_of.hpp"
1220
#include "openvino/pass/manager.hpp"
21+
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
1322

1423
using namespace ov;
1524

@@ -24,4 +33,122 @@ TEST(SDPATOPATest, SDPANotPresent) {
2433
ov::pass::Manager manager;
2534
manager.register_pass<pass::SDPAToPagedAttention>();
2635
EXPECT_THROW(manager.run_passes(model), ov::Exception);
36+
}
37+
38+
TEST(SDPATOPATest, GatherIdx_ConcatAxis_EQ) {
39+
// Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
40+
const int CONCAT_AXIS = 1;
41+
const int GATHER_IDX = 1;
42+
43+
const auto input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
44+
auto variable = std::make_shared<ov::op::util::Variable>(
45+
ov::op::util::VariableInfo{PartialShape::dynamic(), element::i32, "variable"});
46+
const auto read_value = std::make_shared<op::v6::ReadValue>(input, variable);
47+
48+
const auto beam_idx = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
49+
const auto gather_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
50+
const auto gather = std::make_shared<op::v8::Gather>(read_value, beam_idx, gather_axis);
51+
52+
const auto concat_input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1, 2, 3});
53+
const auto concat = std::make_shared<op::v0::Concat>(NodeVector{gather, concat_input}, CONCAT_AXIS);
54+
55+
const auto shape_of = std::make_shared<op::v3::ShapeOf>(concat, element::i64);
56+
57+
const auto gather_indices = op::v0::Constant::create(element::i64, Shape{}, {GATHER_IDX});
58+
const auto gather_axis2 = op::v0::Constant::create(element::i64, Shape{}, {0});
59+
const auto gather1 = std::make_shared<op::v8::Gather>(shape_of, gather_indices, gather_axis2);
60+
61+
const auto result = std::make_shared<op::v0::Result>(gather1);
62+
auto model = std::make_shared<Model>(ResultVector{result}, ParameterVector{input, beam_idx, concat_input});
63+
64+
const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
65+
66+
ov::pass::Manager manager;
67+
manager.set_per_pass_validation(false);
68+
manager.register_pass<ov::pass::TotalSequenceLengthPattern>(max_context_len);
69+
bool transformation_run = manager.run_passes(model);
70+
71+
EXPECT_TRUE(transformation_run);
72+
const auto new_convert =
73+
std::dynamic_pointer_cast<op::v0::Convert>(result->input(0).get_source_output().get_node_shared_ptr());
74+
EXPECT_TRUE(new_convert);
75+
const auto new_max_context_len =
76+
std::dynamic_pointer_cast<op::v0::Parameter>(new_convert->input(0).get_source_output().get_node_shared_ptr());
77+
EXPECT_TRUE(new_max_context_len);
78+
EXPECT_TRUE(new_max_context_len == max_context_len);
79+
}
80+
81+
TEST(SDPATOPATest, GatherIdx_ConcatAxis_NOTEQ_STATIC) {
82+
// Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
83+
const int CONCAT_AXIS = 1;
84+
const int GATHER_IDX = 0;
85+
86+
const auto input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
87+
auto variable = std::make_shared<ov::op::util::Variable>(
88+
ov::op::util::VariableInfo{PartialShape::dynamic(), element::i32, "variable"});
89+
const auto read_value = std::make_shared<op::v6::ReadValue>(input, variable);
90+
91+
const auto beam_idx = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
92+
const auto gather_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
93+
const auto gather = std::make_shared<op::v8::Gather>(read_value, beam_idx, gather_axis);
94+
95+
const auto concat_input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1, 2, 3});
96+
const auto concat = std::make_shared<op::v0::Concat>(NodeVector{gather, concat_input}, CONCAT_AXIS);
97+
98+
const auto shape_of = std::make_shared<op::v3::ShapeOf>(concat, element::i64);
99+
100+
const auto gather_indices = op::v0::Constant::create(element::i64, Shape{}, {GATHER_IDX});
101+
const auto gather_axis2 = op::v0::Constant::create(element::i64, Shape{}, {0});
102+
const auto gather1 = std::make_shared<op::v8::Gather>(shape_of, gather_indices, gather_axis2);
103+
104+
const auto result = std::make_shared<op::v0::Result>(gather1);
105+
auto model = std::make_shared<Model>(ResultVector{result}, ParameterVector{input, beam_idx, concat_input});
106+
107+
const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
108+
109+
ov::pass::Manager manager;
110+
manager.set_per_pass_validation(false);
111+
manager.register_pass<ov::pass::TotalSequenceLengthPattern>(max_context_len);
112+
bool transformation_run = manager.run_passes(model);
113+
114+
EXPECT_TRUE(transformation_run);
115+
const auto new_constant =
116+
std::dynamic_pointer_cast<op::v0::Constant>(result->input(0).get_source_output().get_node_shared_ptr());
117+
EXPECT_TRUE(new_constant);
118+
}
119+
120+
TEST(SDPATOPATest, GatherIdx_ConcatAxis_NOTEQ_DYNAMIC) {
121+
// Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
122+
const int CONCAT_AXIS = 1;
123+
const int GATHER_IDX = 0;
124+
125+
const auto input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
126+
auto variable = std::make_shared<ov::op::util::Variable>(
127+
ov::op::util::VariableInfo{PartialShape::dynamic(), element::i32, "variable"});
128+
const auto read_value = std::make_shared<op::v6::ReadValue>(input, variable);
129+
130+
const auto beam_idx = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
131+
const auto gather_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
132+
const auto gather = std::make_shared<op::v8::Gather>(read_value, beam_idx, gather_axis);
133+
134+
const auto concat_input =
135+
std::make_shared<op::v0::Parameter>(element::i32,
136+
PartialShape{Dimension(1, 2), Dimension(1, 3), Dimension(1, 4)});
137+
const auto concat = std::make_shared<op::v0::Concat>(NodeVector{gather, concat_input}, CONCAT_AXIS);
138+
139+
const auto shape_of = std::make_shared<op::v3::ShapeOf>(concat, element::i64);
140+
141+
const auto gather_indices = op::v0::Constant::create(element::i64, Shape{}, {GATHER_IDX});
142+
const auto gather_axis2 = op::v0::Constant::create(element::i64, Shape{}, {0});
143+
const auto gather1 = std::make_shared<op::v8::Gather>(shape_of, gather_indices, gather_axis2);
144+
145+
const auto result = std::make_shared<op::v0::Result>(gather1);
146+
auto model = std::make_shared<Model>(ResultVector{result}, ParameterVector{input, beam_idx, concat_input});
147+
148+
const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
149+
150+
ov::pass::Manager manager;
151+
manager.set_per_pass_validation(false);
152+
manager.register_pass<ov::pass::TotalSequenceLengthPattern>(max_context_len);
153+
EXPECT_THROW(manager.run_passes(model), ov::Exception);
27154
}

0 commit comments

Comments
 (0)