6
6
7
7
#include < functional>
8
8
#include < memory>
9
+ #include < ngraph/log.hpp>
9
10
#include < set>
10
11
11
12
#include " ngraph/pass/pass.hpp"
@@ -21,6 +22,69 @@ namespace pass {
21
22
using ov::pass::BackwardGraphRewrite;
22
23
using ov::pass::GraphRewrite;
23
24
using ov::pass::MatcherPass;
24
- using ov::pass::RecurrentGraphRewrite;
25
+
26
+ class NGRAPH_DEPRECATED (" Use MatcherPass or FunctionPass instead." ) NGRAPH_API RecurrentGraphRewrite
27
+ : public FunctionPass {
28
+ public:
29
+ RecurrentGraphRewrite (size_t num_iters = 10 ) : ModelPass (), m_num_iters (num_iters) {}
30
+
31
+ void add_matcher (const std::shared_ptr<pattern::RecurrentMatcher>& m,
32
+ const ov::recurrent_graph_rewrite_callback& callback,
33
+ const PassPropertyMask& property) {
34
+ NGRAPH_SUPPRESS_DEPRECATED_START
35
+ m_matchers.push_back (std::make_shared<MatcherPass>(
36
+ " Recurrent matcher" ,
37
+ nullptr ,
38
+ [m, callback](const std::shared_ptr<Node>& node) {
39
+ NGRAPH_DEBUG << " Running recurrent matcher on " << node;
40
+ if (m->match (node->output (0 ))) {
41
+ NGRAPH_DEBUG << " Recurrent matcher matched " << m.get ();
42
+ return callback (*m.get ());
43
+ }
44
+ return false ;
45
+ },
46
+ property));
47
+ NGRAPH_SUPPRESS_DEPRECATED_END
48
+ }
49
+
50
+ // TODO: This interface may deprecate after all passes are refactored.
51
+ void add_matcher (const std::shared_ptr<pattern::RecurrentMatcher>& m,
52
+ const ov::recurrent_graph_rewrite_callback& callback) {
53
+ NGRAPH_SUPPRESS_DEPRECATED_START
54
+ // TODO: before deprecate this function, by default expect the
55
+ // callback require static shape.
56
+ add_matcher (m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
57
+ NGRAPH_SUPPRESS_DEPRECATED_END
58
+ }
59
+
60
+ bool run_on_model (const std::shared_ptr<ov::Model>& m) override {
61
+ NGRAPH_SUPPRESS_DEPRECATED_START
62
+ bool changed = false ;
63
+ size_t i = 0 ;
64
+
65
+ auto run_matchers = [&]() -> bool {
66
+ for (const auto & node : m->get_ops ()) {
67
+ for (auto & m_pass : m_matchers) {
68
+ if (m_pass->apply (node)) {
69
+ return true ;
70
+ }
71
+ }
72
+ }
73
+ return false ;
74
+ };
75
+
76
+ do {
77
+ changed = run_matchers ();
78
+ i++;
79
+ } while (changed && i < m_num_iters);
80
+ return changed;
81
+ NGRAPH_SUPPRESS_DEPRECATED_END
82
+ }
83
+
84
+ private:
85
+ size_t m_num_iters;
86
+
87
+ std::vector<std::shared_ptr<ov::pass::MatcherPass>> m_matchers;
88
+ };
25
89
} // namespace pass
26
90
} // namespace ngraph
0 commit comments