Skip to content

Commit 0902fe4

Browse files
[TRANSFORMATIONS] Create python binding for pattern::Optional (openvinotoolkit#23558)
[TRANSFORMATIONS] Create python binding for pattern::Optional Expose the C++ op::pattern::Optional to Python in order to simplify patterns creation. Cover the functionality with the dedicated tests. ### Tickets: CVS-133523 Signed-off-by: Andrii Staikov <andrii.staikov@intel.com> --------- Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>
1 parent fea5487 commit 0902fe4

File tree

4 files changed

+245
-19
lines changed

4 files changed

+245
-19
lines changed

src/bindings/python/src/openvino/runtime/passes/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# type: ignore
44
# flake8: noqa
55

6-
from openvino._pyopenvino.passes import ModelPass, Matcher, MatcherPass, PassBase, WrapType, Or, AnyInput
6+
from openvino._pyopenvino.passes import ModelPass, Matcher, MatcherPass, PassBase, WrapType, Or, AnyInput, Optional
77
from openvino._pyopenvino.passes import (
88
consumers_count,
99
has_static_dim,

src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp

+113
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <string>
1414

1515
#include "openvino/pass/pattern/op/label.hpp"
16+
#include "openvino/pass/pattern/op/optional.hpp"
1617
#include "openvino/pass/pattern/op/or.hpp"
1718
#include "openvino/pass/pattern/op/pattern.hpp"
1819
#include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -482,6 +483,117 @@ static void reg_pattern_any_input(py::module m) {
482483
});
483484
}
484485

