Skip to content

Commit e35cd4d

Browse files
committed
Fixing NodePredicate conversion to ValuePredicate
Signed-off-by: Evgeniia Nugmanova <evgeniia.nugmanova@intel.com>
1 parent 43fb02a commit e35cd4d

File tree

12 files changed

+41
-57
lines changed

12 files changed

+41
-57
lines changed

src/common/transformations/include/transformations/utils/utils.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output)
294294
TRANSFORMATIONS_API bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node);
295295

296296
template <typename T>
297-
ov::pass::pattern::op::ValuePredicate constant_predicate(std::function<bool(const std::vector<T>&)> predicate) {
298-
return pass::pattern::op::as_value_predicate([=](std::shared_ptr<Node> n) -> bool {
297+
ov::pass::pattern::op::Predicate constant_predicate(std::function<bool(const std::vector<T>&)> predicate) {
298+
return ov::pass::pattern::op::Predicate([=](std::shared_ptr<Node> n) -> bool {
299299
if (auto constant = as_type_ptr<v0::Constant>(n)) {
300300
auto values = constant->cast_vector<T>();
301301
return predicate(values);

src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -893,12 +893,11 @@ pass::EliminateScatterUpdate::EliminateScatterUpdate() {
893893

894894
ov::pass::EliminateNopBroadcast::EliminateNopBroadcast() {
895895
MATCHER_SCOPE(EliminateNopBroadcast);
896-
auto root = pattern::wrap_type<op::v1::Broadcast, op::v3::Broadcast, op::v0::Tile>(
897-
pattern::op::as_value_predicate([](std::shared_ptr<Node> node) {
898-
auto input_rank = node->get_input_partial_shape(0).rank();
899-
auto output_rank = node->get_output_partial_shape(0).rank();
900-
return input_rank.is_static() && output_rank.is_static() && input_rank == output_rank;
901-
}));
896+
auto root = pattern::wrap_type<op::v1::Broadcast, op::v3::Broadcast, op::v0::Tile>([](std::shared_ptr<Node> node) {
897+
auto input_rank = node->get_input_partial_shape(0).rank();
898+
auto output_rank = node->get_output_partial_shape(0).rank();
899+
return input_rank.is_static() && output_rank.is_static() && input_rank == output_rank;
900+
});
902901

903902
ov::matcher_pass_callback matcher_pass_callback = [](pattern::Matcher& m) {
904903
const auto& op = m.get_match_root();

src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp

+7-8
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,10 @@ void pull_reshape_through_optional_concat_and_bea(const ov::pass::pattern::Patte
182182
}
183183
} // namespace
184184

185-
#define IN_RESHAPE \
186-
pattern::wrap_type<op::v1::Reshape>(pattern::op::as_value_predicate([](std::shared_ptr<Node> n) -> bool { \
187-
return pattern::consumers_count(1)(n->output(0)) && reshape_keeps_last_two_dims(n); \
188-
}));
185+
#define IN_RESHAPE \
186+
pattern::wrap_type<op::v1::Reshape>([](std::shared_ptr<Node> n) -> bool { \
187+
return pattern::consumers_count(1)(n->output(0)) && reshape_keeps_last_two_dims(n); \
188+
});
189189

190190
#define SCALAR_INPUT \
191191
pattern::any_input([](ov::Output<Node> out) { \
@@ -252,10 +252,9 @@ ov::pass::DeReshapeMatMul::DeReshapeMatMul() {
252252

253253
auto matmul_or_add = std::make_shared<pattern::op::Or>(OutputVector{matmul, add});
254254
auto final_reshape =
255-
pattern::wrap_type<op::v1::Reshape>({matmul_or_add, pattern::any_input()},
256-
pattern::op::as_value_predicate([](std::shared_ptr<Node> n) -> bool {
257-
return reshape_keeps_last_two_dims(n);
258-
}));
255+
pattern::wrap_type<op::v1::Reshape>({matmul_or_add, pattern::any_input()}, [](std::shared_ptr<Node> n) -> bool {
256+
return reshape_keeps_last_two_dims(n);
257+
});
259258

260259
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
261260
const auto& pm = m.get_pattern_map();

src/core/include/openvino/pass/pattern/op/any.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ class OPENVINO_API Any : public Pattern {
2121
set_output_type(0, type, s);
2222
}
2323
Any(const element::Type& type, const PartialShape& s, const NodePredicate& pred, const NodeVector& wrapped_values)
24-
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
24+
: Any(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
2525
/// \brief creates a Any node containing a sub-pattern described by the type and
2626
/// shape of \sa node.
2727
template <typename TPredicate>
2828
Any(const Output<Node>& node, const TPredicate& pred, const OutputVector& wrapped_values)
2929
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
3030
Any(const Output<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values)
31-
: Any(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
31+
: Any(node, Predicate(pred), as_output_vector(wrapped_values)) {}
3232

3333
bool match_value(pattern::Matcher* matcher,
3434
const Output<Node>& pattern_value,

src/core/include/openvino/pass/pattern/op/any_of.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ class OPENVINO_API AnyOf : public Pattern {
2929
set_output_type(0, type, s);
3030
}
3131
AnyOf(const element::Type& type, const PartialShape& s, const NodePredicate& pred, const NodeVector& wrapped_values)
32-
: AnyOf(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
32+
: AnyOf(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
3333

3434
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
3535
/// shape of \sa node.
3636
template <typename TPredicate>
3737
AnyOf(const Output<Node>& node, const TPredicate& pred, const OutputVector& wrapped_values)
3838
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
3939
AnyOf(const std::shared_ptr<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values)
40-
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
40+
: AnyOf(node, Predicate(pred), as_output_vector(wrapped_values)) {}
4141
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
4242
};
4343
} // namespace ov::pass::pattern::op

src/core/include/openvino/pass/pattern/op/label.hpp

+2-5
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class OPENVINO_API Label : public Pattern {
4646
const PartialShape& s,
4747
const NodePredicate& pred,
4848
const NodeVector& wrapped_values = {})
49-
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
49+
: Label(type, s, Predicate(pred), as_output_vector(wrapped_values)) {}
5050

5151
/// \brief creates a Label node containing a sub-pattern described by the type and
5252
/// shape of \sa node.
@@ -64,10 +64,7 @@ class OPENVINO_API Label : public Pattern {
6464
Label(const Output<Node>& value, const TPredicate& pred, const OutputVector& wrapped_values = {})
6565
: Label(value.get_element_type(), value.get_partial_shape(), Predicate(pred), wrapped_values) {}
6666
Label(const Output<Node>& node, const NodePredicate& pred, const NodeVector& wrapped_values = {})
67-
: Label(node.get_element_type(),
68-
node.get_partial_shape(),
69-
as_value_predicate(pred),
70-
as_output_vector(wrapped_values)) {}
67+
: Label(node.get_element_type(), node.get_partial_shape(), Predicate(pred), as_output_vector(wrapped_values)) {}
7168

7269
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
7370
: Label(type, s, nullptr, OutputVector{}) {}

src/core/include/openvino/pass/pattern/op/pattern.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,8 @@ OPENVINO_API op::Predicate all_of(const std::vector<std::function<bool(Output<No
6464
OPENVINO_API op::Predicate shape_matches(const std::string& shape_notation);
6565

6666
namespace op {
67-
68-
OPENVINO_API
69-
ValuePredicate as_value_predicate(NodePredicate pred);
67+
OPENVINO_DEPRECATED("This method is deprecated. Use constructor of ov::pass::pattern::Predicate instead")
68+
OPENVINO_API Predicate as_value_predicate(NodePredicate pred);
7069

7170
class OPENVINO_API Pattern : public Node {
7271
public:

src/core/include/openvino/pass/pattern/op/predicate.hpp

+2-12
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,8 @@ class OPENVINO_API Predicate {
9494
}
9595

9696
bool operator()(PatternSymbolMap& m, const Output<Node>& output) const;
97-
98-
template <typename T>
99-
bool operator()(const T& arg) const {
100-
OPENVINO_ASSERT(!m_requires_map,
101-
"Predicate " + m_name + " called with unexpected argument: " + std::string(typeid(T).name()));
102-
if constexpr (std::is_convertible_v<T, Output<Node>>) {
103-
PatternSymbolMap dummy_map;
104-
return m_pred(dummy_map, arg);
105-
}
106-
OPENVINO_ASSERT(false,
107-
"Predicate " + m_name + " called with unexpected argument: " + std::string(typeid(T).name()));
108-
}
97+
bool operator()(const std::shared_ptr<Node>& node) const;
98+
bool operator()(const Output<Node>& output) const;
10999

110100
template <typename TPredicate>
111101
Predicate operator||(const TPredicate& other) const {

src/core/src/pattern/op/pattern.cpp

+2-14
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,12 @@ constexpr bool value_true_predicate(const Output<Node>&) {
1717
}
1818
} // namespace
1919

20-
struct NodeValuePredicate {
21-
bool operator()(const Output<Node>& value) const {
22-
return pred(value.get_node_shared_ptr());
23-
}
24-
25-
NodePredicate pred;
26-
};
27-
2820
Pattern::Pattern(const OutputVector& patterns) : Node(patterns), m_predicate() {}
2921

3022
Pattern::Pattern(const OutputVector& patterns, const op::Predicate& pred) : Node(patterns), m_predicate(pred) {}
3123

32-
ValuePredicate as_value_predicate(NodePredicate pred) {
33-
if (pred) {
34-
return NodeValuePredicate{std::move(pred)};
35-
} else {
36-
return value_true_predicate;
37-
}
24+
Predicate as_value_predicate(NodePredicate pred) {
25+
return Predicate(pred);
3826
}
3927

4028
std::ostream& Pattern::write_type_description(std::ostream& out) const {

src/core/src/pattern/op/predicate.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,17 @@ bool Predicate::operator()(pass::pattern::PatternSymbolMap& m, const Output<Node
7878
OPENVINO_DEBUG("Predicate `", m_name, "` has ", (result ? "passed" : "failed"), ". Applied to ", output);
7979
return result;
8080
}
81+
82+
bool Predicate::operator()(const std::shared_ptr<Node>& node) const {
83+
OPENVINO_ASSERT(!m_requires_map, "Predicate " + m_name + " called with unexpected argument: std::shared_ptr<Node>");
84+
PatternSymbolMap dummy_map;
85+
return m_pred(dummy_map, node);
86+
}
87+
88+
bool Predicate::operator()(const Output<Node>& output) const {
89+
OPENVINO_ASSERT(!m_requires_map, "Predicate " + m_name + " called with unexpected argument: Output<Node>");
90+
PatternSymbolMap dummy_map;
91+
return m_pred(dummy_map, output);
92+
}
8193
} // namespace op
8294
} // namespace ov::pass::pattern

src/plugins/intel_gpu/src/plugin/transformations/clamp_fp16_output.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ ClampFP16Output::ClampFP16Output() {
2828
using namespace ov::pass::pattern;
2929
using namespace ov::pass::pattern::op;
3030

31-
auto in0 = any_input(as_value_predicate(class_other_than<v0::Constant>()));
32-
auto in1 = any_input(as_value_predicate(class_other_than<v0::Constant>()));
31+
auto in0 = any_input(class_other_than<v0::Constant>());
32+
auto in1 = any_input(class_other_than<v0::Constant>());
3333
auto matmul_m = wrap_type<v0::MatMul>({in0, in1}, type_matches(ov::element::f16) && consumers_count(1));
3434
auto reshape_m = wrap_type<v1::Reshape>({matmul_m, any_input()}, type_matches(ov::element::f16) && consumers_count(1));
3535
auto add_m = wrap_type<v1::Add>({matmul_m, any_input()}, type_matches(ov::element::f16) && consumers_count(1));

src/plugins/intel_gpu/src/plugin/transformations/lora_horizontal_fusion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ LoRAHorizontalFusion::LoRAHorizontalFusion() {
5858

5959
auto axis_const = wrap_type<ov::op::v0::Constant>();
6060
auto split_const = wrap_type<ov::op::v0::Constant>();
61-
auto split = wrap_type<ov::op::v1::VariadicSplit>({main_flow, axis_const, split_const}, ov::pass::pattern::op::as_value_predicate(is_target_pattern));
61+
auto split = wrap_type<ov::op::v1::VariadicSplit>({main_flow, axis_const, split_const}, is_target_pattern);
6262

6363
ov::matcher_pass_callback callback = [=](Matcher& m) {
6464
const auto& pattern_map = m.get_pattern_value_map();

0 commit comments

Comments
 (0)