Skip to content

Commit 288c5f9

Browse files
[NPUW] Support i4 patterns for compute pipeline (#26785)
Following up with i4 patterns on #25679
1 parent 416bfb4 commit 288c5f9

File tree

4 files changed

+127
-36
lines changed

4 files changed

+127
-36
lines changed

src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp

+4-5
Original file line number | Diff line number | Diff line change
@@ -140,10 +140,8 @@ std::vector<Isolate> getIsolates(const std::string isolates_unparsed) {
140140
if (!isolates.empty()) {
141141
LOG_INFO("Online partitioning will isolate subgraphs containing specified patterns.");
142142
} else {
143-
LOG_WARN("Incorect pattern in NPUW_ONLINE_ISOLATE!"
144-
<< " Please, follow the example: "
145-
<< "Op:Select/NPU,P:DQMatMulGQ/compute,P:DQMatMulCW/compute,P:RMSNorm/compute. "
146-
<< "No isolate rules will be taken into account during partitioning!");
143+
LOG_WARN("Incorect pattern in NPUW_ONLINE_ISOLATE! No isolate rules will be taken into account during "
144+
"partitioning!");
147145
}
148146

149147
return isolates;
@@ -193,7 +191,8 @@ std::vector<std::string> getNoFolds(const std::string& nofolds_unparsed) {
193191

194192
// Fills `ctx` with the built-in "compute" configuration:
//  - ctx.isolates is produced by parsing a hard-coded isolate string with detail::getIsolates();
//  - ctx.nofolds is produced by detail::getNoFolds("compute").
// In this diff view the line under the "-" marker is the pre-change isolate list and the lines
// under the "+" markers are its replacement, which names the u4 and i4 DQMatMul patterns plus
// RMSNorm, all with the "compute" tag.
void setComputeConfig(PassContext& ctx) {
195193
// FIXME: initialize via a dedicated function instead of parsing
196-
ctx.isolates = detail::getIsolates("P:DQMatMulGQ/compute,P:DQMatMulCW/compute,P:RMSNorm/compute");
194+
ctx.isolates = detail::getIsolates("P:DQMatMulGQu4/compute,P:DQMatMulCWu4/compute,P:DQMatMulGQi4/"
195+
"compute,P:DQMatMulCWi4/compute,P:RMSNorm/compute");
197196
ctx.nofolds = detail::getNoFolds("compute");
198197
}
199198

src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp

+12-5
Original file line number | Diff line number | Diff line change
@@ -404,14 +404,21 @@ void Snapshot::earlyRegroup() {
404404
if (isolate.pattern == "RMSNorm") {
405405
rewr.add_matcher<ov::npuw::patterns::compute::RMSNorm>(shared_from_this(), isolate.tag);
406406
handle_patterns = true;
407-
} else if (isolate.pattern == "DQMatMulCW") {
408-
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCW>(shared_from_this(), isolate.tag);
407+
} else if (isolate.pattern == "DQMatMulCWu4") {
408+
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWu4>(shared_from_this(), isolate.tag);
409409
handle_patterns = true;
410-
} else if (isolate.pattern == "DQMatMulGQ") {
411-
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQ>(shared_from_this(), isolate.tag);
410+
} else if (isolate.pattern == "DQMatMulGQu4") {
411+
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQu4>(shared_from_this(), isolate.tag);
412+
handle_patterns = true;
413+
} else if (isolate.pattern == "DQMatMulCWi4") {
414+
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWi4>(shared_from_this(), isolate.tag);
415+
handle_patterns = true;
416+
} else if (isolate.pattern == "DQMatMulGQi4") {
417+
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQi4>(shared_from_this(), isolate.tag);
412418
handle_patterns = true;
413419
} else {
414-
LOG_WARN("OPENVINO_NPUW_ISOLATE only supports RMSNorm, DQMatMulCW, DQMatMulGQ "
420+
LOG_WARN("OPENVINO_NPUW_ISOLATE only supports RMSNorm, DQMatMulCWu4, DQMatMulGQu4, DQMatMulCWi4, "
421+
"DQMatMulGQi4 "
415422
<< "as patterns. Isolate pattern " << isolate.pattern << " is skipped!");
416423
}
417424
}

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp

+97-22
Original file line number | Diff line number | Diff line change
@@ -7,24 +7,7 @@
77
#include "../../logging.hpp"
88
#include "../online/group.hpp" // online::Group
99
#include "../online/snapshot.hpp" // online::Snapshot
10-
#include "openvino/op/add.hpp"
11-
#include "openvino/op/broadcast.hpp"
12-
#include "openvino/op/concat.hpp"
13-
#include "openvino/op/convert.hpp"
14-
#include "openvino/op/divide.hpp"
15-
#include "openvino/op/gather.hpp"
16-
#include "openvino/op/greater.hpp"
17-
#include "openvino/op/matmul.hpp"
18-
#include "openvino/op/mod.hpp"
19-
#include "openvino/op/multiply.hpp"
20-
#include "openvino/op/power.hpp"
21-
#include "openvino/op/reduce_mean.hpp"
22-
#include "openvino/op/reshape.hpp"
23-
#include "openvino/op/shape_of.hpp"
24-
#include "openvino/op/sqrt.hpp"
25-
#include "openvino/op/subtract.hpp"
26-
#include "openvino/op/util/op_types.hpp"
27-
#include "openvino/op/variadic_split.hpp"
10+
#include "openvino/op/ops.hpp"
2811
#include "openvino/pass/pattern/op/label.hpp" // any_input
2912
#include "openvino/pass/pattern/op/wrap_type.hpp"
3013
#include "openvino/util/common_util.hpp"
@@ -37,7 +20,7 @@ namespace compute {
3720
namespace opp = ov::pass::pattern;
3821

3922
// TODO: visualize
40-
DQMatMulGQ::DQMatMulGQ(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
23+
DQMatMulGQu4::DQMatMulGQu4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
4124
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
4225
auto qzerop = opp::wrap_type<ov::op::v0::Constant>();
4326
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
@@ -87,11 +70,11 @@ DQMatMulGQ::DQMatMulGQ(const std::shared_ptr<ov::npuw::online::Snapshot>& snapsh
8770

8871
return false; // root hasn't changed
8972
};
90-
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulGQ"), std::move(callback));
73+
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulGQu4"), std::move(callback));
9174
}
9275

9376
// TODO: visualize
94-
DQMatMulCW::DQMatMulCW(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
77+
DQMatMulCWu4::DQMatMulCWu4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
9578
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
9679
auto qzerop = opp::wrap_type<ov::op::v0::Constant>();
9780
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
@@ -140,7 +123,99 @@ DQMatMulCW::DQMatMulCW(const std::shared_ptr<ov::npuw::online::Snapshot>& snapsh
140123

141124
return false; // root hasn't changed
142125
};
143-
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulCW"), std::move(callback));
126+
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulCWu4"), std::move(callback));
127+
}
128+
129+
// TODO: visualize
130+
// Matcher pass that tags a dequantize-then-MatMul subgraph for isolation.
// NOTE(review): "GQ" presumably means group-quantized weights -- confirm against the NPUW docs.
//
// Pattern, rooted at the MatMul:
//   Constant(qweight) -> Convert -> Multiply(., Constant(qcoeff)) -> Reshape(., any)
//                     -> Convert -> MatMul(any, .)
// No zero-point Constant takes part in this pattern.
//
// When the pattern matches and the weight Constant is i4 or i8 while the coefficient Constant
// is f16, the groups owning the Multiply, the Reshape, the second Convert and the MatMul are
// isolated under `isol_tag`.
//
// snapshot: provides the node-to-group map used to find the group of each matched node.
// isol_tag: tag forwarded to isolate() on each of those groups.
DQMatMulGQi4::DQMatMulGQi4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
131+
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
132+
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
133+
134+
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
135+
136+
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
137+
auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
138+
auto qcvtr = opp::wrap_type<ov::op::v0::Convert>({qreshp});
139+
// The dequantized weights feed the MatMul's second input; the first input is unconstrained.
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), qcvtr});
140+
141+
// Fetched once here and captured by value into the callback below.
auto node_to_gptr = snapshot->getNodeToGroupMap();
142+
143+
// Note: Use [=] to make sure the above objects stay alive in the callback
144+
auto callback = [=](ov::pass::pattern::Matcher& m) {
145+
auto& node_to_output = m.get_pattern_value_map();
146+
147+
auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
148+
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
149+
150+
// Both nodes were matched via wrap_type<Constant>; the asserts guard the static casts below.
NPUW_ASSERT(ov::op::util::is_constant(matched_node_qweight));
151+
NPUW_ASSERT(ov::op::util::is_constant(matched_node_qcoeff));
152+
153+
auto matched_qweight = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qweight);
154+
auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qcoeff);
155+
156+
// Despite the "i4" suffix, i8 weights are accepted as well; the coefficient must be f16.
if ((ov::element::i4 == matched_qweight->get_element_type() ||
157+
ov::element::i8 == matched_qweight->get_element_type()) &&
158+
ov::element::f16 == matched_qcoeff->get_element_type()) {
159+
// Partitioning ignores Const->Convert nodes, so qcvtw is not used
160+
auto matched_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
161+
auto matched_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
162+
auto matched_qcvtr = node_to_output.at(qcvtr).get_node_shared_ptr();
163+
auto matched_qmm = node_to_output.at(qmm).get_node_shared_ptr();
164+
165+
// Tag the group of every non-constant node in the matched chain.
node_to_gptr->at(matched_qmuls)->isolate(isol_tag);
166+
node_to_gptr->at(matched_qreshp)->isolate(isol_tag);
167+
node_to_gptr->at(matched_qcvtr)->isolate(isol_tag);
168+
node_to_gptr->at(matched_qmm)->isolate(isol_tag);
169+
}
170+
171+
return false; // root hasn't changed
172+
};
173+
// qmm (the MatMul) is the pattern root; the matcher is registered as "TagDQMatMulGQi4".
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulGQi4"), std::move(callback));
174+
}
175+
176+
// TODO: visualize
177+
// Matcher pass that tags a dequantize-then-MatMul subgraph for isolation.
// NOTE(review): "CW" presumably means channel-wise quantized weights -- confirm against the NPUW docs.
//
// Pattern, rooted at the MatMul:
//   Constant(qweight) -> Convert -> Multiply(., Constant(qcoeff)) -> Convert -> MatMul(any, .)
// There is no Reshape between the Multiply and the second Convert, and no zero-point Constant
// takes part in this pattern.
//
// When the pattern matches and the weight Constant is i4 or i8 while the coefficient Constant
// is f16, the groups owning the Multiply, the second Convert and the MatMul are isolated under
// `isol_tag`.
//
// snapshot: provides the node-to-group map used to find the group of each matched node.
// isol_tag: tag forwarded to isolate() on each of those groups.
DQMatMulCWi4::DQMatMulCWi4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
178+
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
179+
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
180+
181+
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
182+
183+
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
184+
185+
auto qcvtm = opp::wrap_type<ov::op::v0::Convert>({qmuls});
186+
// The dequantized weights feed the MatMul's second input; the first input is unconstrained.
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), qcvtm});
187+
188+
// Fetched once here and captured by value into the callback below.
auto node_to_gptr = snapshot->getNodeToGroupMap();
189+
190+
// Note: Use [=] to make sure the above objects stay alive in the callback
191+
auto callback = [=](ov::pass::pattern::Matcher& m) {
192+
auto& node_to_output = m.get_pattern_value_map();
193+
194+
auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
195+
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
196+
197+
// Both nodes were matched via wrap_type<Constant>; the asserts guard the static casts below.
NPUW_ASSERT(ov::op::util::is_constant(matched_node_qweight));
198+
NPUW_ASSERT(ov::op::util::is_constant(matched_node_qcoeff));
199+
200+
auto matched_qweight = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qweight);
201+
auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qcoeff);
202+
203+
// Despite the "i4" suffix, i8 weights are accepted as well; the coefficient must be f16.
if ((ov::element::i4 == matched_qweight->get_element_type() ||
204+
ov::element::i8 == matched_qweight->get_element_type()) &&
205+
ov::element::f16 == matched_qcoeff->get_element_type()) {
206+
// Partitioning ignores Const->Convert nodes, so qcvtw is not used
207+
auto matched_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
208+
auto matched_qcvtm = node_to_output.at(qcvtm).get_node_shared_ptr();
209+
auto matched_qmm = node_to_output.at(qmm).get_node_shared_ptr();
210+
211+
// Tag the group of every non-constant node in the matched chain.
node_to_gptr->at(matched_qmuls)->isolate(isol_tag);
212+
node_to_gptr->at(matched_qcvtm)->isolate(isol_tag);
213+
node_to_gptr->at(matched_qmm)->isolate(isol_tag);
214+
}
215+
216+
return false; // root hasn't changed
217+
};
218+
// qmm (the MatMul) is the pattern root; the matcher is registered as "TagDQMatMulCWi4".
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulCWi4"), std::move(callback));
144219
}
145220

