Skip to content

Commit e7840ff

Browse files
committed
[LPT] Quantized LSTM & GRU extended support
1 parent 596c2e8 commit e7840ff

File tree

14 files changed

+239
-35
lines changed

14 files changed

+239
-35
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "transparent_base_transformation.hpp"
8+
9+
namespace ov {
10+
namespace pass {
11+
namespace low_precision {
12+
13+
/**
14+
* @ingroup ov_transformation_common_api
15+
* @brief BroadcastTransformation propagates dequantization operations through Broadcast operation.
16+
*
17+
* For more details about the transformation, refer to
18+
* [BroadcastTransformation](@ref openvino_docs_OV_UG_lpt_BroadcastTransformation) page
19+
* in the OpenVINO Developer Guide.
20+
*/
21+
class LP_TRANSFORMATIONS_API BroadcastTransformation : public TransparentBaseTransformation {
22+
public:
23+
OPENVINO_RTTI("BroadcastTransformation", "0");
24+
BroadcastTransformation(const Params& params = Params());
25+
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<ov::Node> layer) const override;
26+
};
27+
28+
} // namespace low_precision
29+
} // namespace pass
30+
} // namespace ov

src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransform
2323
static std::shared_ptr<ov::Node> wrap_fake_quantize(const std::shared_ptr<ov::Node> parameter);
2424
static std::shared_ptr<ov::Node> wrap_quantization(const std::shared_ptr<ov::Node> parameter);
2525
static std::shared_ptr<ov::Node> wrap_dequantization(const std::shared_ptr<ov::Node> parameter, const bool with_subtract);
26+
27+
private:
28+
void propagate(TransformationContext& context, std::shared_ptr<ov::Node>& node);
2629
};
2730

2831
} // namespace low_precision

src/common/low_precision_transformations/include/low_precision/rt_info/precision_preserved_attribute.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class LP_TRANSFORMATIONS_API PrecisionPreservedAttribute : public SharedAttribut
2626

2727
PrecisionPreservedAttribute() = default;
2828
PrecisionPreservedAttribute(const bool value);
29+
bool is_copyable() const override;
30+
bool is_copyable(const std::shared_ptr<Node>& to) const override;
2931

3032
std::string to_string() const override;
3133
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "low_precision/broadcast.hpp"
6+
7+
#include <memory>
8+
#include "openvino/pass/pattern/op/wrap_type.hpp"
9+
#include "low_precision/network_helper.hpp"
10+
#include "itt.hpp"
11+
12+
using namespace ov::pass::low_precision;
13+
14+
BroadcastTransformation::BroadcastTransformation(const Params& params) : TransparentBaseTransformation(params) {
15+
MATCHER_SCOPE(BroadcastTransformation);
16+
auto matcher = pattern::wrap_type<ov::opset1::Broadcast>({
17+
pattern::wrap_type<ov::opset1::Multiply>(),
18+
ov::pass::pattern::any_input(),
19+
ov::pass::pattern::any_input() });
20+
21+
ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
22+
auto op = m.get_match_root();
23+
if (transformation_callback(op)) {
24+
return false;
25+
}
26+
return transform(*context, m);
27+
};
28+
29+
auto m = std::make_shared<ov::pass::pattern::Matcher>(matcher, matcher_name);
30+
this->register_matcher(m, callback);
31+
}
32+
33+
bool BroadcastTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<ov::Node> layer) const {
34+
if (layer->get_friendly_name() == "model/bidirectional/backward_lstm_1/zeros_1") {
35+
std::cout << "BroadcastTransformation::canBeTransformed: " << layer->get_friendly_name() << std::endl;
36+
}
37+
if (!LayerTransformation::canBeTransformed(context, layer)) {
38+
return false;
39+
}
40+
41+
const auto& dequantization = NetworkHelper::getDequantization(layer, defaultPrecisions);
42+
if (dequantization.multiply != nullptr) {
43+
if (!NetworkHelper::isScalarLike(dequantization.multiplyConstant)) {
44+
return false;
45+
}
46+
}
47+
48+
if (dequantization.subtract != nullptr) {
49+
if (!NetworkHelper::isScalarLike(dequantization.subtractConstant)) {
50+
return false;
51+
}
52+
}
53+
54+
return true;
55+
}

