11
11
#include " openvino/op/unsqueeze.hpp"
12
12
#include " openvino/pass/manager.hpp"
13
13
#include " openvino/pass/pattern/op/optional.hpp"
14
+ #include " openvino/pass/pattern/op/pattern.hpp"
14
15
#include " openvino/pass/pattern/op/wrap_type.hpp"
15
16
#include " transformations/rt_info/dequantization_node.hpp"
16
17
#include " transformations/rt_info/disable_constant_folding.hpp"
@@ -23,9 +24,13 @@ using namespace ov::pass::pattern;
23
24
24
25
namespace {
25
26
26
- bool check_precision (const ov::element::Type_t type_to_check, const ov::element::TypeVector& precisions) {
27
- return std::find (precisions.begin (), precisions.end (), type_to_check) != precisions.end ();
28
- };
27
+ ov::pass::pattern::op::Predicate check_precision (const ov::element::TypeVector& precisions) {
28
+ return ov::pass::pattern::op::Predicate (
29
+ [=](const Output<Node>& output) -> bool {
30
+ return std::find (precisions.begin (), precisions.end (), output.get_element_type ()) != precisions.end ();
31
+ },
32
+ " check_precision" );
33
+ }
29
34
30
35
using RTInfoSetter = std::function<void (const std::shared_ptr<ov::Node>& node)>;
31
36
void set_rt_info (const PatternValueMap& pt_map,
@@ -35,10 +40,9 @@ void set_rt_info(const PatternValueMap& pt_map,
35
40
for (const auto & pattern_node : pattern_nodes) {
36
41
if (pt_map.count (pattern_node)) {
37
42
auto node = pt_map.at (pattern_node).get_node_shared_ptr ();
38
-
39
43
// we don't need to mark Converts with disable_cf attribute if the `from` type (input type)
40
44
// is not in the `precisions` list.
41
- if (ov::as_type_ptr<v0::Convert>(node) && !check_precision (node->get_input_element_type (0 ), precisions )) {
45
+ if (ov::as_type_ptr<v0::Convert>(node) && !check_precision (precisions)( node->input_value (0 ))) {
42
46
continue ;
43
47
}
44
48
@@ -196,7 +200,7 @@ ov::pass::MarkDequantization::MarkDequantization(const element::TypeVector& prec
196
200
MATCHER_SCOPE (MarkDequantization);
197
201
198
202
// data input:
199
- auto input_pattern = any_input ();
203
+ auto input_pattern = any_input (check_precision (precisions) );
200
204
auto convert_pattern = wrap_type<v0::Convert>({input_pattern}, consumers_count (1 ));
201
205
202
206
// zero points:
@@ -217,7 +221,7 @@ ov::pass::MarkDequantization::MarkDequantization(const element::TypeVector& prec
217
221
auto input = pt_map.at (input_pattern);
218
222
const auto multiply = m.get_match_root ();
219
223
220
- if (! check_precision (input. get_element_type (), precisions) || transformation_callback (multiply)) {
224
+ if (transformation_callback (multiply)) {
221
225
return false ;
222
226
}
223
227
@@ -290,8 +294,7 @@ ov::pass::KeepConstPrecision::KeepConstPrecision(const element::TypeVector& prec
290
294
for (const auto & pattern_node : keep_const_precisions) {
291
295
if (pt_map.count (pattern_node.first )) {
292
296
auto node = pt_map.at (pattern_node.first ).get_node_shared_ptr ();
293
- const auto & precision = node->get_output_element_type (0 );
294
- if (ov::as_type_ptr<v0::Constant>(node) && check_precision (precision, precisions)) {
297
+ if (ov::as_type_ptr<v0::Constant>(node) && check_precision (precisions)(node->output (0 ))) {
295
298
if (pattern_node.second ) {
296
299
ov::disable_keep_const_precision (node);
297
300
} else {
0 commit comments