Skip to content

Commit 523aa3e

Browse files
authored
Add GeluFusion to PreProcessing transformations pipeline (#29421)
### Details: In specific cases of Gelu pattern, Mul(input, const(0.5)) might be fused to Conv weights, the pattern can't be detected after that. We have to add Gelu to the preprocessing transformation pipeline. ### Tickets: - *CVS-163672*
1 parent f5309e1 commit 523aa3e

File tree

2 files changed

+60
-6
lines changed

2 files changed

+60
-6
lines changed

src/core/src/preprocess/pre_post_process.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "transformations/common_optimizations/convolution_to_group_convolution_fusion.hpp"
1717
#include "transformations/common_optimizations/disable_random_uniform_constant_folding.hpp"
1818
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
19+
#include "transformations/common_optimizations/gelu_fusion.hpp"
1920
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
2021
#include "transformations/common_optimizations/ric_fusion.hpp"
2122
#include "transformations/common_optimizations/shared_ops_optimization.hpp"
@@ -89,12 +90,14 @@ void transformation_pipeline(std::shared_ptr<ov::Model>& model) {
8990

9091
// 2. Fusion transformations:
9192
REGISTER_PASS(manager, ConvertDivideWithConstant)
92-
auto multiply_fusions = manager.register_pass<GraphRewrite>();
93-
ADD_MATCHER(multiply_fusions, MultiplyConvolutionFusion)
94-
ADD_MATCHER(multiply_fusions, MultiplyGroupConvolutionFusion)
95-
ADD_MATCHER(multiply_fusions, MultiplyConvolutionBackpropDataFusion)
96-
ADD_MATCHER(multiply_fusions, MultiplyGroupConvolutionBackpropDataFusion)
97-
multiply_fusions->set_name("ov::pass::MultiplyFusions");
93+
auto fusions = manager.register_pass<GraphRewrite>();
94+
// Gelu fusion have to be executed before MulConv fusion because Mul(X, 0.5) might be fused to Conv weights
95+
ADD_MATCHER(fusions, GeluFusion)
96+
ADD_MATCHER(fusions, MultiplyConvolutionFusion)
97+
ADD_MATCHER(fusions, MultiplyGroupConvolutionFusion)
98+
ADD_MATCHER(fusions, MultiplyConvolutionBackpropDataFusion)
99+
ADD_MATCHER(fusions, MultiplyGroupConvolutionBackpropDataFusion)
100+
fusions->set_name("ov::pass::MultiplyFusions");
98101
REGISTER_PASS(manager, ReverseInputChannelsFusion)
99102

100103
// 3. CF call due to detected perf degradations

src/core/tests/preprocess.cpp

+51
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

5+
#define _USE_MATH_DEFINES
6+
7+
#include <math.h>
8+
59
#include "common_test_utils/ov_test_utils.hpp"
610
#include "common_test_utils/test_assertions.hpp"
711
#include "common_test_utils/test_tools.hpp"
@@ -2504,3 +2508,50 @@ TEST_F(TransformationTestsF, preprocessing_conv_decompression) {
25042508
model_ref = std::make_shared<ov::Model>(ResultVector{res}, ParameterVector{input});
25052509
}
25062510
}
2511+
2512+
TEST_F(TransformationTestsF, preprocessing_gelu_fusion) {
2513+
auto in_shape = Shape{1, 3, 32, 32};
2514+
auto in_type = element::f32;
2515+
auto weight_type = element::f32;
2516+
{
2517+
auto data = std::make_shared<ov::op::v0::Parameter>(in_type, in_shape);
2518+
2519+
auto mul_const_sqrt_1_2 = ov::op::v0::Constant::create(in_type, Shape{1}, {M_SQRT1_2});
2520+
auto mul_to_erf = std::make_shared<ov::op::v1::Multiply>(data, mul_const_sqrt_1_2);
2521+
auto erf = std::make_shared<ov::op::v0::Erf>(mul_to_erf);
2522+
2523+
auto add_const = ov::op::v0::Constant::create(in_type, Shape{1}, {1.0});
2524+
auto add = std::make_shared<ov::op::v1::Add>(erf, add_const);
2525+
auto mul_first = std::make_shared<ov::op::v1::Multiply>(data, add);
2526+
2527+
auto mul_const = ov::op::v0::Constant::create(in_type, Shape{1}, {0.5});
2528+
auto mul = std::make_shared<ov::op::v1::Multiply>(mul_first, mul_const);
2529+
2530+
std::shared_ptr<Node> weights = std::make_shared<op::v0::Constant>(weight_type, ov::Shape{1, 3, 3, 3}, 1);
2531+
auto conv = std::make_shared<op::v1::Convolution>(mul,
2532+
weights,
2533+
Strides{},
2534+
CoordinateDiff{},
2535+
CoordinateDiff{},
2536+
Strides{});
2537+
auto res = std::make_shared<op::v0::Result>(conv);
2538+
auto f = std::make_shared<ov::Model>(ov::ResultVector{res}, ov::ParameterVector{data});
2539+
auto p = PrePostProcessor(f);
2540+
model = p.build();
2541+
}
2542+
2543+
{
2544+
auto input = std::make_shared<op::v0::Parameter>(in_type, in_shape);
2545+
2546+
auto gelu = std::make_shared<op::v7::Gelu>(input);
2547+
auto weights = op::v0::Constant::create(weight_type, ov::Shape({1, 3, 3, 3}), {1.f});
2548+
auto conv = std::make_shared<op::v1::Convolution>(gelu,
2549+
weights,
2550+
Strides{},
2551+
CoordinateDiff{},
2552+
CoordinateDiff{},
2553+
Strides{});
2554+
auto res = std::make_shared<op::v0::Result>(conv);
2555+
model_ref = std::make_shared<ov::Model>(ResultVector{res}, ParameterVector{input});
2556+
}
2557+
}

0 commit comments

Comments
 (0)