Skip to content

Commit 908dac9

Browse files
authored Mar 22, 2024
[CPU] Fix SDPA pattern matching (openvinotoolkit#23581)
### Details: Limit the Concat layer to at most 3 children. The third child may only be a ShapeOf op (needed to support Mixtral). ### Tickets: - 135375
1 parent 0902fe4 commit 908dac9

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed
 

‎src/plugins/intel_cpu/src/nodes/scaled_attn.cpp

+1
Original file line number | Diff line number | Diff line change
@@ -1111,6 +1111,7 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {
11111111
presentv_input = inputs[ID_VCACHE];
11121112
} else {
11131113
if (m_config.config.fuse_concat) {
1114+
CPU_NODE_ASSERT(m_k_state && m_v_state, "has null input states");
11141115
// initialization will be also completed in this func
11151116
gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber));
11161117

‎src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp

+22-2
Original file line number | Diff line number | Diff line change
@@ -108,8 +108,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
108108

109109
auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
110110
auto present_to = out.get_target_inputs();
111-
if (present_to.size() < 2)
112-
return false;
113111
for (auto& to : present_to) {
114112
auto to_node = to.get_node();
115113
if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
@@ -149,6 +147,28 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
149147
const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
150148
const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
151149

150+
for (auto&& item : {concat_k_node, concat_v_node}) {
151+
auto&& children = item->get_output_target_inputs(0);
152+
switch (children.size()) {
153+
case 2:
154+
// pass, as the existence of Assign will be checked later
155+
break;
156+
case 3:
157+
// the first one leads to SDPA, otherwise the matcher doesn't find the pattern
158+
// the second one leads to Assign, and this is checked later
159+
// the third child is allowed to be a ShapeOf op only, thus one of them must be ShapeOf
160+
if (!std::any_of(children.begin(), children.end(), [](const ov::Input<ov::Node>& child) {
161+
return ov::is_type<ov::op::v3::ShapeOf>(child.get_node()) ||
162+
ov::is_type<ov::op::v0::ShapeOf>(child.get_node());
163+
})) {
164+
return false;
165+
}
166+
break;
167+
default:
168+
return false;
169+
}
170+
}
171+
152172
opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr;
153173
opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
154174
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node))

0 commit comments

Comments
 (0)