486+
static void reg_pattern_optional(py::module m) {
487+
py::class_<ov::pass::pattern::op::Optional, std::shared_ptr<ov::pass::pattern::op::Optional>, ov::Node>
488+
optional_type(m, "Optional");
489+
optional_type.doc() = "openvino.runtime.passes.Optional wraps ov::pass::pattern::op::Optional";
490+
491+
optional_type.def(py::init([](const std::vector<std::string>& type_names) {
492+
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names));
493+
}),
494+
py::arg("type_name"),
495+
R"(
496+
Create Optional with the given node type.
497+
498+
:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
499+
:type type_names: List[str]
500+
)");
501+
502+
optional_type.def(py::init([](const std::vector<std::string>& type_names, const Predicate& predicate) {
503+
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), predicate);
504+
}),
505+
py::arg("type_names"),
506+
py::arg("predicate"),
507+
R"(
508+
Create Optional with the given node type and predicate.
509+
510+
:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
511+
:type type_names: List[str]
512+
513+
:param predicate: Function that performs additional checks for matching.
514+
:type predicate: function
515+
)");
516+
517+
optional_type.def(
518+
py::init([](const std::vector<std::string>& type_names,
519+
const ov::Output<ov::Node>& input,
520+
const Predicate& predicate) {
521+
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), input, predicate);
522+
}),
523+
py::arg("type_names"),
524+
py::arg("input"),
525+
py::arg("predicate"),
526+
R"(
527+
Create Optional with the given node type, input node and predicate.
528+
529+
:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
530+
:type type_names: List[str]
531+
532+
:param input: input node's output.
533+
:type input: openvino.runtime.Output
534+
535+
:param predicate: Function that performs additional checks for matching.
536+
:type predicate: function
537+
)");
538+
539+
optional_type.def(
540+
py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input) {
541+
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), input, nullptr);
542+
}),
543+
py::arg("type_names"),
544+
py::arg("input"),
545+
R"(
546+
Create Optional with the given node type and input node.
547+
548+
:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
549+
:type type_names: List[str]
550+
551+
:param input: input node's output.
552+
:type input: openvino.runtime.Output
553+
)");
554+
555+
optional_type.def(
556+
py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input) {
557+
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), input, nullptr);
558+
}),
559+
py::arg("type_names"),
560+
py::arg("input"),
561+
R"(
562+
Create Optional with the given node type and input node.
563+
564+
:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
565+
:type type_names: List[str]
566+
567+
:param input: input node
568+
:type input: openvino.runtime.Node
569+
)");
570+
571+
optional_type.def(py::init([](const std::vector<std::string>& type_names,
572+
const std::shared_ptr<ov::Node>& input,
573+
const Predicate& pred) {
574+
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), input, pred);
575+
}),
576+
py::arg("type_names"),
577+
py::arg("input"),
578+
py::arg("pred"),
579+
R"(
580+
Create Optional with the given node type, input node and predicate.
581+
582+
:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
583+
:type type_names: List[str]
584+
585+
:param input: input node
586+
:type input: openvino.runtime.Node
587+
588+
:param predicate: Function that performs additional checks for matching.
589+
:type predicate: function
590+
)");
591+
592+
optional_type.def("__repr__", [](const ov::pass::pattern::op::Optional& self) {
593+
return Common::get_simple_repr(self);
594+
});
595+
}
596+
485597
inline void reg_predicates(py::module m) {
486598
m.def("consumers_count", &ov::pass::pattern::consumers_count);
487599
m.def("has_static_dim", &ov::pass::pattern::has_static_dim);
@@ -497,5 +609,6 @@ void reg_passes_pattern_ops(py::module m) {
497609
reg_pattern_any_input(m);
498610
reg_pattern_wrap_type(m);
499611
reg_pattern_or(m);
612+
reg_pattern_optional(m);
500613
reg_predicates(m);
501614
}

src/bindings/python/tests/test_transformations/test_pattern_ops.py

+95-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
# Copyright (C) 2018-2024 Intel Corporation
33
# SPDX-License-Identifier: Apache-2.0
44
import numpy as np
5+
import pytest
56

67
from openvino import PartialShape
78
from openvino.runtime import opset13 as ops
8-
from openvino.runtime.passes import Matcher, WrapType, Or, AnyInput
9+
from openvino.runtime.passes import Matcher, WrapType, Or, AnyInput, Optional
910
from openvino.runtime.passes import (
1011
consumers_count,
1112
has_static_dim,
@@ -85,6 +86,99 @@ def test_any_input_predicate():
8586
assert not matcher.match(slope)
8687

8788

89+
def test_optional_full_match():
90+
model_input = ops.parameter(PartialShape.dynamic())
91+
model_abs = ops.abs(model_input)
92+
model_relu = ops.relu(model_abs.output(0))
93+
94+
pattern_abs = Optional(["opset13.Abs"])
95+
pattern_relu = ops.relu(pattern_abs.output(0))
96+
97+
matcher = Matcher(pattern_relu, "FindRelu")
98+
assert matcher.match(model_relu)
99+
100+
101+
@pytest.mark.skip("Optional is not working properly yet CVS-136454")
102+
def test_optional_half_match():
103+
model_input = ops.parameter(PartialShape.dynamic())
104+
model_relu = ops.relu(model_input)
105+
model_relu1 = ops.relu(model_relu.output(0))
106+
107+
pattern_abs = Optional(["opset13.Abs"])
108+
pattern_relu = ops.relu(pattern_abs.output(0))
109+
110+
matcher = Matcher(pattern_relu, "FindRelu")
111+
assert matcher.match(model_relu1)
112+
113+
114+
@pytest.mark.skip("Optional is not working properly yet CVS-136454")
115+
def test_optional_one_node():
116+
model_input = ops.parameter(PartialShape.dynamic())
117+
model_relu = ops.relu(model_input)
118+
model_abs = ops.abs(model_input)
119+
120+
assert Matcher(Optional(["opset13.Relu"]), "OneNodeTest").match(model_relu)
121+
assert not Matcher(Optional(["opset13.Abs"]), "OneNodeTest").match(model_relu)
122+
123+
assert not Matcher(Optional(["opset13.Relu"]), "OneNodeTest").match(model_abs)
124+
125+
assert Matcher(Optional(["opset13.Parameter"]), "OneNodeTest").match(ops.parameter(PartialShape.dynamic()))
126+
assert not Matcher(Optional(["opset13.Relu"]), "OneNodeTest").match(ops.parameter(PartialShape.dynamic()))
127+
128+
129+
@pytest.mark.skip("Optional is not working properly yet CVS-136454")
130+
def test_optional_predicate():
131+
model_input = ops.parameter(PartialShape.dynamic())
132+
model_add = ops.add(model_input, model_input)
133+
model_relu = ops.relu(model_add.output(0))
134+
model_abs = ops.abs(model_add.output(0))
135+
136+
assert Matcher(Optional(["opset13.Relu"], lambda x: True), "TestInputPredicate").match(model_relu)
137+
assert not Matcher(Optional(["opset13.Relu"], lambda x: False), "TestInputPredicate").match(model_relu)
138+
assert Matcher(Optional(["opset13.Add"], consumers_count(2)), "FindPredicate").match(model_add)
139+
assert not Matcher(Optional(["opset13.Add"], consumers_count(1)), "FindPredicate").match(model_add)
140+
assert Matcher(Optional(["opset13.Abs", "opset13.Result"], consumers_count(0)), "FindPredicate").match(model_abs)
141+
142+
143+
def test_optional_with_input():
144+
model_input = ops.parameter(PartialShape.dynamic())
145+
model_add = ops.add(model_input, model_input)
146+
model_relu = ops.relu(model_add.output(0))
147+
148+
assert Matcher(Optional(["opset13.Relu"], model_add.output(0)), "TestInput").match(model_relu)
149+
assert not Matcher(Optional(["opset13.Cos"], model_add.output(0)), "TestInput").match(model_relu)
150+
151+
152+
def test_optional_with_input_and_predicate():
153+
model_input = ops.parameter(PartialShape.dynamic())
154+
model_add = ops.add(model_input, model_input)
155+
model_relu = ops.relu(model_add.output(0))
156+
157+
pattern_add = ops.add(AnyInput(), AnyInput())
158+
159+
assert Matcher(Optional(["opset13.Relu"], pattern_add.output(0), lambda x: True), "TestInputPredicate").match(model_relu)
160+
assert not Matcher(Optional(["opset13.Relu"], pattern_add.output(0), lambda x: False), "TestInputPredicate").match(model_relu)
161+
162+
163+
def test_optional_with_input_node():
164+
model_input = ops.parameter(PartialShape.dynamic())
165+
model_add = ops.add(model_input, model_input)
166+
model_relu = ops.relu(model_add.output(0))
167+
168+
assert Matcher(Optional(["opset13.Relu"], model_add), "TestInputNode").match(model_relu)
169+
assert not Matcher(Optional(["opset13.Cos"], model_add), "TestInputNode").match(model_relu)
170+
171+
172+
def test_optional_with_input_node_and_predicate():
173+
model_input = ops.parameter(PartialShape.dynamic())
174+
model_add = ops.add(model_input, model_input)
175+
model_relu = ops.relu(model_add.output(0))
176+
177+
assert Matcher(Optional(["opset13.Relu"], model_add, lambda x: True), "TestInputNodePredicate").match(model_relu)
178+
assert not Matcher(Optional(["opset13.Relu"], model_add, lambda x: False), "TestInputNodePredicate").match(model_relu)
179+
assert not Matcher(Optional(["opset13.Cos"], model_add, lambda x: True), "TestInputNodePredicate").match(model_relu)
180+
181+
88182
def test_all_predicates():
89183
static_param = ops.parameter(PartialShape([1, 3, 22, 22]), np.float32)
90184
dynamic_param = ops.parameter(PartialShape([-1, 6]), np.compat.long)

src/core/tests/pattern.cpp

+36-17
Original file line numberDiff line numberDiff line change
@@ -510,37 +510,38 @@ TEST(pattern, matching_optional) {
510510
std::make_shared<op::v0::Abs>(c)));
511511
}
512512

