|
| 1 | +// Copyright (C) 2024 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | + |
| 5 | +#include "transformations/op_conversions/fake_convert_decomposition.hpp" |
| 6 | + |
| 7 | +#include <gtest/gtest.h> |
| 8 | + |
| 9 | +#include "common_test_utils/common_utils.hpp" |
| 10 | +#include "common_test_utils/ov_test_utils.hpp" |
| 11 | +#include "openvino/opsets/opset1.hpp" |
| 12 | +#include "openvino/opsets/opset13.hpp" |
| 13 | + |
| 14 | +using namespace ov; |
| 15 | + |
| 16 | +using FakeConvertDecompositionParams = std::tuple<Shape, // data shape |
| 17 | + Shape, // scale shape |
| 18 | + Shape, // shift shape |
| 19 | + element::Type_t, // input precision |
| 20 | + element::Type_t, // destination precision |
| 21 | + bool>; // default shift |
| 22 | + |
| 23 | +class FakeConvertDecompositionTest : public ov::test::TestsCommon, |
| 24 | + public ::testing::WithParamInterface<FakeConvertDecompositionParams> { |
| 25 | +public: |
| 26 | + static std::string getTestCaseName(::testing::TestParamInfo<FakeConvertDecompositionParams> obj) { |
| 27 | + FakeConvertDecompositionParams params = obj.param; |
| 28 | + |
| 29 | + Shape data_shape, scale_shape, shift_shape; |
| 30 | + element::Type_t data_prec, dst_prec; |
| 31 | + bool default_shift; |
| 32 | + std::tie(data_shape, scale_shape, shift_shape, data_prec, dst_prec, default_shift) = params; |
| 33 | + |
| 34 | + std::ostringstream result; |
| 35 | + result << "dataShape=" << ov::test::utils::vec2str(data_shape) << "_"; |
| 36 | + result << "scaleShape=" << ov::test::utils::vec2str(scale_shape) << "_"; |
| 37 | + result << "shiftShape=" << ov::test::utils::vec2str(shift_shape) << "_"; |
| 38 | + result << "dataPrecision=" << element::Type(data_prec) << "_"; |
| 39 | + result << "destinationPrecision=" << element::Type(dst_prec) << "_"; |
| 40 | + if (default_shift) |
| 41 | + result << "defaultShift=true"; |
| 42 | + else |
| 43 | + result << "defaultShift=false"; |
| 44 | + return result.str(); |
| 45 | + } |
| 46 | +}; |
| 47 | + |
| 48 | +TEST_P(FakeConvertDecompositionTest, CompareFunctions) { |
| 49 | + FakeConvertDecompositionParams params = this->GetParam(); |
| 50 | + |
| 51 | + Shape data_shape, scale_shape, shift_shape; |
| 52 | + element::Type_t data_prec, dst_prec; |
| 53 | + bool default_shift; |
| 54 | + std::tie(data_shape, scale_shape, shift_shape, data_prec, dst_prec, default_shift) = params; |
| 55 | + |
| 56 | + std::shared_ptr<ov::Model> model(nullptr); |
| 57 | + { |
| 58 | + const auto data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape)); |
| 59 | + const auto scale = std::make_shared<opset1::Constant>(data_prec, scale_shape); |
| 60 | + const auto shift = std::make_shared<opset1::Constant>(data_prec, shift_shape); |
| 61 | + |
| 62 | + const auto fake_convert = default_shift ? std::make_shared<opset13::FakeConvert>(data, scale, dst_prec) |
| 63 | + : std::make_shared<opset13::FakeConvert>(data, scale, shift, dst_prec); |
| 64 | + model = std::make_shared<ov::Model>(NodeVector{fake_convert}, ParameterVector{data}); |
| 65 | + |
| 66 | + pass::Manager manager; |
| 67 | + manager.register_pass<ov::pass::InitNodeInfo>(); |
| 68 | + manager.register_pass<ov::pass::FakeConvertDecomposition>(); |
| 69 | + manager.run_passes(model); |
| 70 | + |
| 71 | + OV_ASSERT_NO_THROW(check_rt_info(model)); |
| 72 | + } |
| 73 | + |
| 74 | + std::shared_ptr<ov::Model> model_ref(nullptr); |
| 75 | + { |
| 76 | + const auto input_data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape)); |
| 77 | + const auto input_scale = std::make_shared<opset1::Constant>(data_prec, scale_shape); |
| 78 | + const auto input_shift = std::make_shared<opset1::Constant>(data_prec, shift_shape); |
| 79 | + ParameterVector params; |
| 80 | + params.push_back(input_data); |
| 81 | + std::shared_ptr<Node> data = input_data; |
| 82 | + |
| 83 | + std::shared_ptr<Node> result; |
| 84 | + const auto scale = std::make_shared<ov::op::v1::Multiply>(data, input_scale); |
| 85 | + if (default_shift) { |
| 86 | + const auto downconvert = std::make_shared<ov::op::v0::Convert>(scale, dst_prec); |
| 87 | + const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, data_prec); |
| 88 | + |
| 89 | + result = std::make_shared<ov::op::v1::Divide>(upconvert, input_scale); |
| 90 | + } else { |
| 91 | + const auto shift = std::make_shared<ov::op::v1::Subtract>(scale, input_shift); |
| 92 | + |
| 93 | + const auto downconvert = std::make_shared<ov::op::v0::Convert>(shift, dst_prec); |
| 94 | + const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, data_prec); |
| 95 | + |
| 96 | + const auto deshift = std::make_shared<ov::op::v1::Add>(upconvert, input_shift); |
| 97 | + result = std::make_shared<ov::op::v1::Divide>(deshift, input_scale); |
| 98 | + } |
| 99 | + |
| 100 | + model_ref = std::make_shared<ov::Model>(NodeVector{result}, params); |
| 101 | + } |
| 102 | + |
| 103 | + const auto res = compare_functions(model, model_ref); |
| 104 | + ASSERT_TRUE(res.first) << res.second; |
| 105 | +} |
| 106 | + |
| 107 | +const std::vector<element::Type_t> data_precisions = {element::Type_t::f32, |
| 108 | + element::Type_t::f16, |
| 109 | + element::Type_t::bf16}; |
| 110 | + |
| 111 | +const std::vector<element::Type_t> destination_precisions = {element::Type_t::f8e4m3, element::Type_t::f8e5m2}; |
| 112 | + |
| 113 | +const std::vector<bool> default_shift = {true, false}; |
| 114 | + |
| 115 | +const auto simple_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}), |
| 116 | + ::testing::Values(Shape{1}), |
| 117 | + ::testing::Values(Shape{1}), |
| 118 | + ::testing::ValuesIn(data_precisions), |
| 119 | + ::testing::ValuesIn(destination_precisions), |
| 120 | + ::testing::ValuesIn(default_shift)); |
| 121 | + |
| 122 | +const auto broadcast_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}), |
| 123 | + ::testing::Values(Shape{2, 3, 1, 1}), |
| 124 | + ::testing::Values(Shape{2, 3, 1, 1}), |
| 125 | + ::testing::ValuesIn(data_precisions), |
| 126 | + ::testing::ValuesIn(destination_precisions), |
| 127 | + ::testing::ValuesIn(default_shift)); |
| 128 | + |
| 129 | +const auto elementwise_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}), |
| 130 | + ::testing::Values(Shape{2, 3, 4, 5}), |
| 131 | + ::testing::Values(Shape{2, 3, 4, 5}), |
| 132 | + ::testing::ValuesIn(data_precisions), |
| 133 | + ::testing::ValuesIn(destination_precisions), |
| 134 | + ::testing::ValuesIn(default_shift)); |
| 135 | + |
| 136 | +INSTANTIATE_TEST_SUITE_P(SimpleFakeConvert_Decomposition, |
| 137 | + FakeConvertDecompositionTest, |
| 138 | + simple_fake_convert_params, |
| 139 | + FakeConvertDecompositionTest::getTestCaseName); |
| 140 | + |
| 141 | +INSTANTIATE_TEST_SUITE_P(BroadcastFakeConvert_Decomposition, |
| 142 | + FakeConvertDecompositionTest, |
| 143 | + broadcast_fake_convert_params, |
| 144 | + FakeConvertDecompositionTest::getTestCaseName); |
| 145 | + |
| 146 | +INSTANTIATE_TEST_SUITE_P(ElementwiseFakeConvert_Decomposition, |
| 147 | + FakeConvertDecompositionTest, |
| 148 | + elementwise_fake_convert_params, |
| 149 | + FakeConvertDecompositionTest::getTestCaseName); |
0 commit comments