146221
// TODO: visualize

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp

+14-4
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,24 @@ class Snapshot; // Forward declaration
2121
namespace patterns {
2222
namespace compute {
2323

24-
class DQMatMulGQ : public ov::pass::MatcherPass {
24+
class DQMatMulGQu4 : public ov::pass::MatcherPass {
2525
public:
26-
DQMatMulGQ(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
26+
DQMatMulGQu4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
2727
};
2828

29-
class DQMatMulCW : public ov::pass::MatcherPass {
29+
class DQMatMulCWu4 : public ov::pass::MatcherPass {
3030
public:
31-
DQMatMulCW(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
31+
DQMatMulCWu4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
32+
};
33+
34+
// Matcher pass for the Constant -> Convert -> Multiply -> Reshape -> Convert -> MatMul chain
// with i4/i8 weights and an f16 coefficient (see the constructor definition in compute.cpp).
// The constructor takes the online Snapshot and the tag under which matched groups are isolated.
class DQMatMulGQi4 : public ov::pass::MatcherPass {
35+
public:
36+
DQMatMulGQi4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
37+
};
38+
39+
// Matcher pass for the Constant -> Convert -> Multiply -> Convert -> MatMul chain with i4/i8
// weights and an f16 coefficient (see the constructor definition in compute.cpp).
// The constructor takes the online Snapshot and the tag under which matched groups are isolated.
class DQMatMulCWi4 : public ov::pass::MatcherPass {
40+
public:
41+
DQMatMulCWi4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
3242
};
3343

3444
class RMSNorm : public ov::pass::MatcherPass {

0 commit comments

Comments
 (0)