|
2 | 2 | // SPDX-License-Identifier: Apache-2.0
|
3 | 3 | //
|
4 | 4 |
|
| 5 | +#define _USE_MATH_DEFINES |
| 6 | + |
| 7 | +#include <math.h> |
| 8 | + |
5 | 9 | #include "common_test_utils/ov_test_utils.hpp"
|
6 | 10 | #include "common_test_utils/test_assertions.hpp"
|
7 | 11 | #include "common_test_utils/test_tools.hpp"
|
@@ -2504,3 +2508,50 @@ TEST_F(TransformationTestsF, preprocessing_conv_decompression) {
|
2504 | 2508 | model_ref = std::make_shared<ov::Model>(ResultVector{res}, ParameterVector{input});
|
2505 | 2509 | }
|
2506 | 2510 | }
|
| 2511 | + |
| 2512 | +TEST_F(TransformationTestsF, preprocessing_gelu_fusion) { |
| 2513 | + auto in_shape = Shape{1, 3, 32, 32}; |
| 2514 | + auto in_type = element::f32; |
| 2515 | + auto weight_type = element::f32; |
| 2516 | + { |
| 2517 | + auto data = std::make_shared<ov::op::v0::Parameter>(in_type, in_shape); |
| 2518 | + |
| 2519 | + auto mul_const_sqrt_1_2 = ov::op::v0::Constant::create(in_type, Shape{1}, {M_SQRT1_2}); |
| 2520 | + auto mul_to_erf = std::make_shared<ov::op::v1::Multiply>(data, mul_const_sqrt_1_2); |
| 2521 | + auto erf = std::make_shared<ov::op::v0::Erf>(mul_to_erf); |
| 2522 | + |
| 2523 | + auto add_const = ov::op::v0::Constant::create(in_type, Shape{1}, {1.0}); |
| 2524 | + auto add = std::make_shared<ov::op::v1::Add>(erf, add_const); |
| 2525 | + auto mul_first = std::make_shared<ov::op::v1::Multiply>(data, add); |
| 2526 | + |
| 2527 | + auto mul_const = ov::op::v0::Constant::create(in_type, Shape{1}, {0.5}); |
| 2528 | + auto mul = std::make_shared<ov::op::v1::Multiply>(mul_first, mul_const); |
| 2529 | + |
| 2530 | + std::shared_ptr<Node> weights = std::make_shared<op::v0::Constant>(weight_type, ov::Shape{1, 3, 3, 3}, 1); |
| 2531 | + auto conv = std::make_shared<op::v1::Convolution>(mul, |
| 2532 | + weights, |
| 2533 | + Strides{}, |
| 2534 | + CoordinateDiff{}, |
| 2535 | + CoordinateDiff{}, |
| 2536 | + Strides{}); |
| 2537 | + auto res = std::make_shared<op::v0::Result>(conv); |
| 2538 | + auto f = std::make_shared<ov::Model>(ov::ResultVector{res}, ov::ParameterVector{data}); |
| 2539 | + auto p = PrePostProcessor(f); |
| 2540 | + model = p.build(); |
| 2541 | + } |
| 2542 | + |
| 2543 | + { |
| 2544 | + auto input = std::make_shared<op::v0::Parameter>(in_type, in_shape); |
| 2545 | + |
| 2546 | + auto gelu = std::make_shared<op::v7::Gelu>(input); |
| 2547 | + auto weights = op::v0::Constant::create(weight_type, ov::Shape({1, 3, 3, 3}), {1.f}); |
| 2548 | + auto conv = std::make_shared<op::v1::Convolution>(gelu, |
| 2549 | + weights, |
| 2550 | + Strides{}, |
| 2551 | + CoordinateDiff{}, |
| 2552 | + CoordinateDiff{}, |
| 2553 | + Strides{}); |
| 2554 | + auto res = std::make_shared<op::v0::Result>(conv); |
| 2555 | + model_ref = std::make_shared<ov::Model>(ResultVector{res}, ParameterVector{input}); |
| 2556 | + } |
| 2557 | +} |
0 commit comments