Skip to content

Commit 4d01e56

Browse files
[NPUW] Add new compute pattern (#28935)
1 parent 40de634 commit 4d01e56

File tree

4 files changed

+34
-1
lines changed

4 files changed

+34
-1
lines changed

src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ static const std::map<std::string, std::string> ISOL_PRESETS = {{"COMPUTE",
2727
"P:DQMatMulGQi4/compute,P:DQMatMulCWi4/compute,"
2828
"P:DQMatMulConv/compute,"
2929
"P:VocabMatMul/compute,"
30-
"P:RMSNorm/compute,P:RMSNorm2/compute"}};
30+
"P:RMSNorm/compute,P:RMSNorm2/compute,"
31+
"P:VariadicSplit/compute"}};
3132
}
3233

3334
// For missing declaration warning

src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,7 @@ void Snapshot::earlyRegroup() {
478478
HNDL(DQMatMulGQi4);
479479
HNDL(DQMatMulConv);
480480
HNDL(VocabMatMul);
481+
HNDL(VariadicSplit);
481482
#undef HNDL
482483
}
483484
}

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,31 @@ RMSNorm2::RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot,
406406
register_matcher(std::make_shared<opp::Matcher>(multiply, "TagRMSNorm2"), std::move(callback));
407407
}
408408

409+
// TODO: visualize
410+
VariadicSplit::VariadicSplit(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
411+
auto vsplit = opp::wrap_type<ov::op::v1::VariadicSplit>({opp::any_input(), opp::any_input(), opp::any_input()});
412+
auto swish = opp::wrap_type<ov::op::v4::Swish>({vsplit});
413+
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({vsplit, swish});
414+
415+
auto node_to_gptr = snapshot->getNodeToGroupMap();
416+
417+
// Note: Use [=] to make sure the above objects stay alive in the callback
418+
auto callback = [=](ov::pass::pattern::Matcher& m) {
419+
auto& node_to_output = m.get_pattern_value_map();
420+
421+
auto matched_vsplit = node_to_output.at(vsplit).get_node_shared_ptr();
422+
auto matched_swish = node_to_output.at(swish).get_node_shared_ptr();
423+
auto matched_multiply = node_to_output.at(multiply).get_node_shared_ptr();
424+
425+
node_to_gptr->at(matched_vsplit)->isolate(isol_tag);
426+
node_to_gptr->at(matched_swish)->isolate(isol_tag);
427+
node_to_gptr->at(matched_multiply)->isolate(isol_tag);
428+
429+
return false; // root hasn't changed
430+
};
431+
register_matcher(std::make_shared<opp::Matcher>(multiply, "TagVariadicSplit"), std::move(callback));
432+
}
433+
409434
} // namespace compute
410435
} // namespace patterns
411436
} // namespace npuw

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ class RMSNorm2 : public ov::pass::MatcherPass {
6969
RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
7070
};
7171

72+
class VariadicSplit : public ov::pass::MatcherPass {
73+
public:
74+
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::compute::VariadicSplit");
75+
VariadicSplit(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
76+
};
77+
7278
} // namespace compute
7379
} // namespace patterns
7480
} // namespace npuw

0 commit comments

Comments
 (0)