Skip to content

Commit 66ca73f

Browse files
committed
Apply review comments regarding conversion between f16 and i8(u8)
1 parent 57d8b92 commit 66ca73f

File tree

4 files changed

+122
-33
lines changed

4 files changed

+122
-33
lines changed

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_conversion_emitters.cpp

+14-12
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
3636
break;
3737
case ov::element::i8:
3838
case ov::element::u8:
39-
cvt_byte_to_i32<isa>(h, in_idxs, out_idxs, input_type.is_signed());
39+
cvt_byte_to_dbyte<isa>(h, in_idxs, out_idxs, input_type.is_signed());
40+
cvt_dbyte_to_i32<isa>(h, out_idxs, out_idxs, input_type.is_signed());
4041
cvt_i32_to_f32<isa>(h, out_idxs, out_idxs);
4142
break;
4243
default:
@@ -56,7 +57,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
5657
break;
5758
case ov::element::i8:
5859
case ov::element::u8:
59-
cvt_byte_to_i32<isa>(h, in_idxs, out_idxs, input_type.is_signed());
60+
cvt_byte_to_dbyte<isa>(h, in_idxs, out_idxs, input_type.is_signed());
61+
cvt_dbyte_to_i32<isa>(h, out_idxs, out_idxs, input_type.is_signed());
6062
break;
6163
default:
6264
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
@@ -75,9 +77,8 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
7577
break;
7678
case ov::element::i8:
7779
case ov::element::u8:
78-
cvt_byte_to_i32<isa>(h, in_idxs, out_idxs, input_type.is_signed());
79-
cvt_i32_to_f32<isa>(h, out_idxs, out_idxs);
80-
cvt_f32_to_f16<isa>(h, out_idxs, out_idxs);
80+
cvt_byte_to_dbyte<isa>(h, in_idxs, out_idxs, input_type.is_signed());
81+
cvt_dbyte_to_f16<isa>(h, out_idxs, out_idxs, input_type.is_signed());
8182
break;
8283
default:
8384
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());
@@ -88,20 +89,21 @@ static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
8889
switch (input_type) {
8990
case ov::element::f32:
9091
cvt_f32_to_i32<isa>(h, in_idxs, out_idxs);
91-
cvt_i32_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
92+
cvt_i32_to_dbyte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
93+
cvt_dbyte_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
9294
break;
9395
case ov::element::i32:
94-
cvt_i32_to_byte<isa>(h, in_idxs, out_idxs, output_type.is_signed(), is_saturated);
96+
cvt_i32_to_dbyte<isa>(h, in_idxs, out_idxs, output_type.is_signed(), is_saturated);
97+
cvt_dbyte_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
9598
break;
9699
case ov::element::f16:
97-
cvt_f16_to_f32<isa>(h, in_idxs, out_idxs);
98-
cvt_f32_to_i32<isa>(h, out_idxs, out_idxs);
99-
cvt_i32_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
100+
cvt_f16_to_dbyte<isa>(h, in_idxs, out_idxs);
101+
cvt_dbyte_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
100102
break;
101103
case ov::element::i8:
102104
case ov::element::u8:
103-
cvt_byte_to_i32<isa>(h, in_idxs, out_idxs, input_type.is_signed());
104-
cvt_i32_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
105+
cvt_byte_to_dbyte<isa>(h, in_idxs, out_idxs, input_type.is_signed());
106+
cvt_dbyte_to_byte<isa>(h, out_idxs, out_idxs, output_type.is_signed(), is_saturated);
105107
break;
106108
default:
107109
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", input_type.get_type_name());

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,13 @@ void jit_load_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::v
179179
load_byte<isa>(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs);
180180
switch (dst_prc_) {
181181
case ov::element::f32:
182-
cvt_byte_to_i32<isa>(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed());
182+
cvt_byte_to_dbyte<isa>(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed());
183+
cvt_dbyte_to_i32<isa>(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed());
183184
cvt_i32_to_f32<isa>(h, aux_vec_idxs, out_idxs);
184185
break;
185186
case ov::element::i32:
186-
cvt_byte_to_i32<isa>(h, aux_vec_idxs, out_idxs, src_prc_.is_signed());
187+
cvt_byte_to_dbyte<isa>(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed());
188+
cvt_dbyte_to_i32<isa>(h, aux_vec_idxs, out_idxs, src_prc_.is_signed());
187189
break;
188190
case ov::element::i8:
189191
case ov::element::u8:
@@ -375,10 +377,12 @@ void jit_store_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::
375377
switch (src_prc_) {
376378
case ov::element::f32:
377379
cvt_f32_to_i32<isa>(h, in_idxs, aux_vec_idxs);
378-
cvt_i32_to_byte<isa>(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_);
380+
cvt_i32_to_dbyte<isa>(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_);
381+
cvt_dbyte_to_byte<isa>(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_);
379382
break;
380383
case ov::element::i32:
381-
cvt_i32_to_byte<isa>(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_);
384+
cvt_i32_to_dbyte<isa>(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_);
385+
cvt_dbyte_to_byte<isa>(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), is_saturated_);
382386
break;
383387
case ov::element::i8:
384388
case ov::element::u8:

src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.cpp

+81-13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ namespace ov {
88
namespace intel_cpu {
99
namespace aarch64 {
1010

11+
// In aarch64, conversion between f16 and i16/u16 can be done with single instruction. The supported
12+
// conversion precicions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate
13+
// precision i16/u16 (dbyte) in the following graph. Then the conversion between each pair of
14+
// neighbors in this graph will be done with single instruction.
15+
// f16 - f32 - i32 - dbyte - byte
16+
// | |
17+
// - - - - - - - - - - -
1118
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
1219
void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
1320
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
@@ -41,37 +48,83 @@ void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vecto
4148
}
4249

4350
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
44-
void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
51+
void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
4552
bool is_signed, bool is_saturated) {
4653
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
4754
TReg src = TReg(in_idxs[0]);
4855
TReg dst = TReg(out_idxs[0]);
4956
if (is_saturated) {
5057
if (is_signed) {
5158
h->sqxtn(dst.h4, src.s4);
52-
h->sqxtn(dst.b8, dst.h8);
5359
} else {
5460
h->uqxtn(dst.h4, src.s4);
55-
h->uqxtn(dst.b8, dst.h8);
5661
}
5762
} else {
5863
h->xtn(dst.h4, src.s4);
59-
h->xtn(dst.b8, dst.h8);
6064
}
6165
}
6266

6367
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
64-
void cvt_byte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
68+
void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
6569
bool is_signed) {
6670
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
6771
TReg src = TReg(in_idxs[0]);
6872
TReg dst = TReg(out_idxs[0]);
73+
if (is_signed) {
74+
h->sxtl(dst.s4, src.h4);
75+
} else {
76+
h->uxtl(dst.s4, src.h4);
77+
}
78+
}
79+
80+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
81+
void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
82+
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
83+
TReg src = TReg(in_idxs[0]);
84+
TReg dst = TReg(out_idxs[0]);
85+
h->fcvtzs(dst.h, src.h);
86+
}
87+
88+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
89+
void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
90+
bool is_signed) {
91+
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
92+
TReg src = TReg(in_idxs[0]);
93+
TReg dst = TReg(out_idxs[0]);
94+
if (is_signed) {
95+
h->scvtf(dst.h, src.h);
96+
} else {
97+
h->ucvtf(dst.h, src.h);
98+
}
99+
}
100+
101+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
102+
void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
103+
bool is_signed, bool is_saturated) {
104+
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
105+
TReg src = TReg(in_idxs[0]);
106+
TReg dst = TReg(out_idxs[0]);
107+
if (is_saturated) {
108+
if (is_signed) {
109+
h->sqxtn(dst.b8, src.h8);
110+
} else {
111+
h->uqxtn(dst.b8, src.h8);
112+
}
113+
} else {
114+
h->xtn(dst.b8, src.h8);
115+
}
116+
}
117+
118+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
119+
void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
120+
bool is_signed) {
121+
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
122+
TReg src = TReg(in_idxs[0]);
123+
TReg dst = TReg(out_idxs[0]);
69124
if (is_signed) {
70125
h->sxtl(dst.h8, src.b8);
71-
h->sxtl(dst.s4, dst.h4);
72126
} else {
73127
h->uxtl(dst.h8, src.b8);
74-
h->uxtl(dst.s4, dst.h4);
75128
}
76129
}
77130

