Skip to content

Commit 3c713d4

Browse files
authored
[CPU] Avoid rounding to zero for Reduce node in quantized models (#25766)
### Details: - *If both the input and output precisions of the Reduce node are integers in the original model, the intermediate floating-point value should be rounded toward zero before it is converted to an integer.* - *However, if those integer precisions result from quantization, this rounding should not be done, in order to maintain accuracy.* - *Add corresponding test cases.* ### Tickets: - *CVS-147352*
1 parent 36eebc2 commit 3c713d4

File tree

9 files changed

+219
-7
lines changed

9 files changed

+219
-7
lines changed

src/plugins/intel_cpu/src/nodes/reduce.cpp

+16-7
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ size_t ReduceKey::hash() const {
9191
seed = hash_combine(seed, jcp.reduce_mode);
9292
seed = hash_combine(seed, jcp.fuse_low_precision);
9393
seed = hash_combine(seed, jcp.fuse_broadcast);
94+
seed = hash_combine(seed, jcp.round_to_zero);
9495
seed = hash_combine(seed, jcp.src_dt);
9596
seed = hash_combine(seed, jcp.dst_dt);
9697
seed = get_post_op_hash(seed, *postOps.get());
@@ -101,17 +102,18 @@ size_t ReduceKey::hash() const {
101102
// Equality for the kernel cache key: two keys match when the layout, reduce mode,
// fusion flags, rounding flag and source/destination data types of their jit
// configurations are equal and their post-ops compare equal.
bool ReduceKey::operator==(const ReduceKey &rhs) const {
    const auto &lhs_cfg = jcp;
    const auto &rhs_cfg = rhs.jcp;

    // Plain configuration fields first, in the same order as before.
    const bool same_config = lhs_cfg.layout == rhs_cfg.layout &&
                             lhs_cfg.reduce_mode == rhs_cfg.reduce_mode &&
                             lhs_cfg.fuse_low_precision == rhs_cfg.fuse_low_precision &&
                             lhs_cfg.fuse_broadcast == rhs_cfg.fuse_broadcast &&
                             lhs_cfg.round_to_zero == rhs_cfg.round_to_zero &&
                             lhs_cfg.src_dt == rhs_cfg.src_dt &&
                             lhs_cfg.dst_dt == rhs_cfg.dst_dt;
    if (!same_config) {
        return false;
    }

    // The post-ops are only dereferenced and compared once the configuration matches.
    return *postOps.get() == *rhs.postOps.get();
}
106108
} // namespace
107109

108-
#if defined(OPENVINO_ARCH_X86_64)
109-
110110
// some utility functions
111111
// Returns true when `type` is one of the floating-point data types f32, bf16 or f16,
// and false for every other data type.
static inline bool isFloatCompatible(memory::data_type type) {
    switch (type) {
    case memory::data_type::f32:
    case memory::data_type::bf16:
    case memory::data_type::f16:
        return true;
    default:
        return false;
    }
}
114114

115+
#if defined(OPENVINO_ARCH_X86_64)
116+
115117
template <cpu_isa_t isa>
116118
struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_generator {
117119
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reduce_kernel_f32)
@@ -966,7 +968,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
966968
inline void store_vector(const Xbyak::Address &op, Vmm vmm_dst, memory::data_type dst_dt) {
967969
Xmm xmm_dst = Xmm(vmm_dst.getIdx());
968970
Ymm ymm_dst = Ymm(vmm_dst.getIdx());
969-
if (!isFloatCompatible(jcp_.src_dt) && !support_intermediate_int) {
971+
if (jcp_.round_to_zero && !support_intermediate_int) {
970972
uni_vroundps(vmm_dst, vmm_dst, 3); // rounding to zero
971973
}
972974
if (convert_f32_to_i32(dst_dt)) {
@@ -1020,7 +1022,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
10201022
}
10211023

10221024
inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt) {
1023-
if (!isFloatCompatible(jcp_.src_dt) && !support_intermediate_int) {
1025+
if (jcp_.round_to_zero && !support_intermediate_int) {
10241026
uni_vroundps(xmm_dst, xmm_dst, 3);
10251027
}
10261028
if (convert_f32_to_i32(dst_dt)) {
@@ -1522,7 +1524,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
15221524
int depthwise_inj_idx = 0;
15231525
int quantization_inj_idx = 0;
15241526
int post_ops_data_offset = 0;
1525-
if (!isFloatCompatible(jcp_.src_dt)) {
1527+
if (jcp_.round_to_zero) {
15261528
uni_vroundps(vmm_dst, vmm_dst, 3); // rounding to zero
15271529
}
15281530

@@ -1656,7 +1658,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
16561658
Xmm xmm_dst = Xmm(vmm_dst.getIdx());
16571659
Ymm ymm_dst = Ymm(vmm_dst.getIdx());
16581660
// If post ops are fused, the necessary rounding has already been done, so there is no need to do it again.
1659-
if (!post_ops_fusing && !isFloatCompatible(jcp_.src_dt)) {
1661+
if (!post_ops_fusing && jcp_.round_to_zero) {
16601662
uni_vroundps(vmm_dst, vmm_dst, 3);
16611663
}
16621664
if (!isFloatCompatible(dst_dt)) {
@@ -1710,7 +1712,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
17101712
}
17111713

17121714
inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt) {
1713-
if (!post_ops_fusing && !isFloatCompatible(jcp_.src_dt)) {
1715+
if (!post_ops_fusing && jcp_.round_to_zero) {
17141716
uni_vroundps(xmm_dst, xmm_dst, 3);
17151717
}
17161718
if (!isFloatCompatible(dst_dt)) {
@@ -1913,6 +1915,7 @@ Reduce::Reduce(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr con
19131915
}
19141916
set_use_aux_kernel = false;
19151917
fuse_low_precision = false;
1918+
round_to_zero = false;
19161919
vec_reduceDH_prc.clear();
19171920
vec_reduceCDW_prc.clear();
19181921
setJITBeyond5D();
@@ -1950,6 +1953,11 @@ void Reduce::initSupportedPrimitiveDescriptors() {
19501953
input_prec = getOriginalInputPrecisionAtPort(REDUCE_DATA);
19511954
output_prec = getOriginalOutputPrecisionAtPort(0);
19521955

1956+
if (!isFloatCompatible(DnnlExtensionUtils::ElementTypeToDataType(input_prec)) &&
1957+
!isFloatCompatible(DnnlExtensionUtils::ElementTypeToDataType(output_prec))) {
1958+
round_to_zero = true;
1959+
}
1960+
19531961
jit_mode = canApplyJIT(input_prec, output_prec);
19541962

19551963
auto is_precision_sensitive_reduce = [](const Algorithm &algorithm) {
@@ -2194,6 +2202,7 @@ void Reduce::createPrimitive() {
21942202
jcp.layout = layout;
21952203
jcp.reduce_mode = getAlgorithm();
21962204
jcp.fuse_low_precision = fuse_low_precision;
2205+
jcp.round_to_zero = round_to_zero;
21972206

21982207
#if defined(OPENVINO_ARCH_X86_64)
21992208
compile_post_kernel = true;

src/plugins/intel_cpu/src/nodes/reduce.h

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct jit_reduce_config_params {
2222
Algorithm reduce_mode;
2323
bool fuse_low_precision;
2424
bool fuse_broadcast; // if post ops fusion needs broadcast
25+
bool round_to_zero;
2526
dnnl::memory::data_type src_dt;
2627
dnnl::memory::data_type dst_dt;
2728
int src_data_size;
@@ -138,6 +139,7 @@ class Reduce : public Node {
138139
bool jit_beyond_5D = false;
139140
bool jit_mode = true;
140141
bool keep_dims = true;
142+
bool round_to_zero = false;
141143
bool is_hybrid_layout = false;
142144
bool compile_post_kernel = true;
143145
bool apply_post_kernel = true;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "subgraph_tests/integer_reduce_mean.hpp"
6+
7+
#include <tuple>
8+
#include <vector>
9+
10+
using namespace ov::test;
11+
namespace {
12+
13+
const std::vector<ov::element::Type> input_precision = {ov::element::f32};
14+
const std::vector<ov::element::Type> integer_input_precision = {ov::element::i32, ov::element::i8, ov::element::u8};
15+
const std::vector<std::vector<size_t>> input_shape = {{1, 2, 3, 3}};
16+
const std::vector<std::vector<size_t>> axes = {{2, 3}};
17+
18+
INSTANTIATE_TEST_SUITE_P(smoke_ReduceMeanQuantized,
19+
IntegerReduceMeanTest,
20+
testing::Combine(
21+
::testing::ValuesIn(input_precision),
22+
::testing::ValuesIn(input_shape),
23+
::testing::ValuesIn(axes),
24+
::testing::Values(true),
25+
::testing::Values(ov::test::utils::DEVICE_CPU)),
26+
IntegerReduceMeanTest::getTestCaseName);
27+
28+
INSTANTIATE_TEST_SUITE_P(smoke_ReduceMeanIntegerInput,
29+
IntegerReduceMeanTest,
30+
testing::Combine(
31+
::testing::ValuesIn(integer_input_precision),
32+
::testing::ValuesIn(input_shape),
33+
::testing::ValuesIn(axes),
34+
::testing::Values(false),
35+
::testing::Values(ov::test::utils::DEVICE_CPU)),
36+
IntegerReduceMeanTest::getTestCaseName);
37+
38+
} // namespace

src/plugins/template/backend/ops/ops_evaluates.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ extern template bool evaluate_node<ov::op::v0::LSTMCell>(std::shared_ptr<ov::Nod
7373
ov::TensorVector& outputs,
7474
const ov::TensorVector& inputs);
7575

76+
extern template bool evaluate_node<ov::op::v1::ReduceMean>(std::shared_ptr<ov::Node> node,
77+
ov::TensorVector& outputs,
78+
const ov::TensorVector& inputs);
79+
7680
OPENVINO_SUPPRESS_DEPRECATED_START
7781
extern template bool evaluate_node<ov::op::v0::LSTMSequence>(std::shared_ptr<ov::Node> node,
7882
ov::TensorVector& outputs,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/reference/reduce_mean.hpp"
6+
7+
#include "evaluate_node.hpp"
8+
9+
template <ov::element::Type_t ET>
10+
bool evaluate(const std::shared_ptr<ov::op::v1::ReduceMean>& op,
11+
ov::TensorVector& outputs,
12+
const ov::TensorVector& inputs) {
13+
using T = ov::fundamental_type_for<ET>;
14+
ov::reference::reduce_mean(inputs[0].data<const T>(),
15+
outputs[0].data<T>(),
16+
inputs[0].get_shape(),
17+
op->get_reduction_axes());
18+
return true;
19+
}
20+
21+
template <>
22+
bool evaluate_node<ov::op::v1::ReduceMean>(std::shared_ptr<ov::Node> node,
23+
ov::TensorVector& outputs,
24+
const ov::TensorVector& inputs) {
25+
const auto& element_type = node->get_output_element_type(0);
26+
27+
switch (element_type) {
28+
case ov::element::bf16:
29+
return evaluate<ov::element::bf16>(ov::as_type_ptr<ov::op::v1::ReduceMean>(node), outputs, inputs);
30+
case ov::element::f16:
31+
return evaluate<ov::element::f16>(ov::as_type_ptr<ov::op::v1::ReduceMean>(node), outputs, inputs);
32+
case ov::element::f32:
33+
return evaluate<ov::element::f32>(ov::as_type_ptr<ov::op::v1::ReduceMean>(node), outputs, inputs);
34+
case ov::element::i8:
35+
return evaluate<ov::element::i8>(ov::as_type_ptr<ov::op::v1::ReduceMean>(node), outputs, inputs);
36+
case ov::element::u8:
37+
return evaluate<ov::element::u8>(ov::as_type_ptr<ov::op::v1::ReduceMean>(node), outputs, inputs);
38+
default:
39+
OPENVINO_THROW("Unhandled data type ", element_type, " in evaluate_node()");
40+
}
41+
}

src/plugins/template/backend/opset_int_tbl.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ _OPENVINO_OP_REG(Multiply, op::v1)
6262
_OPENVINO_OP_REG(NonMaxSuppression, op::v1)
6363
_OPENVINO_OP_REG(OneHot, op::v1)
6464
_OPENVINO_OP_REG(Pad, op::v1)
65+
_OPENVINO_OP_REG(ReduceMean, op::v1)
6566
_OPENVINO_OP_REG(Split, op::v1)
6667
_OPENVINO_OP_REG(Reshape, op::v1)
6768
_OPENVINO_OP_REG(Select, op::v1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "shared_test_classes/subgraph/integer_reduce_mean.hpp"
8+
9+
namespace ov {
10+
namespace test {
11+
12+
// Executes the model prepared by IntegerReduceMeanTest::SetUp() through run(), which
// is not defined in this file (presumably inherited from the subgraph test base).
// The stray ';' that followed the test body was an empty declaration and is removed.
TEST_P(IntegerReduceMeanTest, CompareWithRefs) {
    run();
}
15+
16+
} // namespace test
17+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "shared_test_classes/base/ov_subgraph.hpp"
6+
7+
namespace ov {
8+
namespace test {
9+
10+
// Parameter tuple consumed by IntegerReduceMeanTest.
using IntegerReduceMeanParams = std::tuple<ov::element::Type,    // input precision
                                           std::vector<size_t>,  // input shape
                                           std::vector<size_t>,  // axes
                                           bool,                 // quantized
                                           const char*>;         // plugin
16+
17+
// IntegerReduceMeanTest covers the two rounding scenarios in ReduceMean with integer inputs.
18+
// Scenario 1: Both the input and output precisions of ReduceMean are integers in the original model, so rounding to zero should
19+
// be done before converting the intermediate floating point value to integer. Covered by test suite smoke_ReduceMeanIntegerInput.
20+
// Scenario 2: The integer inputs of ReduceMean result from quantization, so such rounding should not be done, in order to maintain
21+
// accuracy. Covered by test suite smoke_ReduceMeanQuantized.
22+
// Parameterized subgraph test fixture for ReduceMean, driven by IntegerReduceMeanParams
// (input precision, input shape, axes, quantized flag, plugin name).
class IntegerReduceMeanTest : public testing::WithParamInterface<IntegerReduceMeanParams>,
23+
public ov::test::SubgraphBaseStaticTest {
24+
public:
25+
// Builds the test-case name string from the parameter tuple.
static std::string getTestCaseName(const testing::TestParamInfo<IntegerReduceMeanParams>& obj);
26+
27+
protected:
28+
// Unpacks the parameters and assembles the model under test: a ReduceMean fed
// either directly by the Parameter or by a FakeQuantize when the quantized flag is set.
void SetUp() override;
29+
};
30+
31+
} // namespace test
32+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "shared_test_classes/subgraph/integer_reduce_mean.hpp"
6+
#include "common_test_utils/node_builders/fake_quantize.hpp"
7+
8+
namespace ov {
9+
namespace test {
10+
11+
// Builds the test-case name from the parameter tuple as a sequence of
// "key=value" fields joined by '_':
//   inputPrecision=..._inputShape=..._axes=..._device=..._quantized=true|false
//
// obj: gtest parameter info whose `param` holds the IntegerReduceMeanParams tuple.
// Returns the assembled name.
std::string IntegerReduceMeanTest::getTestCaseName(const testing::TestParamInfo<IntegerReduceMeanParams>& obj) {
    ov::element::Type input_precision;
    std::vector<size_t> input_shape;
    std::vector<size_t> axes;
    bool quantized = false;
    const char* device = nullptr;
    std::tie(input_precision, input_shape, axes, quantized, device) = obj.param;

    std::ostringstream result;
    result << "inputPrecision=" << input_precision.to_string() << "_";
    result << "inputShape=" << ov::test::utils::vec2str(input_shape) << "_";
    result << "axes=" << ov::test::utils::vec2str(axes) << "_";
    // The device field previously ran straight into the next one
    // ("device=CPUquantized=true"); terminate it with '_' like the other fields.
    result << "device=" << device << "_";
    result << "quantized=" << (quantized ? "true" : "false");
    return result.str();
}
29+
30+
void IntegerReduceMeanTest::SetUp() {
31+
ov::element::Type input_precision;
32+
std::vector<size_t> input_shape;
33+
std::vector<size_t> axes;
34+
std::vector<size_t> axes_shape;
35+
bool quantized;
36+
std::tie(input_precision, input_shape, axes, quantized, targetDevice) = this->GetParam();
37+
axes_shape.push_back(axes.size());
38+
39+
auto dataNode = std::make_shared<ov::op::v0::Parameter>(input_precision, ov::Shape(input_shape));
40+
auto axesNode = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape(axes_shape), axes);
41+
42+
std::shared_ptr<ov::op::v1::ReduceMean> reduce_mean;
43+
if (quantized) {
44+
std::vector<size_t> dataFqConstShapes(input_shape.size(), 1);
45+
size_t constDataSize = ov::shape_size(dataFqConstShapes);
46+
std::vector<float> inputLowData(constDataSize), inputHighData(constDataSize), outputLowData(constDataSize), outputHighData(constDataSize);
47+
for (size_t i = 0; i < constDataSize; i++) {
48+
inputLowData[i] = 0;
49+
inputHighData[i] = 255;
50+
outputLowData[i] = 0;
51+
outputHighData[i] = 255;
52+
}
53+
auto dataFqNode = ov::test::utils::make_fake_quantize(
54+
dataNode, input_precision, 256, dataFqConstShapes, inputLowData, inputHighData, outputLowData, outputHighData);
55+
reduce_mean = std::make_shared<ov::op::v1::ReduceMean>(dataFqNode, axesNode, true);
56+
} else {
57+
reduce_mean = std::make_shared<ov::op::v1::ReduceMean>(dataNode, axesNode, true);
58+
}
59+
60+
ov::ParameterVector inputs;
61+
inputs.push_back(dataNode);
62+
ov::ResultVector outputs;
63+
outputs.push_back(std::make_shared<ov::op::v0::Result>(reduce_mean));
64+
function = std::make_shared<ov::Model>(outputs, inputs);
65+
}
66+
67+
} // namespace test
68+
} // namespace ov

0 commit comments

Comments
 (0)