Skip to content

Commit 113efdd

Browse files
committed
refactor
Signed-off-by: Evgeniia Nugmanova <evgeniia.nugmanova@intel.com>
1 parent 97a4c9d commit 113efdd

File tree

4 files changed

+43
-48
lines changed

4 files changed

+43
-48
lines changed

src/core/include/openvino/pass/pattern/op/any.hpp

+5-10
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,17 @@ class OPENVINO_API Any : public Pattern {
1515
OPENVINO_RTTI("patternAny");
1616
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
1717
/// shape.
18-
template <typename TPredicate>
19-
Any(const element::Type& type, const PartialShape& s, const TPredicate& pred, const OutputVector& wrapped_values)
18+
template <typename TPredicate, typename TArg>
19+
Any(const element::Type& type, const PartialShape& s, const TPredicate& pred, const TArg& wrapped_values)
2020
: Pattern(wrapped_values, Predicate(pred)) {
2121
set_output_type(0, type, s);
2222
}
23-
template <typename TPredicate>
24-
Any(const element::Type& type, const PartialShape& s, const TPredicate& pred, const NodeVector& wrapped_values)
25-
: Any(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
23+
2624
/// \brief creates a Any node containing a sub-pattern described by the type and
2725
/// shape of \sa node.
28-
template <typename TPredicate>
29-
Any(const Output<Node>& node, const TPredicate& pred, const OutputVector& wrapped_values)
26+
template <typename TPredicate, typename TArg>
27+
Any(const Output<Node>& node, const TPredicate& pred, const TArg& wrapped_values)
3028
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
31-
template <typename TPredicate>
32-
Any(const Output<Node>& node, const TPredicate& pred, const NodeVector& wrapped_values)
33-
: Any(node, Predicate(pred), as_output_vector(wrapped_values)) {}
3429

3530
bool match_value(pattern::Matcher* matcher,
3631
const Output<Node>& pattern_value,

src/core/include/openvino/pass/pattern/op/any_of.hpp

+5-10
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,21 @@ class OPENVINO_API AnyOf : public Pattern {
2020
OPENVINO_RTTI("patternAnyOf");
2121
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
2222
/// \sa shape.
23-
template <typename TPredicate>
24-
AnyOf(const element::Type& type, const PartialShape& s, const TPredicate& pred, const OutputVector& wrapped_values)
23+
template <typename TPredicate, typename TArg>
24+
AnyOf(const element::Type& type, const PartialShape& s, const TPredicate& pred, const TArg& wrapped_values)
2525
: Pattern(wrapped_values, Predicate(pred)) {
2626
if (wrapped_values.size() != 1) {
2727
OPENVINO_THROW("AnyOf expects exactly one argument");
2828
}
2929
set_output_type(0, type, s);
3030
}
31-
template <typename TPredicate>
32-
AnyOf(const element::Type& type, const PartialShape& s, const TPredicate& pred, const NodeVector& wrapped_values)
33-
: AnyOf(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
3431

3532
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
3633
/// shape of \sa node.
37-
template <typename TPredicate>
38-
AnyOf(const Output<Node>& node, const TPredicate& pred, const OutputVector& wrapped_values)
34+
template <typename TPredicate, typename TArg>
35+
AnyOf(const Output<Node>& node, const TPredicate& pred, const TArg& wrapped_values)
3936
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
40-
template <typename TPredicate>
41-
AnyOf(const std::shared_ptr<Node>& node, const TPredicate& pred, const NodeVector& wrapped_values)
42-
: AnyOf(node, Predicate(pred), as_output_vector(wrapped_values)) {}
37+
4338
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
4439
};
4540
} // namespace ov::pass::pattern::op

src/core/include/openvino/pass/pattern/op/optional.hpp

+19-16
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,10 @@ class OPENVINO_API Optional : public Pattern {
5151
/// \param type_infos Optional operation types to exclude them from the matching
5252
/// in case the following op types do not exist in a pattern to match.
5353
/// \param patterns The pattern to match a graph.
54-
Optional(const std::vector<DiscreteTypeInfo>& type_infos, const OutputVector& inputs = {})
55-
: Pattern(inputs),
56-
optional_types(type_infos){};
57-
58-
template <typename TPredicate>
59-
Optional(const std::vector<DiscreteTypeInfo>& type_infos, const OutputVector& inputs, const TPredicate& pred)
54+
template <typename TPredicate = nullptr_t>
55+
Optional(const std::vector<DiscreteTypeInfo>& type_infos,
56+
const OutputVector& inputs = {},
57+
const TPredicate& pred = nullptr)
6058
: Pattern(inputs, Predicate(pred)),
6159
optional_types(type_infos){};
6260

@@ -84,8 +82,10 @@ void collect_type_info(std::vector<DiscreteTypeInfo>& type_info_vec) {
8482
collect_type_info<NodeTypeArgs...>(type_info_vec);
8583
}
8684

87-
template <class... NodeTypes, typename TPredicate>
88-
std::shared_ptr<Node> optional(const OutputVector& inputs, const TPredicate& pred, const Attributes& attrs = {}) {
85+
template <class... NodeTypes, typename TPredicate = nullptr_t>
86+
std::shared_ptr<Node> optional(const OutputVector& inputs,
87+
const TPredicate& pred = nullptr,
88+
const Attributes& attrs = {}) {
8989
std::vector<DiscreteTypeInfo> optional_type_info_vec;
9090
collect_type_info<NodeTypes...>(optional_type_info_vec);
9191
return std::make_shared<op::Optional>(
@@ -94,8 +94,10 @@ std::shared_ptr<Node> optional(const OutputVector& inputs, const TPredicate& pre
9494
attrs.empty() ? op::Predicate(pred) : attrs_match(attrs) && op::Predicate(pred));
9595
}
9696

97-
template <class... NodeTypes, typename TPredicate>
98-
std::shared_ptr<Node> optional(const Output<Node>& input, const TPredicate& pred, const Attributes& attrs = {}) {
97+
template <class... NodeTypes, typename TPredicate = nullptr_t>
98+
std::shared_ptr<Node> optional(const Output<Node>& input,
99+
const TPredicate& pred = nullptr,
100+
const Attributes& attrs = {}) {
99101
return optional<NodeTypes...>(OutputVector{input}, op::Predicate(pred), attrs);
100102
}
101103

@@ -107,21 +109,22 @@ std::shared_ptr<Node> optional(const TPredicate& pred, const Attributes& attrs =
107109
}
108110

109111
template <class... NodeTypes>
110-
std::shared_ptr<Node> optional(const std::initializer_list<Output<Node>>& inputs, const Attributes& attrs = {}) {
111-
return optional<NodeTypes...>(OutputVector(inputs), attrs.empty() ? op::Predicate() : attrs_match(attrs));
112+
std::shared_ptr<Node> optional(const OutputVector& inputs, const Attributes& attrs) {
113+
return optional<NodeTypes...>(inputs, attrs.empty() ? op::Predicate() : attrs_match(attrs));
112114
}
115+
113116
template <class... NodeTypes>
114-
std::shared_ptr<Node> optional(const OutputVector& inputs, const Attributes& attrs = {}) {
115-
return optional<NodeTypes...>(inputs, attrs.empty() ? op::Predicate() : attrs_match(attrs));
117+
std::shared_ptr<Node> optional(const std::initializer_list<Output<Node>>& inputs, const Attributes& attrs = {}) {
118+
return optional<NodeTypes...>(OutputVector(inputs), attrs);
116119
}
117120

118121
template <class... NodeTypes>
119-
std::shared_ptr<Node> optional(const Output<Node>& input, const Attributes& attrs = {}) {
122+
std::shared_ptr<Node> optional(const Output<Node>& input, const Attributes& attrs) {
120123
return optional<NodeTypes...>(OutputVector{input}, attrs);
121124
}
122125

123126
template <class... NodeTypes>
124127
std::shared_ptr<Node> optional(const Attributes& attrs = {}) {
125-
return optional<NodeTypes...>(OutputVector{}, attrs.empty() ? op::Predicate() : attrs_match(attrs));
128+
return optional<NodeTypes...>(OutputVector{}, attrs);
126129
}
127130
} // namespace ov::pass::pattern

src/core/include/openvino/pass/pattern/op/wrap_type.hpp

+14-12
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ void collect_wrap_info(std::vector<DiscreteTypeInfo>& info) {
6363
collect_wrap_info<Targs...>(info);
6464
}
6565

66-
template <class... Args, typename TPredicate>
67-
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const TPredicate& pred, const Attributes& attrs = {}) {
66+
template <class... Args, typename TPredicate = nullptr_t>
67+
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {},
68+
const TPredicate& pred = nullptr,
69+
const Attributes& attrs = {}) {
6870
std::vector<DiscreteTypeInfo> info;
6971
collect_wrap_info<Args...>(info);
7072
return std::make_shared<op::WrapType>(
@@ -73,25 +75,25 @@ std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const TPredicate& pr
7375
inputs);
7476
}
7577

78+
template <class... Args,
79+
typename TPredicate,
80+
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>* = nullptr>
81+
std::shared_ptr<Node> wrap_type(const TPredicate& pred, const Attributes& attrs = {}) {
82+
return wrap_type<Args...>({}, op::Predicate(pred), attrs);
83+
}
84+
7685
template <class... Args>
77-
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const Attributes& attrs = {}) {
86+
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const Attributes& attrs) {
7887
return wrap_type<Args...>(inputs, (attrs.empty() ? op::Predicate() : attrs_match(attrs)));
7988
}
8089

8190
template <class... Args>
82-
std::shared_ptr<Node> wrap_type(const std::initializer_list<Output<Node>>& inputs = {}, const Attributes& attrs = {}) {
91+
std::shared_ptr<Node> wrap_type(const std::initializer_list<Output<Node>>& inputs, const Attributes& attrs = {}) {
8392
return wrap_type<Args...>(OutputVector(inputs), attrs);
8493
}
8594

86-
template <class... Args,
87-
typename TPredicate,
88-
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>* = nullptr>
89-
std::shared_ptr<Node> wrap_type(const TPredicate& pred, const Attributes& attrs = {}) {
90-
return wrap_type<Args...>({}, op::Predicate(pred), attrs);
91-
}
92-
9395
template <class... Args>
9496
std::shared_ptr<Node> wrap_type(const Attributes& attrs) {
95-
return wrap_type<Args...>({}, attrs);
97+
return wrap_type<Args...>(attrs_match(attrs));
9698
}
9799
} // namespace ov::pass::pattern

0 commit comments

Comments
 (0)