@@ -87,13 +140,28 @@ template void cvt_f32_to_i32<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::a
87140
template void cvt_i32_to_f32<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
88141
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs);
89142

90-
template void cvt_i32_to_byte<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
91-
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
92-
bool is_signed, bool is_saturation);
143+
template void cvt_i32_to_dbyte<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
144+
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
145+
bool is_signed, bool is_saturation);
146+
147+
template void cvt_dbyte_to_i32<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
148+
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
149+
bool is_signed);
150+
151+
template void cvt_f16_to_dbyte<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
152+
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs);
153+
154+
template void cvt_dbyte_to_f16<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
155+
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
156+
bool is_signed);
157+
158+
template void cvt_dbyte_to_byte<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
159+
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
160+
bool is_signed, bool is_saturation);
93161

94-
template void cvt_byte_to_i32<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
95-
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
96-
bool is_signed);
162+
template void cvt_byte_to_dbyte<dnnl::impl::cpu::aarch64::asimd>(dnnl::impl::cpu::aarch64::jit_generator* h,
163+
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
164+
bool is_signed);
97165

98166
} // namespace aarch64
99167
} // namespace intel_cpu

src/plugins/intel_cpu/src/emitters/plugin/aarch64/utils.hpp

+19-4
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,27 @@ template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
2424
void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs);
2525

2626
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
27-
void cvt_i32_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
28-
bool is_signed, bool is_saturated);
27+
void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
28+
bool is_signed, bool is_saturated);
2929

3030
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
31-
void cvt_byte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
32-
bool is_signed);
31+
void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
32+
bool is_signed);
33+
34+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
35+
void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs);
36+
37+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
38+
void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
39+
bool is_signed);
40+
41+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
42+
void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
43+
bool is_signed, bool is_saturated);
44+
45+
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
46+
void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
47+
bool is_signed);
3348

3449
} // namespace aarch64
3550
} // namespace intel_cpu

0 commit comments

Comments
 (0)