|
38 | 38 | #include "openvino/pass/graph_rewrite.hpp"
|
39 | 39 | #include "openvino/pass/manager.hpp"
|
40 | 40 | #include "openvino/pass/pattern/matcher.hpp"
|
41 |
| -#include "openvino/pass/pattern/op/branch.hpp" |
42 | 41 | #include "openvino/pass/pattern/op/label.hpp"
|
43 | 42 | #include "openvino/pass/pattern/op/optional.hpp"
|
44 | 43 | #include "openvino/pass/pattern/op/or.hpp"
|
@@ -310,16 +309,10 @@ TEST(pattern, matcher) {
|
310 | 309 | ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
|
311 | 310 |
|
312 | 311 | auto abs = make_shared<op::v0::Abs>(a);
|
313 |
| - auto any = std::make_shared<pattern::op::Skip>(a); |
314 |
| - ASSERT_TRUE(n.match(any, abs)); |
315 |
| - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a})); |
316 | 312 |
|
317 | 313 | auto false_pred = [](std::shared_ptr<Node> /* no */) {
|
318 | 314 | return false;
|
319 | 315 | };
|
320 |
| - auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred); |
321 |
| - ASSERT_TRUE(n.match(any_false, a)); |
322 |
| - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a})); |
323 | 316 |
|
324 | 317 | auto pattern = std::make_shared<pattern::op::Label>(a);
|
325 | 318 | ASSERT_TRUE(n.match(pattern, a));
|
@@ -371,39 +364,6 @@ TEST(pattern, matcher) {
|
371 | 364 | ASSERT_FALSE(n.match(std::make_shared<op::v1::Add>(abs, b), std::make_shared<op::v1::Add>(b, b)));
|
372 | 365 | ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
|
373 | 366 |
|
374 |
| - auto add_absb = std::make_shared<op::v1::Add>(abs, b); |
375 |
| - ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(any, b), add_absb)); |
376 |
| - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b})); |
377 |
| - |
378 |
| - ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(pattern, b), add_absb)); |
379 |
| - ASSERT_EQ(n.get_pattern_map()[pattern], abs); |
380 |
| - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b})); |
381 |
| - |
382 |
| - ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(b, pattern), add_absb)); |
383 |
| - ASSERT_EQ(n.get_pattern_map()[pattern], abs); |
384 |
| - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b})); |
385 |
| - |
386 |
| - auto c = make_shared<op::v0::Parameter>(element::i32, shape); |
387 |
| - auto mul_add_absb = std::make_shared<op::v1::Multiply>(c, add_absb); |
388 |
| - ASSERT_TRUE( |
389 |
| - n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(b, pattern)), mul_add_absb)); |
390 |
| - ASSERT_EQ(n.get_pattern_map()[pattern], abs); |
391 |
| - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b})); |
392 |
| - |
393 |
| - ASSERT_TRUE(n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any, b)), |
394 |
| - mul_add_absb)); // nested any |
395 |
| - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b})); |
396 |
| - ASSERT_TRUE(n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any, b)), |
397 |
| - std::make_shared<op::v1::Multiply>(std::make_shared<op::v1::Add>(b, abs), |
398 |
| - c))); // permutations w/ any |
399 |
| - auto mul_c_add_ab = make_shared<op::v1::Multiply>(c, add_ab); |
400 |
| - ASSERT_TRUE(n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any_false, b)), |
401 |
| - std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(a, b)))); // |
402 |
| - // nested any |
403 |
| - ASSERT_TRUE(n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any_false, b)), |
404 |
| - mul_c_add_ab)); // permutations w/ any_false |
405 |
| - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b})); |
406 |
| - |
407 | 367 | auto iconst1_0 = construct_constant_node(1);
|
408 | 368 | auto iconst1_1 = construct_constant_node(1);
|
409 | 369 | ASSERT_TRUE(n.match(make_shared<op::v1::Multiply>(pattern, iconst1_0),
|
@@ -462,18 +422,6 @@ TEST(pattern, matcher) {
|
462 | 422 | std::make_shared<op::v1::Subtract>(a, b)}),
|
463 | 423 | std::make_shared<op::v1::Subtract>(a, b)));
|
464 | 424 |
|
465 |
| - // Branch |
466 |
| - { |
467 |
| - auto branch = std::make_shared<pattern::op::Branch>(); |
468 |
| - auto star = std::make_shared<pattern::op::Or>(OutputVector{branch, std::make_shared<pattern::op::True>()}); |
469 |
| - auto pattern = std::make_shared<op::v1::Add>(star, star); |
470 |
| - branch->set_destination(pattern); |
471 |
| - auto arg = |
472 |
| - std::make_shared<op::v1::Add>(std::make_shared<op::v1::Add>(a, b), std::make_shared<op::v1::Add>(b, a)); |
473 |
| - ASSERT_TRUE(n.match(pattern, std::make_shared<op::v1::Add>(arg, a))); |
474 |
| - ASSERT_EQ(n.get_matched_nodes().size(), 4); |
475 |
| - } |
476 |
| - |
477 | 425 | // strict mode
|
478 | 426 | {
|
479 | 427 | TestMatcher sm(Output<Node>{}, "TestMatcher", true);
|
@@ -959,47 +907,6 @@ TEST(pattern, test_sort) {
|
959 | 907 | }
|
960 | 908 | }
|
961 | 909 |
|
962 |
| -TEST(pattern, label_on_skip) { |
963 |
| - const auto zero = std::string{"0"}; |
964 |
| - const auto is_zero = [&zero](const Output<Node>& node) { |
965 |
| - if (const auto c = as_type_ptr<op::v0::Constant>(node.get_node_shared_ptr())) { |
966 |
| - return (c->get_all_data_elements_bitwise_identical() && c->convert_value_to_string(0) == zero); |
967 |
| - } else { |
968 |
| - return false; |
969 |
| - } |
970 |
| - }; |
971 |
| - |
972 |
| - Shape shape{2, 2}; |
973 |
| - auto a = make_shared<op::v0::Parameter>(element::i32, shape); |
974 |
| - auto b = make_shared<op::v0::Parameter>(element::i32, Shape{}); |
975 |
| - auto iconst = op::v0::Constant::create(element::i32, Shape{}, {0.0f}); |
976 |
| - auto label = std::make_shared<pattern::op::Label>(iconst); |
977 |
| - auto const_label = std::make_shared<pattern::op::Label>(iconst, is_zero, NodeVector{iconst}); |
978 |
| - |
979 |
| - auto bcst_pred = [](std::shared_ptr<Node> n) { |
980 |
| - return ov::as_type_ptr<op::v1::Broadcast>(n) != nullptr; |
981 |
| - }; |
982 |
| - |
983 |
| - auto shape_const = ov::op::v0::Constant::create(element::u64, Shape{shape.size()}, shape); |
984 |
| - auto axes_const = ov::op::v0::Constant::create(element::u8, Shape{}, {0}); |
985 |
| - auto bcst = std::make_shared<pattern::op::Skip>(OutputVector{const_label, shape_const, axes_const}, bcst_pred); |
986 |
| - auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst}); |
987 |
| - auto matcher = |
988 |
| - std::make_shared<pattern::Matcher>(std::make_shared<op::v1::Multiply>(label, bcst_label), "label_on_skip"); |
989 |
| - |
990 |
| - auto const_broadcast = make_shared<op::v1::Broadcast>(iconst, shape_const); |
991 |
| - std::shared_ptr<Node> mul = std::make_shared<op::v1::Multiply>(a, const_broadcast); |
992 |
| - std::shared_ptr<Node> mul_scalar = std::make_shared<op::v1::Multiply>(b, iconst); |
993 |
| - ASSERT_TRUE(matcher->match(mul)); |
994 |
| - ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast); |
995 |
| - ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst); |
996 |
| - ASSERT_EQ(matcher->get_pattern_map()[label], a); |
997 |
| - ASSERT_TRUE(matcher->match(mul_scalar)); |
998 |
| - ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst); |
999 |
| - ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst); |
1000 |
| - ASSERT_EQ(matcher->get_pattern_map()[label], b); |
1001 |
| -} |
1002 |
| - |
1003 | 910 | TEST(pattern, is_contained_match) {
|
1004 | 911 | Shape shape{};
|
1005 | 912 | auto a = make_shared<op::v0::Parameter>(element::i32, shape);
|
|
0 commit comments