Skip to content

Commit 3571d18

Browse files
authored
[CPU] Moved jit_uni_eltwise_generic x64 to separate files (#28749)
### Details:
- *Moved `jit_uni_eltwise_generic` impl for `x64` from node-related files to `kernel/x64` and arch-common structures to `kernel/jit_eltwise_common.hpp`. These changes improve readability and reduce needed copies in arch-dependent kernel implementations*
- *Prerequisite for `jit_uni_eltwise_generic` for `riscv64`*

### Tickets:
- *N/A*
1 parent 4354147 commit 3571d18

10 files changed

+1307
-1299
lines changed

src/plugins/intel_cpu/src/nodes/eltwise.cpp

+33-1,072
Large diffs are not rendered by default.

src/plugins/intel_cpu/src/nodes/eltwise.h

+1-63
Original file line numberDiff line numberDiff line change
@@ -14,65 +14,13 @@
1414
#include "dnnl_postops_composer_legacy.h"
1515
#include "executors/eltwise_list.hpp"
1616
#include "nodes/executors/eltwise.hpp"
17-
#include "nodes/kernels/jit_eltwise_call_args_ptrs.hpp"
18-
19-
#if defined(OPENVINO_ARCH_ARM64)
20-
# include "kernels/aarch64/jit_uni_eltwise_generic.hpp"
21-
#endif
17+
#include "nodes/kernels/jit_eltwise_common.hpp"
2218

2319
namespace ov {
2420
namespace intel_cpu {
2521
namespace node {
2622

27-
#ifndef OPENVINO_ARCH_ARM64
28-
29-
struct jit_eltwise_params {
30-
size_t inputs_number;
31-
size_t input_size;
32-
33-
ov::element::Type src_prc[MAX_ELTWISE_INPUTS];
34-
ov::element::Type dst_prc;
35-
36-
VectorDims dims;
37-
VectorDims src_offsets[MAX_ELTWISE_INPUTS];
38-
VectorDims dst_offsets;
39-
VectorDims oc_offsets;
40-
41-
size_t src_size[MAX_ELTWISE_INPUTS];
42-
size_t dst_size;
43-
size_t oc_size;
44-
45-
size_t work_amount;
46-
bool use_runtime_ptrs;
47-
bool do_output_saturation;
48-
};
49-
50-
struct jit_eltwise_call_args_indexes {
51-
size_t indexes[MAX_ELTWISE_DIM_RANK];
52-
};
53-
54-
class Eltwise;
55-
56-
struct jit_uni_eltwise_kernel {
57-
void (*ker_)(const jit_eltwise_call_args_ptrs*, const jit_eltwise_call_args_indexes*);
58-
59-
void operator()(const jit_eltwise_call_args_ptrs* const_args, const jit_eltwise_call_args_indexes* indexes) {
60-
assert(ker_);
61-
ker_(const_args, indexes);
62-
}
63-
64-
explicit jit_uni_eltwise_kernel(jit_eltwise_params jep) : ker_(nullptr), jep_(std::move(jep)) {}
65-
virtual ~jit_uni_eltwise_kernel() {}
66-
67-
virtual void create_ker() = 0;
68-
69-
jit_eltwise_params jep_;
70-
};
71-
72-
#endif
73-
7423
enum class EltwiseImplType { reference = 0, optimized = 1, optimizedShapeAgnostic = 2 };
75-
7624
class Eltwise : public Node {
7725
public:
7826
class IEltwiseExecutor {
@@ -218,16 +166,6 @@ class Eltwise : public Node {
218166
std::shared_ptr<EltwiseExecutor> eltwiseExecPtr = nullptr;
219167
};
220168

221-
class eltwise_precision_helper {
222-
public:
223-
static ov::element::Type get_precision(const size_t inputs_number,
224-
const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
225-
const std::vector<EltwiseData>& eltwise_data);
226-
227-
private:
228-
static std::set<std::vector<element::Type>> get_supported_precisions(const Algorithm& algo);
229-
};
230-
231169
} // namespace node
232170
} // namespace intel_cpu
233171
} // namespace ov

src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp

+14-81
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@ using namespace Xbyak_aarch64;
1414
using namespace dnnl::impl::cpu;
1515
using namespace dnnl::impl::cpu::aarch64;
1616

17-
void jit_uni_eltwise_kernel::operator()(const node::jit_eltwise_call_args_ptrs* const_args,
18-
const jit_eltwise_call_args_indexes* indexes) {
19-
assert(ker_);
20-
ker_(const_args, indexes);
21-
}
22-
2317
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
2418
jit_uni_eltwise_generic<isa>::jit_uni_eltwise_generic(jit_eltwise_params jep,
2519
std::vector<EltwiseData> eltwise_data,
@@ -35,7 +29,8 @@ template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
3529
void jit_uni_eltwise_generic<isa>::generate() {
3630
preamble();
3731

38-
auto const exec_prc = eltwise_precision_helper::get_precision(jep_.inputs_number, jep_.src_prc, eltwise_data_);
32+
static const std::vector<element::Type> exec_precisions_priority = {element::f16, element::f32};
33+
auto const exec_prc = eltwise_precision_helper::get_precision(jep_.inputs_number, jep_.src_prc, eltwise_data_, exec_precisions_priority);
3934

4035
eltwise_emitter = create_eltwise_emitter(eltwise_data_.front(), exec_prc);
4136
for (size_t i = 1; i < eltwise_data_.size(); ++i) {
@@ -52,11 +47,11 @@ void jit_uni_eltwise_generic<isa>::generate() {
5247
for (size_t i = 0; i < jep.inputs_number; i++) {
5348
ldr(start_to_offsets,
5449
ptr(reg_const_params,
55-
static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, src_offsets) +
50+
static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, src_offsets) +
5651
i * sizeof(size_t))));
5752
ldr(get_src_reg(i),
5853
ptr(reg_const_params,
59-
static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr[0]) + i * sizeof(size_t))));
54+
static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, src_ptr[0]) + i * sizeof(size_t))));
6055
XReg offset_reg = get_aux_gpr(0); // X_TMP_0;
6156
XReg index_reg = get_aux_gpr(1); // X_TMP_1;
6257
for (int j = 0; j < offset_count; j++) {
@@ -67,8 +62,8 @@ void jit_uni_eltwise_generic<isa>::generate() {
6762
}
6863

6964
ldr(start_to_offsets,
70-
ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_offsets))));
71-
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_ptr))));
65+
ptr(reg_const_params, static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, dst_offsets))));
66+
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, dst_ptr))));
7267
XReg offset_reg = get_aux_gpr(0); // X_TMP_0;
7368
XReg index_reg = get_aux_gpr(1); // X_TMP_1;
7469
for (int j = 0; j < offset_count; j++) {
@@ -80,7 +75,7 @@ void jit_uni_eltwise_generic<isa>::generate() {
8075
mov(reg_oc_off, 0);
8176

8277
ldr(reg_work_amount,
83-
ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, work_amount))));
78+
ptr(reg_const_params, static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, work_amount))));
8479
} else {
8580
auto init_ptrs_with_offsets = [this, offset_count, param2](XReg pointer, const std::vector<size_t>& offsets) {
8681
for (int j = 0; j < offset_count; j++) {
@@ -98,11 +93,11 @@ void jit_uni_eltwise_generic<isa>::generate() {
9893
for (size_t i = 0; i < jep.inputs_number; i++) {
9994
ldr(get_src_reg(i),
10095
ptr(param1,
101-
static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, src_ptr) + i * sizeof(size_t))));
96+
static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, src_ptr) + i * sizeof(size_t))));
10297
init_ptrs_with_offsets(get_src_reg(i), jep.src_offsets[i]);
10398
}
10499

