Skip to content

Commit cce35ac

Browse files
authored
[Transformations] Fixing NodePredicate conversion to ValuePredicate (#29172)
### Details: - *Fixing NodePredicate conversion to ValuePredicate* ### Tickets: - *CVS-163062,CVS-163051* --------- Signed-off-by: Evgeniia Nugmanova <evgeniia.nugmanova@intel.com>
1 parent d6147c2 commit cce35ac

File tree

15 files changed

+73
-79
lines changed

15 files changed

+73
-79
lines changed

src/common/transformations/include/transformations/utils/utils.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output)
294294
TRANSFORMATIONS_API bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node);
295295

296296
template <typename T>
297-
ov::pass::pattern::op::ValuePredicate constant_predicate(std::function<bool(const std::vector<T>&)> predicate) {
298-
return pass::pattern::op::as_value_predicate([=](std::shared_ptr<Node> n) -> bool {
297+
ov::pass::pattern::op::Predicate constant_predicate(std::function<bool(const std::vector<T>&)> predicate) {
298+
return ov::pass::pattern::op::Predicate([=](std::shared_ptr<Node> n) -> bool {
299299
if (auto constant = as_type_ptr<v0::Constant>(n)) {
300300
auto values = constant->cast_vector<T>();
301301
return predicate(values);

src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -893,12 +893,11 @@ pass::EliminateScatterUpdate::EliminateScatterUpdate() {
893893

894894
ov::pass::EliminateNopBroadcast::EliminateNopBroadcast() {
895895
MATCHER_SCOPE(EliminateNopBroadcast);
896-
auto root = pattern::wrap_type<op::v1::Broadcast, op::v3::Broadcast, op::v0::Tile>(
897-
pattern::op::as_value_predicate([](std::shared_ptr<Node> node) {
898-
auto input_rank = node->get_input_partial_shape(0).rank();
899-
auto output_rank = node->get_output_partial_shape(0).rank();
900-
return input_rank.is_static() && output_rank.is_static() && input_rank == output_rank;
901-
}));
896+
auto root = pattern::wrap_type<op::v1::Broadcast, op::v3::Broadcast, op::v0::Tile>([](std::shared_ptr<Node> node) {
897+
auto input_rank = node->get_input_partial_shape(0).rank();
898+
auto output_rank = node->get_output_partial_shape(0).rank();
899+
return input_rank.is_static() && output_rank.is_static() && input_rank == output_rank;
900+
});
902901

903902
ov::matcher_pass_callback matcher_pass_callback = [](pattern::Matcher& m) {
904903
const auto& op = m.get_match_root();

src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp

+7-8
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,10 @@ void pull_reshape_through_optional_concat_and_bea(const ov::pass::pattern::Patte
182182
}
183183
} // namespace
184184

185-
#define IN_RESHAPE \
186-
pattern::wrap_type<op::v1::Reshape>(pattern::op::as_value_predicate([](std::shared_ptr<Node> n) -> bool { \
187-
return pattern::consumers_count(1)(n->output(0)) && reshape_keeps_last_two_dims(n); \
188-
}));
185+
#define IN_RESHAPE \
186+
pattern::wrap_type<op::v1::Reshape>([](std::shared_ptr<Node> n) -> bool { \
187+
return pattern::consumers_count(1)(n->output(0)) && reshape_keeps_last_two_dims(n); \
188+
});
189189

190190
#define SCALAR_INPUT \
191191
pattern::any_input([](ov::Output<Node> out) { \
@@ -252,10 +252,9 @@ ov::pass::DeReshapeMatMul::DeReshapeMatMul() {
252252

253253
auto matmul_or_add = std::make_shared<pattern::op::Or>(OutputVector{matmul, add});
254254
auto final_reshape =
255-
pattern::wrap_type<op::v1::Reshape>({matmul_or_add, pattern::any_input()},
256-
pattern::op::as_value_predicate([](std::shared_ptr<Node> n) -> bool {
257-
return reshape_keeps_last_two_dims(n);
258-
}));
255+
pattern::wrap_type<op::v1::Reshape>({matmul_or_add, pattern::any_input()}, [](std::shared_ptr<Node> n) -> bool {
256+
return reshape_keeps_last_two_dims(n);
257+
});
259258

260259
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
261260
const auto& pm = m.get_pattern_map();

src/core/include/openvino/pass/pattern/op/any.hpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@ class OPENVINO_API Any : public Pattern {
2020
: Pattern(wrapped_values, Predicate(pred)) {
2121
set_output_type(0, type, s);
2222
}
23-
Any(const element::Type& type, const PartialShape& s, const NodePredicate& pred, const NodeVector& wrapped_values)
24-
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
23+
template <typename TPredicate>
24+
Any(const element::Type& type, const PartialShape& s, const TPredicate& pred, const NodeVector& wrapped_values)
25+
: Any(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
2526
/// \brief creates a Any node containing a sub-pattern described by the type and
2627
/// shape of \sa node.
2728
template <typename TPredicate>
2829
Any(const Output<Node>& node, const TPredicate& pred, const OutputVector& wrapped_values)
2930
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
30-
Any(const Output<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values)
31-
: Any(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
31+
template <typename TPredicate>
32+
Any(const Output<Node>& node, const TPredicate& pred, const NodeVector& wrapped_values)
33+
: Any(node, Predicate(pred), as_output_vector(wrapped_values)) {}
3234

3335
bool match_value(pattern::Matcher* matcher,
3436
const Output<Node>& pattern_value,

src/core/include/openvino/pass/pattern/op/any_of.hpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,18 @@ class OPENVINO_API AnyOf : public Pattern {
2828
}
2929
set_output_type(0, type, s);
3030
}
31-
AnyOf(const element::Type& type, const PartialShape& s, const NodePredicate& pred, const NodeVector& wrapped_values)
32-
: AnyOf(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
31+
template <typename TPredicate>
32+
AnyOf(const element::Type& type, const PartialShape& s, const TPredicate& pred, const NodeVector& wrapped_values)
33+
: AnyOf(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
3334

3435
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
3536
/// shape of \sa node.
3637
template <typename TPredicate>
3738
AnyOf(const Output<Node>& node, const TPredicate& pred, const OutputVector& wrapped_values)
3839
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
39-
AnyOf(const std::shared_ptr<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values)
40-
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
40+
template <typename TPredicate>
41+
AnyOf(const std::shared_ptr<Node>& node, const TPredicate& pred, const NodeVector& wrapped_values)
42+
: AnyOf(node, Predicate(pred), as_output_vector(wrapped_values)) {}
4143
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
4244
};
4345
} // namespace ov::pass::pattern::op

src/core/include/openvino/pass/pattern/op/label.hpp

+5-14
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,14 @@ class OPENVINO_API Label : public Pattern {
3434
/// nullptr,
3535
/// OutputVector{add});
3636
/// \endcode
37-
template <typename TPredicate>
37+
template <typename TPredicate, typename TArg = OutputVector>
3838
Label(const element::Type& type,
3939
const PartialShape& s,
4040
const TPredicate& pred,
41-
const OutputVector& wrapped_values = {})
41+
const TArg& wrapped_values = OutputVector{})
4242
: Pattern(OutputVector{wrap_values(wrapped_values)}, Predicate(pred)) {
4343
set_output_type(0, type, s);
4444
}
45-
Label(const element::Type& type,
46-
const PartialShape& s,
47-
const NodePredicate& pred,
48-
const NodeVector& wrapped_values = {})
49-
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
5045

5146
/// \brief creates a Label node containing a sub-pattern described by the type and
5247
/// shape of \sa node.
@@ -60,14 +55,9 @@ class OPENVINO_API Label : public Pattern {
6055
/// nullptr,
6156
/// OutputVector{add});
6257
/// \endcode
63-
template <typename TPredicate>
64-
Label(const Output<Node>& value, const TPredicate& pred, const OutputVector& wrapped_values = {})
58+
template <typename TPredicate, typename TArg = OutputVector>
59+
Label(const Output<Node>& value, const TPredicate& pred, const TArg& wrapped_values = OutputVector{})
6560
: Label(value.get_element_type(), value.get_partial_shape(), Predicate(pred), wrapped_values) {}
66-
Label(const Output<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values = {})
67-
: Label(node.get_element_type(),
68-
node.get_partial_shape(),
69-
as_value_predicate(pred),
70-
as_output_vector(wrapped_values)) {}
7161

7262
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
7363
: Label(type, s, nullptr, OutputVector{}) {}
@@ -78,6 +68,7 @@ class OPENVINO_API Label : public Pattern {
7868

7969
protected:
8070
static Output<Node> wrap_values(const OutputVector& wrapped_values);
71+
static Output<Node> wrap_values(const NodeVector& wrapped_values);
8172
};
8273
} // namespace op
8374

src/core/include/openvino/pass/pattern/op/pattern.hpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,17 @@ OPENVINO_API op::Predicate all_of(const std::vector<std::function<bool(Output<No
6464
OPENVINO_API op::Predicate shape_matches(const std::string& shape_notation);
6565

6666
namespace op {
67-
68-
OPENVINO_API
69-
ValuePredicate as_value_predicate(NodePredicate pred);
67+
OPENVINO_DEPRECATED("This method is deprecated. Use constructor of ov::pass::pattern::Predicate instead")
68+
OPENVINO_API Predicate as_value_predicate(NodePredicate pred);
7069

7170
class OPENVINO_API Pattern : public Node {
7271
public:
72+
Pattern();
73+
explicit Pattern(const OutputVector& patterns);
74+
explicit Pattern(const NodeVector& patterns);
7375
/// \brief A base class for all the utility operators used to describe a pattern to match
7476
Pattern(const OutputVector& patterns, const Predicate& pred);
75-
Pattern(const OutputVector& patterns);
77+
Pattern(const NodeVector& patterns, const Predicate& pred);
7678

7779
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
7880
OPENVINO_THROW("Uncopyable");

src/core/include/openvino/pass/pattern/op/predicate.hpp

+2-12
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,8 @@ class OPENVINO_API Predicate {
9494
}
9595

9696
bool operator()(PatternSymbolMap& m, const Output<Node>& output) const;
97-
98-
template <typename T>
99-
bool operator()(const T& arg) const {
100-
OPENVINO_ASSERT(!m_requires_map,
101-
"Predicate " + m_name + " called with unexpected argument: " + std::string(typeid(T).name()));
102-
if constexpr (std::is_convertible_v<T, Output<Node>>) {
103-
PatternSymbolMap dummy_map;
104-
return m_pred(dummy_map, arg);
105-
}
106-
OPENVINO_ASSERT(false,
107-
"Predicate " + m_name + " called with unexpected argument: " + std::string(typeid(T).name()));
108-
}
97+
bool operator()(const std::shared_ptr<Node>& node) const;
98+
bool operator()(const Output<Node>& output) const;
10999

110100
template <typename TPredicate>
111101
Predicate operator||(const TPredicate& other) const {

src/core/include/openvino/pass/pattern/op/true.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class OPENVINO_API True : public Pattern {
1313
public:
1414
OPENVINO_RTTI("patternTrue");
1515
/// \brief Always matches, does not add node to match list.
16-
True() : Pattern(OutputVector{}) {}
16+
True() : Pattern() {}
1717
bool match_value(pattern::Matcher* matcher,
1818
const Output<Node>& pattern_value,
1919
const Output<Node>& graph_value) override;

src/core/include/openvino/pass/pattern/op/wrap_type.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class OPENVINO_API WrapType : public Pattern {
1414
public:
1515
OPENVINO_RTTI("WrapType");
1616

17-
explicit WrapType(const std::vector<NodeTypeInfo>& wrapped_types) : Pattern({}), m_wrapped_types(wrapped_types) {
17+
explicit WrapType(const std::vector<NodeTypeInfo>& wrapped_types) : Pattern(), m_wrapped_types(wrapped_types) {
1818
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
1919
}
2020

@@ -27,7 +27,7 @@ class OPENVINO_API WrapType : public Pattern {
2727
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
2828
}
2929

30-
explicit WrapType(NodeTypeInfo wrapped_type) : Pattern({}), m_wrapped_types({wrapped_type}) {
30+
explicit WrapType(NodeTypeInfo wrapped_type) : Pattern(), m_wrapped_types({wrapped_type}) {
3131
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
3232
}
3333

src/core/src/pattern/op/label.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ ov::Output<ov::Node> ov::pass::pattern::op::Label::wrap_values(const ov::OutputV
2121
}
2222
}
2323

24+
ov::Output<ov::Node> ov::pass::pattern::op::Label::wrap_values(const ov::NodeVector& wrapped_values) {
25+
switch (wrapped_values.size()) {
26+
case 0:
27+
return std::make_shared<pattern::op::True>()->output(0);
28+
case 1:
29+
return wrapped_values[0];
30+
default:
31+
return std::make_shared<pattern::op::Or>(as_output_vector(wrapped_values))->output(0);
32+
}
33+
}
34+
2435
bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* matcher,
2536
const ov::Output<ov::Node>& pattern_value,
2637
const ov::Output<ov::Node>& graph_value) {

src/core/src/pattern/op/pattern.cpp

+5-19
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,16 @@
1111

1212
namespace ov::pass::pattern {
1313
namespace op {
14-
namespace {
15-
constexpr bool value_true_predicate(const Output<Node>&) {
16-
return true;
17-
}
18-
} // namespace
19-
20-
struct NodeValuePredicate {
21-
bool operator()(const Output<Node>& value) const {
22-
return pred(value.get_node_shared_ptr());
23-
}
24-
25-
NodePredicate pred;
26-
};
2714

15+
Pattern::Pattern() : Node(), m_predicate() {}
2816
Pattern::Pattern(const OutputVector& patterns) : Node(patterns), m_predicate() {}
17+
Pattern::Pattern(const NodeVector& patterns) : Pattern(as_output_vector(patterns)) {}
2918

3019
Pattern::Pattern(const OutputVector& patterns, const op::Predicate& pred) : Node(patterns), m_predicate(pred) {}
20+
Pattern::Pattern(const NodeVector& patterns, const op::Predicate& pred) : Pattern(as_output_vector(patterns), pred) {}
3121

32-
ValuePredicate as_value_predicate(NodePredicate pred) {
33-
if (pred) {
34-
return NodeValuePredicate{std::move(pred)};
35-
} else {
36-
return value_true_predicate;
37-
}
22+
Predicate as_value_predicate(NodePredicate pred) {
23+
return Predicate(pred);
3824
}
3925

4026
std::ostream& Pattern::write_type_description(std::ostream& out) const {

src/core/src/pattern/op/predicate.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,17 @@ bool Predicate::operator()(pass::pattern::PatternSymbolMap& m, const Output<Node
7878
OPENVINO_DEBUG("Predicate `", m_name, "` has ", (result ? "passed" : "failed"), ". Applied to ", output);
7979
return result;
8080
}
81+
82+
bool Predicate::operator()(const std::shared_ptr<Node>& node) const {
83+
OPENVINO_ASSERT(!m_requires_map, "Predicate " + m_name + " called with unexpected argument: std::shared_ptr<Node>");
84+
PatternSymbolMap dummy_map;
85+
return m_pred(dummy_map, node);
86+
}
87+
88+
bool Predicate::operator()(const Output<Node>& output) const {
89+
OPENVINO_ASSERT(!m_requires_map, "Predicate " + m_name + " called with unexpected argument: Output<Node>");
90+
PatternSymbolMap dummy_map;
91+
return m_pred(dummy_map, output);
92+
}
8193
} // namespace op
8294
} // namespace ov::pass::pattern

src/plugins/intel_gpu/src/plugin/transformations/clamp_fp16_output.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ ClampFP16Output::ClampFP16Output() {
2828
using namespace ov::pass::pattern;
2929
using namespace ov::pass::pattern::op;
3030

31-
auto in0 = any_input(as_value_predicate(class_other_than<v0::Constant>()));
32-
auto in1 = any_input(as_value_predicate(class_other_than<v0::Constant>()));
31+
auto in0 = any_input(class_other_than<v0::Constant>());
32+
auto in1 = any_input(class_other_than<v0::Constant>());
3333
auto matmul_m = wrap_type<v0::MatMul>({in0, in1}, type_matches(ov::element::f16) && consumers_count(1));
3434
auto reshape_m = wrap_type<v1::Reshape>({matmul_m, any_input()}, type_matches(ov::element::f16) && consumers_count(1));
3535
auto add_m = wrap_type<v1::Add>({matmul_m, any_input()}, type_matches(ov::element::f16) && consumers_count(1));

src/plugins/intel_gpu/src/plugin/transformations/lora_horizontal_fusion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ LoRAHorizontalFusion::LoRAHorizontalFusion() {
5858

5959
auto axis_const = wrap_type<ov::op::v0::Constant>();
6060
auto split_const = wrap_type<ov::op::v0::Constant>();
61-
auto split = wrap_type<ov::op::v1::VariadicSplit>({main_flow, axis_const, split_const}, ov::pass::pattern::op::as_value_predicate(is_target_pattern));
61+
auto split = wrap_type<ov::op::v1::VariadicSplit>({main_flow, axis_const, split_const}, is_target_pattern);
6262

6363
ov::matcher_pass_callback callback = [=](Matcher& m) {
6464
const auto& pattern_map = m.get_pattern_value_map();

0 commit comments

Comments
 (0)