1
- // Copyright (C) 2022 Intel Corporation
1
+ // Copyright (C) 2022-2023 Intel Corporation
2
2
// SPDX-License-Identifier: Apache-2.0
3
3
//
4
4
25
25
#include " transformations/common_optimizations/fq_mul_fusion.hpp"
26
26
#include " transformations/common_optimizations/mul_fake_quantize_fusion.hpp"
27
27
#include " transformations/common_optimizations/nop_elimination.hpp"
28
+ #include " transformations/common_optimizations/reshape_prelu.hpp"
28
29
#include " transformations/common_optimizations/transpose_sinking.hpp"
29
30
#include " transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp"
30
31
#include " transformations/common_optimizations/augru_cell_fusion.hpp"
53
54
#include " transformations/op_conversions/convert_slice_to_strided_slice.hpp"
54
55
#include " transformations/op_conversions/convert_space_to_batch.hpp"
55
56
#include " transformations/op_conversions/convert_space_to_depth.hpp"
56
- #include " transformations/op_conversions/convert_subtract.hpp"
57
- #include " transformations/op_conversions/convert_ti_to_sequences.hpp"
58
57
#include " transformations/op_conversions/detection_output_downgrade.hpp"
59
58
#include " transformations/op_conversions/detection_output_upgrade.hpp"
60
59
#include " transformations/op_conversions/eye_decomposition.hpp"
98
97
#include " transformations/snippets/x64/pass/snippets_mark_skipped.hpp"
99
98
#include " transformations/cpu_opset/x64/pass/mha_fusion.hpp"
100
99
#include " transformations/cpu_opset/x64/pass/convert_to_interaction.hpp"
101
- #include " transformations/cpu_opset/arm/pass/convert_group_conv.hpp"
102
- #include " transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp"
103
- #include " transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp"
104
- #include " transformations/cpu_opset/arm/pass/mish_decomposition.hpp"
105
- #include " transformations/cpu_opset/common/pass/decompose_integer_divide.hpp"
100
+ #include " transformations/cpu_opset/x64/pass/convert_precision_i64_i32.hpp"
106
101
#include " transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp"
107
102
#include " transformations/cpu_opset/common/pass/insert_convert_after_extension.hpp"
108
103
#include " transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp"
@@ -127,7 +122,7 @@ namespace intel_cpu {
127
122
128
123
using const_node_ptr = const std::shared_ptr<const ov::Node>;
129
124
130
- bool Transformations::fuse_type_to_convert (const std::shared_ptr<ngraph ::Node>& node, const precisions_map& precisions) {
125
+ bool Transformations::fuse_type_to_convert (const std::shared_ptr<ov ::Node>& node, const precisions_map& precisions) {
131
126
const auto & from = node->get_output_element_type (0 );
132
127
auto it = precisions.find (from);
133
128
if (it == precisions.end ())
@@ -139,7 +134,7 @@ bool Transformations::fuse_type_to_convert(const std::shared_ptr<ngraph::Node>&
139
134
// is converted to be 1 for boolean, but 0 for u8. Thus an Abs and Ceil node should be added before the
140
135
// Convert node for this scenario.
141
136
if (convert->input (0 ).get_element_type ().is_real () &&
142
- convert->get_convert_element_type () == ngraph ::element::boolean && to.is_integral_number ()) {
137
+ convert->get_convert_element_type () == ov ::element::boolean && to.is_integral_number ()) {
143
138
auto abs = std::make_shared<ov::opset10::Abs>(convert->input_value (0 ).get_node_shared_ptr ());
144
139
auto ceil = std::make_shared<ov::opset10::Ceiling>(abs );
145
140
auto new_convert = std::make_shared<ov::opset10::Convert>(ceil , to);
@@ -208,11 +203,10 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
208
203
if (useLpt) {
209
204
CPU_REGISTER_PASS_COMMON (manager, ov::pass::MarkDequantizationSubgraph, defaultPrecisions);
210
205
}
206
+ bool supportI64 = config.enableNativeI64 ;
211
207
212
- auto get_convert_precisions = []() {
208
+ auto get_convert_precisions = [& ]() {
213
209
precisions_map map = {
214
- {ov::element::i64, ov::element::i32},
215
- {ov::element::u64, ov::element::i32},
216
210
{ov::element::i16, ov::element::i32},
217
211
{ov::element::u16, ov::element::i32},
218
212
{ov::element::u32, ov::element::i32},
@@ -223,12 +217,21 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
223
217
{ov::element::u4, ov::element::u8}
224
218
};
225
219
226
- if (!dnnl::impl::cpu::x64::mayiuse (dnnl::impl::cpu::x64::avx512_core))
220
+ if (supportI64) {
221
+ map.insert ({ov::element::u64, ov::element::i64});
222
+ } else {
223
+ map.insert ({ov::element::u64, ov::element::i32});
224
+ map.insert ({ov::element::i64, ov::element::i32});
225
+ }
226
+
227
+ if (!dnnl::impl::cpu::x64::mayiuse (dnnl::impl::cpu::x64::avx512_core)) {
227
228
map.insert ({ov::element::bf16, ov::element::f32});
229
+ }
228
230
229
231
return map;
230
232
};
231
- static const auto precisions = get_convert_precisions ();
233
+
234
+ const auto precisions = get_convert_precisions ();
232
235
type_to_fuse_map type_to_fuse = {{ov::opset10::Convert::get_type_info_static (), fuse_type_to_convert}};
233
236
234
237
CPU_REGISTER_PASS_COMMON (manager, ov::pass::AUGRUCellFusion);
@@ -263,8 +266,13 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
263
266
// Common ConvertPrecision pass handles only a limited set of OpenVINO operations to match the list of precisions supported by the plugin.
264
267
// However, if the extension operation produces an output precision that is not natively supported, this may lead to inconsistency during
265
268
// element type propagation. This transformation is called before the ConvertPrecision pass to align the actual precisions with the list of supported ones.
266
- CPU_REGISTER_PASS_COMMON (manager, ov::pass::InsertConvertAfterExtension);
269
+ if (!supportI64) {
270
+ CPU_REGISTER_PASS_COMMON (manager, ov::pass::InsertConvertAfterExtension);
271
+ }
267
272
CPU_REGISTER_PASS_COMMON (manager, ov::pass::ConvertPrecision, precisions, type_to_fuse);
273
+ if (supportI64) {
274
+ CPU_REGISTER_PASS_X64 (manager, ConvertPrecisionI64ToI32);
275
+ }
268
276
269
277
CPU_REGISTER_PASS_COMMON (manager, ov::pass::EliminateConvert);
270
278
CPU_REGISTER_PASS_COMMON (manager, SwapConvertTranspose);
0 commit comments