105-
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(node::jit_eltwise_call_args_ptrs, dst_ptr))));
100+
ldr(reg_dst, ptr(reg_const_params, static_cast<int32_t>(offsetof(jit_eltwise_call_args_ptrs, dst_ptr))));
106101
init_ptrs_with_offsets(reg_dst, jep.dst_offsets);
107102

108103
mov(reg_oc_off, 0);
@@ -778,80 +773,21 @@ void jit_uni_eltwise_generic<isa>::apply_post_ops() {
778773
}
779774
}
780775

781-
namespace {
776+
template struct jit_uni_eltwise_generic<cpu_isa_t::asimd>;
782777

778+
} // namespace aarch64
779+
780+
namespace {
783781
template <typename T>
784782
struct SupportedPrecisions {
785783
void operator()(std::set<std::vector<element::Type>>& precisions) {
786784
precisions = T::get_supported_precisions();
787785
}
788786
};
789-
790-
static void set_intersection(const std::set<std::vector<element::Type>>& precisions1,
791-
const std::set<std::vector<element::Type>>& precisions2,
792-
std::set<std::vector<element::Type>>& intersection) {
793-
std::map<element::Type, size_t> intersection_types;
794-
795-
for (auto it1 = precisions1.begin(); it1 != precisions1.end(); ++it1) {
796-
for (auto it2 = precisions2.begin(); it2 != precisions2.end(); ++it2) {
797-
const auto& it1_precisions = *it1;
798-
// all element types are equal
799-
if (it1_precisions[0] == (*it2)[0]) {
800-
// first precisions size is used
801-
intersection_types.emplace(it1_precisions[0], it1_precisions.size());
802-
}
803-
}
804-
}
805-
806-
for (auto it = intersection_types.begin(); it != intersection_types.end(); ++it) {
807-
intersection.insert(std::vector<element::Type>(it->second, it->first));
808-
}
809-
}
810787
} // namespace
811788

