Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] [Snippets] Implement Convert for Snippets on ARM #25815

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
ba3fd33
[CPU] [Snippets] Implement load, store, convert emitters, and add con…
xuchen-intel Jul 2, 2024
d5e8174
Enable LoadConvertSaturation and three other counterparts
xuchen-intel Jul 23, 2024
5e13530
Fix the issue regarding initialization order
xuchen-intel Jul 31, 2024
b732403
Fix issue regarding incorrect path of headers
xuchen-intel Jul 31, 2024
b4c92bc
Support conversion between the same precision
xuchen-intel Aug 1, 2024
96c06c6
Fix issue regarding primitive type
xuchen-intel Aug 1, 2024
205e579
Skip test cases on unaligned conversion behavior
xuchen-intel Aug 5, 2024
b2c31a4
Apply review comments regarding conversion between f16 and i8(u8)
xuchen-intel Aug 6, 2024
ffeadd5
Revise CMakeLists
xuchen-intel Aug 7, 2024
6b03081
Apply arithmetic_mode to align with x64
xuchen-intel Aug 8, 2024
4d99437
Update precision assertion
xuchen-intel Aug 8, 2024
f03800b
Replace post_ptr with ptr
xuchen-intel Aug 9, 2024
63d16c2
Set IGNORE_CALLBACK if rank > 6
xuchen-intel Aug 19, 2024
1ea1b6b
Update isSuitableConvert
xuchen-intel Aug 19, 2024
da5a32d
Update jit_store_memory_emitter constructor
xuchen-intel Aug 19, 2024
7e8d9c0
Update enum class arithmetic_mode
xuchen-intel Aug 19, 2024
1747147
Call convert_emitter in load/store_emitter
xuchen-intel Aug 19, 2024
f7d4f11
Update arguments for instructions regarding f16 conversion
xuchen-intel Aug 20, 2024
b21c9dd
Make conversion between f16 and i8 compatible with ARMv8
xuchen-intel Aug 20, 2024
aa0b7eb
Update XReg prc
xuchen-intel Aug 22, 2024
f879613
Remove unnecessary aux_vec_idxs
xuchen-intel Aug 22, 2024
6365b00
Update mov logic
xuchen-intel Aug 22, 2024
19a43de
Update swtich-case for identical input and output precisions
xuchen-intel Aug 22, 2024
cb094a7
Update template for conversion functions
xuchen-intel Aug 22, 2024
d4dd709
Apply mov for conversion between i8 and u8 for truncation mode
xuchen-intel Aug 22, 2024
66b70da
Update jit_convert_emitter constructor
xuchen-intel Aug 22, 2024
f7ed03c
revert removing unnecessary aux_vec_idxs
xuchen-intel Aug 23, 2024
0fde0e2
Make conversion functions to be member functions of base class
xuchen-intel Aug 23, 2024
e547663
Add assertion
xuchen-intel Aug 23, 2024
431bf95
Update SNIPPETS_REGISTER_PASS_RELATIVE
xuchen-intel Aug 23, 2024
cb497a1
Apply convert_truncation_emitter
xuchen-intel Aug 23, 2024
c16fd38
Add condition for creating conversion emitters
xuchen-intel Aug 23, 2024
dda5225
Update assertion for element number
xuchen-intel Aug 23, 2024
7ec419a
Merge branch 'master' into feature/arm_snippets_convert
xuchen-intel Sep 2, 2024
11a8d58
Merge branch 'master' into feature/arm_snippets_convert
xuchen-intel Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_conversion_emitters.hpp"
#include "emitters/utils.hpp"

using namespace dnnl::impl::cpu::aarch64;
using namespace Xbyak_aarch64;

