@@ -14,12 +14,6 @@ using namespace Xbyak_aarch64;
14
14
using namespace dnnl ::impl::cpu;
15
15
using namespace dnnl ::impl::cpu::aarch64;
16
16
17
- void jit_uni_eltwise_kernel::operator ()(const node::jit_eltwise_call_args_ptrs* const_args,
18
- const jit_eltwise_call_args_indexes* indexes) {
19
- assert (ker_);
20
- ker_ (const_args, indexes);
21
- }
22
-
23
17
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
24
18
jit_uni_eltwise_generic<isa>::jit_uni_eltwise_generic(jit_eltwise_params jep,
25
19
std::vector<EltwiseData> eltwise_data,
@@ -35,7 +29,8 @@ template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
35
29
void jit_uni_eltwise_generic<isa>::generate() {
36
30
preamble ();
37
31
38
- auto const exec_prc = eltwise_precision_helper::get_precision (jep_.inputs_number , jep_.src_prc , eltwise_data_);
32
+ static const std::vector<element::Type> exec_precisions_priority = {element::f16, element::f32};
33
+ auto const exec_prc = eltwise_precision_helper::get_precision (jep_.inputs_number , jep_.src_prc , eltwise_data_, exec_precisions_priority);
39
34
40
35
eltwise_emitter = create_eltwise_emitter (eltwise_data_.front (), exec_prc);
41
36
for (size_t i = 1 ; i < eltwise_data_.size (); ++i) {
@@ -52,11 +47,11 @@ void jit_uni_eltwise_generic<isa>::generate() {
52
47
for (size_t i = 0 ; i < jep.inputs_number ; i++) {
53
48
ldr (start_to_offsets,
54
49
ptr (reg_const_params,
55
- static_cast <int32_t >(offsetof (node:: jit_eltwise_call_args_ptrs, src_offsets) +
50
+ static_cast <int32_t >(offsetof (jit_eltwise_call_args_ptrs, src_offsets) +
56
51
i * sizeof (size_t ))));
57
52
ldr (get_src_reg (i),
58
53
ptr (reg_const_params,
59
- static_cast <int32_t >(offsetof (node:: jit_eltwise_call_args_ptrs, src_ptr[0 ]) + i * sizeof (size_t ))));
54
+ static_cast <int32_t >(offsetof (jit_eltwise_call_args_ptrs, src_ptr[0 ]) + i * sizeof (size_t ))));
60
55
XReg offset_reg = get_aux_gpr (0 ); // X_TMP_0;
61
56
XReg index_reg = get_aux_gpr (1 ); // X_TMP_1;
62
57
for (int j = 0 ; j < offset_count; j++) {
@@ -67,8 +62,8 @@ void jit_uni_eltwise_generic<isa>::generate() {
67
62
}
68
63
69
64
ldr (start_to_offsets,
70
- ptr (reg_const_params, static_cast <int32_t >(offsetof (node:: jit_eltwise_call_args_ptrs, dst_offsets))));
71
- ldr (reg_dst, ptr (reg_const_params, static_cast <int32_t >(offsetof (node:: jit_eltwise_call_args_ptrs, dst_ptr))));
65
+ ptr (reg_const_params, static_cast <int32_t >(offsetof (jit_eltwise_call_args_ptrs, dst_offsets))));
66
+ ldr (reg_dst, ptr (reg_const_params, static_cast <int32_t >(offsetof (jit_eltwise_call_args_ptrs, dst_ptr))));
72
67
XReg offset_reg = get_aux_gpr (0 ); // X_TMP_0;
73
68
XReg index_reg = get_aux_gpr (1 ); // X_TMP_1;
74
69
for (int j = 0 ; j < offset_count; j++) {
@@ -80,7 +75,7 @@ void jit_uni_eltwise_generic<isa>::generate() {
80
75
mov (reg_oc_off, 0 );
81
76
82
77
ldr (reg_work_amount,
83
- ptr (reg_const_params, static_cast <int32_t >(offsetof (node:: jit_eltwise_call_args_ptrs, work_amount))));
78
+ ptr (reg_const_params, static_cast <int32_t >(offsetof (jit_eltwise_call_args_ptrs, work_amount))));
84
79
} else {
85
80
auto init_ptrs_with_offsets = [this , offset_count, param2](XReg pointer, const std::vector<size_t >& offsets) {
86
81
for (int j = 0 ; j < offset_count; j++) {
@@ -98,11 +93,11 @@ void jit_uni_eltwise_generic<isa>::generate() {
98
93
for (size_t i = 0 ; i < jep.inputs_number ; i++) {
99
94
ldr (get_src_reg (i),
100
95
ptr (param1,
101
- static_cast <int32_t >(offsetof (node:: jit_eltwise_call_args_ptrs, src_ptr) + i * sizeof (size_t ))));
96
+ static_cast <int32_t >(offsetof (jit_eltwise_call_args_ptrs, src_ptr) + i * sizeof (size_t ))));
102
97
init_ptrs_with_offsets (get_src_reg (i), jep.src_offsets [i]);
103
98
}
104
99
105
- ldr (reg_dst, ptr (reg_const_params, static_cast <int32_t >(offsetof (node:: jit_eltwise_call_args_ptrs, dst_ptr))));
100
+ ldr (reg_dst, ptr (reg_const_params, static_cast <int32_t >(offsetof (jit_eltwise_call_args_ptrs, dst_ptr))));
106
101
init_ptrs_with_offsets (reg_dst, jep.dst_offsets );
107
102
108
103
mov (reg_oc_off, 0 );
@@ -778,80 +773,21 @@ void jit_uni_eltwise_generic<isa>::apply_post_ops() {
778
773
}
779
774
}
780
775
781
- namespace {
776
+ template struct jit_uni_eltwise_generic < cpu_isa_t ::asimd>;
782
777
778
+ } // namespace aarch64
779
+
780
+ namespace {
783
781
template <typename T>
784
782
struct SupportedPrecisions {
785
783
void operator ()(std::set<std::vector<element::Type>>& precisions) {
786
784
precisions = T::get_supported_precisions ();
787
785
}
788
786
};
789
-
790
- static void set_intersection (const std::set<std::vector<element::Type>>& precisions1,
791
- const std::set<std::vector<element::Type>>& precisions2,
792
- std::set<std::vector<element::Type>>& intersection) {
793
- std::map<element::Type, size_t > intersection_types;
794
-
795
- for (auto it1 = precisions1.begin (); it1 != precisions1.end (); ++it1) {
796
- for (auto it2 = precisions2.begin (); it2 != precisions2.end (); ++it2) {
797
- const auto & it1_precisions = *it1;
798
- // all element types are equal
799
- if (it1_precisions[0 ] == (*it2)[0 ]) {
800
- // first precisions size is used
801
- intersection_types.emplace (it1_precisions[0 ], it1_precisions.size ());
802
- }
803
- }
804
- }
805
-
806
- for (auto it = intersection_types.begin (); it != intersection_types.end (); ++it) {
807
- intersection.insert (std::vector<element::Type>(it->second , it->first ));
808
- }
809
- }
810
787
} // namespace
811
788
812
- ov::element::Type eltwise_precision_helper::get_precision (const size_t inputs_number,
813
- const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
814
- const std::vector<EltwiseData>& eltwise_data) {
815
- ov::element::Type exec_prc = ov::element::undefined;
816
-
817
- const auto algorithm = eltwise_data.front ().algo ;
818
- std::set<std::vector<element::Type>> supported_precision_intersection = get_supported_precisions (algorithm);
819
789
820
- for (size_t i = 1 ; i < eltwise_data.size (); ++i) {
821
- std::set<std::vector<element::Type>> prcs = get_supported_precisions (eltwise_data[i].algo );
822
- std::set<std::vector<element::Type>> prcs_intersect = {};
823
-
824
- set_intersection (supported_precision_intersection, prcs, prcs_intersect);
825
-
826
- supported_precision_intersection = prcs_intersect;
827
- }
828
-
829
- static const element::Type exec_precisions_priority[] = {element::f16, element::f32};
830
-
831
- for (const auto prc : exec_precisions_priority) {
832
- if (std::any_of (supported_precision_intersection.begin (),
833
- supported_precision_intersection.end (),
834
- [&prc](const std::vector<element::Type>& precisions) {
835
- return std::find (precisions.begin (), precisions.end (), prc) != precisions.end ();
836
- })) {
837
- exec_prc = prc;
838
- break ;
839
- }
840
- }
841
-
842
- for (size_t i = 0 ; i < inputs_number; i++) {
843
- if (src_prc[i] != exec_prc) {
844
- exec_prc = ov::element::f32;
845
- break ;
846
- }
847
- }
848
-
849
- if (exec_prc == ov::element::undefined) {
850
- OPENVINO_THROW (" Eltwise jitter failed to specify execution precision for Eltwise node" );
851
- }
852
-
853
- return exec_prc;
854
- }
790
+ using namespace aarch64 ;
855
791
856
792
std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_precisions (const Algorithm& algo) {
857
793
std::set<std::vector<element::Type>> precisions;
@@ -911,8 +847,5 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
911
847
return precisions;
912
848
}
913
849
914
- template struct jit_uni_eltwise_generic <cpu_isa_t ::asimd>;
915
-
916
- } // namespace aarch64
917
850
} // namespace intel_cpu
918
851
} // namespace ov
0 commit comments