812-
ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_number,
813-
const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
814-
const std::vector<EltwiseData>& eltwise_data) {
815-
ov::element::Type exec_prc = ov::element::undefined;
816-
817-
const auto algorithm = eltwise_data.front().algo;
818-
std::set<std::vector<element::Type>> supported_precision_intersection = get_supported_precisions(algorithm);
819789

820-
for (size_t i = 1; i < eltwise_data.size(); ++i) {
821-
std::set<std::vector<element::Type>> prcs = get_supported_precisions(eltwise_data[i].algo);
822-
std::set<std::vector<element::Type>> prcs_intersect = {};
823-
824-
set_intersection(supported_precision_intersection, prcs, prcs_intersect);
825-
826-
supported_precision_intersection = prcs_intersect;
827-
}
828-
829-
static const element::Type exec_precisions_priority[] = {element::f16, element::f32};
830-
831-
for (const auto prc : exec_precisions_priority) {
832-
if (std::any_of(supported_precision_intersection.begin(),
833-
supported_precision_intersection.end(),
834-
[&prc](const std::vector<element::Type>& precisions) {
835-
return std::find(precisions.begin(), precisions.end(), prc) != precisions.end();
836-
})) {
837-
exec_prc = prc;
838-
break;
839-
}
840-
}
841-
842-
for (size_t i = 0; i < inputs_number; i++) {
843-
if (src_prc[i] != exec_prc) {
844-
exec_prc = ov::element::f32;
845-
break;
846-
}
847-
}
848-
849-
if (exec_prc == ov::element::undefined) {
850-
OPENVINO_THROW("Eltwise jitter failed to specify execution precision for Eltwise node");
851-
}
852-
853-
return exec_prc;
854-
}
790+
using namespace aarch64;
855791

856792
std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_precisions(const Algorithm& algo) {
857793
std::set<std::vector<element::Type>> precisions;
@@ -911,8 +847,5 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
911847
return precisions;
912848
}
913849

914-
template struct jit_uni_eltwise_generic<cpu_isa_t::asimd>;
915-
916-
} // namespace aarch64
917850
} // namespace intel_cpu
918851
} // namespace ov

src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp

+1-52
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
#include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp"
3030
#include "emitters/plugin/aarch64/jit_emitter.hpp"
31-
#include "nodes/kernels/jit_eltwise_call_args_ptrs.hpp"
31+
#include "nodes/kernels/jit_eltwise_common.hpp"
3232
#include "utils/cpu_utils.hpp"
3333
#include "utils/general_utils.h"
3434

@@ -40,45 +40,6 @@ using namespace Xbyak_aarch64;
4040
using namespace dnnl::impl::cpu;
4141
using namespace dnnl::impl::cpu::aarch64;
4242

43-
struct jit_eltwise_params {
44-
size_t inputs_number;
45-
size_t input_size;
46-
47-
ov::element::Type src_prc[MAX_ELTWISE_INPUTS];
48-
ov::element::Type dst_prc;
49-
50-
VectorDims dims;
51-
VectorDims src_offsets[MAX_ELTWISE_INPUTS];
52-
VectorDims dst_offsets;
53-
VectorDims oc_offsets;
54-
55-
size_t src_size[MAX_ELTWISE_INPUTS];
56-
size_t dst_size;
57-
size_t oc_size;
58-
59-
size_t work_amount;
60-
bool use_runtime_ptrs;
61-
bool do_output_saturation;
62-
};
63-
64-
struct jit_eltwise_call_args_indexes {
65-
size_t indexes[MAX_ELTWISE_DIM_RANK];
66-
};
67-
68-
struct jit_uni_eltwise_kernel {
69-
void (*ker_)(const node::jit_eltwise_call_args_ptrs*, const jit_eltwise_call_args_indexes*);
70-
71-
void operator()(const node::jit_eltwise_call_args_ptrs* const_args, const jit_eltwise_call_args_indexes* indexes);
72-
73-
jit_uni_eltwise_kernel() {}
74-
jit_uni_eltwise_kernel(jit_eltwise_params jep) : ker_(nullptr), jep_(std::move(jep)) {}
75-
virtual ~jit_uni_eltwise_kernel() {}
76-
77-
virtual void create_ker() = 0;
78-
79-
jit_eltwise_params jep_;
80-
};
81-
8243
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
8344
struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
8445
public:
@@ -89,8 +50,6 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
8950
std::vector<ov::intel_cpu::Type> ops_list,
9051
dnnl::post_ops post_ops);
9152

92-
jit_uni_eltwise_generic() {}
93-
9453
void create_ker() override {
9554
jit_generator::create_kernel();
9655
ker_ = (decltype(ker_))jit_ker();
@@ -255,16 +214,6 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
255214
std::vector<std::shared_ptr<jit_emitter>> post_op_emitters;
256215
};
257216

258-
class eltwise_precision_helper {
259-
public:
260-
static ov::element::Type get_precision(const size_t inputs_number,
261-
const ov::element::Type (&src_prc)[MAX_ELTWISE_INPUTS],
262-
const std::vector<EltwiseData>& eltwise_data);
263-
264-
private:
265-
static std::set<std::vector<element::Type>> get_supported_precisions(const Algorithm& algo);
266-
};
267-
268217
} // namespace aarch64
269218
} // namespace intel_cpu
270219
} // namespace ov

src/plugins/intel_cpu/src/nodes/kernels/jit_eltwise_call_args_ptrs.hpp

-29
This file was deleted.

0 commit comments

Comments
 (0)