namespace ov {
namespace intel_cpu {
namespace aarch64 {

// In aarch64, conversion between f16 and i16/u16 can be done with single instruction. The supported
// conversion precicions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate
// precision i16 in the following graph. Then the conversion between each pair of
// neighbors in this graph will be done with single instruction.
// f16 - f32 - i32 - i16 - byte
// | |
// - - - - - - - - - -
// Note that using single instruction for conversion between f16 and i16 is only available for
// architecture ARMv8.2-A or later versions. So ARM platforms like Raspberry (Model name Cortex-A72)
// with architecture ARMv8 do not support such instructions. And as the isa asimd we supported
// does not distinguish ARMv8.2 with ARMv8.2-A, conversion between f16 and i16 will still use three
// instructions f16 -> f32 -> i32 -> i16 (f16 <- f32 <- i32 <- i16).
template <typename TReg>
inline void jit_convert_emitter::cvt_f16_to_f32(const TReg &src, const TReg &dst) const {
h->fcvtl(dst.s4, src.h4);
}

template <typename TReg>
inline void jit_convert_emitter::cvt_f32_to_f16(const TReg &src, const TReg &dst) const {
h->fcvtn(dst.h4, src.s4);
}

template <typename TReg>
inline void jit_convert_emitter::cvt_f32_to_i32(const TReg &src, const TReg &dst) const {
h->fcvtzs(dst.s, src.s);
}

template <typename TReg>
inline void jit_convert_emitter::cvt_i32_to_f32(const TReg &src, const TReg &dst) const {
h->scvtf(dst.s, src.s);
}

template <typename TReg>
inline void jit_convert_emitter::cvt_i32_to_i16(const TReg &src, const TReg &dst, bool is_saturated) const {
if (is_saturated) {
h->sqxtn(dst.h4, src.s4);
} else {
h->xtn(dst.h4, src.s4);
}
}

template <typename TReg>
inline void jit_convert_emitter::cvt_i16_to_i32(const TReg &src, const TReg &dst) const {
h->sxtl(dst.s4, src.h4);
}

template <typename TReg>
inline void jit_convert_emitter::cvt_f16_to_i16(const TReg &src, const TReg &dst) const {
h->fcvtzs(dst.h4, src.h4);
}

template <typename TReg>
inline void jit_convert_emitter::cvt_i16_to_f16(const TReg &src, const TReg &dst) const {
h->scvtf(dst.h4, src.h4);
}

template <typename TReg>
inline void jit_convert_emitter::cvt_i16_to_byte(const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) const {
if (is_saturated) {
if (is_signed) {
h->sqxtn(dst.b8, src.h8);
} else {
h->uqxtn(dst.b8, src.h8);
}
} else {
h->xtn(dst.b8, src.h8);
}
}

template <typename TReg>
inline void jit_convert_emitter::cvt_byte_to_i16(const TReg &src, const TReg &dst, bool is_signed) const {
if (is_signed) {
h->sxtl(dst.h8, src.b8);
} else {
h->uxtl(dst.h8, src.b8);
}
}

template <typename TReg>
void jit_convert_emitter::jit_convert_process(const TReg &src, const TReg &dst, ov::element::Type input_type, ov::element::Type output_type,
bool is_saturated) const {
if (input_type == output_type || (!is_saturated &&
one_of(input_type, ov::element::i8, ov::element::u8) && one_of(output_type, ov::element::i8, ov::element::u8))) {
if (src.getIdx() != dst.getIdx()) {
h->mov(dst.b16, src.b16);
}
return;
}

switch (output_type) {
case ov::element::f32:
switch (input_type) {
case ov::element::i32:
cvt_i32_to_f32<TReg>(src, dst);
break;
case ov::element::f16:
cvt_f16_to_f32<TReg>(src, dst);
break;
case ov::element::i8:
case ov::element::u8:
cvt_byte_to_i16<TReg>(src, dst, input_type.is_signed());
cvt_i16_to_i32<TReg>(dst, dst);
cvt_i32_to_f32<TReg>(dst, dst);
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
}
break;
case ov::element::i32:
switch (input_type) {
case ov::element::f32:
cvt_f32_to_i32<TReg>(src, dst);
break;
case ov::element::f16:
cvt_f16_to_f32<TReg>(src, dst);
cvt_f32_to_i32<TReg>(dst, dst);
break;
case ov::element::i8:
case ov::element::u8:
cvt_byte_to_i16<TReg>(src, dst, input_type.is_signed());
cvt_i16_to_i32<TReg>(dst, dst);
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
}
break;
case ov::element::f16:
switch (input_type) {
case ov::element::f32:
cvt_f32_to_f16<TReg>(src, dst);
break;
case ov::element::i32:
cvt_i32_to_f32<TReg>(src, dst);
cvt_f32_to_f16<TReg>(dst, dst);
break;
case ov::element::i8:
case ov::element::u8:
cvt_byte_to_i16<TReg>(src, dst, input_type.is_signed());
cvt_i16_to_i32<TReg>(dst, dst);
cvt_i32_to_f32<TReg>(dst, dst);
cvt_f32_to_f16<TReg>(dst, dst);
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
}
break;
case ov::element::i8:
case ov::element::u8:
switch (input_type) {
case ov::element::f32:
cvt_f32_to_i32<TReg>(src, dst);
cvt_i32_to_i16<TReg>(dst, dst, is_saturated);
cvt_i16_to_byte<TReg>(dst, dst, output_type.is_signed(), is_saturated);
break;
case ov::element::i32:
cvt_i32_to_i16<TReg>(src, dst, is_saturated);
cvt_i16_to_byte<TReg>(dst, dst, output_type.is_signed(), is_saturated);
break;
case ov::element::f16:
cvt_f16_to_f32<TReg>(src, dst);
cvt_f32_to_i32<TReg>(dst, dst);
cvt_i32_to_i16<TReg>(dst, dst, is_saturated);
cvt_i16_to_byte<TReg>(dst, dst, output_type.is_signed(), is_saturated);
break;
case ov::element::i8:
case ov::element::u8:
cvt_byte_to_i16<TReg>(src, dst, input_type.is_signed());
cvt_i16_to_byte<TReg>(dst, dst, output_type.is_signed(), is_saturated);
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
}
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", output_type.get_type_name());
}
}

jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, node->get_input_element_type(0), node->get_output_element_type(0), exec_prc) {
}

jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa,
ov::element::Type input_prc,
ov::element::Type output_prc,
ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
input_type = input_prc;
output_type = output_prc;
}

void jit_convert_emitter::validate_types() const {
OV_CPU_JIT_EMITTER_ASSERT(one_of(input_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported input type: ", input_type.get_type_name());
OV_CPU_JIT_EMITTER_ASSERT(one_of(output_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported output type: ", output_type.get_type_name());
}

size_t jit_convert_emitter::get_inputs_count() const { return 1; }

void jit_convert_emitter::emit_data() const {
jit_emitter::emit_data();
}

jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, node, exec_prc) {
}

jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa,
ov::element::Type input_prc,
ov::element::Type output_prc,
ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) {
}

void jit_convert_truncation_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
validate_types();
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_idxs, out_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_);
}
}

