@@ -25,6 +25,29 @@ bool ov::util::is_type_unsupported(const ov::element::Type& type) {
25
25
return std::find (unsupported_types.begin (), unsupported_types.end (), type) != unsupported_types.end ();
26
26
}
27
27
28
+ void ov::util::save_original_input_precisions (const std::shared_ptr<ov::Node>& node) {
29
+ for (size_t i = 0 ; i < node->get_input_size (); i++) {
30
+ auto input = node->input (i);
31
+ input.get_rt_info ()[" original_precision" ] = input.get_element_type ();
32
+ }
33
+ }
34
+
35
+ bool ov::util::has_original_input_precision (const ov::Input<ov::Node>& input) {
36
+ return input.get_rt_info ().count (" original_precision" ) > 0 ;
37
+ }
38
+
39
+ ov::element::Type ov::util::get_original_input_precision (const ov::Input<ov::Node>& input) {
40
+ return input.get_rt_info ().at (" original_precision" ).as <ov::element::Type>();
41
+ }
42
+
43
+ void ov::util::remove_original_input_precision_attribute (ov::Input<ov::Node>& input) {
44
+ auto & rt_info = input.get_rt_info ();
45
+ auto it = rt_info.find (" original_precision" );
46
+ if (it != rt_info.end ()) {
47
+ rt_info.erase (it);
48
+ }
49
+ }
50
+
28
51
namespace {
29
52
30
53
template <typename ... Args>
@@ -105,11 +128,11 @@ static const std::unordered_map<ov::NodeTypeInfo, std::function<bool(const std::
105
128
{ov::op::v4::Range::get_type_info_static (), convert_range_precision},
106
129
};
107
130
108
- std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision (const Node* const node) {
131
+ std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision (Node* const node) {
109
132
return ov::util::convert_to_supported_precision (node, node->input_values ());
110
133
}
111
134
112
- std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision (const Node* const node, const OutputVector& inputs) {
135
+ std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision (Node* const node, const OutputVector& inputs) {
113
136
size_t num_inputs = node->get_input_size ();
114
137
OutputVector converted_inputs;
115
138
converted_inputs.reserve (num_inputs);
@@ -128,23 +151,49 @@ std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* c
128
151
}
129
152
}
130
153
131
- // Create a new node with new (converted) inputs.
132
- auto cloned_node = node->clone_with_new_inputs (converted_inputs);
154
+ std::shared_ptr<Node> cloned_node;
155
+
156
+ auto type_relaxed = dynamic_cast <op::TypeRelaxedBase*>(node);
157
+ if (type_relaxed != nullptr ) {
158
+ // Save TypeRelaxed's origin input types
159
+ // If origin input type is undefined let's temporarily override it with original input precision attribute
160
+ // value. During ConstantFolding, some nodes can have temporarily mismatched input types (e.g. Add(f16, f32)).
161
+ // If the node is TypeRelaxed - we're unable to clone it since TypeRelaxed::clone_with_new_inputs creates a
162
+ // clone with 'fake' inputs based on current inputs and that can trigger an exception for certain nodes if the
163
+ // inputs have mismatched types.
164
+ element::TypeVector origin_input_types;
165
+ origin_input_types.reserve (num_inputs);
166
+ for (size_t i = 0 ; i < num_inputs; i++) {
167
+ const auto & origin_type = type_relaxed->get_origin_input_type (i);
168
+ origin_input_types.push_back (origin_type);
169
+ if (origin_type == element::undefined && has_original_input_precision (node->input (i))) {
170
+ type_relaxed->set_origin_input_type (get_original_input_precision (node->input (i)), i);
171
+ }
172
+ }
173
+
174
+ cloned_node = node->clone_with_new_inputs (converted_inputs);
175
+
176
+ // Restore TypeRelaxed's origin input types
177
+ for (size_t i = 0 ; i < num_inputs; i++) {
178
+ type_relaxed->set_origin_input_type (origin_input_types[i], i);
179
+ }
133
180
134
- // Override TypeRelaxed types
135
- auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(cloned_node);
136
- if (type_relaxed) {
181
+ auto cloned_type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(cloned_node);
182
+ // Override TypeRelaxed types
137
183
for (size_t i = 0 ; i < num_inputs; i++) {
138
- if (ov::util::is_type_unsupported (type_relaxed ->get_origin_input_type (i))) {
139
- type_relaxed ->set_origin_input_type (cloned_node->get_input_element_type (i), i);
184
+ if (ov::util::is_type_unsupported (cloned_type_relaxed ->get_origin_input_type (i))) {
185
+ cloned_type_relaxed ->set_origin_input_type (cloned_node->get_input_element_type (i), i);
140
186
}
141
187
}
142
188
for (size_t i = 0 ; i < cloned_node->get_output_size (); i++) {
143
189
if (ov::util::is_type_unsupported (cloned_node->get_output_element_type (i))) {
144
- type_relaxed ->set_overridden_output_type (element::f32, i);
190
+ cloned_type_relaxed ->set_overridden_output_type (element::f32, i);
145
191
}
146
192
}
147
193
cloned_node->validate_and_infer_types ();
194
+ } else {
195
+ // Create a new node with new (converted) inputs.
196
+ cloned_node = node->clone_with_new_inputs (converted_inputs);
148
197
}
149
198
150
199
// Handle nodes which outputs precisions don't depend on input precisions
@@ -221,9 +270,19 @@ bool ov::util::evaluate_node_with_unsupported_precision(const ov::Node* node,
221
270
}
222
271
}
223
272
224
- // evaluate converted node
225
- if (!node->evaluate (converted_output_tensors, converted_input_tensors)) {
226
- return false ;
273
+ auto type_relaxed = dynamic_cast <const op::TypeRelaxedBase*>(node);
274
+ if (type_relaxed == nullptr ) {
275
+ // evaluate node with converted tensors
276
+ if (!node->evaluate (converted_output_tensors, converted_input_tensors)) {
277
+ return false ;
278
+ }
279
+ } else {
280
+ // node is const so let's clone it
281
+ auto cloned = node->clone_with_new_inputs (node->input_values ());
282
+ cloned = convert_to_supported_precision (cloned.get ());
283
+ if (!cloned->evaluate (converted_output_tensors, converted_input_tensors)) {
284
+ return false ;
285
+ }
227
286
}
228
287
229
288
// convert outputs tensors from f32 to original type if necessary
0 commit comments