8
8
9
9
#include " common_test_utils/test_common.hpp"
10
10
#include " openvino/op/add.hpp"
11
+ #include " openvino/op/concat.hpp"
12
+ #include " openvino/op/constant.hpp"
13
+ #include " openvino/op/convert.hpp"
14
+ #include " openvino/op/divide.hpp"
15
+ #include " openvino/op/gather.hpp"
16
+ #include " openvino/op/power.hpp"
17
+ #include " openvino/op/read_value.hpp"
11
18
#include " openvino/op/scaled_dot_product_attention.hpp"
19
+ #include " openvino/op/shape_of.hpp"
12
20
#include " openvino/pass/manager.hpp"
21
+ #include " transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
13
22
14
23
using namespace ov ;
15
24
@@ -24,4 +33,122 @@ TEST(SDPATOPATest, SDPANotPresent) {
24
33
ov::pass::Manager manager;
25
34
manager.register_pass <pass::SDPAToPagedAttention>();
26
35
EXPECT_THROW (manager.run_passes (model), ov::Exception);
36
+ }
37
+
38
+ TEST (SDPATOPATest, GatherIdx_ConcatAxis_EQ) {
39
+ // Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
40
+ const int CONCAT_AXIS = 1 ;
41
+ const int GATHER_IDX = 1 ;
42
+
43
+ const auto input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
44
+ auto variable = std::make_shared<ov::op::util::Variable>(
45
+ ov::op::util::VariableInfo{PartialShape::dynamic (), element::i32, " variable" });
46
+ const auto read_value = std::make_shared<op::v6::ReadValue>(input, variable);
47
+
48
+ const auto beam_idx = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
49
+ const auto gather_axis = op::v0::Constant::create (element::i64, Shape{}, {0 });
50
+ const auto gather = std::make_shared<op::v8::Gather>(read_value, beam_idx, gather_axis);
51
+
52
+ const auto concat_input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1 , 2 , 3 });
53
+ const auto concat = std::make_shared<op::v0::Concat>(NodeVector{gather, concat_input}, CONCAT_AXIS);
54
+
55
+ const auto shape_of = std::make_shared<op::v3::ShapeOf>(concat, element::i64);
56
+
57
+ const auto gather_indices = op::v0::Constant::create (element::i64, Shape{}, {GATHER_IDX});
58
+ const auto gather_axis2 = op::v0::Constant::create (element::i64, Shape{}, {0 });
59
+ const auto gather1 = std::make_shared<op::v8::Gather>(shape_of, gather_indices, gather_axis2);
60
+
61
+ const auto result = std::make_shared<op::v0::Result>(gather1);
62
+ auto model = std::make_shared<Model>(ResultVector{result}, ParameterVector{input, beam_idx, concat_input});
63
+
64
+ const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
65
+
66
+ ov::pass::Manager manager;
67
+ manager.set_per_pass_validation (false );
68
+ manager.register_pass <ov::pass::TotalSequenceLengthPattern>(max_context_len);
69
+ bool transformation_run = manager.run_passes (model);
70
+
71
+ EXPECT_TRUE (transformation_run);
72
+ const auto new_convert =
73
+ std::dynamic_pointer_cast<op::v0::Convert>(result->input (0 ).get_source_output ().get_node_shared_ptr ());
74
+ EXPECT_TRUE (new_convert);
75
+ const auto new_max_context_len =
76
+ std::dynamic_pointer_cast<op::v0::Parameter>(new_convert->input (0 ).get_source_output ().get_node_shared_ptr ());
77
+ EXPECT_TRUE (new_max_context_len);
78
+ EXPECT_TRUE (new_max_context_len == max_context_len);
79
+ }
80
+
81
+ TEST (SDPATOPATest, GatherIdx_ConcatAxis_NOTEQ_STATIC) {
82
+ // Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
83
+ const int CONCAT_AXIS = 1 ;
84
+ const int GATHER_IDX = 0 ;
85
+
86
+ const auto input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
87
+ auto variable = std::make_shared<ov::op::util::Variable>(
88
+ ov::op::util::VariableInfo{PartialShape::dynamic (), element::i32, " variable" });
89
+ const auto read_value = std::make_shared<op::v6::ReadValue>(input, variable);
90
+
91
+ const auto beam_idx = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
92
+ const auto gather_axis = op::v0::Constant::create (element::i64, Shape{}, {0 });
93
+ const auto gather = std::make_shared<op::v8::Gather>(read_value, beam_idx, gather_axis);
94
+
95
+ const auto concat_input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1 , 2 , 3 });
96
+ const auto concat = std::make_shared<op::v0::Concat>(NodeVector{gather, concat_input}, CONCAT_AXIS);
97
+
98
+ const auto shape_of = std::make_shared<op::v3::ShapeOf>(concat, element::i64);
99
+
100
+ const auto gather_indices = op::v0::Constant::create (element::i64, Shape{}, {GATHER_IDX});
101
+ const auto gather_axis2 = op::v0::Constant::create (element::i64, Shape{}, {0 });
102
+ const auto gather1 = std::make_shared<op::v8::Gather>(shape_of, gather_indices, gather_axis2);
103
+
104
+ const auto result = std::make_shared<op::v0::Result>(gather1);
105
+ auto model = std::make_shared<Model>(ResultVector{result}, ParameterVector{input, beam_idx, concat_input});
106
+
107
+ const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
108
+
109
+ ov::pass::Manager manager;
110
+ manager.set_per_pass_validation (false );
111
+ manager.register_pass <ov::pass::TotalSequenceLengthPattern>(max_context_len);
112
+ bool transformation_run = manager.run_passes (model);
113
+
114
+ EXPECT_TRUE (transformation_run);
115
+ const auto new_constant =
116
+ std::dynamic_pointer_cast<op::v0::Constant>(result->input (0 ).get_source_output ().get_node_shared_ptr ());
117
+ EXPECT_TRUE (new_constant);
118
+ }
119
+
120
+ TEST (SDPATOPATest, GatherIdx_ConcatAxis_NOTEQ_DYNAMIC) {
121
+ // Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
122
+ const int CONCAT_AXIS = 1 ;
123
+ const int GATHER_IDX = 0 ;
124
+
125
+ const auto input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
126
+ auto variable = std::make_shared<ov::op::util::Variable>(
127
+ ov::op::util::VariableInfo{PartialShape::dynamic (), element::i32, " variable" });
128
+ const auto read_value = std::make_shared<op::v6::ReadValue>(input, variable);
129
+
130
+ const auto beam_idx = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
131
+ const auto gather_axis = op::v0::Constant::create (element::i64, Shape{}, {0 });
132
+ const auto gather = std::make_shared<op::v8::Gather>(read_value, beam_idx, gather_axis);
133
+
134
+ const auto concat_input =
135
+ std::make_shared<op::v0::Parameter>(element::i32,
136
+ PartialShape{Dimension (1 , 2 ), Dimension (1 , 3 ), Dimension (1 , 4 )});
137
+ const auto concat = std::make_shared<op::v0::Concat>(NodeVector{gather, concat_input}, CONCAT_AXIS);
138
+
139
+ const auto shape_of = std::make_shared<op::v3::ShapeOf>(concat, element::i64);
140
+
141
+ const auto gather_indices = op::v0::Constant::create (element::i64, Shape{}, {GATHER_IDX});
142
+ const auto gather_axis2 = op::v0::Constant::create (element::i64, Shape{}, {0 });
143
+ const auto gather1 = std::make_shared<op::v8::Gather>(shape_of, gather_indices, gather_axis2);
144
+
145
+ const auto result = std::make_shared<op::v0::Result>(gather1);
146
+ auto model = std::make_shared<Model>(ResultVector{result}, ParameterVector{input, beam_idx, concat_input});
147
+
148
+ const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
149
+
150
+ ov::pass::Manager manager;
151
+ manager.set_per_pass_validation (false );
152
+ manager.register_pass <ov::pass::TotalSequenceLengthPattern>(max_context_len);
153
+ EXPECT_THROW (manager.run_passes (model), ov::Exception);
27
154
}
0 commit comments