src/common/low_precision_transformations/src/layer_transformation.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ std::shared_ptr<ov::Node> LayerTransformation::moveDequantizationAfter(
401401
const FakeQuantizeDequantization& dequantization,
402402
const bool updateOutputPrecision,
403403
const bool moveSubtract) const {
404+
OPENVINO_ASSERT(!dequantization.empty());
404405
const auto result = ov::pass::low_precision::NetworkHelper::moveDequantizationAfter(operation,
405406
dequantization,
406407
updateOutputPrecision,

src/common/low_precision_transformations/src/low_precision.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "low_precision/assign_and_read_value.hpp"
4545
#include "low_precision/avg_pool.hpp"
4646
#include "low_precision/batch_to_space.hpp"
47+
#include "low_precision/broadcast.hpp"
4748
#include "low_precision/clamp.hpp"
4849
#include "low_precision/convolution.hpp"
4950
#include "low_precision/convolution_backprop_data.hpp"
@@ -240,6 +241,7 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
240241
ADD_MATCHER(common, AssignAndReadValueTransformation, f, params)
241242
ADD_MATCHER(common, AvgPoolTransformation, params)
242243
ADD_MATCHER(common, BatchToSpaceTransformation, params)
244+
ADD_MATCHER(common, BroadcastTransformation, params)
243245
ADD_MATCHER(common, ClampTransformation, params)
244246
ADD_MATCHER(common, ConcatTransformation, params)
245247
ADD_MATCHER(common, ConvolutionTransformation, params)

src/common/low_precision_transformations/src/markup_precisions.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ bool ov::pass::low_precision::MarkupPrecisions::isPrecisionPreserved(const std::
152152
{ name<opset1::Relu>() },
153153
// TODO: there are conditions
154154
{ name<opset2::BatchToSpace>() },
155+
{ name<opset1::Broadcast>() },
155156
{ name<opset1::Pad>() },
156157
{ name<ov::opset12::Pad>() },
157158
{ name<opset1::Reshape>() },
@@ -192,6 +193,7 @@ bool ov::pass::low_precision::MarkupPrecisions::isSupported(const std::shared_pt
192193
{ name<opset1::Add>() },
193194
{ name<opset1::AvgPool>() },
194195
{ name<opset2::BatchToSpace>() },
196+
{ name<opset2::Broadcast>() },
195197
{ name<opset1::Clamp>() },
196198
{ name<opset1::Concat>() },
197199
// ?

src/common/low_precision_transformations/src/recurrent_cell.cpp

+88-19
Original file line numberDiff line numberDiff line change
@@ -46,25 +46,10 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) :
4646
const auto dequantization_without_subtract_W = wrap_dequantization(ov::pass::pattern::any_input(), false);
4747
const auto dequantization_without_subtract_R = wrap_dequantization(ov::pass::pattern::any_input(), false);
4848

49-
auto X_in = std::make_shared<ov::pass::pattern::op::Or>(
50-
OutputVector{
51-
fq_X, dequantization_X, dequantization_without_subtract_X
52-
});
53-
54-
auto H_in = std::make_shared<ov::pass::pattern::op::Or>(
55-
OutputVector{
56-
H_as_const, fq_H, dequantization_H, dequantization_without_subtract_H
57-
});
58-
59-
auto W_in = std::make_shared<ov::pass::pattern::op::Or>(
60-
OutputVector{
61-
fq_W, dequantization_W, dequantization_without_subtract_W
62-
});
63-
64-
auto R_in = std::make_shared<ov::pass::pattern::op::Or>(
65-
OutputVector{
66-
fq_R, dequantization_R, dequantization_without_subtract_R
67-
});
49+
auto X_in = ov::pass::pattern::any_input();
50+
auto H_in = ov::pass::pattern::any_input();
51+
auto W_in = ov::pass::pattern::any_input();
52+
auto R_in = ov::pass::pattern::any_input();
6853

6954
const auto lstm_seq = ov::pass::pattern::wrap_type<ov::opset5::LSTMSequence>(
7055
{X_in, H_in, C, S, W_in, R_in, B});
@@ -91,8 +76,92 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) :
9176
this->register_matcher(m, callback);
9277
}
9378

79+
namespace {
80+
81+
std::shared_ptr<ov::opset1::FakeQuantize> find_fake_quantize_upper(const std::shared_ptr<Node>& parent) {
82+
if (is_type<ov::opset1::FakeQuantize>(parent)) {
83+
return as_type_ptr<ov::opset1::FakeQuantize>(parent);
84+
}
85+
86+
if (!NetworkHelper::isPrecisionPreserved(parent)) {
87+
return nullptr;
88+
}
89+
90+
return find_fake_quantize_upper(parent->get_input_node_shared_ptr(0));
91+
}
92+
93+
} // namespace
94+
95+
void RecurrentCellTransformation::propagate(TransformationContext& context, std::shared_ptr<ov::Node>& node) {
96+
if (!NetworkHelper::isPrecisionPreserved(node)) {
97+
return;
98+
}
99+
100+
const auto& normalized_node = NetworkHelper::separateInStandaloneBranch(node, defaultPrecisions);
101+
auto dequantization = NetworkHelper::getDequantization(node, defaultPrecisions);
102+
if (dequantization.empty()) {
103+
return;
104+
}
105+
const auto& new_node = moveDequantizationAfter(context, normalized_node, dequantization);
106+
107+
const auto& new_dequantization = NetworkHelper::getDequantizationBelow(new_node);
108+
if (new_dequantization.empty()) {
109+
return;
110+
}
111+
112+
for (auto output : new_dequantization.multiply->outputs()) {
113+
for (auto input : output.get_target_inputs()) {
114+
auto& child = input.get_node()->shared_from_this();
115+
propagate(context, child);
116+
}
117+
}
118+
}
119+
94120
bool RecurrentCellTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) {
95121
const auto lstm = m.get_match_root();
122+
123+
const auto inputs = is_type<ov::opset5::LSTMSequence>(lstm) ? std::vector<size_t>{0, 1, 4, 5} : std::vector<size_t>{0, 1, 3, 4};
124+
for (const auto input : inputs) {
125+
const auto& parent = lstm->get_input_node_shared_ptr(input);
126+
if (!NetworkHelper::isPrecisionPreserved(parent)) {
127+
continue;
128+
}
129+
130+
const auto& fq = find_fake_quantize_upper(parent);
131+
if (fq != nullptr) {
132+
const auto& quantizationDetails = QuantizationDetails::getDetails(fq);
133+
if ((quantizationDetails.inputLowValues.size() != 1) || (quantizationDetails.inputHighValues.size() != 1) ||
134+
(quantizationDetails.outputLowValues.size() != 1) || (quantizationDetails.outputHighValues.size() != 1)) {
135+
continue;
136+
}
137+
138+
const auto& precisionsAttribute = getAttributeFromOutput<PrecisionsAttribute>(fq);
139+
const auto& precisions = precisionsAttribute.empty() ?
140+
defaultPrecisions :
141+
precisionsAttribute.as<PrecisionsAttribute>().value();
142+
const auto& dataPrecision = getDataPrecision(fq, quantizationDetails, precisions);
143+
if (dataPrecision.empty()) {
144+
continue;
145+
}
146+
147+
auto result = NetworkHelper::decomposeFakeQuantize(
148+
fq,
149+
dataPrecision.precision,
150+
dataPrecision.min,
151+
dataPrecision.max,
152+
dataPrecision.hasZeroPoint,
153+
updatePrecisions);
154+
auto multiply = std::get<1>(result);
155+
156+
for (const auto& output : multiply->outputs()) {
157+
for (const auto& input : output.get_target_inputs()) {
158+
const auto input_node = input.get_node();
159+
propagate(context, input_node->shared_from_this());
160+
}
161+
}
162+
}
163+
}
164+
96165
if (!canBeTransformed(context, lstm)) {
97166
return false;
98167
}

src/common/low_precision_transformations/src/rt_info/precision_preserved_attribute.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,11 @@ std::string PrecisionPreservedAttribute::to_string() const {
2020
ss << "value: " << (value() ? "true" : "false");
2121
return ss.str();
2222
}
23+
24+
bool PrecisionPreservedAttribute::is_copyable() const {
25+
return false;
26+
}
27+
28+
bool PrecisionPreservedAttribute::is_copyable(const std::shared_ptr<Node>& to) const {
29+
return false;
30+
}

src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
9292
::testing::ValuesIn(weights_shapes),
9393
::testing::Values(ov::test::utils::DEVICE_CPU),
9494
::testing::ValuesIn(trasformationParamValues),
95+
::testing::ValuesIn({ true, false }),
9596
::testing::ValuesIn(params)),
9697
RecurrentCellTransformation::getTestCaseName);
9798
} // namespace testValues1
@@ -171,6 +172,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
171172
::testing::ValuesIn(weights_shapes),
172173
::testing::Values(ov::test::utils::DEVICE_CPU),
173174
::testing::ValuesIn(trasformationParamValues),
175+
::testing::ValuesIn({ true, false }),
174176
::testing::ValuesIn(params)),
175177
RecurrentCellTransformation::getTestCaseName);
176178
} // namespace testValues2

