@@ -91,6 +91,7 @@ size_t ReduceKey::hash() const {
91
91
seed = hash_combine (seed, jcp.reduce_mode );
92
92
seed = hash_combine (seed, jcp.fuse_low_precision );
93
93
seed = hash_combine (seed, jcp.fuse_broadcast );
94
+ seed = hash_combine (seed, jcp.round_to_zero );
94
95
seed = hash_combine (seed, jcp.src_dt );
95
96
seed = hash_combine (seed, jcp.dst_dt );
96
97
seed = get_post_op_hash (seed, *postOps.get ());
@@ -101,17 +102,18 @@ size_t ReduceKey::hash() const {
101
102
// Equality for ReduceKey: two keys match when every JIT configuration field
// compared below is equal and the post-ops they point to compare equal.
// Fields are checked in a fixed order and the first mismatch ends the comparison.
bool ReduceKey::operator==(const ReduceKey &rhs) const {
    const auto &mine = jcp;
    const auto &theirs = rhs.jcp;

    // Layout and reduction algorithm.
    if (!(mine.layout == theirs.layout)) {
        return false;
    }
    if (!(mine.reduce_mode == theirs.reduce_mode)) {
        return false;
    }

    // Flags that affect the generated kernel.
    if (!(mine.fuse_low_precision == theirs.fuse_low_precision)) {
        return false;
    }
    if (!(mine.fuse_broadcast == theirs.fuse_broadcast)) {
        return false;
    }
    if (!(mine.round_to_zero == theirs.round_to_zero)) {
        return false;
    }

    // Source and destination data types.
    if (!(mine.src_dt == theirs.src_dt)) {
        return false;
    }
    if (!(mine.dst_dt == theirs.dst_dt)) {
        return false;
    }

    // Post-ops are compared by value through the held pointers, and only
    // once every field above has matched.
    return *postOps.get() == *rhs.postOps.get();
}
106
108
} // namespace
107
109
108
- #if defined(OPENVINO_ARCH_X86_64)
109
-
110
110
// some utility functions
111
111
// Returns true when `type` is one of the floating-point data types
// f32, bf16 or f16; every other data type yields false.
static inline bool isFloatCompatible(memory::data_type type) {
    switch (type) {
    case memory::data_type::f32:
    case memory::data_type::bf16:
    case memory::data_type::f16:
        return true;
    default:
        return false;
    }
}
114
114
115
+ #if defined(OPENVINO_ARCH_X86_64)
116
+
115
117
template <cpu_isa_t isa>
116
118
struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel , public jit_generator {
117
119
DECLARE_CPU_JIT_AUX_FUNCTIONS (jit_uni_reduce_kernel_f32)
@@ -966,7 +968,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
966
968
inline void store_vector (const Xbyak::Address &op, Vmm vmm_dst, memory::data_type dst_dt) {
967
969
Xmm xmm_dst = Xmm (vmm_dst.getIdx ());
968
970
Ymm ymm_dst = Ymm (vmm_dst.getIdx ());
969
- if (! isFloatCompatible ( jcp_.src_dt ) && !support_intermediate_int) {
971
+ if (jcp_.round_to_zero && !support_intermediate_int) {
970
972
uni_vroundps (vmm_dst, vmm_dst, 3 ); // rounding to zero
971
973
}
972
974
if (convert_f32_to_i32 (dst_dt)) {
@@ -1020,7 +1022,7 @@ struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_gene
1020
1022
}
1021
1023
1022
1024
inline void store_scalar (const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt) {
1023
- if (! isFloatCompatible ( jcp_.src_dt ) && !support_intermediate_int) {
1025
+ if (jcp_.round_to_zero && !support_intermediate_int) {
1024
1026
uni_vroundps (xmm_dst, xmm_dst, 3 );
1025
1027
}
1026
1028
if (convert_f32_to_i32 (dst_dt)) {
@@ -1522,7 +1524,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
1522
1524
int depthwise_inj_idx = 0 ;
1523
1525
int quantization_inj_idx = 0 ;
1524
1526
int post_ops_data_offset = 0 ;
1525
- if (! isFloatCompatible ( jcp_.src_dt ) ) {
1527
+ if (jcp_.round_to_zero ) {
1526
1528
uni_vroundps (vmm_dst, vmm_dst, 3 ); // rounding to zero
1527
1529
}
1528
1530
@@ -1656,7 +1658,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
1656
1658
Xmm xmm_dst = Xmm (vmm_dst.getIdx ());
1657
1659
Ymm ymm_dst = Ymm (vmm_dst.getIdx ());
1658
1660
// If there is post ops fusing, necessary rounding has already been done, no need to do it again.
1659
- if (!post_ops_fusing && ! isFloatCompatible ( jcp_.src_dt ) ) {
1661
+ if (!post_ops_fusing && jcp_.round_to_zero ) {
1660
1662
uni_vroundps (vmm_dst, vmm_dst, 3 );
1661
1663
}
1662
1664
if (!isFloatCompatible (dst_dt)) {
@@ -1710,7 +1712,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
1710
1712
}
1711
1713
1712
1714
inline void store_scalar (const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt) {
1713
- if (!post_ops_fusing && ! isFloatCompatible ( jcp_.src_dt ) ) {
1715
+ if (!post_ops_fusing && jcp_.round_to_zero ) {
1714
1716
uni_vroundps (xmm_dst, xmm_dst, 3 );
1715
1717
}
1716
1718
if (!isFloatCompatible (dst_dt)) {
@@ -1913,6 +1915,7 @@ Reduce::Reduce(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr con
1913
1915
}
1914
1916
set_use_aux_kernel = false ;
1915
1917
fuse_low_precision = false ;
1918
+ round_to_zero = false ;
1916
1919
vec_reduceDH_prc.clear ();
1917
1920
vec_reduceCDW_prc.clear ();
1918
1921
setJITBeyond5D ();
@@ -1950,6 +1953,11 @@ void Reduce::initSupportedPrimitiveDescriptors() {
1950
1953
input_prec = getOriginalInputPrecisionAtPort (REDUCE_DATA);
1951
1954
output_prec = getOriginalOutputPrecisionAtPort (0 );
1952
1955
1956
+ if (!isFloatCompatible (DnnlExtensionUtils::ElementTypeToDataType (input_prec)) &&
1957
+ !isFloatCompatible (DnnlExtensionUtils::ElementTypeToDataType (output_prec))) {
1958
+ round_to_zero = true ;
1959
+ }
1960
+
1953
1961
jit_mode = canApplyJIT (input_prec, output_prec);
1954
1962
1955
1963
auto is_precision_sensitive_reduce = [](const Algorithm &algorithm) {
@@ -2194,6 +2202,7 @@ void Reduce::createPrimitive() {
2194
2202
jcp.layout = layout;
2195
2203
jcp.reduce_mode = getAlgorithm ();
2196
2204
jcp.fuse_low_precision = fuse_low_precision;
2205
+ jcp.round_to_zero = round_to_zero;
2197
2206
2198
2207
#if defined(OPENVINO_ARCH_X86_64)
2199
2208
compile_post_kernel = true ;
0 commit comments