Skip to content

Commit bfef7eb

Browse files
authored
[CPU][ARM64] Add JIT emitter for Eltwise LogicalXor operation (#26846)
### Details: - Added a jit_logical_xor_emitter derived class in aarch64/jit_eltwise_emitters - Created entry Algorithm::EltwiseLogicalXor in the get_supported_precisions in nodes/kernels/aarch64 - Add the EltwiseLogicalXor entry in the aarch64 executors supported algorithms ### Tickets: - #24108 Signed-off-by: Nashez Zubair <nashezzubair@gmail.com>
1 parent acb8724 commit bfef7eb

File tree

4 files changed

+83
-0
lines changed

4 files changed

+83
-0
lines changed

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp

+52
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,58 @@ std::set<std::vector<element::Type>> jit_logical_not_emitter::get_supported_prec
11981198
return {{element::f32}};
11991199
}
12001200

1201+
/// LOGICAL_XOR ///
1202+
jit_logical_xor_emitter::jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
1203+
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
1204+
const std::shared_ptr<ov::Node>& node)
1205+
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
1206+
prepare_table();
1207+
}
1208+
1209+
jit_logical_xor_emitter::jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
1210+
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
1211+
const ov::element::Type exec_prc)
1212+
: jit_emitter(host, host_isa, exec_prc) {
1213+
prepare_table();
1214+
}
1215+
1216+
size_t jit_logical_xor_emitter::get_inputs_count() const { return 2; }
1217+
1218+
size_t jit_logical_xor_emitter::get_aux_vecs_count() const { return 1; }
1219+
1220+
size_t jit_logical_xor_emitter::get_aux_gprs_count() const { return 1; }
1221+
1222+
void jit_logical_xor_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
1223+
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
1224+
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
1225+
} else {
1226+
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
1227+
}
1228+
}
1229+
1230+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
1231+
void jit_logical_xor_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
1232+
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());
1233+
1234+
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
1235+
const TReg src1 = TReg(in_vec_idxs[0]);
1236+
const TReg src2 = TReg(in_vec_idxs[1]);
1237+
const TReg dst = TReg(out_vec_idxs[0]);
1238+
const TReg aux = TReg(aux_vec_idxs[0]);
1239+
1240+
h->eor(dst.b16, src1.b16, src2.b16);
1241+
h->ld1r(aux.s, table_val2("one"));
1242+
h->and_(dst.b16, dst.b16, aux.b16);
1243+
}
1244+
1245+
void jit_logical_xor_emitter::register_table_entries() {
1246+
push_arg_entry_of("one", 0x3f800000, true);
1247+
}
1248+
1249+
std::set<std::vector<element::Type>> jit_logical_xor_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
1250+
return {{element::f32, element::f32}};
1251+
}
1252+
12011253
/// MAX ///
12021254
jit_maximum_emitter::jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
12031255
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp

+28
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,34 @@ class jit_logical_not_emitter : public jit_emitter {
549549
void register_table_entries() override;
550550
};
551551

552+
class jit_logical_xor_emitter : public jit_emitter {
553+
public:
554+
jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
555+
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
556+
const ov::element::Type exec_prc = ov::element::f32);
557+
558+
jit_logical_xor_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
559+
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
560+
const std::shared_ptr<ov::Node>& n);
561+
562+
size_t get_inputs_count() const override;
563+
564+
size_t get_aux_vecs_count() const override;
565+
566+
size_t get_aux_gprs_count() const override;
567+
568+
static std::set<std::vector<element::Type>> get_supported_precisions(
569+
const std::shared_ptr<ov::Node>& node = nullptr);
570+
571+
private:
572+
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
573+
574+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
575+
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
576+
577+
void register_table_entries() override;
578+
};
579+
552580
class jit_mod_emitter : public jit_emitter {
553581
public:
554582
jit_mod_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,

src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ bool JitEltwiseExecutor::isSupported(
3535
Algorithm::EltwiseIsNaN,
3636
Algorithm::EltwiseLessEqual,
3737
Algorithm::EltwiseLogicalNot,
38+
Algorithm::EltwiseLogicalXor,
3839
Algorithm::EltwiseMaximum,
3940
Algorithm::EltwiseMinimum,
4041
Algorithm::EltwiseMish,

src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
653653
OV_CASE(Algorithm::EltwiseIsInf, ov::intel_cpu::aarch64::jit_is_inf_emitter),
654654
OV_CASE(Algorithm::EltwiseLessEqual, ov::intel_cpu::aarch64::jit_less_equal_emitter),
655655
OV_CASE(Algorithm::EltwiseLogicalNot, ov::intel_cpu::aarch64::jit_logical_not_emitter),
656+
OV_CASE(Algorithm::EltwiseLogicalXor, ov::intel_cpu::aarch64::jit_logical_xor_emitter),
656657
OV_CASE(Algorithm::EltwiseIsNaN, ov::intel_cpu::aarch64::jit_is_nan_emitter),
657658
OV_CASE(Algorithm::EltwiseMaximum, ov::intel_cpu::aarch64::jit_maximum_emitter),
658659
OV_CASE(Algorithm::EltwiseMinimum, ov::intel_cpu::aarch64::jit_minimum_emitter),
@@ -833,6 +834,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
833834
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
834835
OV_CASE(Algorithm::EltwiseLessEqual, jit_less_equal_emitter),
835836
OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter),
837+
OV_CASE(Algorithm::EltwiseLogicalXor, jit_logical_xor_emitter),
836838
OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter),
837839
OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter),
838840
OV_CASE(Algorithm::EltwiseMish, jit_mish_emitter),

0 commit comments

Comments
 (0)