Commit 68ecdfb

bopeng1234 and wine99 authored
optimize com.microsoft.MatMulNbits operator (#28504)
This PR does some optimization work on the ONNX frontend com.microsoft.MatMulNBits operator. With these changes:
1. It disables constant folding, which used 75 GB for the phi3 INT4 model and 200+ GB for the llama3 INT4 model.
2. It triggers oneDNN matmul primitives, which greatly benefits GPU performance.
We tested these changes along with another PR, #28163, and confirmed that the phi3/llama3 INT4 models run well on LNL.

---------

Co-authored-by: Yu, Zijun <zijun.yu@intel.com>
1 parent 047976e commit 68ecdfb
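
For context on the memory numbers above: once the dequantization chain is constant-folded, the packed u2/u4/u8 weight constant is materialized as a full-precision constant, multiplying the weight memory several times over. The sketch below is illustrative arithmetic only, for a hypothetical 4-bit [4096, 4096] weight; the figures are not taken from this commit.

// Illustrative only: storage for one hypothetical [N, K] = [4096, 4096] weight matrix,
// packed 4-bit vs. the f16/f32 constant that constant folding would materialize.
#include <cstdio>

int main() {
    const double N = 4096, K = 4096;
    const double u4_mib  = N * K * 0.5 / (1 << 20);  // two 4-bit values per byte
    const double f16_mib = N * K * 2.0 / (1 << 20);  // folded to f16
    const double f32_mib = N * K * 4.0 / (1 << 20);  // folded to f32
    std::printf("u4: %.0f MiB, f16: %.0f MiB (4x), f32: %.0f MiB (8x)\n", u4_mib, f16_mib, f32_mib);
    return 0;
}

Keeping the weights as low-precision constants feeding a Convert node lets the plugin (for example, the oneDNN matmul primitives mentioned above) consume the packed data directly instead of folding it into a large full-precision constant.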

File tree

1 file changed: +71 -112 lines changed


src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp

+71 -112
@@ -7,15 +7,13 @@
 #include "exceptions.hpp"
 #include "openvino/frontend/exception.hpp"
 #include "openvino/op/add.hpp"
-#include "openvino/op/broadcast.hpp"
 #include "openvino/op/constant.hpp"
-#include "openvino/op/convert_like.hpp"
+#include "openvino/op/convert.hpp"
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/multiply.hpp"
-#include "openvino/op/shape_of.hpp"
+#include "openvino/op/reshape.hpp"
 #include "openvino/op/slice.hpp"
 #include "openvino/op/subtract.hpp"
-#include "openvino/op/transpose.hpp"
 #include "utils/common.hpp"
 #include "utils/reshape.hpp"
 
@@ -111,142 +109,103 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
                          bias.get_partial_shape());
     }
 
+    ov::Output<ov::Node> mm_output;
     {
         const auto b_const = ov::as_type_ptr<v0::Constant>(b_quantized.get_node_shared_ptr());
 
         ov::Output<ov::Node> casted_b;
         ov::Shape casted_b_shape;
         ov::Output<ov::Node> default_zp;
         // Casting/converting data of source constant.
-        // For further calculations (sub and/or multiply) we need to reshape it from [N][n_blocks_per_col][blob_size *
-        // X] to [N * n_blocks_per_col][blob_size * X] (where X is amount of values in 1 byte) because scale and
-        // zero_point are represented as: ...with shape like: [N * n_blocks_per_col]...
+        // For further calculations (sub and/or multiply) we need to reshape
+        // b -> [N][n_blocks_per_col][block_size]
         switch (bits) {
        case 2:
-            casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 4)};
+            casted_b_shape = ov::Shape{static_cast<size_t>(N),
+                                       static_cast<size_t>(n_blocks_per_col),
+                                       static_cast<size_t>(blob_size * 4)};
             casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr());
-            if (a.get_element_type() != ov::element::dynamic) {
-                default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
-            } else {
-                default_zp =
-                    std::make_shared<v1::ConvertLike>(a,
-                                                      std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2.f));
-            }
+            default_zp = std::make_shared<v0::Constant>(ov::element::u2, Shape{1}, 2);
             break;
        case 4:
-            casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 2)};
+            casted_b_shape = ov::Shape{static_cast<size_t>(N),
+                                       static_cast<size_t>(n_blocks_per_col),
+                                       static_cast<size_t>(blob_size * 2)};
             casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr());
-            if (a.get_element_type() != ov::element::dynamic) {
-                default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
-            } else {
-                default_zp =
-                    std::make_shared<v1::ConvertLike>(a,
-                                                      std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8.f));
-            }
+            default_zp = std::make_shared<v0::Constant>(ov::element::u4, Shape{1}, 8);
             break;
        case 8:
-            casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size)};
-            casted_b = op::util::reshape(b_const, casted_b_shape);
-            if (a.get_element_type() != ov::element::dynamic) {
-                default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
-            } else {
-                default_zp =
-                    std::make_shared<v1::ConvertLike>(a,
-                                                      std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128.f));
-            }
+            casted_b_shape = ov::Shape{static_cast<size_t>(N),
+                                       static_cast<size_t>(n_blocks_per_col),
+                                       static_cast<size_t>(blob_size)};
+            casted_b = std::make_shared<v0::Constant>(ov::element::u8, casted_b_shape, b_const->get_data_ptr());
+            default_zp = std::make_shared<v0::Constant>(ov::element::u8, Shape{1}, 128);
             break;
        default:
            FRONT_END_THROW("Unsupported bits count");
            break;
        }
 
+        if (!zero_points.get_node_shared_ptr()) {
+            zero_points = default_zp;
+        } else {
+            // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits
+            // according to the link, zero points are:
+            // Constrain quantized zero point types to uint8/int32/float16/float.
+            // Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B
+            zero_points =
+                op::util::reshape(zero_points,
+                                  ov::Shape{static_cast<size_t>(N), static_cast<size_t>(n_blocks_per_col), 1});
+        }
+
         // Possible issue with slice implementation, had to move convertion before slice, instead of slicing uint4
         // TODO: Ticket
-        const auto converted_b = std::make_shared<v1::ConvertLike>(casted_b, a);
+        // Comments: it is still there, so we need to convert b to fp16 first.
 
         // TODO: Need to collect performance data in case constant folding is applied. Possible some perf/mem-gap
-
-        // Simple case
-        if (n_blocks_per_col == 1) {
-            // Removing unused items in case block is bigger than column count
-            // For example, if data is (uint8)[1,2,3,4,5,6] then block will be (uint8)[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0].
-            // And last zeros are unused.
-            const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
-            const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
-            const auto elements_const =
-                std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
-            const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
-            const auto slice_b =
-                std::make_shared<v8::Slice>(converted_b, zero_const, elements_const, one_const, axis_const);
-
-            // Transpose matrix
-            const auto transposed_shape =
-                std::make_shared<v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{1, 0});
-            const auto transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);
-
-            // If no zero-points provided - we generate default, depends on data size
-            if (!zero_points.get_node_shared_ptr()) {
-                zero_points = default_zp;
-            }
-            const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);
-
-            // Scaling
-            const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);
-
-            // Adding bias if required
-            if (!bias.get_node_shared_ptr()) {
-                b = scaled_b;
-            } else {
-                b = std::make_shared<v1::Add>(scaled_b, bias);
-            }
+        // Comments: in this latest code, the const folding is gone; it triggers the oneDNN kernel
+        // and uses u2/u4/u8 weights as the kernel's input, so it won't do const folding anymore.
+
+        // use fp16 for compute
+
+        // convert b to fp16
+        auto converted_b = std::make_shared<v0::Convert>(casted_b, a.get_element_type());
+        auto converted_zero_points = std::make_shared<v0::Convert>(zero_points, a.get_element_type());
+
+        // sub and scale
+        const auto sub_b = std::make_shared<v1::Subtract>(converted_b, converted_zero_points);
+        const auto scales_fp16 = std::make_shared<v0::Convert>(scales, a.get_element_type());
+        const auto scales_reshaped =
+            op::util::reshape(scales_fp16, ov::Shape{static_cast<size_t>(N), static_cast<size_t>(n_blocks_per_col), 1});
+        const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales_reshaped);
+
+        // reshape b to [N, K]
+        auto shape_b = v0::Constant::create(ov::element::i32, ov::Shape{2}, {0, -1});
+        auto reshaped_b = std::make_shared<v1::Reshape>(scaled_b, shape_b, true);
+
+        // if n_blocks_per_col * blob_size * X != K
+        // we need to slice it to K
+        // to produce b = [N, K]
+        const bool slice_needed = (K % block_size != 0);
+        if (slice_needed) {
+            const auto zero = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
+            const auto one = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
+            const auto elements = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
+            const auto axis = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
+            b = std::make_shared<v8::Slice>(reshaped_b, zero, elements, one, axis);
         } else {
-            // Transpose matrix. Quantized B matrix is transposed and has a shape [N,K].
-            // To apply further operations on it which operand's shape is [N] we do this
-            // transpose to have a matrix [K,N]...
-            const auto transposed_shape =
-                std::make_shared<v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{1, 0});
-            ov::Output<ov::Node> transposed_b = std::make_shared<v1::Transpose>(converted_b, transposed_shape);
-
-            // If no zero-points provided - we generate default, depends on data size
-            if (!zero_points.get_node_shared_ptr()) {
-                zero_points = default_zp;
-            }
-            const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);
-
-            // Scaling
-            const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);
-
-            // Transpose again to make reshaping and slicing
-            transposed_b = std::make_shared<v1::Transpose>(scaled_b, transposed_shape);
-
-            const auto reshaped_b =
-                op::util::reshape(transposed_b,
-                                  ov::Shape{static_cast<size_t>(casted_b_shape[0] / n_blocks_per_col),
-                                            static_cast<size_t>(casted_b_shape[1] * n_blocks_per_col)});
-
-            // Removing unused items in case block is bigger than column count (see description for
-            // Slice above)
-            const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
-            const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
-            const auto elements_const =
-                std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
-            const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
-            const auto slice_b =
-                std::make_shared<v8::Slice>(reshaped_b, zero_const, elements_const, one_const, axis_const);
-
-            // Adding bias if required
-            if (!bias.get_node_shared_ptr()) {
-                return {std::make_shared<v0::MatMul>(a, slice_b, false, true)};
-            } else {
-                // Transpose again
-                transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);
-
-                b = std::make_shared<v1::Add>(transposed_b, bias);
-            }
+            b = reshaped_b;
         }
+
+        // mm = matmul(a,b)
+        mm_output = std::make_shared<v0::MatMul>(a, b, false, true);
     }
 
-    return {std::make_shared<v0::MatMul>(a, b)};
+    if (bias.get_node_shared_ptr()) {
+        return {std::make_shared<v1::Add>(mm_output, bias)};
+    } else {
+        return {mm_output};
+    }
 }
 
 ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICROSOFT_DOMAIN);
@@ -255,4 +214,4 @@ ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICR
 } // namespace com_microsoft
 } // namespace onnx
 } // namespace frontend
-} // namespace ov
+} // namespace ov
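
For readers new to the operator, below is a standalone reference sketch of what the decomposed subgraph computes for the 4-bit case: Y = A * dequant(B)^T, with B block-quantized along K (n_blocks_per_col blocks of block_size values per output row, one scale per block, default zero point 8). It is illustrative only; the low-nibble-first unpacking order and all names are assumptions, not code from this commit.

// Reference semantics sketch of com.microsoft.MatMulNBits, 4-bit case (illustrative only).
// Computes Y[M, N] = A[M, K] * dequant(B)[N, K]^T with
//   dequant(q) = (q - zero_point) * scale, zero_point defaulting to 8 for 4-bit weights.
// Assumes low-nibble-first packing of two 4-bit values per byte.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<float> matmul_nbits_ref(const std::vector<float>& A,
                                    const std::vector<uint8_t>& B_packed,
                                    const std::vector<float>& scales,
                                    size_t M, size_t N, size_t K, size_t block_size) {
    const size_t n_blocks_per_col = (K + block_size - 1) / block_size;
    const size_t blob_size = block_size / 2;  // bytes per block for 4-bit weights
    std::vector<float> Y(M * N, 0.0f);
    for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (size_t k = 0; k < K; ++k) {
                const size_t block = k / block_size;
                const size_t in_block = k % block_size;
                const uint8_t byte = B_packed[(n * n_blocks_per_col + block) * blob_size + in_block / 2];
                const uint8_t q = (in_block % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
                const float w = (static_cast<float>(q) - 8.0f) * scales[n * n_blocks_per_col + block];
                acc += A[m * K + k] * w;
            }
            Y[m * N + n] = acc;
        }
    }
    return Y;
}

int main() {
    // Tiny example: M = 1, N = 2, K = 4, block_size = 4 (one block per output row).
    const std::vector<float> A{1.0f, 2.0f, 3.0f, 4.0f};
    // Row 0 packs 8, 9, 10, 11 (dequantized: 0, 1, 2, 3);
    // row 1 packs 6, 7, 4, 5 (dequantized: -2, -1, -4, -3).
    const std::vector<uint8_t> B{0x98, 0xBA, 0x76, 0x54};
    const std::vector<float> scales{1.0f, 1.0f};
    const auto Y = matmul_nbits_ref(A, B, scales, 1, 2, 4, 4);
    std::printf("%g %g\n", Y[0], Y[1]);  // prints: 20 -28
    return 0;
}

The graph built in the diff expresses the same computation with Convert, Subtract, Multiply, and Reshape nodes (plus an optional Slice when K is not a multiple of block_size) feeding a MatMul with transpose_b = true.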
