Skip to content

Commit 47abc61

Browse files
xuchen-intelababushk
authored andcommitted
[CPU] [Snippets] Implement Convert for Snippets on ARM (openvinotoolkit#25815)
### Details: - *Add jit implementation for Convert emitters on ARM* - *Add jit implementation for Load/Store emitters for precision i32, f16, i8, u8 on ARM* - *Add Snippets tokenization for Convert on ARM* - *Enable LoadConvertSaturation and three other counterparts* - *Test case coverage* ### Tickets: - *[CVS-141288](https://jira.devtools.intel.com/browse/CVS-141288)* - *[CVS-141294](https://jira.devtools.intel.com/browse/CVS-141294)*
1 parent 4034f74 commit 47abc61

25 files changed

+906
-59
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "jit_conversion_emitters.hpp"
6+
#include "emitters/utils.hpp"
7+
8+
using namespace dnnl::impl::cpu::aarch64;
9+
using namespace Xbyak_aarch64;
10+
11+
namespace ov {
12+
namespace intel_cpu {
13+
namespace aarch64 {
14+
15+
// In aarch64, conversion between f16 and i16/u16 can be done with single instruction. The supported
16+
// conversion precicions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate
17+
// precision i16 in the following graph. Then the conversion between each pair of
18+
// neighbors in this graph will be done with single instruction.
19+
// f16 - f32 - i32 - i16 - byte
20+
// | |
21+
// - - - - - - - - - -
22+
// Note that using single instruction for conversion between f16 and i16 is only available for
23+
// architecture ARMv8.2-A or later versions. So ARM platforms like Raspberry (Model name Cortex-A72)
24+
// with architecture ARMv8 do not support such instructions. And as the isa asimd we supported
25+
// does not distinguish ARMv8.2 with ARMv8.2-A, conversion between f16 and i16 will still use three
26+
// instructions f16 -> f32 -> i32 -> i16 (f16 <- f32 <- i32 <- i16).
27+
template <typename TReg>
28+
inline void jit_convert_emitter::cvt_f16_to_f32(const TReg &src, const TReg &dst) const {
29+
h->fcvtl(dst.s4, src.h4);
30+
}
31+
32+
template <typename TReg>
33+
inline void jit_convert_emitter::cvt_f32_to_f16(const TReg &src, const TReg &dst) const {
34+
h->fcvtn(dst.h4, src.s4);
35+
}
36+
37+
template <typename TReg>
38+
inline void jit_convert_emitter::cvt_f32_to_i32(const TReg &src, const TReg &dst) const {
39+
h->fcvtzs(dst.s, src.s);
40+
}
41+
42+
template <typename TReg>
43+
inline void jit_convert_emitter::cvt_i32_to_f32(const TReg &src, const TReg &dst) const {
44+
h->scvtf(dst.s, src.s);
45+
}
46+
47+
template <typename TReg>
48+
inline void jit_convert_emitter::cvt_i32_to_i16(const TReg &src, const TReg &dst, bool is_saturated) const {
49+
if (is_saturated) {
50+
h->sqxtn(dst.h4, src.s4);
51+
} else {
52+
h->xtn(dst.h4, src.s4);
53+
}
54+
}
55+
56+
template <typename TReg>
57+
inline void jit_convert_emitter::cvt_i16_to_i32(const TReg &src, const TReg &dst) const {
58+
h->sxtl(dst.s4, src.h4);
59+
}
60+
61+
template <typename TReg>
62+
inline void jit_convert_emitter::cvt_f16_to_i16(const TReg &src, const TReg &dst) const {
63+
h->fcvtzs(dst.h4, src.h4);
64+
}
65+
66+
template <typename TReg>
67+
inline void jit_convert_emitter::cvt_i16_to_f16(const TReg &src, const TReg &dst) const {
68+
h->scvtf(dst.h4, src.h4);
69+
}
70+
71+
template <typename TReg>
72+
inline void jit_convert_emitter::cvt_i16_to_byte(const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) const {
73+
if (is_saturated) {
74+
if (is_signed) {
75+
h->sqxtn(dst.b8, src.h8);
76+
} else {
77+
h->uqxtn(dst.b8, src.h8);
78+
}
79+
} else {
80+
h->xtn(dst.b8, src.h8);
81+
}
82+
}
83+
84+
template <typename TReg>
85+
inline void jit_convert_emitter::cvt_byte_to_i16(const TReg &src, const TReg &dst, bool is_signed) const {
86+
if (is_signed) {
87+
h->sxtl(dst.h8, src.b8);
88+
} else {
89+
h->uxtl(dst.h8, src.b8);
90+
}
91+
}
92+
93+
template <typename TReg>
94+
void jit_convert_emitter::jit_convert_process(const TReg &src, const TReg &dst, ov::element::Type input_type, ov::element::Type output_type,
95+
bool is_saturated) const {
96+
if (input_type == output_type || (!is_saturated &&
97+
one_of(input_type, ov::element::i8, ov::element::u8) && one_of(output_type, ov::element::i8, ov::element::u8))) {
98+
if (src.getIdx() != dst.getIdx()) {
99+
h->mov(dst.b16, src.b16);
100+
}
101+
return;
102+
}
103+
104+
switch (output_type) {
105+
case ov::element::f32:
106+
switch (input_type) {
107+
case ov::element::i32:
108+
cvt_i32_to_f32<TReg>(src, dst);
109+
break;
110+
case ov::element::f16:
111+
cvt_f16_to_f32<TReg>(src, dst);
112+
break;
113+
case ov::element::i8:
114+
case ov::element::u8:
115+
cvt_byte_to_i16<TReg>(src, dst, input_type.is_signed());
116+
cvt_i16_to_i32<TReg>(dst, dst);
117+
cvt_i32_to_f32<TReg>(dst, dst);
118+
break;
119+
default:
120+
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
121+
}
122+
break;
123+
case ov::element::i32:
124+
switch (input_type) {
125+
case ov::element::f32:
126+
cvt_f32_to_i32<TReg>(src, dst);
127+
break;
128+
case ov::element::f16:
129+
cvt_f16_to_f32<TReg>(src, dst);
130+
cvt_f32_to_i32<TReg>(dst, dst);
131+
break;
132+
case ov::element::i8:
133+
case ov::element::u8:
134+
cvt_byte_to_i16<TReg>(src, dst, input_type.is_signed());
135+
cvt_i16_to_i32<TReg>(dst, dst);
136+
break;
137+
default:
138+
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
139+
}
140+
break;
141+
case ov::element::f16:
142+
switch (input_type) {
143+
case ov::element::f32:
144+
cvt_f32_to_f16<TReg>(src, dst);
145+
break;
146+
case ov::element::i32:
147+
cvt_i32_to_f32<TReg>(src, dst);
148+
cvt_f32_to_f16<TReg>(dst, dst);
149+
break;
150+
case ov::element::i8:
151+
case ov::element::u8:
152+
cvt_byte_to_i16<TReg>(src, dst, input_type.is_signed());
153+
cvt_i16_to_i32<TReg>(dst, dst);
154+
cvt_i32_to_f32<TReg>(dst, dst);
155+
cvt_f32_to_f16<TReg>(dst, dst);
156+
break;
157+
default:
158+
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
159+
}
160+
break;
161+
case ov::element::i8:
162+
case ov::element::u8:
163+
switch (input_type) {
164+
case ov::element::f32:
165+
cvt_f32_to_i32<TReg>(src, dst);
166+
cvt_i32_to_i16<TReg>(dst, dst, is_saturated);
167+
cvt_i16_to_byte<TReg>(dst, dst, output_type.is_signed(), is_saturated);
168+
break;
169+
case ov::element::i32:
170+
cvt_i32_to_i16<TReg>(src, dst, is_saturated);
171+
cvt_i16_to_byte<TReg>(dst, dst, output_type.is_signed(), is_saturated);
172+
break;
173+
case ov::element::f16:
174+
cvt_f16_to_f32<TReg>(src, dst);
175+
cvt_f32_to_i32<TReg>(dst, dst);
176+
cvt_i32_to_i16<TReg>(dst, dst, is_saturated);
177+
cvt_i16_to_byte<TReg>(dst, dst, output_type.is_signed(), is_saturated);
178+
break;
179+
case ov::element::i8:
180+
case ov::element::u8:
181+
cvt_byte_to_i16<TReg>(src, dst, input_type.is_signed());
182+
cvt_i16_to_byte<TReg>(dst, dst, output_type.is_signed(), is_saturated);
183+
break;
184+
default:
185+
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
186+
}
187+
break;
188+
default:
189+
OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", output_type.get_type_name());
190+
}
191+
}
192+
193+
jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
194+
: jit_convert_emitter(host, host_isa, node->get_input_element_type(0), node->get_output_element_type(0), exec_prc) {
195+
}
196+
197+
jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa,
198+
ov::element::Type input_prc,
199+
ov::element::Type output_prc,
200+
ov::element::Type exec_prc)
201+
: jit_emitter(host, host_isa, exec_prc) {
202+
input_type = input_prc;
203+
output_type = output_prc;
204+
}
205+
206+
void jit_convert_emitter::validate_types() const {
207+
OV_CPU_JIT_EMITTER_ASSERT(one_of(input_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
208+
"Unsupported input type: ", input_type.get_type_name());
209+
OV_CPU_JIT_EMITTER_ASSERT(one_of(output_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
210+
"Unsupported output type: ", output_type.get_type_name());
211+
}
212+
213+
size_t jit_convert_emitter::get_inputs_count() const { return 1; }
214+
215+
void jit_convert_emitter::emit_data() const {
216+
jit_emitter::emit_data();
217+
}
218+
219+
jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa,
220+
const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
221+
: jit_convert_emitter(host, host_isa, node, exec_prc) {
222+
}
223+
224+
jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa,
225+
ov::element::Type input_prc,
226+
ov::element::Type output_prc,
227+
ov::element::Type exec_prc)
228+
: jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) {
229+
}
230+
231+
void jit_convert_truncation_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
232+
validate_types();
233+
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
234+
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_idxs, out_idxs);
235+
} else {
236+
OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_);
237+
}
238+
}
239+
240+
template <cpu_isa_t isa>
241+
void jit_convert_truncation_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
242+
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
243+
TReg src = TReg(in_idxs[0]);
244+
TReg dst = TReg(out_idxs[0]);
245+
jit_convert_process<TReg>(src, dst, input_type, output_type, false);
246+
}
247+
248+
jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa,
249+
const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
250+
: jit_convert_emitter(host, host_isa, node, exec_prc) {
251+
}
252+
253+
jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa,
254+
ov::element::Type input_prc,
255+
ov::element::Type output_prc,
256+
ov::element::Type exec_prc)
257+
: jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) {
258+
}
259+
260+
void jit_convert_saturation_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
261+
validate_types();
262+
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
263+
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_idxs, out_idxs);
264+
} else {
265+
OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_);
266+
}
267+
}
268+
269+
template <cpu_isa_t isa>
270+
void jit_convert_saturation_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
271+
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
272+
TReg src = TReg(in_idxs[0]);
273+
TReg dst = TReg(out_idxs[0]);
274+
jit_convert_process<TReg>(src, dst, input_type, output_type, true);
275+
}
276+
277+
} // namespace aarch64
278+
} // namespace intel_cpu
279+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "jit_emitter.hpp"
8+
9+
namespace ov {
10+
namespace intel_cpu {
11+
namespace aarch64 {
12+
13+
class jit_convert_emitter : public jit_emitter {
14+
public:
15+
jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
16+
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
17+
jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
18+
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);
19+
20+
size_t get_inputs_count() const override;
21+
22+
protected:
23+
void emit_data() const override;
24+
void validate_types() const;
25+
template <typename TReg>
26+
void jit_convert_process(const TReg &src, const TReg &dst, ov::element::Type input_type, ov::element::Type output_type,
27+
bool is_saturated) const;
28+
29+
ov::element::Type input_type;
30+
ov::element::Type output_type;
31+
32+
private:
33+
template <typename TReg>
34+
inline void cvt_f16_to_f32(const TReg &src, const TReg &dst) const;
35+
template <typename TReg>
36+
inline void cvt_f32_to_f16(const TReg &src, const TReg &dst) const;
37+
template <typename TReg>
38+
inline void cvt_f32_to_i32(const TReg &src, const TReg &dst) const;
39+
template <typename TReg>
40+
inline void cvt_i32_to_f32(const TReg &src, const TReg &dst) const;
41+
template <typename TReg>
42+
inline void cvt_i32_to_i16(const TReg &src, const TReg &dst, bool is_saturated) const;
43+
template <typename TReg>
44+
inline void cvt_i16_to_i32(const TReg &src, const TReg &dst) const;
45+
template <typename TReg>
46+
inline void cvt_f16_to_i16(const TReg &src, const TReg &dst) const;
47+
template <typename TReg>
48+
inline void cvt_i16_to_f16(const TReg &src, const TReg &dst) const;
49+
template <typename TReg>
50+
inline void cvt_i16_to_byte(const TReg &src, const TReg &dst, bool is_signed, bool is_saturated) const;
51+
template <typename TReg>
52+
inline void cvt_byte_to_i16(const TReg &src, const TReg &dst, bool is_signed) const;
53+
};
54+
55+
// This emitter is covered by specification of "Convert" operation. The implementation uses a "warp-around" conversion.
56+
// Example:
57+
// int32_t -> int8_t
58+
// 129 -> -127
59+
class jit_convert_truncation_emitter : public jit_convert_emitter {
60+
public:
61+
jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
62+
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
63+
jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
64+
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);
65+
66+
private:
67+
void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const override;
68+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
69+
void emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const;
70+
};
71+
72+
// This emitter is covered by the common dnnl behavior. The implementation uses a "saturation" conversion.
73+
// Example:
74+
// int32_t -> int8_t
75+
// 129 -> 127
76+
class jit_convert_saturation_emitter : public jit_convert_emitter {
77+
public:
78+
jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
79+
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
80+
jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
81+
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);
82+
83+
private:
84+
void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const override;
85+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
86+
void emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const;
87+
};
88+
89+
} // namespace aarch64
90+
} // namespace intel_cpu
91+
} // namespace ov

0 commit comments

Comments
 (0)