src/tests/functional/plugin/shared/include/low_precision_transformations/recurrent_cell_transformation.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ typedef std::tuple<
4242
std::vector<ov::Shape>,
4343
std::string,
4444
ov::pass::low_precision::LayerTransformation::Params,
45+
bool, // use precision transparent operations
4546
RecurrentCellTransformationParam
4647
>RecurrentCellTransformationParams;
4748

src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp

+10-5
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@ std::string RecurrentCellTransformation::getTestCaseName(testing::TestParamInfo<
2121
std::string targetDevice;
2222
RecurrentCellTransformationParam param;
2323
ov::pass::low_precision::LayerTransformation::Params params;
24-
std::tie(netPrecision, activationsShape, weightsShape, targetDevice, params, param) = obj.param;
24+
bool addPrecisionTransparentOperations;
25+
std::tie(netPrecision, activationsShape, weightsShape, targetDevice, params, addPrecisionTransparentOperations, param) = obj.param;
2526

2627
std::ostringstream result;
2728
result << get_test_case_name_by_params(netPrecision, activationsShape[0], targetDevice, params) <<
2829
"FQ_X_" << param.fakeQuantize_X << "_" <<
2930
"DQ_X_" << param.dequantization_X << "_" <<
3031
"FQ_W_" << param.fakeQuantize_W << "_" <<
31-
"DQ_W_" << param.dequantization_W;
32+
"DQ_W_" << param.dequantization_W << "_" <<
33+
"PTO" << addPrecisionTransparentOperations;
3234
return result.str();
3335
}
3436

