@@ -59,8 +59,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
59
59
" Expected rank of quantized weights is 3 [N][n_blocks_per_col][blob_size], got: " ,
60
60
b_quantized.get_partial_shape ().rank ());
61
61
CHECK_VALID_NODE (node,
62
- a.get_element_type () == ov::element::f16 || a.get_element_type () == ov::element::f32,
63
- " Unsupported input A type, accepted FP16, FP32, got: " ,
62
+ a.get_element_type () == ov::element::f16 || a.get_element_type () == ov::element::f32 ||
63
+ a.get_element_type () == ov::element::dynamic,
64
+ " Unsupported input A type, accepted dynamic, FP16, FP32, got: " ,
64
65
a.get_element_type ());
65
66
CHECK_VALID_NODE (
66
67
node,
@@ -96,7 +97,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
96
97
if (inputs.size () > 5 ) {
97
98
bias = inputs[5 ];
98
99
CHECK_VALID_NODE (node,
99
- bias.get_element_type () == a.get_element_type (),
100
+ bias.get_element_type () == a.get_element_type () ||
101
+ a.get_element_type () == ov::element::dynamic ||
102
+ bias.get_element_type () == ov::element::dynamic,
100
103
" Unsupported input bias type, must be equal to input A type, got: " ,
101
104
bias.get_element_type ());
102
105
CHECK_VALID_NODE (node,
@@ -121,17 +124,35 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
121
124
case 2 :
122
125
casted_b_shape = ov::Shape{static_cast <size_t >(N * n_blocks_per_col), static_cast <size_t >(blob_size * 4 )};
123
126
casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr ());
124
- default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 2 );
127
+ if (a.get_element_type () != ov::element::dynamic) {
128
+ default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 2 );
129
+ } else {
130
+ default_zp =
131
+ std::make_shared<v1::ConvertLike>(a,
132
+ std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2 .f ));
133
+ }
125
134
break ;
126
135
case 4 :
127
136
casted_b_shape = ov::Shape{static_cast <size_t >(N * n_blocks_per_col), static_cast <size_t >(blob_size * 2 )};
128
137
casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr ());
129
- default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 8 );
138
+ if (a.get_element_type () != ov::element::dynamic) {
139
+ default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 8 );
140
+ } else {
141
+ default_zp =
142
+ std::make_shared<v1::ConvertLike>(a,
143
+ std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8 .f ));
144
+ }
130
145
break ;
131
146
case 8 :
132
147
casted_b_shape = ov::Shape{static_cast <size_t >(N * n_blocks_per_col), static_cast <size_t >(blob_size)};
133
148
casted_b = op::util::reshape (b_const, casted_b_shape);
134
- default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 128 );
149
+ if (a.get_element_type () != ov::element::dynamic) {
150
+ default_zp = std::make_shared<v0::Constant>(a.get_element_type (), Shape{}, 128 );
151
+ } else {
152
+ default_zp =
153
+ std::make_shared<v1::ConvertLike>(a,
154
+ std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128 .f ));
155
+ }
135
156
break ;
136
157
default :
137
158
FRONT_END_THROW (" Unsupported bits count" );
0 commit comments