From ba3fd33069adadbda658543336476b7181fe4662 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Tue, 2 Jul 2024 07:32:59 +0000 Subject: [PATCH 01/33] [CPU] [Snippets] Implement load, store, convert emitters, and add convert tokenization --- .../aarch64/jit_conversion_emitters.cpp | 170 ++++++++++ .../aarch64/jit_conversion_emitters.hpp | 60 ++++ .../aarch64/jit_load_store_emitters.cpp | 304 +++++++++++++++++- .../aarch64/jit_load_store_emitters.hpp | 14 + .../src/emitters/plugin/aarch64/utils.cpp | 100 ++++++ .../src/emitters/plugin/aarch64/utils.hpp | 36 +++ .../snippets/aarch64/cpu_generator.cpp | 3 + .../snippets/aarch64/jit_memory_emitters.cpp | 18 +- .../aarch64/pass/snippets_mark_skipped.cpp | 25 ++ .../transformation_pipeline.cpp | 3 +- .../intel_cpu/tests/functional/CMakeLists.txt | 9 + .../skip_tests_config.cpp | 4 +- .../snippets/arm/convert.cpp | 169 ++++++++++ .../snippets/{ => x64}/convert.cpp | 0 14 files changed, 889 insertions(+), 26 deletions(-) create mode 100644 src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp create mode 100644 src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp create mode 100644 src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp create mode 100644 src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp create mode 100644 src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/arm/convert.cpp rename src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/{ => x64}/convert.cpp (100%) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp new file mode 100644 index 00000000000000..49a5358b0d24c6 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -0,0 +1,170 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "jit_conversion_emitters.hpp" +#include "emitters/utils.hpp" +#include "utils.hpp" + +using namespace dnnl::impl::cpu::aarch64; +using namespace Xbyak_aarch64; + +namespace ov { +namespace intel_cpu { +namespace aarch64 { + +template +static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs, + ov::element::Type input_type, ov::element::Type output_type, bool is_saturated) { + switch (output_type) { + case ov::element::f32: + switch (input_type) { + case ov::element::f32: + break; + case ov::element::i32: + cvt_i32_to_f32(h, in_idxs, out_idxs); + break; + case ov::element::f16: + cvt_f16_to_f32(h, in_idxs, out_idxs); + break; + case ov::element::i8: + case ov::element::u8: + cvt_byte_to_i32(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_i32_to_f32(h, out_idxs, out_idxs); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); + } + break; + case ov::element::i32: + switch (input_type) { + case ov::element::f32: + cvt_f32_to_i32(h, in_idxs, out_idxs); + break; + case ov::element::i32: + break; + case ov::element::f16: + cvt_f16_to_f32(h, in_idxs, out_idxs); + cvt_f32_to_i32(h, out_idxs, out_idxs); + break; + case ov::element::i8: + case ov::element::u8: + cvt_byte_to_i32(h, in_idxs, out_idxs, input_type.is_signed()); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); + } + break; + case ov::element::f16: + switch (input_type) { + case ov::element::f32: + cvt_f32_to_f16(h, in_idxs, out_idxs); + break; + case ov::element::i32: + cvt_i32_to_f32(h, in_idxs, out_idxs); + cvt_f32_to_f16(h, out_idxs, out_idxs); + break; + case ov::element::f16: + break; + case ov::element::i8: + case ov::element::u8: + cvt_byte_to_i32(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_i32_to_f32(h, out_idxs, out_idxs); + cvt_f32_to_f16(h, out_idxs, out_idxs); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); + } + break; + case ov::element::i8: + case ov::element::u8: + switch (input_type) { + case ov::element::f32: + cvt_f32_to_i32(h, in_idxs, out_idxs); + cvt_i32_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + break; + case ov::element::i32: + cvt_i32_to_byte(h, in_idxs, out_idxs, output_type.is_signed(), is_saturated); + break; + case ov::element::f16: + cvt_f16_to_f32(h, in_idxs, out_idxs); + cvt_f32_to_i32(h, out_idxs, out_idxs); + cvt_i32_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + break; + case ov::element::i8: + case ov::element::u8: + cvt_byte_to_i32(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_i32_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); + } + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", output_type.get_type_name()); + } +} + +jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) +: jit_emitter(host, host_isa, exec_prc) { + input_type = node->get_input_element_type(0); + output_type = node->get_output_element_type(0); +} + +void jit_convert_emitter::validate_types() const { + OV_CPU_JIT_EMITTER_ASSERT(one_of(input_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported input type: ", input_type.get_type_name()); + OV_CPU_JIT_EMITTER_ASSERT(one_of(output_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported output type: ", output_type.get_type_name()); + OV_CPU_JIT_EMITTER_ASSERT(input_type != output_type, "Input type ", input_type.get_type_name(), " and output type ", + output_type.get_type_name(), " should be different."); +} + +size_t jit_convert_emitter::get_inputs_count() const { return 1; } + +void jit_convert_emitter::emit_data() const { + jit_emitter::emit_data(); +} + +jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa, + const std::shared_ptr& node, ov::element::Type exec_prc) + : jit_convert_emitter(host, host_isa, node, exec_prc) { +} + +void jit_convert_truncation_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { + validate_types(); + if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { + emit_isa(in_idxs, out_idxs); + } else { + OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_); + } +} + +template +void jit_convert_truncation_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { + jit_convert_process(h, in_idxs, out_idxs, input_type, output_type, false); +} + +jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa, + const std::shared_ptr& node, ov::element::Type exec_prc) + : jit_convert_emitter(host, host_isa, node, exec_prc) { +} + +void jit_convert_saturation_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { + validate_types(); + if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { + emit_isa(in_idxs, out_idxs); + } else { + OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_); + } +} + +template +void jit_convert_saturation_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { + jit_convert_process(h, in_idxs, out_idxs, input_type, output_type, true); +} + +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp new file mode 100644 index 00000000000000..df24f714ccc55b --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "jit_emitter.hpp" + +namespace ov { +namespace intel_cpu { +namespace aarch64 { + +class jit_convert_emitter : public jit_emitter { +public: + jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + + size_t get_inputs_count() const override; + +protected: + void emit_data() const override; + void validate_types() const; + + ov::element::Type input_type; + ov::element::Type output_type; +}; + +// This emitter is covered by specification of "Convert" operation. The implementation uses a "warp-around" conversion. +// Example: +// int32_t -> int8_t +// 129 -> -127 +class jit_convert_truncation_emitter : public jit_convert_emitter { +public: + jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + +private: + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; + template + void emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const; +}; + +// This emitter is covered by the common dnnl behavior. The implementation uses a "saturation" conversion. +// Example: +// int32_t -> int8_t +// 129 -> 127 +class jit_convert_saturation_emitter : public jit_convert_emitter { +public: + jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + +private: + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; + template + void emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const; +}; + +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 0fe1095e291fb7..8fa0278bacad42 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -5,6 +5,7 @@ #include "jit_load_store_emitters.hpp" #include "cpu/aarch64/cpu_isa_traits.hpp" #include "emitters/utils.hpp" +#include "utils.hpp" using namespace Xbyak_aarch64; @@ -30,12 +31,7 @@ void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std:: } template -void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { - OV_CPU_JIT_EMITTER_ASSERT(src_prc_ == ov::element::f32 && dst_prc_ == ov::element::f32, - "Only supports both input and output precisions of being FP32"); - OV_CPU_JIT_EMITTER_ASSERT(load_num_ <= static_cast((get_vec_length() / dst_prc_.size())), - "Unexpected number of elements to load."); - +void jit_load_emitter::load_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; XReg src = XReg(in_idxs[0]); XReg prc = XReg(aux_gpr_idxs[0]); @@ -65,6 +61,142 @@ void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::v } } +template +void jit_load_emitter::load_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + XReg src = XReg(in_idxs[0]); + XReg prc = XReg(aux_gpr_idxs[0]); + TReg dst = TReg(out_idxs[0]); + DReg dst_d = DReg(out_idxs[0]); + HReg dst_h = HReg(out_idxs[0]); + SReg dst_s = SReg(out_idxs[0]); + + switch (load_num_) { + case 0: + break; + case 1: + h->ldr(dst_h, post_ptr(src, byte_offset_)); + break; + case 2: + h->ldr(dst_s, post_ptr(src, byte_offset_)); + break; + case 3: + h->ldr(dst_s, post_ptr(src, byte_offset_)); + h->add_imm(prc, src, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); + h->ld1(dst.h[2], ptr(prc)); + break; + case 4: + h->ldr(dst_d, post_ptr(src, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); + } +} + +template +void jit_load_emitter::load_byte(const std::vector &in_idxs, const std::vector &out_idxs) const { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + XReg src = XReg(in_idxs[0]); + XReg prc = XReg(aux_gpr_idxs[0]); + TReg dst = TReg(out_idxs[0]); + BReg dst_b = BReg(out_idxs[0]); + HReg dst_h = HReg(out_idxs[0]); + SReg dst_s = SReg(out_idxs[0]); + + switch (load_num_) { + case 0: + break; + case 1: + h->ldr(dst_b, post_ptr(src, byte_offset_)); + break; + case 2: + h->ldr(dst_h, post_ptr(src, byte_offset_)); + break; + case 3: + h->ldr(dst_h, post_ptr(src, byte_offset_)); + h->add_imm(prc, src, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); + h->ld1(dst.b[2], ptr(prc)); + break; + case 4: + h->ldr(dst_s, post_ptr(src, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); + } +} + +template +void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { + bool is_supported_precision = one_of(src_prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && + (src_prc_ == dst_prc_ || one_of(dst_prc_, ov::element::f32, ov::element::i32)); + OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); + OV_CPU_JIT_EMITTER_ASSERT(load_num_ <= static_cast((get_vec_length() / dst_prc_.size())), + "Unexpected number of elements to load."); + + switch (src_prc_) { + case ov::element::f32: + load_qbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); + switch (dst_prc_) { + case ov::element::f32: + break; + case ov::element::i32: + cvt_f32_to_i32(h, aux_vec_idxs, out_idxs); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); + } + break; + case ov::element::i32: + load_qbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); + switch (dst_prc_) { + case ov::element::f32: + cvt_i32_to_f32(h, aux_vec_idxs, out_idxs); + break; + case ov::element::i32: + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); + } + break; + case ov::element::f16: + load_dbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); + switch (dst_prc_) { + case ov::element::f32: + cvt_f16_to_f32(h, aux_vec_idxs, out_idxs); + break; + case ov::element::i32: + cvt_f16_to_f32(h, aux_vec_idxs, aux_vec_idxs); + cvt_f32_to_i32(h, aux_vec_idxs, out_idxs); + break; + case ov::element::f16: + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); + } + break; + case ov::element::i8: + case ov::element::u8: + load_byte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); + switch (dst_prc_) { + case ov::element::f32: + cvt_byte_to_i32(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed()); + cvt_i32_to_f32(h, aux_vec_idxs, out_idxs); + break; + case ov::element::i32: + cvt_byte_to_i32(h, aux_vec_idxs, out_idxs, src_prc_.is_signed()); + break; + case ov::element::i8: + case ov::element::u8: + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); + } + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); + } +} + size_t jit_load_emitter::get_aux_gprs_count() const { if (load_num_ == 3) return 1; @@ -72,6 +204,13 @@ size_t jit_load_emitter::get_aux_gprs_count() const { return 0; } +size_t jit_load_emitter::get_aux_vecs_count() const { + if (src_prc_ != dst_prc_) + return 1; + + return 0; +} + jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset, ov::element::Type exec_prc, emitter_in_out_map in_out_type) @@ -87,16 +226,12 @@ void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std: } template -void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { - OV_CPU_JIT_EMITTER_ASSERT(src_prc_ == ov::element::f32 && dst_prc_ == ov::element::f32, - "Only supports both input and output precisions of being FP32"); - OV_CPU_JIT_EMITTER_ASSERT(store_num_ <= static_cast((get_vec_length() / dst_prc_.size())), - "Unexpected number of elements to store."); - +void jit_store_emitter::store_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); SReg src_s = SReg(in_idxs[0]); DReg src_d = DReg(in_idxs[0]); + QReg src_q = QReg(in_idxs[0]); XReg dst = XReg(out_idxs[0]); XReg prc = XReg(aux_gpr_idxs[0]); @@ -115,13 +250,149 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: h->st1(src.s[2], ptr(prc)); break; case 4: - h->str(QReg(src.getIdx()), post_ptr(dst, byte_offset_)); + h->str(src_q, post_ptr(dst, byte_offset_)); break; default: OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); } } +template +void jit_store_emitter::store_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + HReg src_h = HReg(in_idxs[0]); + SReg src_s = SReg(in_idxs[0]); + DReg src_d = DReg(in_idxs[0]); + XReg dst = XReg(out_idxs[0]); + XReg prc = XReg(aux_gpr_idxs[0]); + + switch (store_num_) { + case 0: + break; + case 1: + h->str(src_h, post_ptr(dst, byte_offset_)); + break; + case 2: + h->str(src_s, post_ptr(dst, byte_offset_)); + break; + case 3: + h->str(src_s, post_ptr(dst, byte_offset_)); + h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); + h->st1(src.h[2], ptr(prc)); + break; + case 4: + h->str(src_d, post_ptr(dst, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); + } +} + +template +void jit_store_emitter::store_byte(const std::vector &in_idxs, const std::vector &out_idxs) const { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + BReg src_b = BReg(in_idxs[0]); + HReg src_h = HReg(in_idxs[0]); + SReg src_s = SReg(in_idxs[0]); + XReg dst = XReg(out_idxs[0]); + XReg prc = XReg(aux_gpr_idxs[0]); + + switch (store_num_) { + case 0: + break; + case 1: + h->str(src_b, post_ptr(dst, byte_offset_)); + break; + case 2: + h->str(src_h, post_ptr(dst, byte_offset_)); + break; + case 3: + h->str(src_h, post_ptr(dst, byte_offset_)); + h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); + h->st1(src.b[2], ptr(prc)); + break; + case 4: + h->str(src_s, post_ptr(dst, byte_offset_)); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); + } +} + +template +void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { + bool is_supported_precision = one_of(dst_prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && + (src_prc_ == dst_prc_ || one_of(src_prc_, ov::element::f32, ov::element::i32)); + OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); + OV_CPU_JIT_EMITTER_ASSERT(store_num_ <= static_cast((get_vec_length() / dst_prc_.size())), + "Unexpected number of elements to store."); + + switch (dst_prc_) { + case ov::element::f32: + switch (src_prc_) { + case ov::element::f32: + break; + case ov::element::i32: + cvt_i32_to_f32(h, in_idxs, aux_vec_idxs); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); + } + store_qbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); + break; + case ov::element::i32: + switch (src_prc_) { + case ov::element::f32: + cvt_f32_to_i32(h, in_idxs, aux_vec_idxs); + break; + case ov::element::i32: + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); + } + store_qbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); + break; + case ov::element::f16: + switch (src_prc_) { + case ov::element::f32: + cvt_f32_to_f16(h, in_idxs, aux_vec_idxs); + break; + case ov::element::i32: + cvt_i32_to_f32(h, in_idxs, aux_vec_idxs); + cvt_f32_to_f16(h, aux_vec_idxs, aux_vec_idxs); + break; + case ov::element::f16: + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); + } + store_dbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); + break; + case ov::element::i8: + case ov::element::u8: + switch (src_prc_) { + case ov::element::f32: + cvt_f32_to_i32(h, in_idxs, aux_vec_idxs); + cvt_i32_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), false); + break; + case ov::element::i32: + cvt_i32_to_byte(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), false); + break; + case ov::element::i8: + case ov::element::u8: + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); + } + store_byte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); + } +} + size_t jit_store_emitter::get_aux_gprs_count() const { if (store_num_ == 3) return 1; @@ -129,6 +400,13 @@ size_t jit_store_emitter::get_aux_gprs_count() const { return 0; } +size_t jit_store_emitter::get_aux_vecs_count() const { + if (src_prc_ != dst_prc_) + return 1; + + return 0; +} + } // namespace aarch64 } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 07ce2c2f89a8ea..0c7f79bb65f508 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -24,7 +24,14 @@ class jit_load_emitter : public jit_emitter { private: template void emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const; + template + void load_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const; + template + void load_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const; + template + void load_byte(const std::vector &in_idxs, const std::vector &out_idxs) const; size_t get_aux_gprs_count() const override; + size_t get_aux_vecs_count() const override; std::string name_; int load_num_; // the element number to load @@ -46,7 +53,14 @@ class jit_store_emitter : public jit_emitter { private: template void emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const; + template + void store_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const; + template + void store_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const; + template + void store_byte(const std::vector &in_idxs, const std::vector &out_idxs) const; size_t get_aux_gprs_count() const override; + size_t get_aux_vecs_count() const override; std::string name_; int store_num_; // the element number to store diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp new file mode 100644 index 00000000000000..92a6358645f2b8 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp @@ -0,0 +1,100 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "utils.hpp" + +namespace ov { +namespace intel_cpu { +namespace aarch64 { + +template +void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->fcvtl(dst.s4, src.h4); +} + +template +void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->fcvtn(dst.h4, src.s4); +} + +template +void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->fcvtzs(dst.s, src.s); +} + +template +void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->scvtf(dst.s, src.s); +} + +template +void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturated) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_saturated) { + if (is_signed) { + h->sqxtn(dst.h4, src.s4); + h->sqxtn(dst.b8, dst.h8); + } else { + h->uqxtn(dst.h4, src.s4); + h->uqxtn(dst.b8, dst.h8); + } + } else { + h->xtn(dst.h4, src.s4); + h->xtn(dst.b8, dst.h8); + } +} + +template +void cvt_byte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_signed) { + h->sxtl(dst.h8, src.b8); + h->sxtl(dst.s4, dst.h4); + } else { + h->uxtl(dst.h8, src.b8); + h->uxtl(dst.s4, dst.h4); + } +} + +template void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs); + +template void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs); + +template void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs); + +template void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs); + +template void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturation); + +template void cvt_byte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed); + +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp new file mode 100644 index 00000000000000..c8218fc696c553 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "cpu/aarch64/cpu_isa_traits.hpp" +#include "cpu/aarch64/jit_generator.hpp" + +namespace ov { +namespace intel_cpu { +namespace aarch64 { + +template +void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); + +template +void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); + +template +void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); + +template +void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); + +template +void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturated); + +template +void cvt_byte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed); + +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp index 685ace977f7415..a56c2316183643 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp @@ -8,6 +8,7 @@ #include "emitters/utils.hpp" #include "emitters/snippets/cpu_runtime_configurator.hpp" #include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp" +#include "emitters/plugin/aarch64/jit_conversion_emitters.hpp" #include "emitters/snippets/aarch64/jit_kernel_emitter.hpp" #include "emitters/snippets/aarch64/jit_loop_emitters.hpp" #include "emitters/snippets/aarch64/jit_memory_emitters.hpp" @@ -108,6 +109,8 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa) jitters[snippets::op::VectorBuffer::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_nop_emitter); jitters[snippets::op::RankNormalization::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_nop_emitter); jitters[snippets::op::BroadcastMove::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_broadcast_move_emitter); + jitters[snippets::op::ConvertTruncation::get_type_info_static()] = CREATE_CPU_EMITTER(jit_convert_truncation_emitter); + jitters[snippets::op::ConvertSaturation::get_type_info_static()] = CREATE_CPU_EMITTER(jit_convert_saturation_emitter); // memory access jitters[snippets::op::Load::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_memory_emitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp index 4b497f0286169c..42d30795ad9e73 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp @@ -22,11 +22,10 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex } jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { - OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Only Supports equal input and output types but gets ", - src_prc.get_type_name(), - " and ", - dst_prc.get_type_name()); - OV_CPU_JIT_EMITTER_ASSERT(src_prc == ov::element::f32, "Only supports FP32 precision."); + OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported input type: ", src_prc.get_type_name()); + OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported output type: ", dst_prc.get_type_name()); const auto load = std::dynamic_pointer_cast(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(load != nullptr, "Expects Load expression"); @@ -89,11 +88,10 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector &in, const s } jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { - OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Only supports equal input and output types but gets ", - src_prc.get_type_name(), - " and ", - dst_prc.get_type_name()); - OV_CPU_JIT_EMITTER_ASSERT(src_prc == ov::element::f32, "Only supports FP32 precision."); + OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported input type: ", src_prc.get_type_name()); + OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), + "Unsupported output type: ", dst_prc.get_type_name()); const auto store = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(store != nullptr, "Expects Store expression"); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp index 818f54983b2dfc..22fa00e5752bfd 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp @@ -192,6 +192,28 @@ void MarkSubgraphOpAsSkipped(const std::shared_ptr &node) { } } +bool isSuitableConvert(const std::shared_ptr& node) { + if (!ov::is_type(node)) + return false; + auto isSuitableParent = [](const std::shared_ptr& node) { + for (const auto& input : node->inputs()) { + const auto parent = input.get_source_output().get_node_shared_ptr(); + if (!ov::is_type(parent)) + return false; + } + return true; + }; + auto isSuitableChild = [](const std::shared_ptr& node) { + for (const auto &out : node->outputs()) { + const auto &child = out.get_node_shared_ptr(); + if (!ov::is_type(child)) + return false; + } + return true; + }; + return isSuitableParent(node) || isSuitableChild(node); +} + auto is_skipped_op(const std::shared_ptr& op) -> bool { return ov::is_type(op) || ov::is_type(op) || @@ -225,6 +247,9 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { SetNodeFusingType(node, NodeFusingType::FusedWithMatMul); channelAxis = out_rank.is_static() ? out_rank.get_length() - 1 : DEFAULT_AXIS; } + } else if (isSuitableConvert(node)) { + SetSnippetsNodeType(node, snippets::pass::SnippetsNodeType::SkippedByPlugin); + channelAxis = DEFAULT_AXIS; } else { for (const auto fusingChainType : getContinuableChains(node)) { if (isSuitableChildForFusingBias(node, channelAxis)) { diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 54a038c9492db6..6014d1a7270f39 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -999,6 +999,7 @@ void Transformations::MainSnippets(void) { return (ov::is_type(n) || ov::is_type(n) || ov::is_type(n) || + ov::is_type(n) || ov::is_type(n) || ov::is_type(n) || ov::is_type(n) || @@ -1047,7 +1048,7 @@ void Transformations::MainSnippets(void) { // So i32 is supported exclusively for transposes and broadcast static const std::set supported_element_types = #if defined(OPENVINO_ARCH_ARM64) - { ov::element::f32 }; + {ov::element::f32, ov::element::f16, ov::element::i8, ov::element::u8}; #else {ov::element::f32, ov::element::bf16, ov::element::f16, ov::element::i8, ov::element::u8}; #endif diff --git a/src/plugins/intel_cpu/tests/functional/CMakeLists.txt b/src/plugins/intel_cpu/tests/functional/CMakeLists.txt index dada93a41f7875..cc2fb1f62121b3 100644 --- a/src/plugins/intel_cpu/tests/functional/CMakeLists.txt +++ b/src/plugins/intel_cpu/tests/functional/CMakeLists.txt @@ -76,6 +76,15 @@ else() set(TMP_EXPLICITLY_ENABLED_TESTS "${TMP_LIST_OF_EXPLICITLY_ENABLED_TESTS}") endif() +if(X86_64) + list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/arm) +elseif(AARCH64) + list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/x64) +else() + list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/x64) + list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/arm) +endif() + if(NOT X86_64) list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/x64 diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 50fab4febfa150..d29a1b4dc1aa54 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -150,7 +150,7 @@ std::vector disabledTestPatterns() { R"(.*smoke_Proposal_(Static|Dynamic)_Test_Case1/ProposalLayerCPUTest.*)", // Issue: 111418 R"(.*smoke_Snippets_ConvertStub/ConvertStub\.CompareWithRefImpl/IS.*_OT=\(bf16\)_#N=2_#S=2_targetDevice=CPU.*)", - R"(.*smoke_Snippets_Convert/Convert\.CompareWithRefImpl/IS.*_IT=\(f32\)_OT=\(u8\)_#N=1_#S=1_targetDevice=CPU.*)", + R"(.*smoke_Snippets_Convert/Convert\.CompareWithRefImpl/IS.*_IT=\((f32|f16)\)_OT=\(u8\)_#N=1_#S=1_targetDevice=CPU.*)", R"(.*smoke_Snippets_ConvertManyOnInputs/ConvertManyOnInputs\.CompareWithRefImpl/IS.*_IT=\(f32\.u8\)_OT=\(\)_#N=1_#S=1_targetDevice=CPU.*)", // Issue: 106939 R"(.*ScatterNDUpdateLayerCPUTest.*-1.-1.-1.-2.-2.-2.*)", @@ -485,7 +485,7 @@ std::vector disabledTestPatterns() { retVector.emplace_back(R"(smoke_Snippets.*\[.*\?.*\].*)"); retVector.emplace_back(R"(smoke_Snippets_Eltwise.*\[1.1..10.1..8.1..4\].*)"); // smoke_Snippets test cases are not supported on arm64 platforms, except for smoke_Snippets_Eltwise - retVector.emplace_back(R"(smoke_Snippets(?!_Eltwise).*)"); + retVector.emplace_back(R"(smoke_Snippets(?!_Eltwise|_Convert).*)"); // arm snippets doesn't support sve_128 that required by dnnl injector jit_uni_eltwise_injector_f32 yet retVector.emplace_back(R"(smoke_Snippets_Eltwise_TwoResults.*)"); retVector.emplace_back(R"(smoke_Snippets_Eltwise/TwoInputsAndOutputs.*)"); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/arm/convert.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/arm/convert.cpp new file mode 100644 index 00000000000000..1230034b778437 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/arm/convert.cpp @@ -0,0 +1,169 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/convert.hpp" +#include "common_test_utils/test_constants.hpp" + +namespace ov { +namespace test { +namespace snippets { + + +namespace { + +const std::vector, std::vector>> types_Convert = { + { { ov::element::f32 }, { ov::element::f16 } }, + { { ov::element::f32 }, { ov::element::i8 } }, + { { ov::element::f32 }, { ov::element::u8 } }, + + { { ov::element::f16 }, { ov::element::f32 } }, + { { ov::element::f16 }, { ov::element::i8 } }, + { { ov::element::f16 }, { ov::element::u8 } }, + + { { ov::element::i8 }, { ov::element::f32 } }, + { { ov::element::i8 }, { ov::element::f16 } }, + { { ov::element::i8 }, { ov::element::u8 } }, + + { { ov::element::u8 }, { ov::element::f32 } }, + { { ov::element::u8 }, { ov::element::f16 } }, + { { ov::element::u8 }, { ov::element::i8 } }, +}; + +const std::vector> inputShapes_Convert = { + { {{}, {{2, 16}}} }, + { {{}, {{5, 7}}} }, + { {{}, {{2, 12, 1}}} }, + { {{{1, 6}, 6}, {{6, 6}, {1, 6}, {6, 6}}} }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Convert, Convert, + ::testing::Combine( + ::testing::ValuesIn(inputShapes_Convert), + ::testing::ValuesIn(types_Convert), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + Convert::getTestCaseName); + +const std::vector, std::vector>> types_ConvertInput = { + { { ov::element::f32 }, { ov::element::f16 } }, + + { { ov::element::f16 }, { ov::element::f32 } }, + + { { ov::element::i8 }, { ov::element::f32 } }, + { { ov::element::i8 }, { ov::element::f16 } }, + + { { ov::element::u8 }, { ov::element::f32 } }, + { { ov::element::u8 }, { ov::element::f16 } }, +}; + +const std::vector, std::vector>> types_ConvertStub = { + { { ov::element::i8 }, { ov::element::f32 } }, + { { ov::element::i8 }, { ov::element::f16 } }, + + { { ov::element::u8 }, { ov::element::f32 } }, + { { ov::element::u8 }, { ov::element::f16 } }, +}; + +const std::vector> inputShapes_ConvertInput = { + { {{}, {{2, 16}}}, {{}, {{1, 16}}} }, + { {{}, {{5, 18}}}, {{}, {{5, 1}}} }, + { {{}, {{3, 1}}}, {{}, {{3, 21}}} }, + { {{{1, 6}, 6}, {{6, 6}, {1, 6}, {6, 6}}}, {{{1, 6}, 6}, {{1, 6}, {6, 6}, {1, 6}}} }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ConvertInput, ConvertInput, + ::testing::Combine( + ::testing::ValuesIn(inputShapes_ConvertInput), + ::testing::ValuesIn(types_ConvertInput), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + Convert::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ConvertOutput, ConvertOutput, + ::testing::Combine( + ::testing::ValuesIn(inputShapes_ConvertInput), + ::testing::ValuesIn(types_ConvertInput), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + Convert::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ConvertStub, ConvertStub, + ::testing::Combine( + ::testing::ValuesIn(inputShapes_ConvertInput), + ::testing::ValuesIn(types_ConvertStub), + ::testing::Values(2), + ::testing::Values(2), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + Convert::getTestCaseName); + +const std::vector, std::vector>> types_ConvertPartialInputsAndResults = { + { { ov::element::i8, ov::element::i8, ov::element::f32 }, { ov::element::f32, ov::element::i8 } }, +}; + +const std::vector> inputShapes_ConvertPartialInputsAndResults = { + { {{}, {{2, 16}}}, {{}, {{1, 16}}}, {{}, {{1, 1}}} }, + { {{}, {{5, 18}}}, {{}, {{5, 1}}}, {{}, {{1, 18}}} }, + { {{}, {{3, 1}}}, {{}, {{3, 21}}}, {{}, {{3, 1}}} }, + { {{{1, 3}, 4}, {{1, 4}, {2, 4}, {3, 4}}}, {{{1, 3}, 4}, {{3, 4}, {2, 4}, {3, 4}}}, {{{1, 3}, 4}, {{1, 4}, {1, 4}, {3, 4}}} }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ConvertPartialInputsAndResults, ConvertPartialInputsAndResults, + ::testing::Combine( + ::testing::ValuesIn(inputShapes_ConvertPartialInputsAndResults), + ::testing::ValuesIn(types_ConvertPartialInputsAndResults), + ::testing::Values(2), // subgraph & roll after subgraph + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + Convert::getTestCaseName); + +const std::vector, std::vector>> types_ConvertMany = { + { { ov::element::f32, ov::element::u8}, {} }, + { { ov::element::f32, ov::element::u8, ov::element::i8 }, {} }, + { { ov::element::f32, ov::element::f32, ov::element::i8, ov::element::i8 }, {} }, +}; + +const std::vector> inputShapes_ConvertManyOnInputs = { + { {{}, {{5, 5, 5, 5}}} }, + { {{{3, 5}, {3, 5}, {3, 5}, 5}, {{5, 5, 5, 5}, {3, 3, 3, 5}, {5, 5, 5, 5}}} } +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ConvertManyOnInputs, ConvertManyOnInputs, + ::testing::Combine( + ::testing::ValuesIn(inputShapes_ConvertManyOnInputs), + ::testing::ValuesIn(types_ConvertMany), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + Convert::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ConvertManyOnOutputs, ConvertManyOnOutputs, + ::testing::Combine( + ::testing::ValuesIn(inputShapes_ConvertManyOnInputs), + ::testing::ValuesIn(types_ConvertMany), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + Convert::getTestCaseName); + +const std::vector, std::vector>> types_ConvertManyIO = { + { { ov::element::f32, ov::element::u8}, {ov::element::i8} }, + { { ov::element::f32, ov::element::u8, ov::element::i8 }, { ov::element::u8, ov::element::i8, ov::element::f32, ov::element::f32 } }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ConvertManyOnInputOutput, ConvertManyOnInputOutput, + ::testing::Combine( + ::testing::ValuesIn(inputShapes_ConvertManyOnInputs), + ::testing::ValuesIn(types_ConvertManyIO), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + Convert::getTestCaseName); + +} // namespace +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/convert.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/x64/convert.cpp similarity index 100% rename from src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/convert.cpp rename to src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/x64/convert.cpp From d5e817461590e6d5d8bd189b2eb9d52982ece1dd Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Tue, 23 Jul 2024 09:18:05 +0000 Subject: [PATCH 02/33] Enable LoadConvertSaturation and three other counterparts --- .../aarch64/jit_load_store_emitters.cpp | 8 ++-- .../aarch64/jit_load_store_emitters.hpp | 3 +- .../snippets/aarch64/cpu_generator.cpp | 6 +++ .../snippets/aarch64/jit_memory_emitters.cpp | 48 +++++++++++++------ .../snippets/aarch64/jit_memory_emitters.hpp | 5 +- .../emitters/snippets/x64/cpu_generator.cpp | 6 +-- .../snippets/x64/jit_memory_emitters.cpp | 4 +- src/plugins/intel_cpu/src/extension.cpp | 4 +- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 35 ++++++++------ .../snippets/aarch64/shape_inference.cpp | 6 +++ .../{x64 => common}/op/load_convert.cpp | 0 .../{x64 => common}/op/load_convert.hpp | 0 .../{x64 => common}/op/store_convert.cpp | 0 .../{x64 => common}/op/store_convert.hpp | 0 .../lowered/fuse_load_store_and_convert.cpp | 4 +- .../lowered/fuse_load_store_and_convert.hpp | 0 16 files changed, 82 insertions(+), 47 deletions(-) rename src/plugins/intel_cpu/src/transformations/snippets/{x64 => common}/op/load_convert.cpp (100%) rename src/plugins/intel_cpu/src/transformations/snippets/{x64 => common}/op/load_convert.hpp (100%) rename src/plugins/intel_cpu/src/transformations/snippets/{x64 => common}/op/store_convert.cpp (100%) rename src/plugins/intel_cpu/src/transformations/snippets/{x64 => common}/op/store_convert.hpp (100%) rename src/plugins/intel_cpu/src/transformations/snippets/{x64 => common}/pass/lowered/fuse_load_store_and_convert.cpp (97%) rename src/plugins/intel_cpu/src/transformations/snippets/{x64 => common}/pass/lowered/fuse_load_store_and_convert.hpp (100%) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 8fa0278bacad42..70a9fbc6502c11 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -213,9 +213,9 @@ size_t jit_load_emitter::get_aux_vecs_count() const { jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset, - ov::element::Type exec_prc, emitter_in_out_map in_out_type) + bool is_saturated, ov::element::Type exec_prc, emitter_in_out_map in_out_type) : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset), - src_prc_(src_prc), dst_prc_(dst_prc) {} + is_saturated_(is_saturated), src_prc_(src_prc), dst_prc_(dst_prc) {} void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { @@ -375,10 +375,10 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: switch (src_prc_) { case ov::element::f32: cvt_f32_to_i32(h, in_idxs, aux_vec_idxs); - cvt_i32_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), false); + cvt_i32_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); break; case ov::element::i32: - cvt_i32_to_byte(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), false); + cvt_i32_to_byte(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); break; case ov::element::i8: case ov::element::u8: diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 0c7f79bb65f508..77b2f124eff513 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -44,7 +44,7 @@ class jit_store_emitter : public jit_emitter { public: jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset_, - ov::element::Type exec_prc = ov::element::f32, + bool is_saturated = true, ov::element::Type exec_prc = ov::element::f32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr); void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; @@ -67,6 +67,7 @@ class jit_store_emitter : public jit_emitter { int byte_offset_; ov::element::Type src_prc_; ov::element::Type dst_prc_; + bool is_saturated_; // true: saturated; false: truncated }; } // namespace aarch64 diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp index a56c2316183643..0c405bdff1e5ab 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp @@ -14,6 +14,8 @@ #include "emitters/snippets/aarch64/jit_memory_emitters.hpp" #include "emitters/snippets/aarch64/jit_fill_emitter.hpp" +#include "transformations/snippets/common/op/load_convert.hpp" +#include "transformations/snippets/common/op/store_convert.hpp" #include "transformations/snippets/common/op/fused_mul_add.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" @@ -115,7 +117,11 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa) // memory access jitters[snippets::op::Load::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_memory_emitter); jitters[snippets::op::BroadcastLoad::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_broadcast_emitter); + jitters[intel_cpu::LoadConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_memory_emitter); + jitters[intel_cpu::LoadConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_memory_emitter); jitters[snippets::op::Store::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_store_memory_emitter); + jitters[intel_cpu::StoreConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_store_memory_emitter); + jitters[intel_cpu::StoreConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_store_memory_emitter); // ternary jitters[intel_cpu::FusedMulAdd::get_type_info_static()] = CREATE_CPU_EMITTER(jit_mul_add_emitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp index 42d30795ad9e73..2ac1f0c1c1de69 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp @@ -3,6 +3,8 @@ // #include "jit_memory_emitters.hpp" +#include "transformations/snippets/common/op/load_convert.hpp" +#include "transformations/snippets/common/op/store_convert.hpp" #include "emitters/utils.hpp" using namespace Xbyak_aarch64; @@ -15,13 +17,30 @@ using jit_generator = dnnl::impl::cpu::aarch64::jit_generator; using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t; using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; -jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_emitter(h, isa) { +jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr, emitter_in_out_map in_out_type) + : jit_emitter(h, isa) { + in_out_type_ = in_out_type; + const auto n = expr->get_node(); src_prc = n->get_input_element_type(0); dst_prc = n->get_output_element_type(0); + + const auto& memory_access = std::dynamic_pointer_cast(expr->get_node()); + if (in_out_type_ == emitter_in_out_map::gpr_to_vec) { + OV_CPU_JIT_EMITTER_ASSERT(memory_access->is_memory_access_input_port(0), "Must be input port - memory access"); + count = memory_access->get_input_count(); + byte_offset = memory_access->get_input_offset(); + } else if (in_out_type_ == emitter_in_out_map::vec_to_gpr) { + OV_CPU_JIT_EMITTER_ASSERT(memory_access->is_memory_access_output_port(0), "Must be output port - memory access"); + count = memory_access->get_output_count(); + byte_offset = memory_access->get_output_offset(); + } else { + OV_CPU_JIT_EMITTER_THROW("Unsupported in_out_type"); + } } -jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { +jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + : jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) { OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), "Unsupported input type: ", src_prc.get_type_name()); OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), @@ -29,9 +48,6 @@ jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa const auto load = std::dynamic_pointer_cast(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(load != nullptr, "Expects Load expression"); - count = load->get_count(); - byte_offset = load->get_offset(); - in_out_type_ = emitter_in_out_map::gpr_to_vec; load_emitter.reset(new jit_load_emitter(h, isa, src_prc, dst_prc, count, byte_offset)); } @@ -56,7 +72,7 @@ void jit_load_memory_emitter::emit_data() const { } jit_load_broadcast_emitter::jit_load_broadcast_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) - : jit_memory_emitter(h, isa, expr) { + : jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) { OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Only support equal input and output types but gets ", src_prc.get_type_name(), " and ", @@ -65,8 +81,6 @@ jit_load_broadcast_emitter::jit_load_broadcast_emitter(jit_generator* h, cpu_isa const auto broadcast_load = std::dynamic_pointer_cast(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(broadcast_load != nullptr, "Expects BroadcastLoad expression"); - byte_offset = broadcast_load->get_offset(); - in_out_type_ = emitter_in_out_map::gpr_to_vec; } void jit_load_broadcast_emitter::emit_impl(const std::vector& in, @@ -87,18 +101,22 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector &in, const s h->uni_ld1rw(dst.s, src, byte_offset); } -jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { +jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + : jit_memory_emitter(h, isa, expr, emitter_in_out_map::vec_to_gpr) { OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), "Unsupported input type: ", src_prc.get_type_name()); OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), "Unsupported output type: ", dst_prc.get_type_name()); - const auto store = ov::as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(store != nullptr, "Expects Store expression"); - count = store->get_count(); - byte_offset = store->get_offset(); - in_out_type_ = emitter_in_out_map::vec_to_gpr; - store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset)); + if (ov::is_type(expr->get_node())) { + store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, false)); + } else if (ov::is_type(expr->get_node())) { + store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, true)); + } else if (ov::is_type(expr->get_node())) { + store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset)); + } else { + OV_CPU_JIT_EMITTER_THROW("Expects Store node"); + } } void jit_store_memory_emitter::emit_impl(const std::vector& in, diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp index ba0b4e4acfedb4..72cee670ae4ca8 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp @@ -13,9 +13,8 @@ namespace aarch64 { class jit_memory_emitter : public jit_emitter { public: - jit_memory_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, - dnnl::impl::cpu::aarch64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + jit_memory_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr, emitter_in_out_map in_out_type); protected: ov::element::Type src_prc; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp index 4b000ee1521d43..01a87d849f9731 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp @@ -20,14 +20,14 @@ #include "emitters/plugin/x64/jit_dnnl_ext_emitters.hpp" #include "emitters/plugin/x64/jit_conversion_emitters.hpp" -#include "transformations/snippets/x64/op/load_convert.hpp" -#include "transformations/snippets/x64/op/store_convert.hpp" +#include "transformations/snippets/common/op/load_convert.hpp" +#include "transformations/snippets/common/op/store_convert.hpp" #include "transformations/snippets/common/op/fused_mul_add.hpp" #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/perf_count_rdtsc.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" -#include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp" +#include "transformations/snippets/common/pass/lowered/fuse_load_store_and_convert.hpp" #include #include "emitters/snippets/cpu_kernel_executor_table.hpp" diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp index 52a77f4feced2c..e54f33a0f56db5 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp @@ -5,8 +5,8 @@ #include "jit_memory_emitters.hpp" #include "emitters/snippets/jit_snippets_call_args.hpp" -#include "transformations/snippets/x64/op/load_convert.hpp" -#include "transformations/snippets/x64/op/store_convert.hpp" +#include "transformations/snippets/common/op/load_convert.hpp" +#include "transformations/snippets/common/op/store_convert.hpp" #include "snippets/op/buffer.hpp" diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index 61a6255d9c8a01..d5a8801ffedeac 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -27,9 +27,9 @@ #include "transformations/cpu_opset/x64/op/qkv_proj.hpp" #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" -#include "transformations/snippets/x64/op/load_convert.hpp" +#include "transformations/snippets/common/op/load_convert.hpp" #include "transformations/snippets/x64/op/perf_count_rdtsc.hpp" -#include "transformations/snippets/x64/op/store_convert.hpp" +#include "transformations/snippets/common/op/store_convert.hpp" namespace { diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index b676f54e27c2d0..0935a8b8f1d5ab 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -24,6 +24,7 @@ #include "transformations/defs.hpp" #include "transformations/cpu_opset/common/pass/convert_to_swish_cpu.hpp" #include "transformations/snippets/common/pass/mul_add_to_fma.hpp" +#include "transformations/snippets/common/pass/lowered/fuse_load_store_and_convert.hpp" #if defined(OPENVINO_ARCH_ARM64) #include "emitters/snippets/aarch64/cpu_generator.hpp" @@ -31,7 +32,6 @@ #else #include "emitters/snippets/x64/cpu_generator.hpp" #include "transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp" -#include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp" #include "transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.hpp" #include "transformations/snippets/x64/pass/remove_converts.hpp" #include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp" @@ -668,30 +668,35 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const { ControlFlowPasses backend_passes; -#if defined(OPENVINO_ARCH_X86_64) using PassPosition = ov::snippets::pass::PassPosition; using Place = PassPosition::Place; -# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...) \ + +# define SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(PASS_PLACE, TARGET_PASS, PASS, ...) \ + backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), std::make_shared(__VA_ARGS__)) + +#if defined(OPENVINO_ARCH_X86_64) +# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) \ backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), std::make_shared(__VA_ARGS__)) #else -# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...) +# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) #endif // OPENVINO_ARCH_X86_64 - SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::snippets::lowered::pass::MarkLoops, - ov::intel_cpu::pass::BrgemmCPUBlocking); - SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::snippets::lowered::pass::InsertLoops, - ov::intel_cpu::pass::FuseLoadStoreConvert); - SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, - ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::snippets::lowered::pass::MarkLoops, + ov::intel_cpu::pass::BrgemmCPUBlocking); + SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::lowered::pass::InsertLoops, + ov::intel_cpu::pass::FuseLoadStoreConvert); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, + ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape); #ifdef SNIPPETS_LIBXSMM_TPP - SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before, ov::intel_cpu::pass::BrgemmCPUBlocking, - ov::intel_cpu::tpp::pass::BrgemmTPPBlocking); - SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, - ov::intel_cpu::tpp::pass::SetTPPLeadingDim); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, ov::intel_cpu::pass::BrgemmCPUBlocking, + ov::intel_cpu::tpp::pass::BrgemmTPPBlocking); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, + ov::intel_cpu::tpp::pass::SetTPPLeadingDim); #endif -#undef SNIPPETS_REGISTER_PASS_RELATIVE +#undef SNIPPETS_REGISTER_PASS_RELATIVE_COMMON +#undef SNIPPETS_REGISTER_PASS_RELATIVE_X86_64 return backend_passes; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp index 967afe946e0793..41777712201a68 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp @@ -6,6 +6,8 @@ #include "snippets/shape_inference/shape_infer_instances.hpp" #include "transformations/snippets/common/op/fused_mul_add.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" +#include "transformations/snippets/common/op/load_convert.hpp" +#include "transformations/snippets/common/op/store_convert.hpp" namespace ov { namespace snippets { @@ -29,6 +31,10 @@ ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry { SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer), SHAPE_INFER_PREDEFINED(ov::intel_cpu::SwishNode, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertSaturation, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer), + SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer), }; #undef SHAPE_INFER_OP_SPECIFIC #undef SHAPE_INFER_PREDEFINED diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/load_convert.cpp b/src/plugins/intel_cpu/src/transformations/snippets/common/op/load_convert.cpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/snippets/x64/op/load_convert.cpp rename to src/plugins/intel_cpu/src/transformations/snippets/common/op/load_convert.cpp diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/load_convert.hpp b/src/plugins/intel_cpu/src/transformations/snippets/common/op/load_convert.hpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/snippets/x64/op/load_convert.hpp rename to src/plugins/intel_cpu/src/transformations/snippets/common/op/load_convert.hpp diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/store_convert.cpp b/src/plugins/intel_cpu/src/transformations/snippets/common/op/store_convert.cpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/snippets/x64/op/store_convert.cpp rename to src/plugins/intel_cpu/src/transformations/snippets/common/op/store_convert.cpp diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/store_convert.hpp b/src/plugins/intel_cpu/src/transformations/snippets/common/op/store_convert.hpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/snippets/x64/op/store_convert.hpp rename to src/plugins/intel_cpu/src/transformations/snippets/common/op/store_convert.hpp diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.cpp b/src/plugins/intel_cpu/src/transformations/snippets/common/pass/lowered/fuse_load_store_and_convert.cpp similarity index 97% rename from src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.cpp rename to src/plugins/intel_cpu/src/transformations/snippets/common/pass/lowered/fuse_load_store_and_convert.cpp index de36f1e4a70148..d0fd97b5bd3133 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/common/pass/lowered/fuse_load_store_and_convert.cpp @@ -8,8 +8,8 @@ #include "snippets/snippets_isa.hpp" #include "snippets/lowered/loop_manager.hpp" -#include "transformations/snippets/x64/op/load_convert.hpp" -#include "transformations/snippets/x64/op/store_convert.hpp" +#include "transformations/snippets/common/op/load_convert.hpp" +#include "transformations/snippets/common/op/store_convert.hpp" bool ov::intel_cpu::pass::FuseLoadStoreConvert::fuse_load_convert(snippets::lowered::LinearIR& linear_ir, diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp b/src/plugins/intel_cpu/src/transformations/snippets/common/pass/lowered/fuse_load_store_and_convert.hpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp rename to src/plugins/intel_cpu/src/transformations/snippets/common/pass/lowered/fuse_load_store_and_convert.hpp From 5e13530eaf2182810020aa515c45bfbe04697c66 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Wed, 31 Jul 2024 08:49:07 +0000 Subject: [PATCH 03/33] Fix the issue regarding initialization order --- .../src/emitters/plugin/aarch64/jit_load_store_emitters.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 70a9fbc6502c11..3a21cf6e2d70eb 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -215,7 +215,7 @@ jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *ho ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset, bool is_saturated, ov::element::Type exec_prc, emitter_in_out_map in_out_type) : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset), - is_saturated_(is_saturated), src_prc_(src_prc), dst_prc_(dst_prc) {} + src_prc_(src_prc), dst_prc_(dst_prc), is_saturated_(is_saturated) {} void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { From b73240310f2376e967e7a3ef381aff198160caf0 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Wed, 31 Jul 2024 09:22:23 +0000 Subject: [PATCH 04/33] Fix issue regarding incorrect path of headers --- .../src/transformations/snippets/x64/shape_inference.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp index 9096673bf08cc4..bca9f85bdd53e2 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp @@ -7,8 +7,8 @@ #include "op/brgemm_copy_b.hpp" #include "op/brgemm_cpu.hpp" #include "transformations/snippets/common/op/fused_mul_add.hpp" -#include "op/load_convert.hpp" -#include "op/store_convert.hpp" +#include "transformations/snippets/common/op/load_convert.hpp" +#include "transformations/snippets/common/op/store_convert.hpp" #include "op/perf_count_rdtsc.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" #ifdef SNIPPETS_LIBXSMM_TPP From b4c92bc3107e037a8257c8e4b3aa10ff6936512f Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 1 Aug 2024 03:27:16 +0000 Subject: [PATCH 05/33] Support conversion between the same precision --- .../emitters/plugin/aarch64/jit_conversion_emitters.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index 49a5358b0d24c6..b888cdecc7a216 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -17,6 +17,12 @@ template static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, ov::element::Type input_type, ov::element::Type output_type, bool is_saturated) { + if (input_type == output_type) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + h->mov(TReg(out_idxs[0]).b16, TReg(in_idxs[0]).b16); + return; + } + switch (output_type) { case ov::element::f32: switch (input_type) { @@ -117,8 +123,6 @@ void jit_convert_emitter::validate_types() const { "Unsupported input type: ", input_type.get_type_name()); OV_CPU_JIT_EMITTER_ASSERT(one_of(output_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), "Unsupported output type: ", output_type.get_type_name()); - OV_CPU_JIT_EMITTER_ASSERT(input_type != output_type, "Input type ", input_type.get_type_name(), " and output type ", - output_type.get_type_name(), " should be different."); } size_t jit_convert_emitter::get_inputs_count() const { return 1; } From 96c06c6a6ab1d8512302302867396c9bf5aa6416 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 1 Aug 2024 04:27:48 +0000 Subject: [PATCH 06/33] Fix issue regarding primitive type --- .../custom/single_layer_tests/classes/conversion.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp index 2cde53b8ebddb6..aedc6e843b314d 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp @@ -63,7 +63,17 @@ void ConvertCPULayerTest::SetUp() { auto primitive = selectedType; if (primitive.empty()) primitive = getPrimitiveType(); - if (!isInOutPrecisionSupported(inPrc, outPrc)) +#if defined(OPENVINO_ARCH_ARM64) + if (inPrc == ov::element::u4 || inPrc == ov::element::i4) + primitive = "ref"; + else if (shapes.first.is_static() && shapes.first.rank().get_length() <= 6 && + inPrc != ov::element::bf16 && outPrc != ov::element::bf16 && + inPrc != ov::element::i32 && outPrc != ov::element::i32) // Apply "jit" for the snippets cases + primitive = "jit"; + else + primitive = "acl"; +#endif + if (primitive != "jit" && !isInOutPrecisionSupported(inPrc, outPrc)) primitive = "ref"; validate_out_prc(); From 205e579f4f19d7b03e94596815c8f9a9a0b791f5 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Mon, 5 Aug 2024 03:33:00 +0000 Subject: [PATCH 07/33] Skip test cases on unaligned conversion behavior --- .../functional/shared_tests_instances/skip_tests_config.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index d29a1b4dc1aa54..d771e9116d5f67 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -379,6 +379,12 @@ std::vector disabledTestPatterns() { retVector.emplace_back(R"(.*smoke_RoPETest.*)"); #endif +#if defined(OPENVINO_ARCH_ARM64) + // Issue: 149216. For low precision model from original framework, Snippets PropagatePrecision should insert ConvertTruncation instead + // of ConvertSaturation when converting larger integer to smaller integer to align with c++ standard and ngraph reference. + retVector.emplace_back(R"(.*smoke_EltwiseChain_MergeConvert_int8/.*Op0=Prod.*Conversion=i8.*)"); +#endif + #if defined(OPENVINO_ARCH_RISCV64) // object is not initialized retVector.emplace_back(R"(.*StaticLoopDynamicSubgraphCPUTest.smoke_StaticLoopWithDynSubgraph.*)"); From b2c31a4582f0a22434fd8be99450d6c29b6486ac Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Tue, 6 Aug 2024 09:32:32 +0000 Subject: [PATCH 08/33] Apply review comments regarding conversion between f16 and i8(u8) --- .../aarch64/jit_conversion_emitters.cpp | 26 ++--- .../aarch64/jit_load_store_emitters.cpp | 12 ++- .../src/emitters/plugin/aarch64/utils.cpp | 94 ++++++++++++++++--- .../src/emitters/plugin/aarch64/utils.hpp | 23 ++++- 4 files changed, 122 insertions(+), 33 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index b888cdecc7a216..db3eae5f53fae6 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -36,7 +36,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i32(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_byte_to_dbyte(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_dbyte_to_i32(h, out_idxs, out_idxs, input_type.is_signed()); cvt_i32_to_f32(h, out_idxs, out_idxs); break; default: @@ -56,7 +57,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i32(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_byte_to_dbyte(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_dbyte_to_i32(h, out_idxs, out_idxs, input_type.is_signed()); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -75,9 +77,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i32(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_i32_to_f32(h, out_idxs, out_idxs); - cvt_f32_to_f16(h, out_idxs, out_idxs); + cvt_byte_to_dbyte(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_dbyte_to_f16(h, out_idxs, out_idxs, input_type.is_signed()); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -88,20 +89,21 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, switch (input_type) { case ov::element::f32: cvt_f32_to_i32(h, in_idxs, out_idxs); - cvt_i32_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_i32_to_dbyte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_dbyte_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); break; case ov::element::i32: - cvt_i32_to_byte(h, in_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_i32_to_dbyte(h, in_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_dbyte_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); break; case ov::element::f16: - cvt_f16_to_f32(h, in_idxs, out_idxs); - cvt_f32_to_i32(h, out_idxs, out_idxs); - cvt_i32_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_f16_to_dbyte(h, in_idxs, out_idxs); + cvt_dbyte_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i32(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_i32_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_byte_to_dbyte(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_dbyte_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 3a21cf6e2d70eb..bb1ebdbbb8069f 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -179,11 +179,13 @@ void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::v load_byte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); switch (dst_prc_) { case ov::element::f32: - cvt_byte_to_i32(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed()); + cvt_byte_to_dbyte(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed()); + cvt_dbyte_to_i32(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed()); cvt_i32_to_f32(h, aux_vec_idxs, out_idxs); break; case ov::element::i32: - cvt_byte_to_i32(h, aux_vec_idxs, out_idxs, src_prc_.is_signed()); + cvt_byte_to_dbyte(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed()); + cvt_dbyte_to_i32(h, aux_vec_idxs, out_idxs, src_prc_.is_signed()); break; case ov::element::i8: case ov::element::u8: @@ -375,10 +377,12 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: switch (src_prc_) { case ov::element::f32: cvt_f32_to_i32(h, in_idxs, aux_vec_idxs); - cvt_i32_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); + cvt_i32_to_dbyte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); + cvt_dbyte_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); break; case ov::element::i32: - cvt_i32_to_byte(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); + cvt_i32_to_dbyte(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); + cvt_dbyte_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); break; case ov::element::i8: case ov::element::u8: diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp index 92a6358645f2b8..1baebf7ea5242b 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp @@ -8,6 +8,13 @@ namespace ov { namespace intel_cpu { namespace aarch64 { +// In aarch64, conversion between f16 and i16/u16 can be done with single instruction. The supported +// conversion precicions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate +// precision i16/u16 (dbyte) in the following graph. Then the conversion between each pair of +// neighbors in this graph will be done with single instruction. +// f16 - f32 - i32 - dbyte - byte +// | | +// - - - - - - - - - - - template void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -41,7 +48,7 @@ void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vecto } template -void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, +void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, bool is_signed, bool is_saturated) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); @@ -49,29 +56,75 @@ void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vect if (is_saturated) { if (is_signed) { h->sqxtn(dst.h4, src.s4); - h->sqxtn(dst.b8, dst.h8); } else { h->uqxtn(dst.h4, src.s4); - h->uqxtn(dst.b8, dst.h8); } } else { h->xtn(dst.h4, src.s4); - h->xtn(dst.b8, dst.h8); } } template -void cvt_byte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, +void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, bool is_signed) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); + if (is_signed) { + h->sxtl(dst.s4, src.h4); + } else { + h->uxtl(dst.s4, src.h4); + } +} + +template +void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->fcvtzs(dst.h, src.h); +} + +template +void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_signed) { + h->scvtf(dst.h, src.h); + } else { + h->ucvtf(dst.h, src.h); + } +} + +template +void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturated) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_saturated) { + if (is_signed) { + h->sqxtn(dst.b8, src.h8); + } else { + h->uqxtn(dst.b8, src.h8); + } + } else { + h->xtn(dst.b8, src.h8); + } +} + +template +void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); if (is_signed) { h->sxtl(dst.h8, src.b8); - h->sxtl(dst.s4, dst.h4); } else { h->uxtl(dst.h8, src.b8); - h->uxtl(dst.s4, dst.h4); } } @@ -87,13 +140,28 @@ template void cvt_f32_to_i32(dnnl::impl::cpu::a template void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); -template void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturation); +template void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturation); + +template void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed); + +template void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs); + +template void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed); + +template void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturation); -template void cvt_byte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed); +template void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, + const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed); } // namespace aarch64 } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp index c8218fc696c553..3bcd56db46c35c 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp @@ -24,12 +24,27 @@ template void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); template -void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturated); +void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturated); template -void cvt_byte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed); +void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed); + +template +void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); + +template +void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed); + +template +void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturated); + +template +void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed); } // namespace aarch64 } // namespace intel_cpu From ffeadd5df28f0de97f8879ab1ed314b3fd8c793a Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Wed, 7 Aug 2024 09:24:25 +0000 Subject: [PATCH 09/33] Revise CMakeLists --- .../intel_cpu/tests/functional/CMakeLists.txt | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/plugins/intel_cpu/tests/functional/CMakeLists.txt b/src/plugins/intel_cpu/tests/functional/CMakeLists.txt index cc2fb1f62121b3..3092356e1189b6 100644 --- a/src/plugins/intel_cpu/tests/functional/CMakeLists.txt +++ b/src/plugins/intel_cpu/tests/functional/CMakeLists.txt @@ -47,9 +47,10 @@ elseif(NOT OV_CPU_WITH_SHL) endif() if(NOT (X86 OR X86_64)) - list(APPEND EXCLUDED_SOURCE_PATHS + list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/x64 ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/x64 + ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/x64 ${CMAKE_CURRENT_SOURCE_DIR}/utils/x64) endif() @@ -57,6 +58,7 @@ if(NOT (ARM OR AARCH64)) list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/arm ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/arm + ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/arm ${CMAKE_CURRENT_SOURCE_DIR}/utils/arm) else() # temporary disable all custom tests for ARM @@ -76,15 +78,6 @@ else() set(TMP_EXPLICITLY_ENABLED_TESTS "${TMP_LIST_OF_EXPLICITLY_ENABLED_TESTS}") endif() -if(X86_64) - list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/arm) -elseif(AARCH64) - list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/x64) -else() - list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/x64) - list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/shared_tests_instances/snippets/arm) -endif() - if(NOT X86_64) list(APPEND EXCLUDED_SOURCE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/x64 From 6b030811e9a5012264de800e40e9664164182644 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 8 Aug 2024 03:09:01 +0000 Subject: [PATCH 10/33] Apply arithmetic_mode to align with x64 --- .../plugin/aarch64/jit_load_store_emitters.cpp | 12 ++++++------ .../plugin/aarch64/jit_load_store_emitters.hpp | 10 ++++++++-- .../snippets/aarch64/jit_memory_emitters.cpp | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index bb1ebdbbb8069f..a11076449b4969 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -215,9 +215,9 @@ size_t jit_load_emitter::get_aux_vecs_count() const { jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset, - bool is_saturated, ov::element::Type exec_prc, emitter_in_out_map in_out_type) + arithmetic_mode mode, ov::element::Type exec_prc, emitter_in_out_map in_out_type) : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset), - src_prc_(src_prc), dst_prc_(dst_prc), is_saturated_(is_saturated) {} + src_prc_(src_prc), dst_prc_(dst_prc), mode_(mode) {} void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { @@ -377,12 +377,12 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: switch (src_prc_) { case ov::element::f32: cvt_f32_to_i32(h, in_idxs, aux_vec_idxs); - cvt_i32_to_dbyte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); - cvt_dbyte_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); + cvt_i32_to_dbyte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation); + cvt_dbyte_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation); break; case ov::element::i32: - cvt_i32_to_dbyte(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); - cvt_dbyte_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_); + cvt_i32_to_dbyte(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation); + cvt_dbyte_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation); break; case ov::element::i8: case ov::element::u8: diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 77b2f124eff513..6892a48f5a1b03 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -11,6 +11,12 @@ namespace ov { namespace intel_cpu { namespace aarch64 { +// Arithmetic modes for data type conversion in store_emitter +enum arithmetic_mode { + saturation, + truncation +}; + class jit_load_emitter : public jit_emitter { public: jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, @@ -44,7 +50,7 @@ class jit_store_emitter : public jit_emitter { public: jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset_, - bool is_saturated = true, ov::element::Type exec_prc = ov::element::f32, + arithmetic_mode mode = arithmetic_mode::saturation, ov::element::Type exec_prc = ov::element::f32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr); void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; @@ -67,7 +73,7 @@ class jit_store_emitter : public jit_emitter { int byte_offset_; ov::element::Type src_prc_; ov::element::Type dst_prc_; - bool is_saturated_; // true: saturated; false: truncated + arithmetic_mode mode_ = arithmetic_mode::saturation; }; } // namespace aarch64 diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp index 2ac1f0c1c1de69..985b3d3cc3d580 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp @@ -109,9 +109,9 @@ jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t i "Unsupported output type: ", dst_prc.get_type_name()); if (ov::is_type(expr->get_node())) { - store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, false)); + store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::truncation)); } else if (ov::is_type(expr->get_node())) { - store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, true)); + store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::saturation)); } else if (ov::is_type(expr->get_node())) { store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset)); } else { From 4d99437aedb36f3d845f28dc5235e6803e49478b Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 8 Aug 2024 03:41:07 +0000 Subject: [PATCH 11/33] Update precision assertion --- .../snippets/aarch64/jit_memory_emitters.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp index 985b3d3cc3d580..ae4cda0e13ce8c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp @@ -41,10 +41,9 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) { - OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported input type: ", src_prc.get_type_name()); - OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported output type: ", dst_prc.get_type_name()); + bool is_supported_precision = one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && + (src_prc == dst_prc || one_of(dst_prc, ov::element::f32, ov::element::i32)); + OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); const auto load = std::dynamic_pointer_cast(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(load != nullptr, "Expects Load expression"); @@ -103,10 +102,9 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector &in, const s jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr, emitter_in_out_map::vec_to_gpr) { - OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported input type: ", src_prc.get_type_name()); - OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported output type: ", dst_prc.get_type_name()); + bool is_supported_precision = one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && + (src_prc == dst_prc || one_of(src_prc, ov::element::f32, ov::element::i32)); + OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); if (ov::is_type(expr->get_node())) { store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::truncation)); From f03800b945d07589334d56e154b1bc01fcdfef30 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Fri, 9 Aug 2024 05:32:36 +0000 Subject: [PATCH 12/33] Replace post_ptr with ptr --- .../aarch64/jit_load_store_emitters.cpp | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index a11076449b4969..4f6d5bda38f82b 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -43,13 +43,13 @@ void jit_load_emitter::load_qbyte(const std::vector &in_idxs, const std: case 0: break; case 1: - h->ldr(dst_s, post_ptr(src, byte_offset_)); + h->ldr(dst_s, ptr(src, byte_offset_)); break; case 2: - h->ldr(dst_d, post_ptr(src, byte_offset_)); + h->ldr(dst_d, ptr(src, byte_offset_)); break; case 3: - h->ldr(dst_d, post_ptr(src, byte_offset_)); + h->ldr(dst_d, ptr(src, byte_offset_)); h->add_imm(prc, src, byte_offset_ + 2 * sizeof(float), h->X_DEFAULT_ADDR); h->ld1(dst.s[2], ptr(prc)); break; @@ -75,18 +75,18 @@ void jit_load_emitter::load_dbyte(const std::vector &in_idxs, const std: case 0: break; case 1: - h->ldr(dst_h, post_ptr(src, byte_offset_)); + h->ldr(dst_h, ptr(src, byte_offset_)); break; case 2: - h->ldr(dst_s, post_ptr(src, byte_offset_)); + h->ldr(dst_s, ptr(src, byte_offset_)); break; case 3: - h->ldr(dst_s, post_ptr(src, byte_offset_)); + h->ldr(dst_s, ptr(src, byte_offset_)); h->add_imm(prc, src, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); h->ld1(dst.h[2], ptr(prc)); break; case 4: - h->ldr(dst_d, post_ptr(src, byte_offset_)); + h->ldr(dst_d, ptr(src, byte_offset_)); break; default: OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); @@ -107,18 +107,18 @@ void jit_load_emitter::load_byte(const std::vector &in_idxs, const std:: case 0: break; case 1: - h->ldr(dst_b, post_ptr(src, byte_offset_)); + h->ldr(dst_b, ptr(src, byte_offset_)); break; case 2: - h->ldr(dst_h, post_ptr(src, byte_offset_)); + h->ldr(dst_h, ptr(src, byte_offset_)); break; case 3: - h->ldr(dst_h, post_ptr(src, byte_offset_)); + h->ldr(dst_h, ptr(src, byte_offset_)); h->add_imm(prc, src, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); h->ld1(dst.b[2], ptr(prc)); break; case 4: - h->ldr(dst_s, post_ptr(src, byte_offset_)); + h->ldr(dst_s, ptr(src, byte_offset_)); break; default: OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to load."); @@ -241,18 +241,18 @@ void jit_store_emitter::store_qbyte(const std::vector &in_idxs, const st case 0: break; case 1: - h->str(src_s, post_ptr(dst, byte_offset_)); + h->str(src_s, ptr(dst, byte_offset_)); break; case 2: - h->str(src_d, post_ptr(dst, byte_offset_)); + h->str(src_d, ptr(dst, byte_offset_)); break; case 3: - h->str(src_d, post_ptr(dst, byte_offset_)); + h->str(src_d, ptr(dst, byte_offset_)); h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(float), h->X_DEFAULT_ADDR); h->st1(src.s[2], ptr(prc)); break; case 4: - h->str(src_q, post_ptr(dst, byte_offset_)); + h->str(src_q, ptr(dst, byte_offset_)); break; default: OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); @@ -273,18 +273,18 @@ void jit_store_emitter::store_dbyte(const std::vector &in_idxs, const st case 0: break; case 1: - h->str(src_h, post_ptr(dst, byte_offset_)); + h->str(src_h, ptr(dst, byte_offset_)); break; case 2: - h->str(src_s, post_ptr(dst, byte_offset_)); + h->str(src_s, ptr(dst, byte_offset_)); break; case 3: - h->str(src_s, post_ptr(dst, byte_offset_)); + h->str(src_s, ptr(dst, byte_offset_)); h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); h->st1(src.h[2], ptr(prc)); break; case 4: - h->str(src_d, post_ptr(dst, byte_offset_)); + h->str(src_d, ptr(dst, byte_offset_)); break; default: OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); @@ -305,18 +305,18 @@ void jit_store_emitter::store_byte(const std::vector &in_idxs, const std case 0: break; case 1: - h->str(src_b, post_ptr(dst, byte_offset_)); + h->str(src_b, ptr(dst, byte_offset_)); break; case 2: - h->str(src_h, post_ptr(dst, byte_offset_)); + h->str(src_h, ptr(dst, byte_offset_)); break; case 3: - h->str(src_h, post_ptr(dst, byte_offset_)); + h->str(src_h, ptr(dst, byte_offset_)); h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); h->st1(src.b[2], ptr(prc)); break; case 4: - h->str(src_s, post_ptr(dst, byte_offset_)); + h->str(src_s, ptr(dst, byte_offset_)); break; default: OV_CPU_JIT_EMITTER_THROW("Unexpected number of elements to store."); From 63d16c2d46403f62b52cb480242e719f0ff67001 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Mon, 19 Aug 2024 06:21:53 +0000 Subject: [PATCH 13/33] Set IGNORE_CALLBACK if rank > 6 --- .../single_layer_tests/classes/conversion.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp index aedc6e843b314d..4989fb3a0f04b7 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp @@ -5,6 +5,7 @@ #include "conversion.hpp" #include "gtest/gtest.h" +#include "internal_properties.hpp" #include "utils/cpu_test_utils.hpp" #include "common_test_utils/data_utils.hpp" #include "shared_test_classes/base/utils/compare_results.hpp" @@ -64,14 +65,18 @@ void ConvertCPULayerTest::SetUp() { if (primitive.empty()) primitive = getPrimitiveType(); #if defined(OPENVINO_ARCH_ARM64) - if (inPrc == ov::element::u4 || inPrc == ov::element::i4) + if (inPrc == ov::element::u4 || inPrc == ov::element::i4) { primitive = "ref"; - else if (shapes.first.is_static() && shapes.first.rank().get_length() <= 6 && + } else if (shapes.first.is_static() && inPrc != ov::element::bf16 && outPrc != ov::element::bf16 && - inPrc != ov::element::i32 && outPrc != ov::element::i32) // Apply "jit" for the snippets cases + inPrc != ov::element::i32 && outPrc != ov::element::i32) { // Apply "jit" for the snippets cases primitive = "jit"; - else + if (shapes.first.rank().get_length() > 6) { + configuration.insert(ov::intel_cpu::snippets_mode(ov::intel_cpu::SnippetsMode::IGNORE_CALLBACK)); + } + } else { primitive = "acl"; + } #endif if (primitive != "jit" && !isInOutPrecisionSupported(inPrc, outPrc)) primitive = "ref"; From 1ea1b6b6ebeda8aee14c0bfffc6143ecc9906fbe Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Mon, 19 Aug 2024 07:30:07 +0000 Subject: [PATCH 14/33] Update isSuitableConvert --- .../snippets/aarch64/pass/snippets_mark_skipped.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp index 22fa00e5752bfd..c38d088ef95e7b 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp @@ -206,7 +206,7 @@ bool isSuitableConvert(const std::shared_ptr& node) { auto isSuitableChild = [](const std::shared_ptr& node) { for (const auto &out : node->outputs()) { const auto &child = out.get_node_shared_ptr(); - if (!ov::is_type(child)) + if (!ov::is_type(child)) return false; } return true; From da5a32d6cd2110009f0dc35b01d4593eb2a83b24 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Mon, 19 Aug 2024 08:01:57 +0000 Subject: [PATCH 15/33] Update jit_store_memory_emitter constructor --- .../snippets/aarch64/jit_memory_emitters.cpp | 36 ++++++++----------- .../snippets/aarch64/jit_memory_emitters.hpp | 5 +-- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp index ae4cda0e13ce8c..9dcaeaea126de4 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp @@ -17,36 +17,22 @@ using jit_generator = dnnl::impl::cpu::aarch64::jit_generator; using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t; using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; -jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr, emitter_in_out_map in_out_type) - : jit_emitter(h, isa) { - in_out_type_ = in_out_type; - +jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_emitter(h, isa) { const auto n = expr->get_node(); src_prc = n->get_input_element_type(0); dst_prc = n->get_output_element_type(0); - - const auto& memory_access = std::dynamic_pointer_cast(expr->get_node()); - if (in_out_type_ == emitter_in_out_map::gpr_to_vec) { - OV_CPU_JIT_EMITTER_ASSERT(memory_access->is_memory_access_input_port(0), "Must be input port - memory access"); - count = memory_access->get_input_count(); - byte_offset = memory_access->get_input_offset(); - } else if (in_out_type_ == emitter_in_out_map::vec_to_gpr) { - OV_CPU_JIT_EMITTER_ASSERT(memory_access->is_memory_access_output_port(0), "Must be output port - memory access"); - count = memory_access->get_output_count(); - byte_offset = memory_access->get_output_offset(); - } else { - OV_CPU_JIT_EMITTER_THROW("Unsupported in_out_type"); - } } -jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) - : jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) { +jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { bool is_supported_precision = one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && (src_prc == dst_prc || one_of(dst_prc, ov::element::f32, ov::element::i32)); OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); const auto load = std::dynamic_pointer_cast(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(load != nullptr, "Expects Load expression"); + count = load->get_count(); + byte_offset = load->get_offset(); + in_out_type_ = emitter_in_out_map::gpr_to_vec; load_emitter.reset(new jit_load_emitter(h, isa, src_prc, dst_prc, count, byte_offset)); } @@ -71,7 +57,7 @@ void jit_load_memory_emitter::emit_data() const { } jit_load_broadcast_emitter::jit_load_broadcast_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) - : jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) { + : jit_memory_emitter(h, isa, expr) { OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Only support equal input and output types but gets ", src_prc.get_type_name(), " and ", @@ -80,6 +66,8 @@ jit_load_broadcast_emitter::jit_load_broadcast_emitter(jit_generator* h, cpu_isa const auto broadcast_load = std::dynamic_pointer_cast(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(broadcast_load != nullptr, "Expects BroadcastLoad expression"); + byte_offset = broadcast_load->get_offset(); + in_out_type_ = emitter_in_out_map::gpr_to_vec; } void jit_load_broadcast_emitter::emit_impl(const std::vector& in, @@ -100,12 +88,16 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector &in, const s h->uni_ld1rw(dst.s, src, byte_offset); } -jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) - : jit_memory_emitter(h, isa, expr, emitter_in_out_map::vec_to_gpr) { +jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) { bool is_supported_precision = one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && (src_prc == dst_prc || one_of(src_prc, ov::element::f32, ov::element::i32)); OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); + const auto store = ov::as_type_ptr(expr->get_node()); + OV_CPU_JIT_EMITTER_ASSERT(store != nullptr, "Expects Store expression"); + count = store->get_count(); + byte_offset = store->get_offset(); + in_out_type_ = emitter_in_out_map::vec_to_gpr; if (ov::is_type(expr->get_node())) { store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::truncation)); } else if (ov::is_type(expr->get_node())) { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp index 72cee670ae4ca8..ba0b4e4acfedb4 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.hpp @@ -13,8 +13,9 @@ namespace aarch64 { class jit_memory_emitter : public jit_emitter { public: - jit_memory_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr, emitter_in_out_map in_out_type); + jit_memory_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); protected: ov::element::Type src_prc; From 7e8d9c036a4390af2049f65796c3a8b76fa30472 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Mon, 19 Aug 2024 08:03:47 +0000 Subject: [PATCH 16/33] Update enum class arithmetic_mode --- .../src/emitters/plugin/aarch64/jit_load_store_emitters.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 6892a48f5a1b03..34a422c96739fd 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -12,7 +12,7 @@ namespace intel_cpu { namespace aarch64 { // Arithmetic modes for data type conversion in store_emitter -enum arithmetic_mode { +enum class arithmetic_mode { saturation, truncation }; From 17471479219561ecfb361f793958c90fdd378d4d Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Mon, 19 Aug 2024 09:32:19 +0000 Subject: [PATCH 17/33] Call convert_emitter in load/store_emitter --- .../aarch64/jit_conversion_emitters.cpp | 144 ++++++++++++++- .../aarch64/jit_conversion_emitters.hpp | 6 + .../aarch64/jit_load_store_emitters.cpp | 121 +++---------- .../aarch64/jit_load_store_emitters.hpp | 6 +- .../src/emitters/plugin/aarch64/utils.cpp | 168 ------------------ .../src/emitters/plugin/aarch64/utils.hpp | 51 ------ 6 files changed, 174 insertions(+), 322 deletions(-) delete mode 100644 src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp delete mode 100644 src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index db3eae5f53fae6..faf481b525c351 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -4,7 +4,6 @@ #include "jit_conversion_emitters.hpp" #include "emitters/utils.hpp" -#include "utils.hpp" using namespace dnnl::impl::cpu::aarch64; using namespace Xbyak_aarch64; @@ -13,6 +12,126 @@ namespace ov { namespace intel_cpu { namespace aarch64 { +// In aarch64, conversion between f16 and i16/u16 can be done with single instruction. The supported +// conversion precicions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate +// precision i16/u16 (dbyte) in the following graph. Then the conversion between each pair of +// neighbors in this graph will be done with single instruction. +// f16 - f32 - i32 - dbyte - byte +// | | +// - - - - - - - - - - - +template +static void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->fcvtl(dst.s4, src.h4); +} + +template +static void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->fcvtn(dst.h4, src.s4); +} + +template +static void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->fcvtzs(dst.s, src.s); +} + +template +static void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->scvtf(dst.s, src.s); +} + +template +static void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturated) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_saturated) { + if (is_signed) { + h->sqxtn(dst.h4, src.s4); + } else { + h->uqxtn(dst.h4, src.s4); + } + } else { + h->xtn(dst.h4, src.s4); + } +} + +template +static void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_signed) { + h->sxtl(dst.s4, src.h4); + } else { + h->uxtl(dst.s4, src.h4); + } +} + +template +static void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + h->fcvtzs(dst.h, src.h); +} + +template +static void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_signed) { + h->scvtf(dst.h, src.h); + } else { + h->ucvtf(dst.h, src.h); + } +} + +template +static void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed, bool is_saturated) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_saturated) { + if (is_signed) { + h->sqxtn(dst.b8, src.h8); + } else { + h->uqxtn(dst.b8, src.h8); + } + } else { + h->xtn(dst.b8, src.h8); + } +} + +template +static void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_signed) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (is_signed) { + h->sxtl(dst.h8, src.b8); + } else { + h->uxtl(dst.h8, src.b8); + } +} + template static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, @@ -120,6 +239,15 @@ jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa output_type = node->get_output_element_type(0); } +jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, + ov::element::Type input_prc, + ov::element::Type output_prc, + ov::element::Type exec_prc) +: jit_emitter(host, host_isa, exec_prc) { + input_type = input_prc; + output_type = output_prc; +} + void jit_convert_emitter::validate_types() const { OV_CPU_JIT_EMITTER_ASSERT(one_of(input_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), "Unsupported input type: ", input_type.get_type_name()); @@ -138,6 +266,13 @@ jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *ho : jit_convert_emitter(host, host_isa, node, exec_prc) { } +jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa, + ov::element::Type input_prc, + ov::element::Type output_prc, + ov::element::Type exec_prc) + : jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) { +} + void jit_convert_truncation_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { validate_types(); if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { @@ -157,6 +292,13 @@ jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *ho : jit_convert_emitter(host, host_isa, node, exec_prc) { } +jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa, + ov::element::Type input_prc, + ov::element::Type output_prc, + ov::element::Type exec_prc) + : jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) { +} + void jit_convert_saturation_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { validate_types(); if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp index df24f714ccc55b..16a08bdf4b0f0e 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp @@ -14,6 +14,8 @@ class jit_convert_emitter : public jit_emitter { public: jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_count() const override; @@ -33,6 +35,8 @@ class jit_convert_truncation_emitter : public jit_convert_emitter { public: jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32); private: void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; @@ -48,6 +52,8 @@ class jit_convert_saturation_emitter : public jit_convert_emitter { public: jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32); private: void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 4f6d5bda38f82b..0f9d2a817fa1ac 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -5,7 +5,6 @@ #include "jit_load_store_emitters.hpp" #include "cpu/aarch64/cpu_isa_traits.hpp" #include "emitters/utils.hpp" -#include "utils.hpp" using namespace Xbyak_aarch64; @@ -20,7 +19,9 @@ jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator *host ov::element::Type src_prc, ov::element::Type dst_prc, int load_num, int byte_offset, ov::element::Type exec_prc, emitter_in_out_map in_out_type) : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), load_num_(load_num), byte_offset_(byte_offset), - src_prc_(src_prc), dst_prc_(dst_prc) {} + src_prc_(src_prc), dst_prc_(dst_prc) { + convert_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); +} void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { @@ -135,68 +136,23 @@ void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::v switch (src_prc_) { case ov::element::f32: - load_qbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); - switch (dst_prc_) { - case ov::element::f32: - break; - case ov::element::i32: - cvt_f32_to_i32(h, aux_vec_idxs, out_idxs); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); - } - break; case ov::element::i32: load_qbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); - switch (dst_prc_) { - case ov::element::f32: - cvt_i32_to_f32(h, aux_vec_idxs, out_idxs); - break; - case ov::element::i32: - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); - } break; case ov::element::f16: load_dbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); - switch (dst_prc_) { - case ov::element::f32: - cvt_f16_to_f32(h, aux_vec_idxs, out_idxs); - break; - case ov::element::i32: - cvt_f16_to_f32(h, aux_vec_idxs, aux_vec_idxs); - cvt_f32_to_i32(h, aux_vec_idxs, out_idxs); - break; - case ov::element::f16: - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); - } break; case ov::element::i8: case ov::element::u8: load_byte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); - switch (dst_prc_) { - case ov::element::f32: - cvt_byte_to_dbyte(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed()); - cvt_dbyte_to_i32(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed()); - cvt_i32_to_f32(h, aux_vec_idxs, out_idxs); - break; - case ov::element::i32: - cvt_byte_to_dbyte(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed()); - cvt_dbyte_to_i32(h, aux_vec_idxs, out_idxs, src_prc_.is_signed()); - break; - case ov::element::i8: - case ov::element::u8: - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); - } break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); } + + if (src_prc_ != dst_prc_) { + convert_emitter->emit_code(aux_vec_idxs, out_idxs); + } } size_t jit_load_emitter::get_aux_gprs_count() const { @@ -217,7 +173,15 @@ jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *ho ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset, arithmetic_mode mode, ov::element::Type exec_prc, emitter_in_out_map in_out_type) : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset), - src_prc_(src_prc), dst_prc_(dst_prc), mode_(mode) {} + src_prc_(src_prc), dst_prc_(dst_prc) { + if (mode == arithmetic_mode::truncation) { + convert_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); + } else if (mode == arithmetic_mode::saturation) { + convert_emitter.reset(new jit_convert_saturation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); + } else { + OV_CPU_JIT_EMITTER_THROW("Unsupported Convert emitter."); + } +} void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { @@ -331,65 +295,20 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: OV_CPU_JIT_EMITTER_ASSERT(store_num_ <= static_cast((get_vec_length() / dst_prc_.size())), "Unexpected number of elements to store."); + if (src_prc_ != dst_prc_) { + convert_emitter->emit_code(in_idxs, aux_vec_idxs); + } + switch (dst_prc_) { case ov::element::f32: - switch (src_prc_) { - case ov::element::f32: - break; - case ov::element::i32: - cvt_i32_to_f32(h, in_idxs, aux_vec_idxs); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); - } - store_qbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); - break; case ov::element::i32: - switch (src_prc_) { - case ov::element::f32: - cvt_f32_to_i32(h, in_idxs, aux_vec_idxs); - break; - case ov::element::i32: - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); - } store_qbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); break; case ov::element::f16: - switch (src_prc_) { - case ov::element::f32: - cvt_f32_to_f16(h, in_idxs, aux_vec_idxs); - break; - case ov::element::i32: - cvt_i32_to_f32(h, in_idxs, aux_vec_idxs); - cvt_f32_to_f16(h, aux_vec_idxs, aux_vec_idxs); - break; - case ov::element::f16: - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); - } store_dbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); break; case ov::element::i8: case ov::element::u8: - switch (src_prc_) { - case ov::element::f32: - cvt_f32_to_i32(h, in_idxs, aux_vec_idxs); - cvt_i32_to_dbyte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation); - cvt_dbyte_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation); - break; - case ov::element::i32: - cvt_i32_to_dbyte(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation); - cvt_dbyte_to_byte(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation); - break; - case ov::element::i8: - case ov::element::u8: - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); - } store_byte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); break; default: diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 34a422c96739fd..3467e083cdf6b1 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -6,6 +6,7 @@ #include "jit_emitter.hpp" #include "cpu/aarch64/jit_generator.hpp" +#include "emitters/plugin/aarch64/jit_conversion_emitters.hpp" namespace ov { namespace intel_cpu { @@ -39,6 +40,8 @@ class jit_load_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; size_t get_aux_vecs_count() const override; + std::unique_ptr convert_emitter = nullptr; + std::string name_; int load_num_; // the element number to load int byte_offset_; @@ -68,12 +71,13 @@ class jit_store_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; size_t get_aux_vecs_count() const override; + std::unique_ptr convert_emitter = nullptr; + std::string name_; int store_num_; // the element number to store int byte_offset_; ov::element::Type src_prc_; ov::element::Type dst_prc_; - arithmetic_mode mode_ = arithmetic_mode::saturation; }; } // namespace aarch64 diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp deleted file mode 100644 index 1baebf7ea5242b..00000000000000 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "utils.hpp" - -namespace ov { -namespace intel_cpu { -namespace aarch64 { - -// In aarch64, conversion between f16 and i16/u16 can be done with single instruction. The supported -// conversion precicions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate -// precision i16/u16 (dbyte) in the following graph. Then the conversion between each pair of -// neighbors in this graph will be done with single instruction. -// f16 - f32 - i32 - dbyte - byte -// | | -// - - - - - - - - - - - -template -void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - h->fcvtl(dst.s4, src.h4); -} - -template -void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - h->fcvtn(dst.h4, src.s4); -} - -template -void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - h->fcvtzs(dst.s, src.s); -} - -template -void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - h->scvtf(dst.s, src.s); -} - -template -void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturated) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - if (is_saturated) { - if (is_signed) { - h->sqxtn(dst.h4, src.s4); - } else { - h->uqxtn(dst.h4, src.s4); - } - } else { - h->xtn(dst.h4, src.s4); - } -} - -template -void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - if (is_signed) { - h->sxtl(dst.s4, src.h4); - } else { - h->uxtl(dst.s4, src.h4); - } -} - -template -void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - h->fcvtzs(dst.h, src.h); -} - -template -void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - if (is_signed) { - h->scvtf(dst.h, src.h); - } else { - h->ucvtf(dst.h, src.h); - } -} - -template -void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturated) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - if (is_saturated) { - if (is_signed) { - h->sqxtn(dst.b8, src.h8); - } else { - h->uqxtn(dst.b8, src.h8); - } - } else { - h->xtn(dst.b8, src.h8); - } -} - -template -void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - if (is_signed) { - h->sxtl(dst.h8, src.b8); - } else { - h->uxtl(dst.h8, src.b8); - } -} - -template void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs); - -template void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs); - -template void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs); - -template void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs); - -template void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturation); - -template void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed); - -template void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs); - -template void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed); - -template void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturation); - -template void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed); - -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp deleted file mode 100644 index 3bcd56db46c35c..00000000000000 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "cpu/aarch64/cpu_isa_traits.hpp" -#include "cpu/aarch64/jit_generator.hpp" - -namespace ov { -namespace intel_cpu { -namespace aarch64 { - -template -void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); - -template -void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); - -template -void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); - -template -void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); - -template -void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturated); - -template -void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed); - -template -void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs); - -template -void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed); - -template -void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturated); - -template -void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed); - -} // namespace aarch64 -} // namespace intel_cpu -} // namespace ov From f7d4f11e565af8c780e6da91c8a09afe1b8fe2f1 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Tue, 20 Aug 2024 06:35:15 +0000 Subject: [PATCH 18/33] Update arguments for instructions regarding f16 conversion --- .../src/emitters/plugin/aarch64/jit_conversion_emitters.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index faf481b525c351..c3eccbf8727b4b 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -86,7 +86,7 @@ static void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const s using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); - h->fcvtzs(dst.h, src.h); + h->fcvtzs(dst.h4, src.h4); } template @@ -96,9 +96,9 @@ static void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const s TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); if (is_signed) { - h->scvtf(dst.h, src.h); + h->scvtf(dst.h4, src.h4); } else { - h->ucvtf(dst.h, src.h); + h->ucvtf(dst.h4, src.h4); } } From b21c9ddbf7c4d30491275fa586f718685cd72490 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Tue, 20 Aug 2024 08:46:34 +0000 Subject: [PATCH 19/33] Make conversion between f16 and i8 compatible with ARMv8 --- .../aarch64/jit_conversion_emitters.cpp | 79 +++++++++---------- 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index c3eccbf8727b4b..09a9ac7d28b483 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -14,11 +14,16 @@ namespace aarch64 { // In aarch64, conversion between f16 and i16/u16 can be done with single instruction. The supported // conversion precicions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate -// precision i16/u16 (dbyte) in the following graph. Then the conversion between each pair of +// precision i16 in the following graph. Then the conversion between each pair of // neighbors in this graph will be done with single instruction. -// f16 - f32 - i32 - dbyte - byte -// | | -// - - - - - - - - - - - +// f16 - f32 - i32 - i16 - byte +// | | +// - - - - - - - - - - +// Note that using single instruction for conversion between f16 and i16 is only available for +// architecture ARMv8.2-A or later versions. So ARM platforms like Raspberry (Model name Cortex-A72) +// with architecture ARMv8 do not support such instructions. And as the isa asimd we supported +// does not distinguish ARMv8.2 with ARMv8.2-A, conversion between f16 and i16 will still use three +// instructions f16 -> f32 -> i32 -> i16 (f16 <- f32 <- i32 <- i16). template static void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; @@ -52,37 +57,28 @@ static void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std } template -static void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed, bool is_saturated) { +static void cvt_i32_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, + bool is_saturated) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); if (is_saturated) { - if (is_signed) { - h->sqxtn(dst.h4, src.s4); - } else { - h->uqxtn(dst.h4, src.s4); - } + h->sqxtn(dst.h4, src.s4); } else { h->xtn(dst.h4, src.s4); } } template -static void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed) { +static void cvt_i16_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); - if (is_signed) { - h->sxtl(dst.s4, src.h4); - } else { - h->uxtl(dst.s4, src.h4); - } + h->sxtl(dst.s4, src.h4); } template -static void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { +static void cvt_f16_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); @@ -90,20 +86,15 @@ static void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const s } template -static void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, - bool is_signed) { +static void cvt_i16_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); - if (is_signed) { - h->scvtf(dst.h4, src.h4); - } else { - h->ucvtf(dst.h4, src.h4); - } + h->scvtf(dst.h4, src.h4); } template -static void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, +static void cvt_i16_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, bool is_signed, bool is_saturated) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); @@ -120,7 +111,7 @@ static void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const } template -static void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, +static void cvt_byte_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, bool is_signed) { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; TReg src = TReg(in_idxs[0]); @@ -155,8 +146,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_dbyte(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_dbyte_to_i32(h, out_idxs, out_idxs, input_type.is_signed()); + cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_i16_to_i32(h, out_idxs, out_idxs); cvt_i32_to_f32(h, out_idxs, out_idxs); break; default: @@ -176,8 +167,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_dbyte(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_dbyte_to_i32(h, out_idxs, out_idxs, input_type.is_signed()); + cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_i16_to_i32(h, out_idxs, out_idxs); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -196,8 +187,10 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_dbyte(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_dbyte_to_f16(h, out_idxs, out_idxs, input_type.is_signed()); + cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_i16_to_i32(h, out_idxs, out_idxs); + cvt_i32_to_f32(h, out_idxs, out_idxs); + cvt_f32_to_f16(h, out_idxs, out_idxs); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -208,21 +201,23 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, switch (input_type) { case ov::element::f32: cvt_f32_to_i32(h, in_idxs, out_idxs); - cvt_i32_to_dbyte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); - cvt_dbyte_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_i32_to_i16(h, out_idxs, out_idxs, is_saturated); + cvt_i16_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); break; case ov::element::i32: - cvt_i32_to_dbyte(h, in_idxs, out_idxs, output_type.is_signed(), is_saturated); - cvt_dbyte_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_i32_to_i16(h, in_idxs, out_idxs, is_saturated); + cvt_i16_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); break; case ov::element::f16: - cvt_f16_to_dbyte(h, in_idxs, out_idxs); - cvt_dbyte_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_f16_to_f32(h, in_idxs, out_idxs); + cvt_f32_to_i32(h, out_idxs, out_idxs); + cvt_i32_to_i16(h, out_idxs, out_idxs, is_saturated); + cvt_i16_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_dbyte(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_dbyte_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); + cvt_i16_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); From aa0b7ebadf7db6c6e26736c18a10baeeec97fb22 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 22 Aug 2024 04:11:39 +0000 Subject: [PATCH 20/33] Update XReg prc --- .../aarch64/jit_load_store_emitters.cpp | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 0f9d2a817fa1ac..4cab844df4105f 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -35,7 +35,6 @@ template void jit_load_emitter::load_qbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; XReg src = XReg(in_idxs[0]); - XReg prc = XReg(aux_gpr_idxs[0]); TReg dst = TReg(out_idxs[0]); SReg dst_s = SReg(out_idxs[0]); DReg dst_d = DReg(out_idxs[0]); @@ -49,11 +48,13 @@ void jit_load_emitter::load_qbyte(const std::vector &in_idxs, const std: case 2: h->ldr(dst_d, ptr(src, byte_offset_)); break; - case 3: + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); h->ldr(dst_d, ptr(src, byte_offset_)); h->add_imm(prc, src, byte_offset_ + 2 * sizeof(float), h->X_DEFAULT_ADDR); h->ld1(dst.s[2], ptr(prc)); break; + } case 4: h->uni_ldr(dst, src, byte_offset_); break; @@ -66,11 +67,10 @@ template void jit_load_emitter::load_dbyte(const std::vector &in_idxs, const std::vector &out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; XReg src = XReg(in_idxs[0]); - XReg prc = XReg(aux_gpr_idxs[0]); TReg dst = TReg(out_idxs[0]); - DReg dst_d = DReg(out_idxs[0]); HReg dst_h = HReg(out_idxs[0]); SReg dst_s = SReg(out_idxs[0]); + DReg dst_d = DReg(out_idxs[0]); switch (load_num_) { case 0: @@ -81,11 +81,13 @@ void jit_load_emitter::load_dbyte(const std::vector &in_idxs, const std: case 2: h->ldr(dst_s, ptr(src, byte_offset_)); break; - case 3: + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); h->ldr(dst_s, ptr(src, byte_offset_)); h->add_imm(prc, src, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); h->ld1(dst.h[2], ptr(prc)); break; + } case 4: h->ldr(dst_d, ptr(src, byte_offset_)); break; @@ -98,7 +100,6 @@ template void jit_load_emitter::load_byte(const std::vector &in_idxs, const std::vector &out_idxs) const { using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; XReg src = XReg(in_idxs[0]); - XReg prc = XReg(aux_gpr_idxs[0]); TReg dst = TReg(out_idxs[0]); BReg dst_b = BReg(out_idxs[0]); HReg dst_h = HReg(out_idxs[0]); @@ -113,11 +114,13 @@ void jit_load_emitter::load_byte(const std::vector &in_idxs, const std:: case 2: h->ldr(dst_h, ptr(src, byte_offset_)); break; - case 3: + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); h->ldr(dst_h, ptr(src, byte_offset_)); h->add_imm(prc, src, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); h->ld1(dst.b[2], ptr(prc)); break; + } case 4: h->ldr(dst_s, ptr(src, byte_offset_)); break; @@ -199,7 +202,6 @@ void jit_store_emitter::store_qbyte(const std::vector &in_idxs, const st DReg src_d = DReg(in_idxs[0]); QReg src_q = QReg(in_idxs[0]); XReg dst = XReg(out_idxs[0]); - XReg prc = XReg(aux_gpr_idxs[0]); switch (store_num_) { case 0: @@ -210,11 +212,13 @@ void jit_store_emitter::store_qbyte(const std::vector &in_idxs, const st case 2: h->str(src_d, ptr(dst, byte_offset_)); break; - case 3: + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); h->str(src_d, ptr(dst, byte_offset_)); h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(float), h->X_DEFAULT_ADDR); h->st1(src.s[2], ptr(prc)); break; + } case 4: h->str(src_q, ptr(dst, byte_offset_)); break; @@ -231,7 +235,6 @@ void jit_store_emitter::store_dbyte(const std::vector &in_idxs, const st SReg src_s = SReg(in_idxs[0]); DReg src_d = DReg(in_idxs[0]); XReg dst = XReg(out_idxs[0]); - XReg prc = XReg(aux_gpr_idxs[0]); switch (store_num_) { case 0: @@ -242,11 +245,13 @@ void jit_store_emitter::store_dbyte(const std::vector &in_idxs, const st case 2: h->str(src_s, ptr(dst, byte_offset_)); break; - case 3: + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); h->str(src_s, ptr(dst, byte_offset_)); h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(uint16_t), h->X_DEFAULT_ADDR); h->st1(src.h[2], ptr(prc)); break; + } case 4: h->str(src_d, ptr(dst, byte_offset_)); break; @@ -263,7 +268,6 @@ void jit_store_emitter::store_byte(const std::vector &in_idxs, const std HReg src_h = HReg(in_idxs[0]); SReg src_s = SReg(in_idxs[0]); XReg dst = XReg(out_idxs[0]); - XReg prc = XReg(aux_gpr_idxs[0]); switch (store_num_) { case 0: @@ -274,11 +278,13 @@ void jit_store_emitter::store_byte(const std::vector &in_idxs, const std case 2: h->str(src_h, ptr(dst, byte_offset_)); break; - case 3: + case 3: { + XReg prc = XReg(aux_gpr_idxs[0]); h->str(src_h, ptr(dst, byte_offset_)); h->add_imm(prc, dst, byte_offset_ + 2 * sizeof(int8_t), h->X_DEFAULT_ADDR); h->st1(src.b[2], ptr(prc)); break; + } case 4: h->str(src_s, ptr(dst, byte_offset_)); break; From f879613aaaad9b1c64aef4c91442ae5b3fe28079 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 22 Aug 2024 06:21:54 +0000 Subject: [PATCH 21/33] Remove unnecessary aux_vec_idxs --- .../aarch64/jit_load_store_emitters.cpp | 30 +++++-------------- .../aarch64/jit_load_store_emitters.hpp | 2 -- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 4cab844df4105f..728699d38e247d 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -140,21 +140,21 @@ void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::v switch (src_prc_) { case ov::element::f32: case ov::element::i32: - load_qbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); + load_qbyte(in_idxs, out_idxs); break; case ov::element::f16: - load_dbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); + load_dbyte(in_idxs, out_idxs); break; case ov::element::i8: case ov::element::u8: - load_byte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); + load_byte(in_idxs, out_idxs); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); } if (src_prc_ != dst_prc_) { - convert_emitter->emit_code(aux_vec_idxs, out_idxs); + convert_emitter->emit_code(out_idxs, out_idxs); } } @@ -165,13 +165,6 @@ size_t jit_load_emitter::get_aux_gprs_count() const { return 0; } -size_t jit_load_emitter::get_aux_vecs_count() const { - if (src_prc_ != dst_prc_) - return 1; - - return 0; -} - jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset, arithmetic_mode mode, ov::element::Type exec_prc, emitter_in_out_map in_out_type) @@ -302,20 +295,20 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: "Unexpected number of elements to store."); if (src_prc_ != dst_prc_) { - convert_emitter->emit_code(in_idxs, aux_vec_idxs); + convert_emitter->emit_code(in_idxs, in_idxs); } switch (dst_prc_) { case ov::element::f32: case ov::element::i32: - store_qbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); + store_qbyte(in_idxs, out_idxs); break; case ov::element::f16: - store_dbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); + store_dbyte(in_idxs, out_idxs); break; case ov::element::i8: case ov::element::u8: - store_byte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); + store_byte(in_idxs, out_idxs); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); @@ -329,13 +322,6 @@ size_t jit_store_emitter::get_aux_gprs_count() const { return 0; } -size_t jit_store_emitter::get_aux_vecs_count() const { - if (src_prc_ != dst_prc_) - return 1; - - return 0; -} - } // namespace aarch64 } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 3467e083cdf6b1..0829da352b6e5c 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -38,7 +38,6 @@ class jit_load_emitter : public jit_emitter { template void load_byte(const std::vector &in_idxs, const std::vector &out_idxs) const; size_t get_aux_gprs_count() const override; - size_t get_aux_vecs_count() const override; std::unique_ptr convert_emitter = nullptr; @@ -69,7 +68,6 @@ class jit_store_emitter : public jit_emitter { template void store_byte(const std::vector &in_idxs, const std::vector &out_idxs) const; size_t get_aux_gprs_count() const override; - size_t get_aux_vecs_count() const override; std::unique_ptr convert_emitter = nullptr; From 6365b00a9016eb38e545c6416b2234095ee22ac8 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 22 Aug 2024 06:28:29 +0000 Subject: [PATCH 22/33] Update mov logic --- .../src/emitters/plugin/aarch64/jit_conversion_emitters.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index 09a9ac7d28b483..bbb4c6d5d7a959 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -128,6 +128,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, ov::element::Type input_type, ov::element::Type output_type, bool is_saturated) { if (input_type == output_type) { + if (in_idxs[0] == out_idxs[0]) + return; using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; h->mov(TReg(out_idxs[0]).b16, TReg(in_idxs[0]).b16); return; From 19a43de36603ec384a874f77d0c50af47930fe2b Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 22 Aug 2024 06:36:30 +0000 Subject: [PATCH 23/33] Update swtich-case for identical input and output precisions --- .../src/emitters/plugin/aarch64/jit_conversion_emitters.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index bbb4c6d5d7a959..a0f02e1eb71fd3 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -138,8 +138,6 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, switch (output_type) { case ov::element::f32: switch (input_type) { - case ov::element::f32: - break; case ov::element::i32: cvt_i32_to_f32(h, in_idxs, out_idxs); break; @@ -161,8 +159,6 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::f32: cvt_f32_to_i32(h, in_idxs, out_idxs); break; - case ov::element::i32: - break; case ov::element::f16: cvt_f16_to_f32(h, in_idxs, out_idxs); cvt_f32_to_i32(h, out_idxs, out_idxs); @@ -185,8 +181,6 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, cvt_i32_to_f32(h, in_idxs, out_idxs); cvt_f32_to_f16(h, out_idxs, out_idxs); break; - case ov::element::f16: - break; case ov::element::i8: case ov::element::u8: cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); From cb094a73fac4d531207c7f2b746e44775f9c7c1d Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 22 Aug 2024 08:04:00 +0000 Subject: [PATCH 24/33] Update template for conversion functions --- .../aarch64/jit_conversion_emitters.cpp | 133 +++++++----------- 1 file changed, 53 insertions(+), 80 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index a0f02e1eb71fd3..a7d86504c9bd06 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -24,44 +24,29 @@ namespace aarch64 { // with architecture ARMv8 do not support such instructions. And as the isa asimd we supported // does not distinguish ARMv8.2 with ARMv8.2-A, conversion between f16 and i16 will still use three // instructions f16 -> f32 -> i32 -> i16 (f16 <- f32 <- i32 <- i16). -template -static void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); +template +static inline void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { h->fcvtl(dst.s4, src.h4); } -template -static void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); +template +static inline void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { h->fcvtn(dst.h4, src.s4); } -template -static void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); +template +static inline void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { h->fcvtzs(dst.s, src.s); } -template -static void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); +template +static inline void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { h->scvtf(dst.s, src.s); } -template -static void cvt_i32_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, +template +static inline void cvt_i32_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst, bool is_saturated) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); if (is_saturated) { h->sqxtn(dst.h4, src.s4); } else { @@ -69,36 +54,24 @@ static void cvt_i32_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std } } -template -static void cvt_i16_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); +template +static inline void cvt_i16_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { h->sxtl(dst.s4, src.h4); } -template -static void cvt_f16_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); +template +static inline void cvt_f16_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { h->fcvtzs(dst.h4, src.h4); } -template -static void cvt_i16_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); +template +static inline void cvt_i16_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { h->scvtf(dst.h4, src.h4); } -template -static void cvt_i16_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, +template +static inline void cvt_i16_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); if (is_saturated) { if (is_signed) { h->sqxtn(dst.b8, src.h8); @@ -110,12 +83,9 @@ static void cvt_i16_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const st } } -template -static void cvt_byte_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, +template +static inline void cvt_byte_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst, bool is_signed) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); if (is_signed) { h->sxtl(dst.h8, src.b8); } else { @@ -127,11 +97,14 @@ template static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector &in_idxs, const std::vector &out_idxs, ov::element::Type input_type, ov::element::Type output_type, bool is_saturated) { + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + if (input_type == output_type) { if (in_idxs[0] == out_idxs[0]) return; - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - h->mov(TReg(out_idxs[0]).b16, TReg(in_idxs[0]).b16); + h->mov(dst.b16, src.b16); return; } @@ -139,16 +112,16 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::f32: switch (input_type) { case ov::element::i32: - cvt_i32_to_f32(h, in_idxs, out_idxs); + cvt_i32_to_f32(h, src, dst); break; case ov::element::f16: - cvt_f16_to_f32(h, in_idxs, out_idxs); + cvt_f16_to_f32(h, src, dst); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_i16_to_i32(h, out_idxs, out_idxs); - cvt_i32_to_f32(h, out_idxs, out_idxs); + cvt_byte_to_i16(h, src, dst, input_type.is_signed()); + cvt_i16_to_i32(h, dst, dst); + cvt_i32_to_f32(h, dst, dst); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -157,16 +130,16 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::i32: switch (input_type) { case ov::element::f32: - cvt_f32_to_i32(h, in_idxs, out_idxs); + cvt_f32_to_i32(h, src, dst); break; case ov::element::f16: - cvt_f16_to_f32(h, in_idxs, out_idxs); - cvt_f32_to_i32(h, out_idxs, out_idxs); + cvt_f16_to_f32(h, src, dst); + cvt_f32_to_i32(h, dst, dst); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_i16_to_i32(h, out_idxs, out_idxs); + cvt_byte_to_i16(h, src, dst, input_type.is_signed()); + cvt_i16_to_i32(h, dst, dst); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -175,18 +148,18 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::f16: switch (input_type) { case ov::element::f32: - cvt_f32_to_f16(h, in_idxs, out_idxs); + cvt_f32_to_f16(h, src, dst); break; case ov::element::i32: - cvt_i32_to_f32(h, in_idxs, out_idxs); - cvt_f32_to_f16(h, out_idxs, out_idxs); + cvt_i32_to_f32(h, src, dst); + cvt_f32_to_f16(h, dst, dst); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_i16_to_i32(h, out_idxs, out_idxs); - cvt_i32_to_f32(h, out_idxs, out_idxs); - cvt_f32_to_f16(h, out_idxs, out_idxs); + cvt_byte_to_i16(h, src, dst, input_type.is_signed()); + cvt_i16_to_i32(h, dst, dst); + cvt_i32_to_f32(h, dst, dst); + cvt_f32_to_f16(h, dst, dst); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -196,24 +169,24 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::u8: switch (input_type) { case ov::element::f32: - cvt_f32_to_i32(h, in_idxs, out_idxs); - cvt_i32_to_i16(h, out_idxs, out_idxs, is_saturated); - cvt_i16_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_f32_to_i32(h, src, dst); + cvt_i32_to_i16(h, dst, dst, is_saturated); + cvt_i16_to_byte(h, dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::i32: - cvt_i32_to_i16(h, in_idxs, out_idxs, is_saturated); - cvt_i16_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_i32_to_i16(h, src, dst, is_saturated); + cvt_i16_to_byte(h, dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::f16: - cvt_f16_to_f32(h, in_idxs, out_idxs); - cvt_f32_to_i32(h, out_idxs, out_idxs); - cvt_i32_to_i16(h, out_idxs, out_idxs, is_saturated); - cvt_i16_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_f16_to_f32(h, src, dst); + cvt_f32_to_i32(h, dst, dst); + cvt_i32_to_i16(h, dst, dst, is_saturated); + cvt_i16_to_byte(h, dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i16(h, in_idxs, out_idxs, input_type.is_signed()); - cvt_i16_to_byte(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated); + cvt_byte_to_i16(h, src, dst, input_type.is_signed()); + cvt_i16_to_byte(h, dst, dst, output_type.is_signed(), is_saturated); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); From d4dd7096a8018664e9ceb229c3425c994467399b Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 22 Aug 2024 08:17:10 +0000 Subject: [PATCH 25/33] Apply mov for conversion between i8 and u8 for truncation mode --- .../emitters/plugin/aarch64/jit_conversion_emitters.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index a7d86504c9bd06..0f757f9f0a6c89 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -101,10 +101,11 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, TReg src = TReg(in_idxs[0]); TReg dst = TReg(out_idxs[0]); - if (input_type == output_type) { - if (in_idxs[0] == out_idxs[0]) - return; - h->mov(dst.b16, src.b16); + if (input_type == output_type || (!is_saturated && + one_of(input_type, ov::element::i8, ov::element::u8) && one_of(output_type, ov::element::i8, ov::element::u8))) { + if (in_idxs[0] != out_idxs[0]) { + h->mov(dst.b16, src.b16); + } return; } From 66b70daec150999e85f1eef9f6e72e18aea9a465 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Thu, 22 Aug 2024 08:25:18 +0000 Subject: [PATCH 26/33] Update jit_convert_emitter constructor --- .../src/emitters/plugin/aarch64/jit_conversion_emitters.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index 0f757f9f0a6c89..c3cfe514d69b6e 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -199,9 +199,7 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, } jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) { - input_type = node->get_input_element_type(0); - output_type = node->get_output_element_type(0); +: jit_convert_emitter(host, host_isa, node->get_input_element_type(0), node->get_output_element_type(0), exec_prc) { } jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, From f7ed03c883fd7caea308eaddf74037679bb4986f Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Fri, 23 Aug 2024 04:14:49 +0000 Subject: [PATCH 27/33] revert removing unnecessary aux_vec_idxs --- .../aarch64/jit_load_store_emitters.cpp | 30 ++++++++++++++----- .../aarch64/jit_load_store_emitters.hpp | 2 ++ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 728699d38e247d..4cab844df4105f 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -140,21 +140,21 @@ void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::v switch (src_prc_) { case ov::element::f32: case ov::element::i32: - load_qbyte(in_idxs, out_idxs); + load_qbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); break; case ov::element::f16: - load_dbyte(in_idxs, out_idxs); + load_dbyte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); break; case ov::element::i8: case ov::element::u8: - load_byte(in_idxs, out_idxs); + load_byte(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name()); } if (src_prc_ != dst_prc_) { - convert_emitter->emit_code(out_idxs, out_idxs); + convert_emitter->emit_code(aux_vec_idxs, out_idxs); } } @@ -165,6 +165,13 @@ size_t jit_load_emitter::get_aux_gprs_count() const { return 0; } +size_t jit_load_emitter::get_aux_vecs_count() const { + if (src_prc_ != dst_prc_) + return 1; + + return 0; +} + jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset, arithmetic_mode mode, ov::element::Type exec_prc, emitter_in_out_map in_out_type) @@ -295,20 +302,20 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: "Unexpected number of elements to store."); if (src_prc_ != dst_prc_) { - convert_emitter->emit_code(in_idxs, in_idxs); + convert_emitter->emit_code(in_idxs, aux_vec_idxs); } switch (dst_prc_) { case ov::element::f32: case ov::element::i32: - store_qbyte(in_idxs, out_idxs); + store_qbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); break; case ov::element::f16: - store_dbyte(in_idxs, out_idxs); + store_dbyte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); break; case ov::element::i8: case ov::element::u8: - store_byte(in_idxs, out_idxs); + store_byte(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name()); @@ -322,6 +329,13 @@ size_t jit_store_emitter::get_aux_gprs_count() const { return 0; } +size_t jit_store_emitter::get_aux_vecs_count() const { + if (src_prc_ != dst_prc_) + return 1; + + return 0; +} + } // namespace aarch64 } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 0829da352b6e5c..3467e083cdf6b1 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -38,6 +38,7 @@ class jit_load_emitter : public jit_emitter { template void load_byte(const std::vector &in_idxs, const std::vector &out_idxs) const; size_t get_aux_gprs_count() const override; + size_t get_aux_vecs_count() const override; std::unique_ptr convert_emitter = nullptr; @@ -68,6 +69,7 @@ class jit_store_emitter : public jit_emitter { template void store_byte(const std::vector &in_idxs, const std::vector &out_idxs) const; size_t get_aux_gprs_count() const override; + size_t get_aux_vecs_count() const override; std::unique_ptr convert_emitter = nullptr; From 0fde0e2fb54a7f2eb0542013923f52902683f65b Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Fri, 23 Aug 2024 06:10:21 +0000 Subject: [PATCH 28/33] Make conversion functions to be member functions of base class --- .../aarch64/jit_conversion_emitters.cpp | 102 +++++++++--------- .../aarch64/jit_conversion_emitters.hpp | 25 +++++ 2 files changed, 75 insertions(+), 52 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp index c3cfe514d69b6e..38c6b9d710db7c 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp @@ -25,28 +25,27 @@ namespace aarch64 { // does not distinguish ARMv8.2 with ARMv8.2-A, conversion between f16 and i16 will still use three // instructions f16 -> f32 -> i32 -> i16 (f16 <- f32 <- i32 <- i16). template -static inline void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { +inline void jit_convert_emitter::cvt_f16_to_f32(const TReg &src, const TReg &dst) const { h->fcvtl(dst.s4, src.h4); } template -static inline void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { +inline void jit_convert_emitter::cvt_f32_to_f16(const TReg &src, const TReg &dst) const { h->fcvtn(dst.h4, src.s4); } template -static inline void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { +inline void jit_convert_emitter::cvt_f32_to_i32(const TReg &src, const TReg &dst) const { h->fcvtzs(dst.s, src.s); } template -static inline void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { +inline void jit_convert_emitter::cvt_i32_to_f32(const TReg &src, const TReg &dst) const { h->scvtf(dst.s, src.s); } template -static inline void cvt_i32_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst, - bool is_saturated) { +inline void jit_convert_emitter::cvt_i32_to_i16(const TReg &src, const TReg &dst, bool is_saturated) const { if (is_saturated) { h->sqxtn(dst.h4, src.s4); } else { @@ -55,23 +54,22 @@ static inline void cvt_i32_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, co } template -static inline void cvt_i16_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { +inline void jit_convert_emitter::cvt_i16_to_i32(const TReg &src, const TReg &dst) const { h->sxtl(dst.s4, src.h4); } template -static inline void cvt_f16_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { +inline void jit_convert_emitter::cvt_f16_to_i16(const TReg &src, const TReg &dst) const { h->fcvtzs(dst.h4, src.h4); } template -static inline void cvt_i16_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst) { +inline void jit_convert_emitter::cvt_i16_to_f16(const TReg &src, const TReg &dst) const { h->scvtf(dst.h4, src.h4); } template -static inline void cvt_i16_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst, - bool is_signed, bool is_saturated) { +inline void jit_convert_emitter::cvt_i16_to_byte(const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) const { if (is_saturated) { if (is_signed) { h->sqxtn(dst.b8, src.h8); @@ -84,8 +82,7 @@ static inline void cvt_i16_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, c } template -static inline void cvt_byte_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const TReg &src, const TReg &dst, - bool is_signed) { +inline void jit_convert_emitter::cvt_byte_to_i16(const TReg &src, const TReg &dst, bool is_signed) const { if (is_signed) { h->sxtl(dst.h8, src.b8); } else { @@ -93,17 +90,12 @@ static inline void cvt_byte_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, c } } -template -static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, - const std::vector &in_idxs, const std::vector &out_idxs, - ov::element::Type input_type, ov::element::Type output_type, bool is_saturated) { - using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; - TReg src = TReg(in_idxs[0]); - TReg dst = TReg(out_idxs[0]); - +template +void jit_convert_emitter::jit_convert_process(const TReg &src, const TReg &dst, ov::element::Type input_type, ov::element::Type output_type, + bool is_saturated) const { if (input_type == output_type || (!is_saturated && one_of(input_type, ov::element::i8, ov::element::u8) && one_of(output_type, ov::element::i8, ov::element::u8))) { - if (in_idxs[0] != out_idxs[0]) { + if (src.getIdx() != dst.getIdx()) { h->mov(dst.b16, src.b16); } return; @@ -113,16 +105,16 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::f32: switch (input_type) { case ov::element::i32: - cvt_i32_to_f32(h, src, dst); + cvt_i32_to_f32(src, dst); break; case ov::element::f16: - cvt_f16_to_f32(h, src, dst); + cvt_f16_to_f32(src, dst); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i16(h, src, dst, input_type.is_signed()); - cvt_i16_to_i32(h, dst, dst); - cvt_i32_to_f32(h, dst, dst); + cvt_byte_to_i16(src, dst, input_type.is_signed()); + cvt_i16_to_i32(dst, dst); + cvt_i32_to_f32(dst, dst); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -131,16 +123,16 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::i32: switch (input_type) { case ov::element::f32: - cvt_f32_to_i32(h, src, dst); + cvt_f32_to_i32(src, dst); break; case ov::element::f16: - cvt_f16_to_f32(h, src, dst); - cvt_f32_to_i32(h, dst, dst); + cvt_f16_to_f32(src, dst); + cvt_f32_to_i32(dst, dst); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i16(h, src, dst, input_type.is_signed()); - cvt_i16_to_i32(h, dst, dst); + cvt_byte_to_i16(src, dst, input_type.is_signed()); + cvt_i16_to_i32(dst, dst); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -149,18 +141,18 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::f16: switch (input_type) { case ov::element::f32: - cvt_f32_to_f16(h, src, dst); + cvt_f32_to_f16(src, dst); break; case ov::element::i32: - cvt_i32_to_f32(h, src, dst); - cvt_f32_to_f16(h, dst, dst); + cvt_i32_to_f32(src, dst); + cvt_f32_to_f16(dst, dst); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i16(h, src, dst, input_type.is_signed()); - cvt_i16_to_i32(h, dst, dst); - cvt_i32_to_f32(h, dst, dst); - cvt_f32_to_f16(h, dst, dst); + cvt_byte_to_i16(src, dst, input_type.is_signed()); + cvt_i16_to_i32(dst, dst); + cvt_i32_to_f32(dst, dst); + cvt_f32_to_f16(dst, dst); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -170,24 +162,24 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h, case ov::element::u8: switch (input_type) { case ov::element::f32: - cvt_f32_to_i32(h, src, dst); - cvt_i32_to_i16(h, dst, dst, is_saturated); - cvt_i16_to_byte(h, dst, dst, output_type.is_signed(), is_saturated); + cvt_f32_to_i32(src, dst); + cvt_i32_to_i16(dst, dst, is_saturated); + cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::i32: - cvt_i32_to_i16(h, src, dst, is_saturated); - cvt_i16_to_byte(h, dst, dst, output_type.is_signed(), is_saturated); + cvt_i32_to_i16(src, dst, is_saturated); + cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::f16: - cvt_f16_to_f32(h, src, dst); - cvt_f32_to_i32(h, dst, dst); - cvt_i32_to_i16(h, dst, dst, is_saturated); - cvt_i16_to_byte(h, dst, dst, output_type.is_signed(), is_saturated); + cvt_f16_to_f32(src, dst); + cvt_f32_to_i32(dst, dst); + cvt_i32_to_i16(dst, dst, is_saturated); + cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); break; case ov::element::i8: case ov::element::u8: - cvt_byte_to_i16(h, src, dst, input_type.is_signed()); - cvt_i16_to_byte(h, dst, dst, output_type.is_signed(), is_saturated); + cvt_byte_to_i16(src, dst, input_type.is_signed()); + cvt_i16_to_byte(dst, dst, output_type.is_signed(), is_saturated); break; default: OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name()); @@ -247,7 +239,10 @@ void jit_convert_truncation_emitter::emit_impl(const std::vector &in_idx template void jit_convert_truncation_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { - jit_convert_process(h, in_idxs, out_idxs, input_type, output_type, false); + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + jit_convert_process(src, dst, input_type, output_type, false); } jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa, @@ -273,7 +268,10 @@ void jit_convert_saturation_emitter::emit_impl(const std::vector &in_idx template void jit_convert_saturation_emitter::emit_isa(const std::vector &in_idxs, const std::vector &out_idxs) const { - jit_convert_process(h, in_idxs, out_idxs, input_type, output_type, true); + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + TReg src = TReg(in_idxs[0]); + TReg dst = TReg(out_idxs[0]); + jit_convert_process(src, dst, input_type, output_type, true); } } // namespace aarch64 diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp index 16a08bdf4b0f0e..2bb1ad34a7272f 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.hpp @@ -22,9 +22,34 @@ class jit_convert_emitter : public jit_emitter { protected: void emit_data() const override; void validate_types() const; + template + void jit_convert_process(const TReg &src, const TReg &dst, ov::element::Type input_type, ov::element::Type output_type, + bool is_saturated) const; ov::element::Type input_type; ov::element::Type output_type; + +private: + template + inline void cvt_f16_to_f32(const TReg &src, const TReg &dst) const; + template + inline void cvt_f32_to_f16(const TReg &src, const TReg &dst) const; + template + inline void cvt_f32_to_i32(const TReg &src, const TReg &dst) const; + template + inline void cvt_i32_to_f32(const TReg &src, const TReg &dst) const; + template + inline void cvt_i32_to_i16(const TReg &src, const TReg &dst, bool is_saturated) const; + template + inline void cvt_i16_to_i32(const TReg &src, const TReg &dst) const; + template + inline void cvt_f16_to_i16(const TReg &src, const TReg &dst) const; + template + inline void cvt_i16_to_f16(const TReg &src, const TReg &dst) const; + template + inline void cvt_i16_to_byte(const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) const; + template + inline void cvt_byte_to_i16(const TReg &src, const TReg &dst, bool is_signed) const; }; // This emitter is covered by specification of "Convert" operation. The implementation uses a "warp-around" conversion. From e547663dc733ab6394ec24316ac3d4f9247984bf Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Fri, 23 Aug 2024 06:31:52 +0000 Subject: [PATCH 29/33] Add assertion --- .../src/emitters/plugin/aarch64/jit_load_store_emitters.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 4cab844df4105f..529aaa3cb79993 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -154,6 +154,7 @@ void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::v } if (src_prc_ != dst_prc_) { + OPENVINO_ASSERT(convert_emitter, "Invalid convert_emitter."); convert_emitter->emit_code(aux_vec_idxs, out_idxs); } } @@ -302,6 +303,7 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: "Unexpected number of elements to store."); if (src_prc_ != dst_prc_) { + OPENVINO_ASSERT(convert_emitter, "Invalid convert_emitter."); convert_emitter->emit_code(in_idxs, aux_vec_idxs); } From 431bf95b033e29ae2aefb63c44cc0c67c82378b8 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Fri, 23 Aug 2024 06:58:55 +0000 Subject: [PATCH 30/33] Update SNIPPETS_REGISTER_PASS_RELATIVE --- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 35 +++++++++----------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 0935a8b8f1d5ab..e166fc8bf453e7 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -670,33 +670,30 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const { using PassPosition = ov::snippets::pass::PassPosition; using Place = PassPosition::Place; - -# define SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(PASS_PLACE, TARGET_PASS, PASS, ...) \ +# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...) \ backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), std::make_shared(__VA_ARGS__)) #if defined(OPENVINO_ARCH_X86_64) -# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) \ - backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), std::make_shared(__VA_ARGS__)) -#else -# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) -#endif // OPENVINO_ARCH_X86_64 + SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::snippets::lowered::pass::MarkLoops, + ov::intel_cpu::pass::BrgemmCPUBlocking); +#endif + + SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::snippets::lowered::pass::InsertLoops, + ov::intel_cpu::pass::FuseLoadStoreConvert); - SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::snippets::lowered::pass::MarkLoops, - ov::intel_cpu::pass::BrgemmCPUBlocking); - SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::lowered::pass::InsertLoops, - ov::intel_cpu::pass::FuseLoadStoreConvert); - SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, - ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape); +#if defined(OPENVINO_ARCH_X86_64) + SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, + ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape); +#endif #ifdef SNIPPETS_LIBXSMM_TPP - SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, ov::intel_cpu::pass::BrgemmCPUBlocking, - ov::intel_cpu::tpp::pass::BrgemmTPPBlocking); - SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, - ov::intel_cpu::tpp::pass::SetTPPLeadingDim); + SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before, ov::intel_cpu::pass::BrgemmCPUBlocking, + ov::intel_cpu::tpp::pass::BrgemmTPPBlocking); + SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, + ov::intel_cpu::tpp::pass::SetTPPLeadingDim); #endif -#undef SNIPPETS_REGISTER_PASS_RELATIVE_COMMON -#undef SNIPPETS_REGISTER_PASS_RELATIVE_X86_64 +#undef SNIPPETS_REGISTER_PASS_RELATIVE return backend_passes; } From cb497a147033feaedce1ff65a6204be0fdf51cb1 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Fri, 23 Aug 2024 07:03:55 +0000 Subject: [PATCH 31/33] Apply convert_truncation_emitter --- .../src/emitters/plugin/aarch64/jit_load_store_emitters.cpp | 6 +++--- .../src/emitters/plugin/aarch64/jit_load_store_emitters.hpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 529aaa3cb79993..2023526793f050 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -20,7 +20,7 @@ jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator *host ov::element::Type exec_prc, emitter_in_out_map in_out_type) : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), load_num_(load_num), byte_offset_(byte_offset), src_prc_(src_prc), dst_prc_(dst_prc) { - convert_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); + convert_truncation_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); } void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { @@ -154,8 +154,8 @@ void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::v } if (src_prc_ != dst_prc_) { - OPENVINO_ASSERT(convert_emitter, "Invalid convert_emitter."); - convert_emitter->emit_code(aux_vec_idxs, out_idxs); + OPENVINO_ASSERT(convert_truncation_emitter, "Invalid convert_truncation_emitter."); + convert_truncation_emitter->emit_code(aux_vec_idxs, out_idxs); } } diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp index 3467e083cdf6b1..6ad53146ebd4ce 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.hpp @@ -40,7 +40,7 @@ class jit_load_emitter : public jit_emitter { size_t get_aux_gprs_count() const override; size_t get_aux_vecs_count() const override; - std::unique_ptr convert_emitter = nullptr; + std::unique_ptr convert_truncation_emitter = nullptr; std::string name_; int load_num_; // the element number to load From c16fd381f3f53e86aac490fc1a75d4107cf6104c Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Fri, 23 Aug 2024 07:10:04 +0000 Subject: [PATCH 32/33] Add condition for creating conversion emitters --- .../plugin/aarch64/jit_load_store_emitters.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 2023526793f050..5ca032ca57a70d 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -20,7 +20,9 @@ jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator *host ov::element::Type exec_prc, emitter_in_out_map in_out_type) : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), load_num_(load_num), byte_offset_(byte_offset), src_prc_(src_prc), dst_prc_(dst_prc) { - convert_truncation_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); + if (src_prc_ != dst_prc_) { + convert_truncation_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); + } } void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { @@ -178,12 +180,14 @@ jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *ho arithmetic_mode mode, ov::element::Type exec_prc, emitter_in_out_map in_out_type) : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset), src_prc_(src_prc), dst_prc_(dst_prc) { - if (mode == arithmetic_mode::truncation) { - convert_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); - } else if (mode == arithmetic_mode::saturation) { - convert_emitter.reset(new jit_convert_saturation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); - } else { - OV_CPU_JIT_EMITTER_THROW("Unsupported Convert emitter."); + if (src_prc_ != dst_prc_) { + if (mode == arithmetic_mode::truncation) { + convert_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); + } else if (mode == arithmetic_mode::saturation) { + convert_emitter.reset(new jit_convert_saturation_emitter(host, host_isa, src_prc, dst_prc, exec_prc)); + } else { + OV_CPU_JIT_EMITTER_THROW("Unsupported Convert emitter."); + } } } From dda522564dffcffce2990e2502e7d5c290780e48 Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Fri, 23 Aug 2024 10:33:56 +0000 Subject: [PATCH 33/33] Update assertion for element number --- .../src/emitters/plugin/aarch64/jit_load_store_emitters.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp index 5ca032ca57a70d..019c9177f81401 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp @@ -136,8 +136,7 @@ void jit_load_emitter::emit_isa(const std::vector &in_idxs, const std::v bool is_supported_precision = one_of(src_prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && (src_prc_ == dst_prc_ || one_of(dst_prc_, ov::element::f32, ov::element::i32)); OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); - OV_CPU_JIT_EMITTER_ASSERT(load_num_ <= static_cast((get_vec_length() / dst_prc_.size())), - "Unexpected number of elements to load."); + OV_CPU_JIT_EMITTER_ASSERT(load_num_ <= 4, "Unexpected number of elements to load."); switch (src_prc_) { case ov::element::f32: @@ -303,8 +302,7 @@ void jit_store_emitter::emit_isa(const std::vector &in_idxs, const std:: bool is_supported_precision = one_of(dst_prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && (src_prc_ == dst_prc_ || one_of(src_prc_, ov::element::f32, ov::element::i32)); OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); - OV_CPU_JIT_EMITTER_ASSERT(store_num_ <= static_cast((get_vec_length() / dst_prc_.size())), - "Unexpected number of elements to store."); + OV_CPU_JIT_EMITTER_ASSERT(store_num_ <= 4, "Unexpected number of elements to store."); if (src_prc_ != dst_prc_) { OPENVINO_ASSERT(convert_emitter, "Invalid convert_emitter.");