7
7
#include " ../../logging.hpp"
8
8
#include " ../online/group.hpp" // online::Group
9
9
#include " ../online/snapshot.hpp" // online::Snapshot
10
- #include " openvino/op/add.hpp"
11
- #include " openvino/op/broadcast.hpp"
12
- #include " openvino/op/concat.hpp"
13
- #include " openvino/op/convert.hpp"
14
- #include " openvino/op/divide.hpp"
15
- #include " openvino/op/gather.hpp"
16
- #include " openvino/op/greater.hpp"
17
- #include " openvino/op/matmul.hpp"
18
- #include " openvino/op/mod.hpp"
19
- #include " openvino/op/multiply.hpp"
20
- #include " openvino/op/power.hpp"
21
- #include " openvino/op/reduce_mean.hpp"
22
- #include " openvino/op/reshape.hpp"
23
- #include " openvino/op/shape_of.hpp"
24
- #include " openvino/op/sqrt.hpp"
25
- #include " openvino/op/subtract.hpp"
26
- #include " openvino/op/util/op_types.hpp"
27
- #include " openvino/op/variadic_split.hpp"
10
+ #include " openvino/op/ops.hpp"
28
11
#include " openvino/pass/pattern/op/label.hpp" // any_input
29
12
#include " openvino/pass/pattern/op/wrap_type.hpp"
30
13
#include " openvino/util/common_util.hpp"
@@ -37,7 +20,7 @@ namespace compute {
37
20
namespace opp = ov::pass::pattern;
38
21
39
22
// TODO: visualize
40
- DQMatMulGQ::DQMatMulGQ (const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
23
+ DQMatMulGQu4::DQMatMulGQu4 (const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
41
24
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
42
25
auto qzerop = opp::wrap_type<ov::op::v0::Constant>();
43
26
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
@@ -87,11 +70,11 @@ DQMatMulGQ::DQMatMulGQ(const std::shared_ptr<ov::npuw::online::Snapshot>& snapsh
87
70
88
71
return false ; // root hasn't changed
89
72
};
90
- register_matcher (std::make_shared<opp::Matcher>(qmm, " TagDQMatMulGQ " ), std::move (callback));
73
+ register_matcher (std::make_shared<opp::Matcher>(qmm, " TagDQMatMulGQu4 " ), std::move (callback));
91
74
}
92
75
93
76
// TODO: visualize
94
- DQMatMulCW::DQMatMulCW (const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
77
+ DQMatMulCWu4::DQMatMulCWu4 (const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
95
78
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
96
79
auto qzerop = opp::wrap_type<ov::op::v0::Constant>();
97
80
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
@@ -140,7 +123,99 @@ DQMatMulCW::DQMatMulCW(const std::shared_ptr<ov::npuw::online::Snapshot>& snapsh
140
123
141
124
return false ; // root hasn't changed
142
125
};
143
- register_matcher (std::make_shared<opp::Matcher>(qmm, " TagDQMatMulCW" ), std::move (callback));
126
+ register_matcher (std::make_shared<opp::Matcher>(qmm, " TagDQMatMulCWu4" ), std::move (callback));
127
+ }
128
+
129
+ // TODO: visualize
130
+ DQMatMulGQi4::DQMatMulGQi4 (const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
131
+ auto qweight = opp::wrap_type<ov::op::v0::Constant>();
132
+ auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
133
+
134
+ auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
135
+
136
+ auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
137
+ auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input ()});
138
+ auto qcvtr = opp::wrap_type<ov::op::v0::Convert>({qreshp});
139
+ auto qmm = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input (), qcvtr});
140
+
141
+ auto node_to_gptr = snapshot->getNodeToGroupMap ();
142
+
143
+ // Note: Use [=] to make sure the above objects stay alive in the callback
144
+ auto callback = [=](ov::pass::pattern::Matcher& m) {
145
+ auto & node_to_output = m.get_pattern_value_map ();
146
+
147
+ auto matched_node_qweight = node_to_output.at (qweight).get_node_shared_ptr ();
148
+ auto matched_node_qcoeff = node_to_output.at (qcoeff).get_node_shared_ptr ();
149
+
150
+ NPUW_ASSERT (ov::op::util::is_constant (matched_node_qweight));
151
+ NPUW_ASSERT (ov::op::util::is_constant (matched_node_qcoeff));
152
+
153
+ auto matched_qweight = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qweight);
154
+ auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qcoeff);
155
+
156
+ if ((ov::element::i4 == matched_qweight->get_element_type () ||
157
+ ov::element::i8 == matched_qweight->get_element_type ()) &&
158
+ ov::element::f16 == matched_qcoeff->get_element_type ()) {
159
+ // Partitioning ignores Const->Convert nodes, so qcvtw is not used
160
+ auto matched_qmuls = node_to_output.at (qmuls).get_node_shared_ptr ();
161
+ auto matched_qreshp = node_to_output.at (qreshp).get_node_shared_ptr ();
162
+ auto matched_qcvtr = node_to_output.at (qcvtr).get_node_shared_ptr ();
163
+ auto matched_qmm = node_to_output.at (qmm).get_node_shared_ptr ();
164
+
165
+ node_to_gptr->at (matched_qmuls)->isolate (isol_tag);
166
+ node_to_gptr->at (matched_qreshp)->isolate (isol_tag);
167
+ node_to_gptr->at (matched_qcvtr)->isolate (isol_tag);
168
+ node_to_gptr->at (matched_qmm)->isolate (isol_tag);
169
+ }
170
+
171
+ return false ; // root hasn't changed
172
+ };
173
+ register_matcher (std::make_shared<opp::Matcher>(qmm, " TagDQMatMulGQi4" ), std::move (callback));
174
+ }
175
+
176
+ // TODO: visualize
177
+ DQMatMulCWi4::DQMatMulCWi4 (const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
178
+ auto qweight = opp::wrap_type<ov::op::v0::Constant>();
179
+ auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
180
+
181
+ auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
182
+
183
+ auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
184
+
185
+ auto qcvtm = opp::wrap_type<ov::op::v0::Convert>({qmuls});
186
+ auto qmm = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input (), qcvtm});
187
+
188
+ auto node_to_gptr = snapshot->getNodeToGroupMap ();
189
+
190
+ // Note: Use [=] to make sure the above objects stay alive in the callback
191
+ auto callback = [=](ov::pass::pattern::Matcher& m) {
192
+ auto & node_to_output = m.get_pattern_value_map ();
193
+
194
+ auto matched_node_qweight = node_to_output.at (qweight).get_node_shared_ptr ();
195
+ auto matched_node_qcoeff = node_to_output.at (qcoeff).get_node_shared_ptr ();
196
+
197
+ NPUW_ASSERT (ov::op::util::is_constant (matched_node_qweight));
198
+ NPUW_ASSERT (ov::op::util::is_constant (matched_node_qcoeff));
199
+
200
+ auto matched_qweight = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qweight);
201
+ auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qcoeff);
202
+
203
+ if ((ov::element::i4 == matched_qweight->get_element_type () ||
204
+ ov::element::i8 == matched_qweight->get_element_type ()) &&
205
+ ov::element::f16 == matched_qcoeff->get_element_type ()) {
206
+ // Partitioning ignores Const->Convert nodes, so qcvtw is not used
207
+ auto matched_qmuls = node_to_output.at (qmuls).get_node_shared_ptr ();
208
+ auto matched_qcvtm = node_to_output.at (qcvtm).get_node_shared_ptr ();
209
+ auto matched_qmm = node_to_output.at (qmm).get_node_shared_ptr ();
210
+
211
+ node_to_gptr->at (matched_qmuls)->isolate (isol_tag);
212
+ node_to_gptr->at (matched_qcvtm)->isolate (isol_tag);
213
+ node_to_gptr->at (matched_qmm)->isolate (isol_tag);
214
+ }
215
+
216
+ return false ; // root hasn't changed
217
+ };
218
+ register_matcher (std::make_shared<opp::Matcher>(qmm, " TagDQMatMulCWi4" ), std::move (callback));
144
219
}
145
220
146
221
// TODO: visualize
0 commit comments