
Commit f8d0710

Fix convert_to_supported_precision for TypeRelaxed types (openvinotoolkit#23143)
If TypeRelaxed's origin input type is undefined, let's temporarily override it with the value of the original input precision attribute. During ConstantFolding, some nodes can have temporarily mismatched input types (e.g. Add(f16, f32)). If such a node is TypeRelaxed, we are unable to clone it, since TypeRelaxed::clone_with_new_inputs creates a clone with 'fake' inputs based on the current inputs, and that can trigger an exception for certain nodes if the inputs have mismatched types.

Ticket: CVS-134604
1 parent a3c7e15 commit f8d0710
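
In outline, the change saves a TypeRelaxed node's origin input types, patches any undefined ones from the precision recorded on the input, clones the node, and then restores the saved types. A condensed sketch of that flow follows; the helper name clone_type_relaxed_safely is made up for illustration, and the real logic lives inside convert_to_supported_precision in the diff below:

    // Condensed sketch (not the literal implementation) of the TypeRelaxed handling
    // added to ov::util::convert_to_supported_precision in this commit.
    #include "openvino/core/constant_fold_utils.hpp"
    #include "ov_ops/type_relaxed.hpp"

    static std::shared_ptr<ov::Node> clone_type_relaxed_safely(ov::Node* node, const ov::OutputVector& converted_inputs) {
        auto type_relaxed = dynamic_cast<ov::op::TypeRelaxedBase*>(node);
        if (!type_relaxed)
            return node->clone_with_new_inputs(converted_inputs);

        // Save the current origin input types and temporarily replace the undefined ones
        // with the precision ConstantFolding recorded in the "original_precision" rt_info.
        ov::element::TypeVector saved_types;
        for (size_t i = 0; i < node->get_input_size(); i++) {
            saved_types.push_back(type_relaxed->get_origin_input_type(i));
            if (saved_types.back() == ov::element::undefined &&
                ov::util::has_original_input_precision(node->input(i))) {
                type_relaxed->set_origin_input_type(ov::util::get_original_input_precision(node->input(i)), i);
            }
        }

        // With consistent origin types, clone_with_new_inputs no longer builds 'fake'
        // inputs with mismatched precisions (e.g. Add(f16, f32)).
        auto cloned_node = node->clone_with_new_inputs(converted_inputs);

        // Put the original (possibly undefined) origin types back on the source node.
        for (size_t i = 0; i < saved_types.size(); i++)
            type_relaxed->set_origin_input_type(saved_types[i], i);
        return cloned_node;
    }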

File tree

4 files changed: +121 −44 lines changed

src/core/dev_api/openvino/core/constant_fold_utils.hpp

+15 −3
@@ -14,7 +14,19 @@ OPENVINO_API
 const element::TypeVector& unsupported_types();
 
 OPENVINO_API
-bool is_type_unsupported(const ov::element::Type& type);
+bool is_type_unsupported(const element::Type& type);
+
+OPENVINO_API
+void save_original_input_precisions(const std::shared_ptr<Node>& node);
+
+OPENVINO_API
+bool has_original_input_precision(const Input<Node>& input);
+
+OPENVINO_API
+element::Type get_original_input_precision(const Input<Node>& input);
+
+OPENVINO_API
+void remove_original_input_precision_attribute(Input<Node>& input);
 
 OPENVINO_API bool node_requires_precision_conversion(const Node* const node);

@@ -25,9 +37,9 @@ OPENVINO_API bool node_requires_precision_conversion(const Node* const node);
 /// \param node
 ///
 /// \return New node with f32 inputs if the inputs require conversion or the input node otherwise
-OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(const Node* const node);
+OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(Node* const node);
 
-OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(const Node* const node, const OutputVector& inputs);
+OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(Node* const node, const OutputVector& inputs);
 
 OPENVINO_API bool evaluate_node_with_unsupported_precision(const Node* node,
                                                            TensorVector& outputs,
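
A minimal usage sketch for the helpers declared above, assuming a hypothetical ov::Model named model and a wrapper function invented for illustration; it mirrors how the ConstantFolding pass (see src/core/src/pass/constant_folding.cpp below) records the attribute before folding and reads it back afterwards:

    // Sketch only: how the rt_info helpers can be used around a folding pass.
    #include "openvino/core/constant_fold_utils.hpp"
    #include "openvino/core/model.hpp"

    void stamp_and_restore_precisions(const std::shared_ptr<ov::Model>& model) {
        // Before folding: remember every input's precision in its rt_info.
        for (const auto& node : model->get_ordered_ops())
            ov::util::save_original_input_precisions(node);

        // ... constant folding may rewrite constants of unsupported precisions (e.g. f16) as f32 ...

        // Afterwards: read the attribute back and remove it once it has been handled.
        for (const auto& node : model->get_ordered_ops()) {
            for (auto& input : node->inputs()) {
                if (!ov::util::has_original_input_precision(input))
                    continue;
                const auto original_type = ov::util::get_original_input_precision(input);
                ov::util::remove_original_input_precision_attribute(input);
                // e.g. insert a Convert back to original_type here if the folded type differs
                (void)original_type;
            }
        }
    }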

src/core/src/constant_fold_utils.cpp

+72 −13
@@ -25,6 +25,29 @@ bool ov::util::is_type_unsupported(const ov::element::Type& type) {
     return std::find(unsupported_types.begin(), unsupported_types.end(), type) != unsupported_types.end();
 }
 
+void ov::util::save_original_input_precisions(const std::shared_ptr<ov::Node>& node) {
+    for (size_t i = 0; i < node->get_input_size(); i++) {
+        auto input = node->input(i);
+        input.get_rt_info()["original_precision"] = input.get_element_type();
+    }
+}
+
+bool ov::util::has_original_input_precision(const ov::Input<ov::Node>& input) {
+    return input.get_rt_info().count("original_precision") > 0;
+}
+
+ov::element::Type ov::util::get_original_input_precision(const ov::Input<ov::Node>& input) {
+    return input.get_rt_info().at("original_precision").as<ov::element::Type>();
+}
+
+void ov::util::remove_original_input_precision_attribute(ov::Input<ov::Node>& input) {
+    auto& rt_info = input.get_rt_info();
+    auto it = rt_info.find("original_precision");
+    if (it != rt_info.end()) {
+        rt_info.erase(it);
+    }
+}
+
 namespace {
 
 template <typename... Args>

@@ -105,11 +128,11 @@ static const std::unordered_map<ov::NodeTypeInfo, std::function<bool(const std::
     {ov::op::v4::Range::get_type_info_static(), convert_range_precision},
 };
 
-std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* const node) {
+std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(Node* const node) {
     return ov::util::convert_to_supported_precision(node, node->input_values());
 }
 
-std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* const node, const OutputVector& inputs) {
+std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(Node* const node, const OutputVector& inputs) {
     size_t num_inputs = node->get_input_size();
     OutputVector converted_inputs;
     converted_inputs.reserve(num_inputs);

@@ -128,23 +151,49 @@ std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* c
         }
     }
 
-    // Create a new node with new (converted) inputs.
-    auto cloned_node = node->clone_with_new_inputs(converted_inputs);
+    std::shared_ptr<Node> cloned_node;
+
+    auto type_relaxed = dynamic_cast<op::TypeRelaxedBase*>(node);
+    if (type_relaxed != nullptr) {
+        // Save TypeRelaxed's origin input types
+        // If origin input type is undefined let's temporarily override it with original input precision attribute
+        // value. During ConstantFolding, some nodes can have temporarily mismatched input types (e.g. Add(f16, f32)).
+        // If the node is TypeRelaxed - we're unable to clone it since TypeRelaxed::clone_with_new_inputs creates a
+        // clone with 'fake' inputs based on current inputs and that can trigger an exception for certain nodes if the
+        // inputs have mismatched types.
+        element::TypeVector origin_input_types;
+        origin_input_types.reserve(num_inputs);
+        for (size_t i = 0; i < num_inputs; i++) {
+            const auto& origin_type = type_relaxed->get_origin_input_type(i);
+            origin_input_types.push_back(origin_type);
+            if (origin_type == element::undefined && has_original_input_precision(node->input(i))) {
+                type_relaxed->set_origin_input_type(get_original_input_precision(node->input(i)), i);
+            }
+        }
+
+        cloned_node = node->clone_with_new_inputs(converted_inputs);
+
+        // Restore TypeRelaxed's origin input types
+        for (size_t i = 0; i < num_inputs; i++) {
+            type_relaxed->set_origin_input_type(origin_input_types[i], i);
+        }
 
-    // Override TypeRelaxed types
-    auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(cloned_node);
-    if (type_relaxed) {
+        auto cloned_type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(cloned_node);
+        // Override TypeRelaxed types
         for (size_t i = 0; i < num_inputs; i++) {
-            if (ov::util::is_type_unsupported(type_relaxed->get_origin_input_type(i))) {
-                type_relaxed->set_origin_input_type(cloned_node->get_input_element_type(i), i);
+            if (ov::util::is_type_unsupported(cloned_type_relaxed->get_origin_input_type(i))) {
+                cloned_type_relaxed->set_origin_input_type(cloned_node->get_input_element_type(i), i);
             }
         }
         for (size_t i = 0; i < cloned_node->get_output_size(); i++) {
             if (ov::util::is_type_unsupported(cloned_node->get_output_element_type(i))) {
-                type_relaxed->set_overridden_output_type(element::f32, i);
+                cloned_type_relaxed->set_overridden_output_type(element::f32, i);
             }
         }
         cloned_node->validate_and_infer_types();
+    } else {
+        // Create a new node with new (converted) inputs.
+        cloned_node = node->clone_with_new_inputs(converted_inputs);
     }
 
     // Handle nodes which outputs precisions don't depend on input precisions

@@ -221,9 +270,19 @@ bool ov::util::evaluate_node_with_unsupported_precision(const ov::Node* node,
         }
     }
 
-    // evaluate converted node
-    if (!node->evaluate(converted_output_tensors, converted_input_tensors)) {
-        return false;
+    auto type_relaxed = dynamic_cast<const op::TypeRelaxedBase*>(node);
+    if (type_relaxed == nullptr) {
+        // evaluate node with converted tensors
+        if (!node->evaluate(converted_output_tensors, converted_input_tensors)) {
+            return false;
+        }
+    } else {
+        // node is const so let's clone it
+        auto cloned = node->clone_with_new_inputs(node->input_values());
+        cloned = convert_to_supported_precision(cloned.get());
+        if (!cloned->evaluate(converted_output_tensors, converted_input_tensors)) {
+            return false;
+        }
     }
 
     // convert outputs tensors from f32 to original type if necessary
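
As a side note on the structure: the save / override / clone / restore sequence in convert_to_supported_precision could equally be wrapped in a small scope guard. The class below is purely a hypothetical illustration of that pattern, not part of this change; it also lines up with the const Node* to Node* signature change above, since node.input(i) and set_origin_input_type both need a mutable node:

    // Hypothetical RAII guard (illustration only, not in the PR): temporarily fills in
    // undefined TypeRelaxed origin input types from the "original_precision" attribute
    // and restores the saved types when it goes out of scope.
    #include "openvino/core/constant_fold_utils.hpp"
    #include "ov_ops/type_relaxed.hpp"

    class OriginInputTypesGuard {
    public:
        OriginInputTypesGuard(ov::op::TypeRelaxedBase& type_relaxed, ov::Node& node)
            : m_type_relaxed(type_relaxed) {
            for (size_t i = 0; i < node.get_input_size(); i++) {
                m_saved.push_back(m_type_relaxed.get_origin_input_type(i));
                if (m_saved.back() == ov::element::undefined &&
                    ov::util::has_original_input_precision(node.input(i))) {
                    m_type_relaxed.set_origin_input_type(ov::util::get_original_input_precision(node.input(i)), i);
                }
            }
        }
        ~OriginInputTypesGuard() {
            // Restore whatever origin types the node had before the guard was created.
            for (size_t i = 0; i < m_saved.size(); i++)
                m_type_relaxed.set_origin_input_type(m_saved[i], i);
        }

    private:
        ov::op::TypeRelaxedBase& m_type_relaxed;
        ov::element::TypeVector m_saved;
    };

With such a guard, the TypeRelaxed branch would reduce to constructing the guard and calling clone_with_new_inputs.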

src/core/src/pass/constant_folding.cpp

+5 −28
@@ -49,42 +49,19 @@ const auto friendly_name_from = [](const ov::Node& node, const size_t output_cou
     }
 };
 
-static void save_original_input_precisions(const std::shared_ptr<ov::Node>& node) {
-    for (size_t i = 0; i < node->get_input_size(); i++) {
-        auto input = node->input(i);
-        input.get_rt_info()["original_precision"] = input.get_element_type();
-    }
-}
-
-static bool has_original_input_precision(const ov::Input<ov::Node>& input) {
-    return input.get_rt_info().count("original_precision") > 0;
-}
-
-static ov::element::Type get_original_input_precision(const ov::Input<ov::Node>& input) {
-    return input.get_rt_info().at("original_precision").as<ov::element::Type>();
-}
-
-static void remove_original_input_precision_attribute(ov::Input<ov::Node>& input) {
-    auto& rt_info = input.get_rt_info();
-    auto it = rt_info.find("original_precision");
-    if (it != rt_info.end()) {
-        rt_info.erase(it);
-    }
-}
-
 static bool restore_original_input_precision(const std::shared_ptr<ov::Node>& node) {
     bool restored = false;
     if (ov::is_type<ov::op::v0::Convert>(node)) {
         auto input = node->input(0);
-        remove_original_input_precision_attribute(input);
+        ov::util::remove_original_input_precision_attribute(input);
         return restored;
     }
     for (size_t i = 0; i < node->get_input_size(); i++) {
         auto input = node->input(i);
-        if (!has_original_input_precision(input))
+        if (!ov::util::has_original_input_precision(input))
             continue;
-        const auto original_type = get_original_input_precision(input);
-        remove_original_input_precision_attribute(input);
+        const auto original_type = ov::util::get_original_input_precision(input);
+        ov::util::remove_original_input_precision_attribute(input);
         if (original_type != node->get_input_element_type(i)) {
             auto convert = std::make_shared<ov::op::v0::Convert>(node->input_value(i), original_type);
             ov::OutputVector replacements(1);

@@ -206,7 +183,7 @@ bool ov::pass::ConstantFolding::pre_calculated_values_folding(const std::shared_
             // we need to convert constants with those types to f32. And at some point - this f32 constant may
             // become an input to a node that's not constfoldable. Then we need to convert that constant back to
             // that input's original precision.
-            save_original_input_precisions(node);
+            util::save_original_input_precisions(node);
             if (!node_has_disabled_constant_folding && util::node_requires_precision_conversion(node.get())) {
                 mark_node_requires_precision_conversion(node);
             }

src/core/tests/pass/constant_folding.cpp

+29
@@ -16,6 +16,7 @@
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/loop.hpp"
 #include "openvino/op/multiply.hpp"
+#include "ov_ops/type_relaxed.hpp"
 #include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
 #include "transformations/utils/utils.hpp"

@@ -4000,6 +4001,34 @@ TEST_P(UnsupportedTypesTest, convert_like) {
     ASSERT_EQ(m->get_results().size(), 1);
 }
 
+TEST_P(UnsupportedTypesTest, type_relaxed) {
+    Shape shape_in{2, 4, 1};
+
+    const auto& type = GetParam();
+    auto cond = op::v0::Constant::create(element::boolean, shape_in, {1});
+    auto param = std::make_shared<op::v0::Parameter>(type, shape_in);
+    auto constant1 = op::v0::Constant::create(type, shape_in, {2});
+    auto then_value = std::make_shared<op::v0::Concat>(OutputVector{param, constant1}, 2);
+    auto constant2 = op::v0::Constant::create(type, shape_in, {3});
+    auto else_value = std::make_shared<op::v3::Broadcast>(
+        constant2,
+        op::v0::Constant::create(element::u64, Shape{shape_in.size()}, Shape{shape_in[0], shape_in[1], 2}));
+    auto select = make_shared<op::v1::Select>(cond, then_value, else_value);
+    auto type_relaxed = make_shared<op::TypeRelaxed<op::v1::Select>>(*select,
+                                                                     element::TypeVector{element::boolean},
+                                                                     element::TypeVector{});
+    auto m = make_shared<Model>(type_relaxed, ParameterVector{param});
+
+    run_constant_folding(m);
+
+    EXPECT_EQ(m->get_ops().size(), 7);
+    EXPECT_EQ(count_ops_of_type<op::v1::Select>(m), 1);
+    EXPECT_EQ(count_ops_of_type<op::v0::Constant>(m), 3);
+    EXPECT_EQ(count_ops_of_type<op::v3::Broadcast>(m), 0);
+    EXPECT_EQ(count_ops_of_type<op::v0::Concat>(m), 1);
+    ASSERT_EQ(m->get_results().size(), 1);
+}
+
 static std::string unsupported_types_test_case_name(const testing::TestParamInfo<element::Type>& info) {
     return info.param.get_type_name();
 }