513-
TEST(pattern, optional_full_match) {
513+
// Optional is not working properly yet CVS-136454
514+
TEST(pattern, DISABLED_optional_full_match) {
514515
Shape shape{};
515-
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
516-
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
517-
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
518-
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));
516+
auto model_input = std::make_shared<op::v0::Parameter>(element::i32, shape);
517+
auto model_relu = std::make_shared<op::v0::Relu>(model_input);
518+
auto model_relu1 = std::make_shared<op::v0::Relu>(model_relu->output(0));
519519

520-
auto pattern_add = ov::pass::pattern::optional<op::v1::Add>();
521-
auto pattern_relu = std::make_shared<op::v0::Relu>(pattern_add->output(0));
520+
auto pattern_relu = ov::pass::pattern::optional<op::v0::Relu>();
521+
auto pattern_relu1 = std::make_shared<op::v0::Relu>(pattern_relu->output(0));
522522

523523
TestMatcher tm;
524524

525-
ASSERT_TRUE(tm.match(pattern_relu, model_relu));
525+
ASSERT_TRUE(tm.match(pattern_relu1, model_relu1));
526526
}
527527

528-
TEST(pattern, optional_half_match) {
528+
// Optional is not working properly yet CVS-136454
529+
TEST(pattern, DISABLED_optional_half_match) {
529530
Shape shape{};
530-
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
531-
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
532-
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
533-
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));
531+
auto model_input = std::make_shared<op::v0::Parameter>(element::i32, shape);
532+
auto model_relu = std::make_shared<op::v0::Relu>(model_input);
533+
auto model_relu1 = std::make_shared<op::v0::Relu>(model_relu->output(0));
534534

535-
auto pattern_relu = ov::pass::pattern::optional<op::v0::Relu>();
536-
auto pattern_relu1 = std::make_shared<op::v0::Relu>(pattern_relu->output(0));
535+
auto pattern_abs = ov::pass::pattern::optional<op::v0::Abs>();
536+
auto pattern_relu = std::make_shared<op::v0::Relu>(pattern_abs->output(0));
537537

538538
TestMatcher tm;
539539

540-
ASSERT_TRUE(tm.match(pattern_relu1, model_relu));
540+
ASSERT_TRUE(tm.match(pattern_relu, model_relu1));
541541
}
542542

543-
TEST(pattern, optional_testing) {
543+
// Optional is not working properly yet CVS-136454
544+
TEST(pattern, DISABLED_optional_testing) {
544545
Shape shape{};
545546
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
546547
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
@@ -572,6 +573,24 @@ TEST(pattern, optional_testing) {
572573
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));
573574
}
574575

576+
// Optional is not working properly yet CVS-136454
577+
TEST(pattern, DISABLED_optional_one_node) {
578+
Shape shape{};
579+
auto model_input = std::make_shared<op::v0::Parameter>(element::i32, shape);
580+
auto model_relu = std::make_shared<op::v0::Relu>(model_input);
581+
auto model_abs = std::make_shared<op::v0::Abs>(model_input);
582+
583+
TestMatcher tm;
584+
585+
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(), model_relu));
586+
ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Abs>(), model_relu));
587+
588+
ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(), model_abs));
589+
590+
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Parameter>(), model_input));
591+
ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(), model_input));
592+
}
593+
575594
TEST(pattern, mean) {
576595
// construct mean
577596
TestMatcher n;

0 commit comments

Comments
 (0)