Skip to content

Commit 2efd2b1

Browse files
authored
[ONNX] Added support for dynamic input shapes in com.microsoft.MatMulNBits (openvinotoolkit#26898)
### Details: - Added option to receive an input with dynamic shape, shape must be calculated later while shape inference ### Tickets: - N/A
1 parent 9c65ba2 commit 2efd2b1

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp

+27-6
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
5959
"Expected rank of quantized weights is 3 [N][n_blocks_per_col][blob_size], got: ",
6060
b_quantized.get_partial_shape().rank());
6161
CHECK_VALID_NODE(node,
62-
a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32,
63-
"Unsupported input A type, accepted FP16, FP32, got: ",
62+
a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32 ||
63+
a.get_element_type() == ov::element::dynamic,
64+
"Unsupported input A type, accepted dynamic, FP16, FP32, got: ",
6465
a.get_element_type());
6566
CHECK_VALID_NODE(
6667
node,
@@ -96,7 +97,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
9697
if (inputs.size() > 5) {
9798
bias = inputs[5];
9899
CHECK_VALID_NODE(node,
99-
bias.get_element_type() == a.get_element_type(),
100+
bias.get_element_type() == a.get_element_type() ||
101+
a.get_element_type() == ov::element::dynamic ||
102+
bias.get_element_type() == ov::element::dynamic,
100103
"Unsupported input bias type, must be equal to input A type, got: ",
101104
bias.get_element_type());
102105
CHECK_VALID_NODE(node,
@@ -121,17 +124,35 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
121124
case 2:
122125
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 4)};
123126
casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr());
124-
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
127+
if (a.get_element_type() != ov::element::dynamic) {
128+
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
129+
} else {
130+
default_zp =
131+
std::make_shared<v1::ConvertLike>(a,
132+
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2.f));
133+
}
125134
break;
126135
case 4:
127136
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 2)};
128137
casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr());
129-
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
138+
if (a.get_element_type() != ov::element::dynamic) {
139+
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
140+
} else {
141+
default_zp =
142+
std::make_shared<v1::ConvertLike>(a,
143+
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8.f));
144+
}
130145
break;
131146
case 8:
132147
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size)};
133148
casted_b = op::util::reshape(b_const, casted_b_shape);
134-
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
149+
if (a.get_element_type() != ov::element::dynamic) {
150+
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
151+
} else {
152+
default_zp =
153+
std::make_shared<v1::ConvertLike>(a,
154+
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128.f));
155+
}
135156
break;
136157
default:
137158
FRONT_END_THROW("Unsupported bits count");

0 commit comments

Comments
 (0)