Skip to content

Commit 6eef6da

Browse files
committed
Make conversion between f16 and i8 compatible with ARMv8
1 parent e010c59 commit 6eef6da

File tree

1 file changed

+37
-42
lines changed

1 file changed

+37
-42
lines changed

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp

+37-42
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@ namespace aarch64 {
1414

1515
// In aarch64, conversion between f16 and i16/u16 can be done with a single instruction. The supported
1616
// conversion precisions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate
17-
// precision i16/u16 (dbyte) in the following graph. Then the conversion between each pair of
17+
// precision i16 in the following graph, then the conversion between each pair of
1818
// neighbors in this graph will be done with a single instruction.
19-
// f16 - f32 - i32 - dbyte - byte
20-
// | |
21-
// - - - - - - - - - - -
19+
// f16 - f32 - i32 - i16 - byte
20+
// | |
21+
// - - - - - - - - - -
22+
// Note that a single-instruction conversion between f16 and i16 is only available for
23+
// architecture ARMv8.2-A or later versions. So ARM platforms like Raspberry Pi (CPU model name Cortex-A72)
24+
// with architecture ARMv8 do not support such instructions. And since the asimd isa we support
25+
// does not distinguish ARMv8 from ARMv8.2-A, conversion between f16 and i16 will still use three
26+
// instructions f16 -> f32 -> i32 -> i16 (f16 <- f32 <- i32 <- i16).
2227
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
2328
static void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
2429
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
@@ -52,58 +57,44 @@ static void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std
5257
}
5358

5459
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
55-
static void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
56-
bool is_signed, bool is_saturated) {
60+
static void cvt_i32_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
61+
bool is_saturated) {
5762
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
5863
TReg src = TReg(in_idxs[0]);
5964
TReg dst = TReg(out_idxs[0]);
6065
if (is_saturated) {
61-
if (is_signed) {
62-
h->sqxtn(dst.h4, src.s4);
63-
} else {
64-
h->uqxtn(dst.h4, src.s4);
65-
}
66+
h->sqxtn(dst.h4, src.s4);
6667
} else {
6768
h->xtn(dst.h4, src.s4);
6869
}
6970
}
7071

7172
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
72-
static void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
73-
bool is_signed) {
73+
static void cvt_i16_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
7474
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
7575
TReg src = TReg(in_idxs[0]);
7676
TReg dst = TReg(out_idxs[0]);
77-
if (is_signed) {
78-
h->sxtl(dst.s4, src.h4);
79-
} else {
80-
h->uxtl(dst.s4, src.h4);
81-
}
77+
h->sxtl(dst.s4, src.h4);
8278
}
8379

8480
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
85-
static void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
81+
static void cvt_f16_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
8682
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
8783
TReg src = TReg(in_idxs[0]);
8884
TReg dst = TReg(out_idxs[0]);
8985
h->fcvtzs(dst.h4, src.h4);
9086
}
9187

9288
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
93-
static void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
94-
bool is_signed) {
89+
static void cvt_i16_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
9590
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
9691
TReg src = TReg(in_idxs[0]);
9792
TReg dst = TReg(out_idxs[0]);
98-
if (is_signed) {
99-
h->scvtf(dst.h4, src.h4);
100-
} else {
101-
h->ucvtf(dst.h4, src.h4);
102-
}
93+
h->scvtf(dst.h4, src.h4);
10394
}
10495

10596
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
106-
static void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
97+
static void cvt_i16_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
10798
bool is_signed, bool is_saturated) {
10899
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
109100
TReg src = TReg(in_idxs[0]);
@@ -120,7 +111,7 @@ static void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const
120111
}
121112

122113
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
123-
static void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
114+
static void cvt_byte_to_i16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
124115
bool is_signed) {
125116
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
126117
TReg src = TReg(in_idxs[0]);
@@ -155,8 +146,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
155146
break;
156147
case ov::element::i8:
157148
case ov::element::u8:
158-
cvt_byte_to_dbyte<isa>(h, in_idxs, out_idxs, input_type.is_signed());
159-
cvt_dbyte_to_i32<isa>(h, out_idxs, out_idxs, input_type.is_signed());
149+
cvt_byte_to_i16<isa>(h, in_idxs, out_idxs, input_type.is_signed());
150+
cvt_i16_to_i32<isa>(h, out_idxs, out_idxs);
160151
cvt_i32_to_f32<isa>(h, out_idxs, out_idxs);
161152
break;
162153
default:
@@ -176,8 +167,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
176167
break;
177168
case ov::element::i8:
178169
case ov::element::u8:
179-
cvt_byte_to_dbyte<isa>(h, in_idxs, out_idxs, input_type.is_signed());
180-
cvt_dbyte_to_i32<isa>(h, out_idxs, out_idxs, input_type.is_signed());
170+
cvt_byte_to_i16<isa>(h, in_idxs, out_idxs, input_type.is_signed());
171+
cvt_i16_to_i32<isa>(h, out_idxs, out_idxs);
181172
break;
182173
default:
183174
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
@@ -196,8 +187,10 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
196187
break;
197188
case ov::element::i8:
198189
case ov::element::u8:
199-
cvt_byte_to_dbyte<isa>(h, in_idxs, out_idxs, input_type.is_signed());
200-
cvt_dbyte_to_f16<isa>(h, out_idxs, out_idxs, input_type.is_signed());
190+
cvt_byte_to_i16<isa>(h, in_idxs, out_idxs, input_type.is_signed());
191+
cvt_i16_to_i32<isa>(h, out_idxs, out_idxs);
192+
cvt_i32_to_f32<isa>(h, out_idxs, out_idxs);
193+
cvt_f32_to_f16<isa>(h, out_idxs, out_idxs);
201194
break;
202195
default:
203196
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
@@ -208,21 +201,23 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
208201
switch (input_type) {
209202
case ov::element::f32:
210203
cvt_f32_to_i32<isa>(h, in_idxs, out_idxs);
211-
cvt_i32_to_dbyte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
212-
cvt_dbyte_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
204+
cvt_i32_to_i16<isa>(h, out_idxs, out_idxs, is_saturated);
205+
cvt_i16_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
213206
break;
214207
case ov::element::i32:
215-
cvt_i32_to_dbyte<isa>(h, in_idxs, out_idxs, output_type.is_signed(), is_saturated);
216-
cvt_dbyte_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
208+
cvt_i32_to_i16<isa>(h, in_idxs, out_idxs, is_saturated);
209+
cvt_i16_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
217210
break;
218211
case ov::element::f16:
219-
cvt_f16_to_dbyte<isa>(h, in_idxs, out_idxs);
220-
cvt_dbyte_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
212+
cvt_f16_to_f32<isa>(h, in_idxs, out_idxs);
213+
cvt_f32_to_i32<isa>(h, out_idxs, out_idxs);
214+
cvt_i32_to_i16<isa>(h, out_idxs, out_idxs, is_saturated);
215+
cvt_i16_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
221216
break;
222217
case ov::element::i8:
223218
case ov::element::u8:
224-
cvt_byte_to_dbyte<isa>(h, in_idxs, out_idxs, input_type.is_signed());
225-
cvt_dbyte_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
219+
cvt_byte_to_i16<isa>(h, in_idxs, out_idxs, input_type.is_signed());
220+
cvt_i16_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
226221
break;
227222
default:
228223
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());

0 commit comments

Comments
 (0)