Skip to content

Commit 3056b53

Browse files
authored Jul 27, 2024
[LPT] Quantized LSTMSequence & GRUSequence extended support (openvinotoolkit#25654)
### Details:
 - *Low Precision Transformations: Quantized LSTMSequence & GRUSequence extended support*

### Tickets:
 - Current implementation for: *CVS-146067*
 - Will be changed in feature request: *CVS-147588*
1 parent 4a5bd43 commit 3056b53

File tree

16 files changed

+608
-76
lines changed

16 files changed

+608
-76
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "transparent_base_transformation.hpp"
8+
9+
namespace ov {
10+
namespace pass {
11+
namespace low_precision {
12+
13+
/**
14+
* @ingroup ov_transformation_common_api
15+
* @brief BroadcastTransformation propagates dequantization operations through Broadcast operation.
16+
*
17+
* For more details about the transformation, refer to
18+
* [BroadcastTransformation](@ref openvino_docs_OV_UG_lpt_BroadcastTransformation) page
19+
* in the OpenVINO Developer Guide.
20+
*/
21+
class LP_TRANSFORMATIONS_API BroadcastTransformation : public TransparentBaseTransformation {
22+
public:
23+
OPENVINO_RTTI("BroadcastTransformation", "0");
24+
BroadcastTransformation(const Params& params = Params());
25+
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<ov::Node> layer) const override;
26+
};
27+
28+
} // namespace low_precision
29+
} // namespace pass
30+
} // namespace ov

‎src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2022 Intel Corporation
1+
// Copyright (C) 2022-2024 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

@@ -23,6 +23,9 @@ class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransform
2323
static std::shared_ptr<ov::Node> wrap_fake_quantize(const std::shared_ptr<ov::Node> parameter);
2424
static std::shared_ptr<ov::Node> wrap_quantization(const std::shared_ptr<ov::Node> parameter);
2525
static std::shared_ptr<ov::Node> wrap_dequantization(const std::shared_ptr<ov::Node> parameter, const bool with_subtract);
26+
27+
private:
28+
void propagate(TransformationContext& context, const std::shared_ptr<ov::Node> node);
2629
};
2730

2831
} // namespace low_precision
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "low_precision/broadcast.hpp"
6+
7+
#include <memory>
8+
9+
#include "openvino/opsets/opset1.hpp"
10+
#include "openvino/opsets/opset3.hpp"
11+
#include "openvino/pass/pattern/op/or.hpp"
12+
#include "openvino/pass/pattern/op/wrap_type.hpp"
13+
#include "low_precision/network_helper.hpp"
14+
15+
#include "itt.hpp"
16+
17+
using namespace ov::pass::low_precision;
18+
19+
// Registers a matcher that fires on an opset1 or opset3 Broadcast with three inputs whose
// first input is produced by a Multiply; the remaining two inputs are unconstrained.
BroadcastTransformation::BroadcastTransformation(const Params& params) : TransparentBaseTransformation(params) {
    MATCHER_SCOPE(BroadcastTransformation);

    // Every call builds a fresh set of input patterns, so the two Broadcast patterns
    // below do not share pattern nodes.
    const auto make_inputs = []() {
        return ov::OutputVector{pattern::wrap_type<ov::opset1::Multiply>(),
                                ov::pass::pattern::any_input(),
                                ov::pass::pattern::any_input()};
    };

    auto broadcast_v1 = pattern::wrap_type<ov::opset1::Broadcast>(make_inputs());
    auto broadcast_v3 = pattern::wrap_type<ov::opset3::Broadcast>(make_inputs());

    // Accept either Broadcast version as the match root.
    const auto broadcast_any =
        std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{broadcast_v1, broadcast_v3});

    ov::graph_rewrite_callback callback = [this](pattern::Matcher& matched) {
        const auto root = matched.get_match_root();
        // A positive transformation_callback result means the node must be left untouched.
        return transformation_callback(root) ? false : transform(*context, matched);
    };

    this->register_matcher(std::make_shared<ov::pass::pattern::Matcher>(broadcast_any, matcher_name), callback);
}
44+
45+
bool BroadcastTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<ov::Node> layer) const {
46+
if (!LayerTransformation::canBeTransformed(context, layer)) {
47+
return false;
48+
}
49+
50+
const auto& dequantization = NetworkHelper::getDequantization(layer, defaultPrecisions);
51+
if (dequantization.empty()) {
52+
return false;
53+
}
54+
55+
if (dequantization.isPerTensor()) {
56+
return true;
57+
}
58+
59+
const auto& inputShape = layer->get_input_partial_shape(0);
60+
if (inputShape.rank().is_dynamic() || inputShape[dequantization.channelDimIndex].is_dynamic()) {
61+
return false;
62+
}
63+
64+
const auto targetShapeConstant = ov::as_type_ptr<ov::opset1::Constant>(layer->get_input_node_shared_ptr(1));
65+
const auto& targetShape = targetShapeConstant->cast_vector<int64_t>();
66+
if (targetShape[dequantization.channelDimIndex] != inputShape[dequantization.channelDimIndex].get_length()) {
67+
return false;
68+
}
69+
70+
const auto axesMappingConstant = ov::as_type_ptr<ov::opset1::Constant>(layer->get_input_node_shared_ptr(2));
71+
const auto& axesMapping = axesMappingConstant->cast_vector<int64_t>();
72+
if (static_cast<size_t>(axesMapping[dequantization.channelDimIndex]) != dequantization.channelDimIndex) {
73+
return false;
74+
}
75+
76+
return true;
77+
}

