7
7
#include " exceptions.hpp"
8
8
#include " openvino/frontend/exception.hpp"
9
9
#include " openvino/op/add.hpp"
10
- #include " openvino/op/broadcast.hpp"
11
10
#include " openvino/op/constant.hpp"
12
- #include " openvino/op/convert_like .hpp"
11
+ #include " openvino/op/convert .hpp"
13
12
#include " openvino/op/matmul.hpp"
14
13
#include " openvino/op/multiply.hpp"
15
- #include " openvino/op/shape_of .hpp"
14
+ #include " openvino/op/reshape .hpp"
16
15
#include " openvino/op/slice.hpp"
17
16
#include " openvino/op/subtract.hpp"
18
- #include " openvino/op/transpose.hpp"
19
17
#include " utils/common.hpp"
20
18
#include " utils/reshape.hpp"
21
19
@@ -111,142 +109,103 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
111
109
bias.get_partial_shape ());
112
110
}
113
111
112
+ ov::Output<ov::Node> mm_output;
114
113
{
115
114
const auto b_const = ov::as_type_ptr<v0::Constant>(b_quantized.get_node_shared_ptr ());
116
115
117
116
ov::Output<ov::Node> casted_b;
118
117
ov::Shape casted_b_shape;
119
118
ov::Output<ov::Node> default_zp;
120
119
// Casting/converting data of source constant.
121
- // For further calculations (sub and/or multiply) we need to reshape it from [N][n_blocks_per_col][blob_size *
122
- // X] to [N * n_blocks_per_col][blob_size * X] (where X is amount of values in 1 byte) because scale and
123
- // zero_point are represented as: ...with shape like: [N * n_blocks_per_col]...
120
+ // For further calculations (sub and/or multiply) we need to reshape
121
+ // b -> [N][n_blocks_per_col][block_size]
124
122
switch (bits) {
125
123
case 2 :
126
- casted_b_shape = ov::Shape{static_cast <size_t >(N * n_blocks_per_col), static_cast <size_t >(blob_size * 4 )};
124
+ casted_b_shape = ov::Shape{static_cast <size_t >(N),
125
+ static_cast <size_t >(n_blocks_per_col),
126
+ static_cast <size_t >(blob_size * 4 )};
127
127
casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr ());
128
- if (a.get_element_type () != ov::element::dynamic) {
129
- default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 2 );
130
- } else {
131
- default_zp =
132
- std::make_shared<v1::ConvertLike>(a,
133
- std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2 .f ));
134
- }
128
+ default_zp = std::make_shared<v0::Constant>(ov::element::u2, Shape{1 }, 2 );
135
129
break ;
136
130
case 4 :
137
- casted_b_shape = ov::Shape{static_cast <size_t >(N * n_blocks_per_col), static_cast <size_t >(blob_size * 2 )};
131
+ casted_b_shape = ov::Shape{static_cast <size_t >(N),
132
+ static_cast <size_t >(n_blocks_per_col),
133
+ static_cast <size_t >(blob_size * 2 )};
138
134
casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr ());
139
- if (a.get_element_type () != ov::element::dynamic) {
140
- default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 8 );
141
- } else {
142
- default_zp =
143
- std::make_shared<v1::ConvertLike>(a,
144
- std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8 .f ));
145
- }
135
+ default_zp = std::make_shared<v0::Constant>(ov::element::u4, Shape{1 }, 8 );
146
136
break ;
147
137
case 8 :
148
- casted_b_shape = ov::Shape{static_cast <size_t >(N * n_blocks_per_col), static_cast <size_t >(blob_size)};
149
- casted_b = op::util::reshape (b_const, casted_b_shape);
150
- if (a.get_element_type () != ov::element::dynamic) {
151
- default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 128 );
152
- } else {
153
- default_zp =
154
- std::make_shared<v1::ConvertLike>(a,
155
- std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128 .f ));
156
- }
138
+ casted_b_shape = ov::Shape{static_cast <size_t >(N),
139
+ static_cast <size_t >(n_blocks_per_col),
140
+ static_cast <size_t >(blob_size)};
141
+ casted_b = std::make_shared<v0::Constant>(ov::element::u8, casted_b_shape, b_const->get_data_ptr ());
142
+ default_zp = std::make_shared<v0::Constant>(ov::element::u8, Shape{1 }, 128 );
157
143
break ;
158
144
default :
159
145
FRONT_END_THROW (" Unsupported bits count" );
160
146
break ;
161
147
}
162
148
149
+ if (!zero_points.get_node_shared_ptr ()) {
150
+ zero_points = default_zp;
151
+ } else {
152
+ // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits
153
+ // according to the link, zero point are:
154
+ // Constrain quantized zero point types to uint8/int32/float16/float.
155
+ // Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B
156
+ zero_points =
157
+ op::util::reshape (zero_points,
158
+ ov::Shape{static_cast <size_t >(N), static_cast <size_t >(n_blocks_per_col), 1 });
159
+ }
160
+
163
161
// Possible issue with slice implementation, had to move convertion before slice, instead of slicing uint4
164
162
// TODO: Ticket
165
- const auto converted_b = std::make_shared<v1::ConvertLike>(casted_b, a);
163
+ // Comments: it is still there, so need to convert b to fp16 first.
166
164
167
165
// TODO: Need to collect performance data in case constant folding is applied. Possible some perf/mem-gap
168
-
169
- // Simple case
170
- if (n_blocks_per_col == 1 ) {
171
- // Removing unused items in case block is bigger than column count
172
- // For example, if data is (uint8)[1,2,3,4,5,6] then block will be (uint8)[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0].
173
- // And last zeros are unused.
174
- const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 0 );
175
- const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 1 );
176
- const auto elements_const =
177
- std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, static_cast <int32_t >(K));
178
- const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 1 );
179
- const auto slice_b =
180
- std::make_shared<v8::Slice>(converted_b, zero_const, elements_const, one_const, axis_const);
181
-
182
- // Transpose matrix
183
- const auto transposed_shape =
184
- std::make_shared<v0::Constant>(ov::element::i64, Shape{2 }, std::vector<int64_t >{1 , 0 });
185
- const auto transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);
186
-
187
- // If no zero-points provided - we generate default, depends on data size
188
- if (!zero_points.get_node_shared_ptr ()) {
189
- zero_points = default_zp;
190
- }
191
- const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);
192
-
193
- // Scaling
194
- const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);
195
-
196
- // Adding bias if required
197
- if (!bias.get_node_shared_ptr ()) {
198
- b = scaled_b;
199
- } else {
200
- b = std::make_shared<v1::Add>(scaled_b, bias);
201
- }
166
+ // Comments: in this latest code, the const folding is gone, it trigle the oneDNN kernel
167
+ // and use u2/u4/u8 weights as the kernel's input, won't do const folding anymore.
168
+
169
+ // use fp16 for compute
170
+
171
+ // convert b to fp16
172
+ auto converted_b = std::make_shared<v0::Convert>(casted_b, a.get_element_type ());
173
+ auto converted_zero_points = std::make_shared<v0::Convert>(zero_points, a.get_element_type ());
174
+
175
+ // sub and scale
176
+ const auto sub_b = std::make_shared<v1::Subtract>(converted_b, converted_zero_points);
177
+ const auto scales_fp16 = std::make_shared<v0::Convert>(scales, a.get_element_type ());
178
+ const auto scales_reshaped =
179
+ op::util::reshape (scales_fp16, ov::Shape{static_cast <size_t >(N), static_cast <size_t >(n_blocks_per_col), 1 });
180
+ const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales_reshaped);
181
+
182
+ // reshape b to [N, K]
183
+ auto shape_b = v0::Constant::create (ov::element::i32, ov::Shape{2 }, {0 , -1 });
184
+ auto reshaped_b = std::make_shared<v1::Reshape>(scaled_b, shape_b, true );
185
+
186
+ // if n_blocks_per_col*blob_size*X != K
187
+ // need slice it to K
188
+ // to produce b = [N, K]
189
+ const bool slice_needed = (K % block_size != 0 );
190
+ if (slice_needed) {
191
+ const auto zero = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 0 );
192
+ const auto one = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 1 );
193
+ const auto elements = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, static_cast <int32_t >(K));
194
+ const auto axis = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 1 );
195
+ b = std::make_shared<v8::Slice>(reshaped_b, zero, elements, one, axis);
202
196
} else {
203
- // Transpose matrix. Quantized B matrix is transposed and has a shape [N,K].
204
- // To apply further operations on it which operand's shape is [N] we do this
205
- // transpose to have a matrix [K,N]...
206
- const auto transposed_shape =
207
- std::make_shared<v0::Constant>(ov::element::i64, Shape{2 }, std::vector<int64_t >{1 , 0 });
208
- ov::Output<ov::Node> transposed_b = std::make_shared<v1::Transpose>(converted_b, transposed_shape);
209
-
210
- // If no zero-points provided - we generate default, depends on data size
211
- if (!zero_points.get_node_shared_ptr ()) {
212
- zero_points = default_zp;
213
- }
214
- const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);
215
-
216
- // Scaling
217
- const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);
218
-
219
- // Transpose again to make reshaping and slicing
220
- transposed_b = std::make_shared<v1::Transpose>(scaled_b, transposed_shape);
221
-
222
- const auto reshaped_b =
223
- op::util::reshape (transposed_b,
224
- ov::Shape{static_cast <size_t >(casted_b_shape[0 ] / n_blocks_per_col),
225
- static_cast <size_t >(casted_b_shape[1 ] * n_blocks_per_col)});
226
-
227
- // Removing unused items in case block is bigger than column count (see description for
228
- // Slice above)
229
- const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 0 );
230
- const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 1 );
231
- const auto elements_const =
232
- std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, static_cast <int32_t >(K));
233
- const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1 }, 1 );
234
- const auto slice_b =
235
- std::make_shared<v8::Slice>(reshaped_b, zero_const, elements_const, one_const, axis_const);
236
-
237
- // Adding bias if required
238
- if (!bias.get_node_shared_ptr ()) {
239
- return {std::make_shared<v0::MatMul>(a, slice_b, false , true )};
240
- } else {
241
- // Transpose again
242
- transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);
243
-
244
- b = std::make_shared<v1::Add>(transposed_b, bias);
245
- }
197
+ b = reshaped_b;
246
198
}
199
+
200
+ // mm = matmul(a,b)
201
+ mm_output = std::make_shared<v0::MatMul>(a, b, false , true );
247
202
}
248
203
249
- return {std::make_shared<v0::MatMul>(a, b)};
204
+ if (bias.get_node_shared_ptr ()) {
205
+ return {std::make_shared<v1::Add>(mm_output, bias)};
206
+ } else {
207
+ return {mm_output};
208
+ }
250
209
}
251
210
252
211
ONNX_OP (" MatMulNBits" , OPSET_SINCE(1 ), com_microsoft::opset_1::matmulnbits, MICROSOFT_DOMAIN);
@@ -255,4 +214,4 @@ ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICR
255
214
} // namespace com_microsoft
256
215
} // namespace onnx
257
216
} // namespace frontend
258
- } // namespace ov
217
+ } // namespace ov
0 commit comments