@@ -46,25 +46,10 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) :
46
46
const auto dequantization_without_subtract_W = wrap_dequantization (ov::pass::pattern::any_input (), false );
47
47
const auto dequantization_without_subtract_R = wrap_dequantization (ov::pass::pattern::any_input (), false );
48
48
49
- auto X_in = std::make_shared<ov::pass::pattern::op::Or>(
50
- OutputVector{
51
- fq_X, dequantization_X, dequantization_without_subtract_X
52
- });
53
-
54
- auto H_in = std::make_shared<ov::pass::pattern::op::Or>(
55
- OutputVector{
56
- H_as_const, fq_H, dequantization_H, dequantization_without_subtract_H
57
- });
58
-
59
- auto W_in = std::make_shared<ov::pass::pattern::op::Or>(
60
- OutputVector{
61
- fq_W, dequantization_W, dequantization_without_subtract_W
62
- });
63
-
64
- auto R_in = std::make_shared<ov::pass::pattern::op::Or>(
65
- OutputVector{
66
- fq_R, dequantization_R, dequantization_without_subtract_R
67
- });
49
+ auto X_in = ov::pass::pattern::any_input ();
50
+ auto H_in = ov::pass::pattern::any_input ();
51
+ auto W_in = ov::pass::pattern::any_input ();
52
+ auto R_in = ov::pass::pattern::any_input ();
68
53
69
54
const auto lstm_seq = ov::pass::pattern::wrap_type<ov::opset5::LSTMSequence>(
70
55
{X_in, H_in, C, S, W_in, R_in, B});
@@ -91,8 +76,92 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) :
91
76
this ->register_matcher (m, callback);
92
77
}
93
78
79
+ namespace {
80
+
81
+ std::shared_ptr<ov::opset1::FakeQuantize> find_fake_quantize_upper (const std::shared_ptr<Node>& parent) {
82
+ if (is_type<ov::opset1::FakeQuantize>(parent)) {
83
+ return as_type_ptr<ov::opset1::FakeQuantize>(parent);
84
+ }
85
+
86
+ if (!NetworkHelper::isPrecisionPreserved (parent)) {
87
+ return nullptr ;
88
+ }
89
+
90
+ return find_fake_quantize_upper (parent->get_input_node_shared_ptr (0 ));
91
+ }
92
+
93
+ } // namespace
94
+
95
+ void RecurrentCellTransformation::propagate (TransformationContext& context, std::shared_ptr<ov::Node>& node) {
96
+ if (!NetworkHelper::isPrecisionPreserved (node)) {
97
+ return ;
98
+ }
99
+
100
+ const auto & normalized_node = NetworkHelper::separateInStandaloneBranch (node, defaultPrecisions);
101
+ auto dequantization = NetworkHelper::getDequantization (node, defaultPrecisions);
102
+ if (dequantization.empty ()) {
103
+ return ;
104
+ }
105
+ const auto & new_node = moveDequantizationAfter (context, normalized_node, dequantization);
106
+
107
+ const auto & new_dequantization = NetworkHelper::getDequantizationBelow (new_node);
108
+ if (new_dequantization.empty ()) {
109
+ return ;
110
+ }
111
+
112
+ for (auto output : new_dequantization.multiply ->outputs ()) {
113
+ for (auto input : output.get_target_inputs ()) {
114
+ auto & child = input.get_node ()->shared_from_this ();
115
+ propagate (context, child);
116
+ }
117
+ }
118
+ }
119
+
94
120
bool RecurrentCellTransformation::transform (TransformationContext& context, ov::pass::pattern::Matcher& m) {
95
121
const auto lstm = m.get_match_root ();
122
+
123
+ const auto inputs = is_type<ov::opset5::LSTMSequence>(lstm) ? std::vector<size_t >{0 , 1 , 4 , 5 } : std::vector<size_t >{0 , 1 , 3 , 4 };
124
+ for (const auto input : inputs) {
125
+ const auto & parent = lstm->get_input_node_shared_ptr (input);
126
+ if (!NetworkHelper::isPrecisionPreserved (parent)) {
127
+ continue ;
128
+ }
129
+
130
+ const auto & fq = find_fake_quantize_upper (parent);
131
+ if (fq != nullptr ) {
132
+ const auto & quantizationDetails = QuantizationDetails::getDetails (fq);
133
+ if ((quantizationDetails.inputLowValues .size () != 1 ) || (quantizationDetails.inputHighValues .size () != 1 ) ||
134
+ (quantizationDetails.outputLowValues .size () != 1 ) || (quantizationDetails.outputHighValues .size () != 1 )) {
135
+ continue ;
136
+ }
137
+
138
+ const auto & precisionsAttribute = getAttributeFromOutput<PrecisionsAttribute>(fq);
139
+ const auto & precisions = precisionsAttribute.empty () ?
140
+ defaultPrecisions :
141
+ precisionsAttribute.as <PrecisionsAttribute>().value ();
142
+ const auto & dataPrecision = getDataPrecision (fq, quantizationDetails, precisions);
143
+ if (dataPrecision.empty ()) {
144
+ continue ;
145
+ }
146
+
147
+ auto result = NetworkHelper::decomposeFakeQuantize (
148
+ fq,
149
+ dataPrecision.precision ,
150
+ dataPrecision.min ,
151
+ dataPrecision.max ,
152
+ dataPrecision.hasZeroPoint ,
153
+ updatePrecisions);
154
+ auto multiply = std::get<1 >(result);
155
+
156
+ for (const auto & output : multiply->outputs ()) {
157
+ for (const auto & input : output.get_target_inputs ()) {
158
+ const auto input_node = input.get_node ();
159
+ propagate (context, input_node->shared_from_this ());
160
+ }
161
+ }
162
+ }
163
+ }
164
+
96
165
if (!canBeTransformed (context, lstm)) {
97
166
return false ;
98
167
}
0 commit comments