Skip to content

Commit 7593a1e

Browse files
authored
Fix MarkDequantization transformation (#29151)
### Details: Move the precision check to the predicate It will fix the multiply commutativity to match both cases : mul(lp_weights, scale) and mul(scale, lp_weights) ### Tickets: - *CVS-160006* CVS-161724
1 parent e6b90c3 commit 7593a1e

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

src/common/low_precision_transformations/tests/mark_dequantization_subgraph_transformation.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,37 @@ TEST_F(TransformationTestsF, KeepConstPrecision2BranchesSameShapes) {
176176
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
177177
}
178178

179+
TEST_F(TransformationTestsF, MarkDequantizationScaleOnTheLeftBranch) {
180+
{
181+
auto lp_const = std::make_shared<opset10::Constant>(element::u4, Shape{27}, 1);
182+
auto scale = opset10::Constant::create(element::i64, Shape{}, {2});
183+
184+
auto convert_lp = std::make_shared<opset10::Convert>(lp_const, element::f32);
185+
auto convert_scale = std::make_shared<opset10::Convert>(scale, element::f32);
186+
auto multiply = std::make_shared<opset10::Multiply>(convert_scale, convert_lp);
187+
auto stub_op = std::make_shared<opset10::Relu>(multiply);
188+
model = std::make_shared<Model>(stub_op, ParameterVector{});
189+
}
190+
191+
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u4});
192+
manager.register_pass<pass::ConstantFolding>();
193+
194+
{
195+
auto lp_const = std::make_shared<opset10::Constant>(element::u4, Shape{27}, 1);
196+
auto scale = opset10::Constant::create(element::f32, Shape{}, {2});
197+
198+
auto convert_lp = std::make_shared<opset10::Convert>(lp_const, element::f32);
199+
auto multiply = std::make_shared<opset10::Multiply>(scale, convert_lp);
200+
auto stub_op = std::make_shared<opset10::Relu>(multiply);
201+
model_ref = std::make_shared<Model>(stub_op, ParameterVector{});
202+
203+
mark_as_dequantization_node(multiply);
204+
ov::pass::disable_constant_folding(convert_lp);
205+
}
206+
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
207+
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
208+
}
209+
179210
TEST_F(TransformationTestsF, KeepConstPrecision) {
180211
{
181212
auto lp_const = std::make_shared<opset10::Constant>(element::u4, Shape{27}, 1);

src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp

+12-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "openvino/op/unsqueeze.hpp"
1212
#include "openvino/pass/manager.hpp"
1313
#include "openvino/pass/pattern/op/optional.hpp"
14+
#include "openvino/pass/pattern/op/pattern.hpp"
1415
#include "openvino/pass/pattern/op/wrap_type.hpp"
1516
#include "transformations/rt_info/dequantization_node.hpp"
1617
#include "transformations/rt_info/disable_constant_folding.hpp"
@@ -23,9 +24,13 @@ using namespace ov::pass::pattern;
2324

2425
namespace {
2526

26-
bool check_precision(const ov::element::Type_t type_to_check, const ov::element::TypeVector& precisions) {
27-
return std::find(precisions.begin(), precisions.end(), type_to_check) != precisions.end();
28-
};
27+
ov::pass::pattern::op::Predicate check_precision(const ov::element::TypeVector& precisions) {
28+
return ov::pass::pattern::op::Predicate(
29+
[=](const Output<Node>& output) -> bool {
30+
return std::find(precisions.begin(), precisions.end(), output.get_element_type()) != precisions.end();
31+
},
32+
"check_precision");
33+
}
2934

3035
using RTInfoSetter = std::function<void(const std::shared_ptr<ov::Node>& node)>;
3136
void set_rt_info(const PatternValueMap& pt_map,
@@ -35,10 +40,9 @@ void set_rt_info(const PatternValueMap& pt_map,
3540
for (const auto& pattern_node : pattern_nodes) {
3641
if (pt_map.count(pattern_node)) {
3742
auto node = pt_map.at(pattern_node).get_node_shared_ptr();
38-
3943
// we don't need to mark Converts with disable_cf attribute if the `from` type (input type)
4044
// is not in the `precisions` list.
41-
if (ov::as_type_ptr<v0::Convert>(node) && !check_precision(node->get_input_element_type(0), precisions)) {
45+
if (ov::as_type_ptr<v0::Convert>(node) && !check_precision(precisions)(node->input_value(0))) {
4246
continue;
4347
}
4448

@@ -196,7 +200,7 @@ ov::pass::MarkDequantization::MarkDequantization(const element::TypeVector& prec
196200
MATCHER_SCOPE(MarkDequantization);
197201

198202
// data input:
199-
auto input_pattern = any_input();
203+
auto input_pattern = any_input(check_precision(precisions));
200204
auto convert_pattern = wrap_type<v0::Convert>({input_pattern}, consumers_count(1));
201205

202206
// zero points:
@@ -217,7 +221,7 @@ ov::pass::MarkDequantization::MarkDequantization(const element::TypeVector& prec
217221
auto input = pt_map.at(input_pattern);
218222
const auto multiply = m.get_match_root();
219223

220-
if (!check_precision(input.get_element_type(), precisions) || transformation_callback(multiply)) {
224+
if (transformation_callback(multiply)) {
221225
return false;
222226
}
223227

@@ -290,8 +294,7 @@ ov::pass::KeepConstPrecision::KeepConstPrecision(const element::TypeVector& prec
290294
for (const auto& pattern_node : keep_const_precisions) {
291295
if (pt_map.count(pattern_node.first)) {
292296
auto node = pt_map.at(pattern_node.first).get_node_shared_ptr();
293-
const auto& precision = node->get_output_element_type(0);
294-
if (ov::as_type_ptr<v0::Constant>(node) && check_precision(precision, precisions)) {
297+
if (ov::as_type_ptr<v0::Constant>(node) && check_precision(precisions)(node->output(0))) {
295298
if (pattern_node.second) {
296299
ov::disable_keep_const_precision(node);
297300
} else {

src/core/src/preprocess/pre_post_process.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
1919
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
2020
#include "transformations/common_optimizations/ric_fusion.hpp"
21+
#include "transformations/common_optimizations/shared_ops_optimization.hpp"
2122
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
2223
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
2324
#include "transformations/op_conversions/convert_divide.hpp"
@@ -74,6 +75,9 @@ void transformation_pipeline(std::shared_ptr<ov::Model>& model) {
7475
Manager manager("pre_post_processing");
7576
manager.set_per_pass_validation(false);
7677

78+
// prerequisite: the model structure optimization before applying of the markup
79+
REGISTER_PASS(manager, SharedOpOptimization)
80+
7781
// 1. Set "disable_const_folding" attribute
7882
REGISTER_PASS(manager, MarkDequantization, TypeVector{i8, u8, i4, u4, nf4, f4e2m1, f8e4m3, f8e5m2, f8e8m0});
7983
REGISTER_PASS(manager, DisableShapeOfConstantFolding, false);

0 commit comments

Comments
 (0)