Skip to content

Commit f81729e

Browse files
committed
Refactor clamp tests to use a simple function and update expected node types
1 parent cb49167 commit f81729e

File tree

1 file changed

+7
-54
lines changed

1 file changed

+7
-54
lines changed

src/core/tests/preprocess.cpp

+7-54
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,12 @@
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

5-
#include <openvino/runtime/core.hpp>
6-
75
#include "common_test_utils/ov_test_utils.hpp"
86
#include "common_test_utils/test_assertions.hpp"
97
#include "common_test_utils/test_tools.hpp"
108
#include "gtest/gtest.h"
119
#include "openvino/core/except.hpp"
12-
#include "openvino/core/model.hpp"
1310
#include "openvino/core/preprocess/pre_post_process.hpp"
14-
#include "openvino/core/shape.hpp"
15-
#include "openvino/core/type.hpp"
1611
#include "openvino/opsets/opset8.hpp"
1712
#include "openvino/util/common_util.hpp"
1813
#include "preprocess/color_utils.hpp"
@@ -129,57 +124,13 @@ TEST(pre_post_process, simple_mean_scale_getters_f64) {
129124
EXPECT_EQ(f->get_output_element_type(0), element::f64);
130125
}
131126

132-
TEST(pre_post_process, simple_clamp_f16) {
133-
auto f = create_clamp_function(element::f16, Shape{1, 3, 2, 2}, 0.0f, 1.0f);
134-
auto p = PrePostProcessor(f);
135-
p.input("tensor_input").preprocess().clamp(0.0f, 1.0f);
136-
p.output("tensor_output").postprocess().clamp(0.0f, 1.0f);
137-
f = p.build();
138-
EXPECT_EQ(f->get_output_element_type(0), element::f16);
139-
140-
EXPECT_EQ(f->input().get_shape(), (Shape{1, 3, 2, 2}));
141-
EXPECT_EQ(f->output().get_shape(), (Shape{1, 3, 2, 2}));
142-
}
143-
144-
TEST(pre_post_process, simple_clamp_f64) {
145-
auto f = create_clamp_function(element::f64, Shape{1, 3, 2, 2}, 0.0f, 1.0f);
146-
auto p = PrePostProcessor(f);
147-
p.input("tensor_input").preprocess().clamp(0.0f, 1.0f);
148-
p.output("tensor_output").postprocess().clamp(0.0f, 1.0f);
149-
f = p.build();
150-
EXPECT_EQ(f->get_output_element_type(0), element::f64);
127+
TEST(pre_post_process, clamp_operation_on_input_preprocess) {
128+
auto model = create_simple_function(element::f32, Shape{1, 3, 2, 2});
151129

152-
EXPECT_EQ(f->input().get_shape(), (Shape{1, 3, 2, 2}));
153-
EXPECT_EQ(f->output().get_shape(), (Shape{1, 3, 2, 2}));
154-
}
155-
156-
class PreprocessClampTest : public ::testing::Test {
157-
protected:
158-
std::shared_ptr<Model> model;
159-
std::shared_ptr<op::v0::Parameter> input;
160-
161-
void SetUp() override {
162-
input = std::make_shared<op::v0::Parameter>(element::f64, Shape{1});
163-
input->set_friendly_name("input");
164-
input->get_output_tensor(0).set_names({"tensor_input"});
165-
166-
auto add_op = std::make_shared<op::v1::Add>(input, input);
167-
add_op->set_friendly_name("Add");
168-
add_op->get_output_tensor(0).set_names({"tensor_add"});
169-
170-
auto result = std::make_shared<op::v0::Result>(add_op);
171-
result->set_friendly_name("Result");
172-
result->get_output_tensor(0).set_names({"tensor_output"});
173-
174-
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{input});
175-
}
176-
};
177-
178-
TEST_F(PreprocessClampTest, clamp_operation_on_input_preprocess) {
179130
{
180131
auto input_node = model->get_parameters().front();
181132
auto connected_node = input_node->output(0).get_target_inputs().begin()->get_node();
182-
EXPECT_STREQ(connected_node->get_type_name(), "Add");
133+
EXPECT_STREQ(connected_node->get_type_name(), "Relu");
183134
}
184135
auto p = PrePostProcessor(model);
185136
p.input().preprocess().clamp(0.0, 1.0);
@@ -191,11 +142,13 @@ TEST_F(PreprocessClampTest, clamp_operation_on_input_preprocess) {
191142
}
192143
}
193144

194-
TEST_F(PreprocessClampTest, clamp_operation_on_output_postprocess) {
145+
TEST(pre_post_process, clamp_operation_on_output_postprocess) {
146+
auto model = create_simple_function(element::f32, Shape{1, 3, 2, 2});
147+
195148
{
196149
auto result_node = model->get_results().front();
197150
auto connected_node = result_node->input_value(0).get_node_shared_ptr();
198-
EXPECT_STREQ(connected_node->get_type_name(), "Add");
151+
EXPECT_STREQ(connected_node->get_type_name(), "Relu");
199152
}
200153
auto p = PrePostProcessor(model);
201154
p.output().postprocess().clamp(0.0, 1.0);

0 commit comments

Comments
 (0)