Skip to content

Commit ab23914

Browse files
committed
Resolve NodePredicate ambiguity
Signed-off-by: Evgeniia Nugmanova <evgeniia.nugmanova@intel.com>
1 parent ba52522 commit ab23914

File tree

8 files changed

+34
-19
lines changed

8 files changed

+34
-19
lines changed

src/core/include/openvino/pass/pattern/op/any.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@ class OPENVINO_API Any : public Pattern {
2020
: Pattern(wrapped_values, Predicate(pred)) {
2121
set_output_type(0, type, s);
2222
}
23-
Any(const element::Type& type, const PartialShape& s, const NodePredicate& pred, const NodeVector& wrapped_values)
23+
template <typename TPredicate>
24+
Any(const element::Type& type, const PartialShape& s, const TPredicate& pred, const NodeVector& wrapped_values)
2425
: Any(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
2526
/// \brief creates a Any node containing a sub-pattern described by the type and
2627
/// shape of \sa node.
2728
template <typename TPredicate>
2829
Any(const Output<Node>& node, const TPredicate& pred, const OutputVector& wrapped_values)
2930
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
30-
Any(const Output<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values)
31+
template <typename TPredicate>
32+
Any(const Output<Node>& node, const TPredicate& pred, const NodeVector& wrapped_values)
3133
: Any(node, Predicate(pred), as_output_vector(wrapped_values)) {}
3234

3335
bool match_value(pattern::Matcher* matcher,

src/core/include/openvino/pass/pattern/op/any_of.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,17 @@ class OPENVINO_API AnyOf : public Pattern {
2828
}
2929
set_output_type(0, type, s);
3030
}
31-
AnyOf(const element::Type& type, const PartialShape& s, const NodePredicate& pred, const NodeVector& wrapped_values)
31+
template <typename TPredicate>
32+
AnyOf(const element::Type& type, const PartialShape& s, const TPredicate& pred, const NodeVector& wrapped_values)
3233
: AnyOf(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
3334

3435
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
3536
/// shape of \sa node.
3637
template <typename TPredicate>
3738
AnyOf(const Output<Node>& node, const TPredicate& pred, const OutputVector& wrapped_values)
3839
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
39-
AnyOf(const std::shared_ptr<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values)
40+
template <typename TPredicate>
41+
AnyOf(const std::shared_ptr<Node>& node, const TPredicate& pred, const NodeVector& wrapped_values)
4042
: AnyOf(node, Predicate(pred), as_output_vector(wrapped_values)) {}
4143
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
4244
};

src/core/include/openvino/pass/pattern/op/label.hpp

+5-11
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,14 @@ class OPENVINO_API Label : public Pattern {
3434
/// nullptr,
3535
/// OutputVector{add});
3636
/// \endcode
37-
template <typename TPredicate>
37+
template <typename TPredicate, typename TArg = OutputVector>
3838
Label(const element::Type& type,
3939
const PartialShape& s,
4040
const TPredicate& pred,
41-
const OutputVector& wrapped_values = {})
41+
const TArg& wrapped_values = OutputVector{})
4242
: Pattern(OutputVector{wrap_values(wrapped_values)}, Predicate(pred)) {
4343
set_output_type(0, type, s);
4444
}
45-
Label(const element::Type& type,
46-
const PartialShape& s,
47-
const NodePredicate& pred,
48-
const NodeVector& wrapped_values = {})
49-
: Label(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
5045

5146
/// \brief creates a Label node containing a sub-pattern described by the type and
5247
/// shape of \sa node.
@@ -60,11 +55,9 @@ class OPENVINO_API Label : public Pattern {
6055
/// nullptr,
6156
/// OutputVector{add});
6257
/// \endcode
63-
template <typename TPredicate>
64-
Label(const Output<Node>& value, const TPredicate& pred, const OutputVector& wrapped_values = {})
58+
template <typename TPredicate, typename TArg = OutputVector>
59+
Label(const Output<Node>& value, const TPredicate& pred, const TArg& wrapped_values = OutputVector{})
6560
: Label(value.get_element_type(), value.get_partial_shape(), Predicate(pred), wrapped_values) {}
66-
Label(const Output<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values = {})
67-
: Label(node.get_element_type(), node.get_partial_shape(), Predicate(pred), as_output_vector(wrapped_values)) {}
6861

6962
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
7063
: Label(type, s, nullptr, OutputVector{}) {}
@@ -75,6 +68,7 @@ class OPENVINO_API Label : public Pattern {
7568

7669
protected:
7770
static Output<Node> wrap_values(const OutputVector& wrapped_values);
71+
static Output<Node> wrap_values(const NodeVector& wrapped_values);
7872
};
7973
} // namespace op
8074

src/core/include/openvino/pass/pattern/op/pattern.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,12 @@ OPENVINO_API Predicate as_value_predicate(NodePredicate pred);
6969

7070
class OPENVINO_API Pattern : public Node {
7171
public:
72+
Pattern();
73+
explicit Pattern(const OutputVector& patterns);
74+
explicit Pattern(const NodeVector& patterns);
7275
/// \brief A base class for all the utility operators used to describe a pattern to match
7376
Pattern(const OutputVector& patterns, const Predicate& pred);
74-
Pattern(const OutputVector& patterns);
77+
Pattern(const NodeVector& patterns, const Predicate& pred);
7578

7679
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
7780
OPENVINO_THROW("Uncopyable");

src/core/include/openvino/pass/pattern/op/true.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class OPENVINO_API True : public Pattern {
1313
public:
1414
OPENVINO_RTTI("patternTrue");
1515
/// \brief Always matches, does not add node to match list.
16-
True() : Pattern(OutputVector{}) {}
16+
True() : Pattern() {}
1717
bool match_value(pattern::Matcher* matcher,
1818
const Output<Node>& pattern_value,
1919
const Output<Node>& graph_value) override;

src/core/include/openvino/pass/pattern/op/wrap_type.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class OPENVINO_API WrapType : public Pattern {
1414
public:
1515
OPENVINO_RTTI("WrapType");
1616

17-
explicit WrapType(const std::vector<NodeTypeInfo>& wrapped_types) : Pattern({}), m_wrapped_types(wrapped_types) {
17+
explicit WrapType(const std::vector<NodeTypeInfo>& wrapped_types) : Pattern(), m_wrapped_types(wrapped_types) {
1818
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
1919
}
2020

@@ -27,7 +27,7 @@ class OPENVINO_API WrapType : public Pattern {
2727
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
2828
}
2929

30-
explicit WrapType(NodeTypeInfo wrapped_type) : Pattern({}), m_wrapped_types({wrapped_type}) {
30+
explicit WrapType(NodeTypeInfo wrapped_type) : Pattern(), m_wrapped_types({wrapped_type}) {
3131
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
3232
}
3333

src/core/src/pattern/op/label.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ ov::Output<ov::Node> ov::pass::pattern::op::Label::wrap_values(const ov::OutputV
2121
}
2222
}
2323

24+
ov::Output<ov::Node> ov::pass::pattern::op::Label::wrap_values(const ov::NodeVector& wrapped_values) {
25+
switch (wrapped_values.size()) {
26+
case 0:
27+
return std::make_shared<pattern::op::True>()->output(0);
28+
case 1:
29+
return wrapped_values[0];
30+
default:
31+
return std::make_shared<pattern::op::Or>(as_output_vector(wrapped_values))->output(0);
32+
}
33+
}
34+
2435
bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* matcher,
2536
const ov::Output<ov::Node>& pattern_value,
2637
const ov::Output<ov::Node>& graph_value) {

src/core/src/pattern/op/pattern.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
namespace ov::pass::pattern {
1313
namespace op {
1414

15+
Pattern::Pattern() : Node(), m_predicate() {}
1516
Pattern::Pattern(const OutputVector& patterns) : Node(patterns), m_predicate() {}
17+
Pattern::Pattern(const NodeVector& patterns) : Pattern(as_output_vector(patterns)) {}
1618

1719
Pattern::Pattern(const OutputVector& patterns, const op::Predicate& pred) : Node(patterns), m_predicate(pred) {}
20+
Pattern::Pattern(const NodeVector& patterns, const op::Predicate& pred) : Pattern(as_output_vector(patterns), pred) {}
1821

1922
Predicate as_value_predicate(NodePredicate pred) {
2023
return Predicate(pred);

0 commit comments

Comments
 (0)