template <cpu_isa_t isa>
void jit_convert_truncation_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src = TReg(in_idxs[0]);
TReg dst = TReg(out_idxs[0]);
jit_convert_process<TReg>(src, dst, input_type, output_type, false);
}

jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, node, exec_prc) {
}

jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa,
ov::element::Type input_prc,
ov::element::Type output_prc,
ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) {
}

void jit_convert_saturation_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
validate_types();
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_idxs, out_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_);
}
}

template <cpu_isa_t isa>
void jit_convert_saturation_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src = TReg(in_idxs[0]);
TReg dst = TReg(out_idxs[0]);
jit_convert_process<TReg>(src, dst, input_type, output_type, true);
}

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "jit_emitter.hpp"

namespace ov {
namespace intel_cpu {
namespace aarch64 {

class jit_convert_emitter : public jit_emitter {
public:
jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

size_t get_inputs_count() const override;

protected:
void emit_data() const override;
void validate_types() const;
template <typename TReg>
void jit_convert_process(const TReg &src, const TReg &dst, ov::element::Type input_type, ov::element::Type output_type,
bool is_saturated) const;

ov::element::Type input_type;
ov::element::Type output_type;

private:
template <typename TReg>
inline void cvt_f16_to_f32(const TReg &src, const TReg &dst) const;
template <typename TReg>
inline void cvt_f32_to_f16(const TReg &src, const TReg &dst) const;
template <typename TReg>
inline void cvt_f32_to_i32(const TReg &src, const TReg &dst) const;
template <typename TReg>
inline void cvt_i32_to_f32(const TReg &src, const TReg &dst) const;
template <typename TReg>
inline void cvt_i32_to_i16(const TReg &src, const TReg &dst, bool is_saturated) const;
template <typename TReg>
inline void cvt_i16_to_i32(const TReg &src, const TReg &dst) const;
template <typename TReg>
inline void cvt_f16_to_i16(const TReg &src, const TReg &dst) const;
template <typename TReg>
inline void cvt_i16_to_f16(const TReg &src, const TReg &dst) const;
template <typename TReg>
inline void cvt_i16_to_byte(const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) const;
template <typename TReg>
inline void cvt_byte_to_i16(const TReg &src, const TReg &dst, bool is_signed) const;
};

// This emitter is covered by specification of "Convert" operation. The implementation uses a "warp-around" conversion.
// Example:
// int32_t -> int8_t
// 129 -> -127
class jit_convert_truncation_emitter : public jit_convert_emitter {
public:
jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

private:
void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const override;
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const;
};

// This emitter is covered by the common dnnl behavior. The implementation uses a "saturation" conversion.
// Example:
// int32_t -> int8_t
// 129 -> 127
class jit_convert_saturation_emitter : public jit_convert_emitter {
public:
jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

private:
void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const override;
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const;
};

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Loading
Loading