Skip to content

Commit 2ef42d4

Browse files
authored
[CPU] [TRANSFORMATIONS] FakeConvert decomposition transformation (#28118)
### Details:
- *Implement FakeConvert decomposition transformation.*

### Tickets:
- *[CVS-156963](https://jira.devtools.intel.com/browse/CVS-156963)*

### Prerequisites:
- *#27949*
1 parent 82d553e commit 2ef42d4

File tree

8 files changed

+426
-0
lines changed

8 files changed

+426
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "openvino/pass/matcher_pass.hpp"
8+
#include "transformations_visibility.hpp"
9+
10+
namespace ov {
11+
namespace pass {
12+
13+
class TRANSFORMATIONS_API FakeConvertDecomposition;
14+
15+
} // namespace pass
16+
} // namespace ov
17+
18+
/**
19+
* @ingroup ov_transformation_common_api
20+
* @brief FakeConvertDecomposition transformation decomposes FakeConvert layer.
21+
* f8: f8e4m3, f8e5m2
22+
* downconvert: f32->f8, f16->f8, bf16->f8
23+
* upconvert: f8->f32, f8->f16, f8->bf16
24+
* output = (upconvert(downconvert(input * scale - shift)) + shift) / scale
25+
*
26+
*/
27+
28+
class ov::pass::FakeConvertDecomposition : public ov::pass::MatcherPass {
29+
public:
30+
OPENVINO_MATCHER_PASS_RTTI("FakeConvertDecomposition");
31+
FakeConvertDecomposition();
32+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/op_conversions/fake_convert_decomposition.hpp"
6+
7+
#include "itt.hpp"
8+
#include "openvino/core/rt_info.hpp"
9+
#include "openvino/op/add.hpp"
10+
#include "openvino/op/constant.hpp"
11+
#include "openvino/op/convert.hpp"
12+
#include "openvino/op/divide.hpp"
13+
#include "openvino/op/fake_convert.hpp"
14+
#include "openvino/op/multiply.hpp"
15+
#include "openvino/op/subtract.hpp"
16+
#include "openvino/pass/pattern/op/wrap_type.hpp"
17+
18+
// Registers a matcher on ov::op::v13::FakeConvert and replaces every matched node with
//     output = (upconvert(downconvert(data * scale - shift)) + shift) / scale
// The Subtract/Add pair is emitted only when FakeConvert carries a shift input
// (i.e. its input count is not 2).
ov::pass::FakeConvertDecomposition::FakeConvertDecomposition() {
    MATCHER_SCOPE(FakeConvertDecomposition);

    auto fake_convert = ov::pass::pattern::wrap_type<ov::op::v13::FakeConvert>();

    matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        const auto fake_convert_node =
            ov::as_type_ptr<ov::op::v13::FakeConvert>(pattern_to_output.at(fake_convert).get_node_shared_ptr());

        // Leave the node untouched if the cast failed or the transformation callback returns true for it.
        if (fake_convert_node == nullptr || transformation_callback(fake_convert_node)) {
            return false;
        }

        Output<Node> data{fake_convert_node->input_value(0)};
        const Output<Node> input_scale{fake_convert_node->input_value(1)};
        auto input_type = data.get_element_type();

        // Every node created through the registry receives the runtime info of the original FakeConvert below.
        ov::pass::NodeRegistry decomp_ops;

        // The arithmetic is carried out in the element type of the scale input, so data is converted to it first.
        if (input_type != input_scale.get_element_type()) {
            input_type = input_scale.get_element_type();
            data = decomp_ops.make<ov::op::v0::Convert>(data, input_type);
        }

        const bool has_shift = fake_convert_node->get_input_size() != 2;
        Output<Node> input_shift;
        if (has_shift) {
            input_shift = fake_convert_node->input_value(2);
        }

        // Build the chain linearly; the shift steps are simply left out when there is no shift input.
        std::shared_ptr<Node> result = decomp_ops.make<ov::op::v1::Multiply>(data, input_scale);
        if (has_shift) {
            result = decomp_ops.make<ov::op::v1::Subtract>(result, input_shift);
        }

        // Round trip through the destination type.
        result = decomp_ops.make<ov::op::v0::Convert>(result, fake_convert_node->get_destination_element_type());
        result = decomp_ops.make<ov::op::v0::Convert>(result, input_type);

        if (has_shift) {
            result = decomp_ops.make<ov::op::v1::Add>(result, input_shift);
        }
        result = decomp_ops.make<ov::op::v1::Divide>(result, input_scale);

        // Restore the original output element type if the computation type differs from it.
        if (result->get_output_element_type(0) != fake_convert_node->get_output_element_type(0)) {
            result = decomp_ops.make<ov::op::v0::Convert>(result, fake_convert_node->get_output_element_type(0));
        }

        result->set_friendly_name(m.get_match_root()->get_friendly_name());
        ov::copy_runtime_info(fake_convert_node, decomp_ops.get());
        ov::replace_node(m.get_match_root(), result);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(fake_convert, matcher_name);
    register_matcher(m, callback);
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/op_conversions/fake_convert_decomposition.hpp"
6+
7+
#include <gtest/gtest.h>
8+
9+
#include "common_test_utils/common_utils.hpp"
10+
#include "common_test_utils/ov_test_utils.hpp"
11+
#include "openvino/opsets/opset1.hpp"
12+
#include "openvino/opsets/opset13.hpp"
13+
14+
using namespace ov;
15+
16+
using FakeConvertDecompositionParams = std::tuple<Shape, // data shape
17+
Shape, // scale shape
18+
Shape, // shift shape
19+
element::Type_t, // input precision
20+
element::Type_t, // destination precision
21+
bool>; // default shift
22+
23+
class FakeConvertDecompositionTest : public ov::test::TestsCommon,
24+
public ::testing::WithParamInterface<FakeConvertDecompositionParams> {
25+
public:
26+
static std::string getTestCaseName(::testing::TestParamInfo<FakeConvertDecompositionParams> obj) {
27+
FakeConvertDecompositionParams params = obj.param;
28+
29+
Shape data_shape, scale_shape, shift_shape;
30+
element::Type_t data_prec, dst_prec;
31+
bool default_shift;
32+
std::tie(data_shape, scale_shape, shift_shape, data_prec, dst_prec, default_shift) = params;
33+
34+
std::ostringstream result;
35+
result << "dataShape=" << ov::test::utils::vec2str(data_shape) << "_";
36+
result << "scaleShape=" << ov::test::utils::vec2str(scale_shape) << "_";
37+
result << "shiftShape=" << ov::test::utils::vec2str(shift_shape) << "_";
38+
result << "dataPrecision=" << element::Type(data_prec) << "_";
39+
result << "destinationPrecision=" << element::Type(dst_prec) << "_";
40+
if (default_shift)
41+
result << "defaultShift=true";
42+
else
43+
result << "defaultShift=false";
44+
return result.str();
45+
}
46+
};
47+
48+
// Runs FakeConvertDecomposition on a single-FakeConvert model and compares the outcome
// with a hand-built reference graph via compare_functions.
TEST_P(FakeConvertDecompositionTest, CompareFunctions) {
    const FakeConvertDecompositionParams& test_params = this->GetParam();
    const Shape& data_shape = std::get<0>(test_params);
    const Shape& scale_shape = std::get<1>(test_params);
    const Shape& shift_shape = std::get<2>(test_params);
    const element::Type_t data_prec = std::get<3>(test_params);
    const element::Type_t dst_prec = std::get<4>(test_params);
    const bool default_shift = std::get<5>(test_params);

    // Model under test: Parameter -> FakeConvert, with or without the explicit shift input.
    std::shared_ptr<ov::Model> model;
    {
        const auto data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape));
        const auto scale = std::make_shared<opset1::Constant>(data_prec, scale_shape);
        const auto shift = std::make_shared<opset1::Constant>(data_prec, shift_shape);

        std::shared_ptr<opset13::FakeConvert> fake_convert;
        if (default_shift) {
            fake_convert = std::make_shared<opset13::FakeConvert>(data, scale, dst_prec);
        } else {
            fake_convert = std::make_shared<opset13::FakeConvert>(data, scale, shift, dst_prec);
        }
        model = std::make_shared<ov::Model>(NodeVector{fake_convert}, ParameterVector{data});

        pass::Manager manager;
        manager.register_pass<ov::pass::InitNodeInfo>();
        manager.register_pass<ov::pass::FakeConvertDecomposition>();
        manager.run_passes(model);

        OV_ASSERT_NO_THROW(check_rt_info(model));
    }

    // Reference: Multiply -> [Subtract] -> Convert(dst) -> Convert(data) -> [Add] -> Divide.
    std::shared_ptr<ov::Model> model_ref;
    {
        const auto input_data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape));
        const auto input_scale = std::make_shared<opset1::Constant>(data_prec, scale_shape);
        const auto input_shift = std::make_shared<opset1::Constant>(data_prec, shift_shape);

        std::shared_ptr<Node> expected = std::make_shared<ov::op::v1::Multiply>(input_data, input_scale);
        if (!default_shift) {
            expected = std::make_shared<ov::op::v1::Subtract>(expected, input_shift);
        }
        expected = std::make_shared<ov::op::v0::Convert>(expected, dst_prec);
        expected = std::make_shared<ov::op::v0::Convert>(expected, data_prec);
        if (!default_shift) {
            expected = std::make_shared<ov::op::v1::Add>(expected, input_shift);
        }
        expected = std::make_shared<ov::op::v1::Divide>(expected, input_scale);

        model_ref = std::make_shared<ov::Model>(NodeVector{expected}, ParameterVector{input_data});
    }

    const auto res = compare_functions(model, model_ref);
    ASSERT_TRUE(res.first) << res.second;
}
106+
107+
const std::vector<element::Type_t> data_precisions = {element::Type_t::f32,
108+
element::Type_t::f16,
109+
element::Type_t::bf16};
110+
111+
const std::vector<element::Type_t> destination_precisions = {element::Type_t::f8e4m3, element::Type_t::f8e5m2};
112+
113+
const std::vector<bool> default_shift = {true, false};
114+
115+
const auto simple_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}),
116+
::testing::Values(Shape{1}),
117+
::testing::Values(Shape{1}),
118+
::testing::ValuesIn(data_precisions),
119+
::testing::ValuesIn(destination_precisions),
120+
::testing::ValuesIn(default_shift));
121+
122+
const auto broadcast_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}),
123+
::testing::Values(Shape{2, 3, 1, 1}),
124+
::testing::Values(Shape{2, 3, 1, 1}),
125+
::testing::ValuesIn(data_precisions),
126+
::testing::ValuesIn(destination_precisions),
127+
::testing::ValuesIn(default_shift));
128+
129+
const auto elementwise_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}),
130+
::testing::Values(Shape{2, 3, 4, 5}),
131+
::testing::Values(Shape{2, 3, 4, 5}),
132+
::testing::ValuesIn(data_precisions),
133+
::testing::ValuesIn(destination_precisions),
134+
::testing::ValuesIn(default_shift));
135+
136+
INSTANTIATE_TEST_SUITE_P(SimpleFakeConvert_Decomposition,
137+
FakeConvertDecompositionTest,
138+
simple_fake_convert_params,
139+
FakeConvertDecompositionTest::getTestCaseName);
140+
141+
INSTANTIATE_TEST_SUITE_P(BroadcastFakeConvert_Decomposition,
142+
FakeConvertDecompositionTest,
143+
broadcast_fake_convert_params,
144+
FakeConvertDecompositionTest::getTestCaseName);
145+
146+
INSTANTIATE_TEST_SUITE_P(ElementwiseFakeConvert_Decomposition,
147+
FakeConvertDecompositionTest,
148+
elementwise_fake_convert_params,
149+
FakeConvertDecompositionTest::getTestCaseName);

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
#include "transformations/op_conversions/detection_output_downgrade.hpp"
8181
#include "transformations/op_conversions/detection_output_upgrade.hpp"
8282
#include "transformations/op_conversions/eye_decomposition.hpp"
83+
#include "transformations/op_conversions/fake_convert_decomposition.hpp"
8384
#include "transformations/op_conversions/fq_decomposition.hpp"
8485
#include "transformations/op_conversions/gelu7_downgrade.hpp"
8586
#include "transformations/op_conversions/group_normalization_decomposition.hpp"
@@ -1293,6 +1294,7 @@ void Transformations::PostSnippets(void) {
12931294
return node::FakeQuantize::isSupportedOperation(node, errMsg);
12941295
},
12951296
ov::pass::FakeQuantizeDecomposition);
1297+
CPU_REGISTER_PASS_COMMON(postSnippetsManager, ov::pass::FakeConvertDecomposition);
12961298
CPU_REGISTER_PASS_COMMON(postSnippetsManager, ov::pass::ConstantFolding);
12971299
postSnippetsManager.run_passes(model);
12981300
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "single_op_tests/fake_convert.hpp"
6+
7+
namespace {
8+
using ov::test::FakeConvertLayerTest;
9+
10+
const std::vector<std::vector<ov::Shape>> shapes = {{{2, 3, 4, 5}}};
11+
12+
const std::vector<ov::element::Type> data_precisions = {ov::element::f32, ov::element::f16, ov::element::bf16};
13+
14+
const std::vector<ov::element::Type> destination_precisions = {ov::element::f8e4m3, ov::element::f8e5m2};
15+
16+
const std::vector<bool> default_shift = {true, false};
17+
18+
const auto simple_fake_convert_params =
19+
::testing::Combine(::testing::ValuesIn(ov::test::static_shapes_to_test_representation(shapes)),
20+
::testing::Values(ov::Shape{1}),
21+
::testing::Values(ov::Shape{1}),
22+
::testing::ValuesIn(data_precisions),
23+
::testing::ValuesIn(destination_precisions),
24+
::testing::ValuesIn(default_shift),
25+
::testing::Values(ov::test::utils::DEVICE_CPU));
26+
27+
const auto broadcast_fake_convert_params =
28+
::testing::Combine(::testing::ValuesIn(ov::test::static_shapes_to_test_representation(shapes)),
29+
::testing::Values(ov::Shape{2, 3, 1, 1}),
30+
::testing::Values(ov::Shape{2, 3, 1, 1}),
31+
::testing::ValuesIn(data_precisions),
32+
::testing::ValuesIn(destination_precisions),
33+
::testing::ValuesIn(default_shift),
34+
::testing::Values(ov::test::utils::DEVICE_CPU));
35+
36+
const auto elementwise_fake_convert_params =
37+
::testing::Combine(::testing::ValuesIn(ov::test::static_shapes_to_test_representation(shapes)),
38+
::testing::Values(ov::Shape{2, 3, 4, 5}),
39+
::testing::Values(ov::Shape{2, 3, 4, 5}),
40+
::testing::ValuesIn(data_precisions),
41+
::testing::ValuesIn(destination_precisions),
42+
::testing::ValuesIn(default_shift),
43+
::testing::Values(ov::test::utils::DEVICE_CPU));
44+
45+
INSTANTIATE_TEST_SUITE_P(smoke_FakeConvert_simple,
46+
FakeConvertLayerTest,
47+
simple_fake_convert_params,
48+
FakeConvertLayerTest::getTestCaseName);
49+
50+
INSTANTIATE_TEST_SUITE_P(smoke_FakeConvert_broadcast,
51+
FakeConvertLayerTest,
52+
broadcast_fake_convert_params,
53+
FakeConvertLayerTest::getTestCaseName);
54+
55+
INSTANTIATE_TEST_SUITE_P(smoke_FakeConvert_elementwise,
56+
FakeConvertLayerTest,
57+
elementwise_fake_convert_params,
58+
FakeConvertLayerTest::getTestCaseName);
59+
} // namespace
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "shared_test_classes/single_op/fake_convert.hpp"
8+
9+
namespace ov {
10+
namespace test {
11+
12+
TEST_P(FakeConvertLayerTest, Inference) {
13+
run();
14+
}
15+
} // namespace test
16+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "shared_test_classes/base/ov_subgraph.hpp"
8+
9+
namespace ov {
10+
namespace test {
11+
using FakeConvertParams = std::tuple<std::vector<InputShape>, // Data shape
12+
Shape, // Scale shape
13+
Shape, // Shift shape
14+
ov::element::Type, // Input precision
15+
ov::element::Type, // Ddestination precision
16+
bool, // Default shift
17+
std::string>; // Device name
18+
19+
class FakeConvertLayerTest : public testing::WithParamInterface<FakeConvertParams>,
20+
virtual public ov::test::SubgraphBaseTest {
21+
public:
22+
static std::string getTestCaseName(const testing::TestParamInfo<FakeConvertParams>& obj);
23+
24+
protected:
25+
void SetUp() override;
26+
};
27+
} // namespace test
28+
} // namespace ov

0 commit comments

Comments
 (0)