@@ -108,8 +108,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
108
108
109
109
auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
110
110
auto present_to = out.get_target_inputs ();
111
- if (present_to.size () < 2 )
112
- return false ;
113
111
for (auto & to : present_to) {
114
112
auto to_node = to.get_node ();
115
113
if (auto convert = dynamic_cast <opset1::Convert*>(to_node)) {
@@ -149,6 +147,28 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
149
147
const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at (concat_k).get_node_shared_ptr ());
150
148
const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at (concat_v).get_node_shared_ptr ());
151
149
150
+ for (auto && item : {concat_k_node, concat_v_node}) {
151
+ auto && children = item->get_output_target_inputs (0 );
152
+ switch (children.size ()) {
153
+ case 2 :
154
+ // pass, as the existence of Assign will be checked later
155
+ break ;
156
+ case 3 :
157
+ // the first one leads to SDPA, otherwise the matcher doesn't find the pattern
158
+ // the second one leads to Assign, and this is checked later
159
+ // the third child is allowed to be a ShapeOf op only, thus one of them must be ShapeOf
160
+ if (!std::any_of (children.begin (), children.end (), [](const ov::Input<ov::Node>& child) {
161
+ return ov::is_type<ov::op::v3::ShapeOf>(child.get_node ()) ||
162
+ ov::is_type<ov::op::v0::ShapeOf>(child.get_node ());
163
+ })) {
164
+ return false ;
165
+ }
166
+ break ;
167
+ default :
168
+ return false ;
169
+ }
170
+ }
171
+
152
172
opset6::Assign *assign_k_node = nullptr , *assign_v_node = nullptr ;
153
173
opset1::Convert *assign_cvt_k_node = nullptr , *assign_cvt_v_node = nullptr ;
154
174
if (!find_assign (concat_k_node, assign_k_node, assign_cvt_k_node))
0 commit comments