@@ -37,9 +39,10 @@ void RecurrentCellTransformation::SetUp() {
3739
std::vector<ov::PartialShape> activations_shapes;
3840
std::vector<ov::Shape> weights_shapes;
3941
RecurrentCellTransformationParam param;
42+
bool addPrecisionTransparentOperations;
4043
ov::pass::low_precision::LayerTransformation::Params params;
4144

42-
std::tie(precision, activations_shapes, weights_shapes, targetDevice, params, param) = this->GetParam();
45+
std::tie(precision, activations_shapes, weights_shapes, targetDevice, params, addPrecisionTransparentOperations, param) = this->GetParam();
4346

4447
init_input_shapes(activations_shapes);
4548

@@ -64,13 +67,15 @@ void RecurrentCellTransformation::SetUp() {
6467
param.dequantization_H,
6568
param.dequantization_W,
6669
param.dequantization_R
67-
});
70+
},
71+
addPrecisionTransparentOperations);
72+
ov::pass::Serialize("test.original.xml", "test.original.bin").run_on_model(function);
6873
}
6974

7075
void RecurrentCellTransformation::run() {
7176
LayerTransformation::run();
7277

73-
const auto params = std::get<5>(GetParam());
78+
const auto params = std::get<6>(GetParam());
7479
const auto actualPrecision = get_runtime_precision_by_type(params.layerName);
7580
auto expectedPrecision = params.expectedKernelType;
7681
if (expectedPrecision == "FP32" && std::get<0>(GetParam()) == ov::element::f16) {

src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/recurrent_cell.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,17 @@ class RecurrentCellFunction {
2525
const RNNType type,
2626
const std::vector<FakeQuantizeOnDataWithConstant>& fqOnDatas,
2727
const std::vector<DequantizationOperations::Convert>& converts,
28-
const std::vector<DequantizationOperations>& dequantizations);
28+
const std::vector<DequantizationOperations>& dequantizations,
29+
const bool addPrecisionTransparentOperations);
2930
};
3031

3132
std::shared_ptr<Node> makeQuantizationAndDequantization(const std::shared_ptr<Node> input,
3233
const ov::element::Type inputPrecision,
3334
const std::string friendly_name,
3435
const FakeQuantizeOnDataWithConstant& fqOnData,
3536
const DequantizationOperations::Convert& convert,
36-
const DequantizationOperations& dequantization);
37+
const DequantizationOperations& dequantization,
38+
const bool addPrecisionTransparentOperations = false);
3739
} // namespace subgraph
3840
} // namespace builder
3941
} // namespace ov

0 commit comments

Comments
 (0)