‎src/common/low_precision_transformations/src/layer_transformation.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ std::shared_ptr<ov::Node> LayerTransformation::moveDequantizationAfter(
401401
const FakeQuantizeDequantization& dequantization,
402402
const bool updateOutputPrecision,
403403
const bool moveSubtract) const {
404+
OPENVINO_ASSERT(!dequantization.empty());
404405
const auto result = ov::pass::low_precision::NetworkHelper::moveDequantizationAfter(operation,
405406
dequantization,
406407
updateOutputPrecision,

‎src/common/low_precision_transformations/src/low_precision.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "low_precision/assign_and_read_value.hpp"
4545
#include "low_precision/avg_pool.hpp"
4646
#include "low_precision/batch_to_space.hpp"
47+
#include "low_precision/broadcast.hpp"
4748
#include "low_precision/clamp.hpp"
4849
#include "low_precision/convolution.hpp"
4950
#include "low_precision/convolution_backprop_data.hpp"
@@ -240,6 +241,7 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
240241
ADD_MATCHER(common, AssignAndReadValueTransformation, f, params)
241242
ADD_MATCHER(common, AvgPoolTransformation, params)
242243
ADD_MATCHER(common, BatchToSpaceTransformation, params)
244+
ADD_MATCHER(common, BroadcastTransformation, params)
243245
ADD_MATCHER(common, ClampTransformation, params)
244246
ADD_MATCHER(common, ConcatTransformation, params)
245247
ADD_MATCHER(common, ConvolutionTransformation, params)

‎src/common/low_precision_transformations/src/markup_precisions.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "openvino/opsets/opset1.hpp"
1313
#include "openvino/opsets/opset2.hpp"
14+
#include "openvino/opsets/opset3.hpp"
1415
#include "openvino/opsets/opset4.hpp"
1516
#include "openvino/opsets/opset5.hpp"
1617
#include "openvino/opsets/opset6.hpp"
@@ -152,6 +153,8 @@ bool ov::pass::low_precision::MarkupPrecisions::isPrecisionPreserved(const std::
152153
{ name<opset1::Relu>() },
153154
// TODO: there are conditions
154155
{ name<opset2::BatchToSpace>() },
156+
{ name<opset1::Broadcast>() },
157+
{ name<opset3::Broadcast>() },
155158
{ name<opset1::Pad>() },
156159
{ name<ov::opset12::Pad>() },
157160
{ name<opset1::Reshape>() },
@@ -192,6 +195,8 @@ bool ov::pass::low_precision::MarkupPrecisions::isSupported(const std::shared_pt
192195
{ name<opset1::Add>() },
193196
{ name<opset1::AvgPool>() },
194197
{ name<opset2::BatchToSpace>() },
198+
{ name<opset1::Broadcast>() },
199+
{ name<opset3::Broadcast>() },
195200
{ name<opset1::Clamp>() },
196201
{ name<opset1::Concat>() },
197202
// ?

0 commit comments

Comments
 (0)