Skip to content

Commit f0ebba0

Browse files
authored
[CPU] [ARM64] jit emitter pipeline fix (openvinotoolkit#23387)
### Details: - *[CPU] [ARM64] jit emitter pipeline fix*
1 parent 496a5de commit f0ebba0

File tree

4 files changed

+152
-62
lines changed

4 files changed

+152
-62
lines changed

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp

+124-44
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "jit_emitter.hpp"
66
#include <vector>
77
#include "utils/general_utils.h"
8+
#include "emitters/utils.hpp"
89

910
using namespace dnnl::impl::cpu;
1011
using namespace dnnl::impl;
@@ -13,7 +14,7 @@ namespace ov {
1314
namespace intel_cpu {
1415
namespace aarch64 {
1516

16-
const std::vector<uint32_t> jit_emitter::store_gpr_regs = {
17+
const std::vector<size_t> jit_emitter::store_gpr_regs = {
1718
// Parameter/result registers
1819
0, 1, 2, 3, 4, 5, 6, 7,
1920
// r8: Indirect result location register
@@ -24,6 +25,13 @@ const std::vector<uint32_t> jit_emitter::store_gpr_regs = {
2425
29, 30
2526
};
2627

28+
static const std::vector<size_t> vec_regs = {
29+
0, 1, 2, 3, 4, 5, 6, 7,
30+
8, 9, 10, 11, 12, 13, 14, 15,
31+
16, 17, 18, 19, 20, 21, 22, 23,
32+
24, 25, 26, 27, 28, 29, 30, 31
33+
};
34+
2735
void jit_emitter::emit_code(const std::vector<size_t> &in_idxs,
2836
const std::vector<size_t> &out_idxs,
2937
const std::vector<size_t> &pool_vec_idxs,
@@ -98,135 +106,207 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
98106
OPENVINO_THROW("Failed to allocate required number of gpr registers");
99107
}
100108

109+
using namespace Xbyak_aarch64::util;
110+
const bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) ||
111+
(in_out_type_ == emitter_in_out_map::vec_to_gpr);
112+
const bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) ||
113+
(in_out_type_ == emitter_in_out_map::gpr_to_vec);
114+
115+
// vector registers
101116
for (auto idx : pool_aux_vec_idxs) {
102117
aux_vec_idxs.push_back(static_cast<uint32_t>(idx));
103118
}
104119

120+
for (size_t idx = 0; idx < get_max_vecs_count(); idx++) {
121+
if (aux_vec_idxs.size() >= get_aux_vecs_count()) break;
122+
123+
if (is_vec_input) {
124+
if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue;
125+
}
126+
if (is_vec_output) {
127+
if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue;
128+
}
129+
130+
if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue;
131+
if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue;
132+
133+
if (std::find(aux_vec_idxs.begin(), aux_vec_idxs.end(), idx) != aux_vec_idxs.end()) continue;
134+
if (std::find(preserved_vec_idxs.begin(), preserved_vec_idxs.end(), idx) != preserved_vec_idxs.end()) continue;
135+
136+
aux_vec_idxs.push_back(idx);
137+
preserved_vec_idxs.push_back(idx);
138+
}
139+
if (aux_vec_idxs.size() < get_aux_vecs_count())
140+
OV_CPU_JIT_EMITTER_THROW("Failed to allocate required number of vector registers");
141+
142+
// gpr registers
105143
for (auto idx : pool_aux_gpr_idxs) {
106-
aux_gpr_idxs.push_back(static_cast<uint32_t>(idx));
107-
preserved_gpr_idxs.push_back(static_cast<uint32_t>(idx));
144+
aux_gpr_idxs.push_back(idx);
145+
}
146+
147+
const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X30;
148+
for (size_t gpr_idx = 0; gpr_idx <= end_gpr_idx; ++gpr_idx) {
149+
size_t _idx = end_gpr_idx - gpr_idx; // we allocate from the end
150+
151+
if (aux_gpr_idxs.size() >= get_aux_gprs_count()) break;
152+
if ((_idx == Xbyak_aarch64::Operand::X18) ||
153+
(_idx == Xbyak_aarch64::Operand::X23) ||
154+
(_idx == Xbyak_aarch64::Operand::X28)) continue;
155+
156+
if (!is_vec_input) {
157+
if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) continue;
158+
}
159+
if (!is_vec_output) {
160+
if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) continue;
161+
}
162+
163+
if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) continue;
164+
if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) continue;
165+
166+
aux_gpr_idxs.push_back(_idx);
167+
preserved_gpr_idxs.push_back(_idx);
108168
}
169+
if (aux_gpr_idxs.size() < get_aux_gprs_count())
170+
OV_CPU_JIT_EMITTER_THROW("Failed to allocate required number of general-purpose registers");
109171

110172
if (!entry_map_.empty()) {
111173
// last aux_gpr_idx is for p_table, we can use aux_gpr_idxs from idx 0 for other purpose
112174
p_table = Xbyak_aarch64::XReg(aux_gpr_idxs[aux_gpr_idxs.size() - 1]);
113175
aux_gpr_idxs.erase(aux_gpr_idxs.end() - 1);
114176
}
115177

116-
for (size_t i = 0; i < preserved_gpr_idxs.size(); ++i) {
117-
h->str(Xbyak_aarch64::XReg(preserved_gpr_idxs[i]), pre_ptr(h->sp, -16));
118-
}
178+
store_context(preserved_gpr_idxs, preserved_vec_idxs);
119179

120180
if (!entry_map_.empty()) {
121181
load_table_addr();
122182
}
123183
}
124184

125185
void jit_emitter::emitter_postamble() const {
126-
const int size = static_cast<int>(preserved_gpr_idxs.size());
127-
for (int i = (size - 1); i >= 0; --i) {
128-
h->ldr(Xbyak_aarch64::XReg(preserved_gpr_idxs[i]), post_ptr(h->sp, 16));
129-
}
186+
restore_context(preserved_gpr_idxs, preserved_vec_idxs);
187+
188+
preserved_vec_idxs.clear();
130189
preserved_gpr_idxs.clear();
131190

132191
aux_vec_idxs.clear();
133192
aux_gpr_idxs.clear();
134193
}
135194

136195
void jit_emitter::store_context(const std::unordered_set<size_t>& ignore_registers) const {
196+
store_context(store_gpr_regs, vec_regs, ignore_registers);
197+
}
198+
199+
void jit_emitter::store_context(
200+
const std::vector<size_t>& gpr_regs,
201+
const std::vector<size_t>& vec_regs,
202+
const std::unordered_set<size_t>& ignore_vec_regs) const {
137203
// 1. General-purpose Registers
138204
// 1.1. store pair registers
139-
const auto store_gpr_regs_size = store_gpr_regs.size();
205+
const auto store_gpr_regs_size = gpr_regs.size();
140206
const auto last = store_gpr_regs_size % 2;
141207
for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) {
142-
h->stp(Xbyak_aarch64::XReg(store_gpr_regs[i]),
143-
Xbyak_aarch64::XReg(store_gpr_regs[i + 1]),
144-
pre_ptr(h->sp, -get_gpr_length() * 2));
208+
h->stp(Xbyak_aarch64::XReg(gpr_regs[i]),
209+
Xbyak_aarch64::XReg(gpr_regs[i + 1]),
210+
pre_ptr(h->sp, -get_gpr_length() * 2));
145211
}
146-
147-
// 1.1. store the remaining register
212+
// 1.2. store the remaining register
148213
if (last != 0) {
149-
h->str(Xbyak_aarch64::XReg(store_gpr_regs[store_gpr_regs_size - 1]),
150-
pre_ptr(h->sp, -get_gpr_length() * 2));
214+
h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]),
215+
pre_ptr(h->sp, -get_gpr_length()));
151216
}
152217

153218
// 2. SIMD and Floating-Point registers
154219
// 2.1. store pair registers
155220
int prev_reg_idx = -1;
156221
size_t ignore_registers_count = 0;
157-
for (size_t reg_idx = 0; reg_idx < get_asimd_vectors_count(); reg_idx++) {
158-
if (ignore_registers.find(reg_idx) != ignore_registers.end()) {
222+
for (size_t reg_idx = 0; reg_idx < vec_regs.size(); reg_idx++) {
223+
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
159224
ignore_registers_count++;
160225
continue;
161226
}
162-
163227
if (prev_reg_idx == -1) {
164228
prev_reg_idx = static_cast<int>(reg_idx);
165229
continue;
166230
}
167-
168231
h->stp(Xbyak_aarch64::QReg(prev_reg_idx),
169232
Xbyak_aarch64::QReg(reg_idx),
170233
pre_ptr(h->sp, -get_vec_length() * 2));
171234
prev_reg_idx = -1;
172235
}
173-
OPENVINO_ASSERT(ignore_registers_count == ignore_registers.size(),
174-
"ignored registers size is not equal actual ignored registers count");
175236

176237
// 2.1. store the remaining register
177238
if (prev_reg_idx != -1) {
178-
h->str(Xbyak_aarch64::QReg(prev_reg_idx),
179-
pre_ptr(h->sp, -get_vec_length()));
239+
if (ignore_vec_regs.find(prev_reg_idx) == ignore_vec_regs.end()) {
240+
h->str(Xbyak_aarch64::QReg(prev_reg_idx),
241+
pre_ptr(h->sp, -get_vec_length()));
242+
} else {
243+
ignore_registers_count++;
244+
}
180245
}
246+
247+
OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(),
248+
"ignored registers size is not equal actual ignored registers count");
181249
}
182250

183-
void jit_emitter::restore_context(const std::unordered_set<size_t>& ignore_registers) const {
251+
void jit_emitter::restore_context(const std::unordered_set<size_t>& ignore_vec_regs) const {
252+
restore_context(store_gpr_regs, vec_regs, ignore_vec_regs);
253+
}
254+
255+
void jit_emitter::restore_context(
256+
const std::vector<size_t>& gpr_regs,
257+
const std::vector<size_t>& vec_regs,
258+
const std::unordered_set<size_t>& ignore_vec_regs) const {
184259
// 1. SIMD and Floating-Point registers
185260
// 1.1. restore the remaining register
186-
const auto v_last = (get_asimd_vectors_count() - ignore_registers.size()) % 2;
261+
auto v_last = (vec_regs.size() - ignore_vec_regs.size()) % 2;
187262
if (v_last != 0) {
188-
const auto reg_idx = get_asimd_vectors_count() - 1;
189-
h->ldr(Xbyak_aarch64::QReg(reg_idx),
190-
post_ptr(h->sp, get_vec_length()));
263+
for (size_t i = 0; i < vec_regs.size(); i++) {
264+
const auto reg_idx = vec_regs.size() - 1 - i;
265+
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
266+
v_last++;
267+
continue;
268+
}
269+
270+
h->ldr(Xbyak_aarch64::QReg(reg_idx),
271+
post_ptr(h->sp, get_vec_length()));
272+
break;
273+
}
191274
}
192-
193-
// 2.2. restore pair registers
275+
// 1.2. restore pair registers
194276
size_t ignore_registers_count = 0;
195277
int prev_reg_idx = -1;
196-
for (size_t i = v_last; i < get_asimd_vectors_count(); i++) {
197-
const auto reg_idx = get_asimd_vectors_count() - 1 - i;
198-
if (ignore_registers.find(reg_idx) != ignore_registers.end()) {
278+
for (size_t i = v_last; i < vec_regs.size(); i++) {
279+
const auto reg_idx = vec_regs.size() - 1 - i;
280+
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
199281
ignore_registers_count++;
200282
continue;
201283
}
202-
203284
if (prev_reg_idx == -1) {
204285
prev_reg_idx = static_cast<int>(reg_idx);
205286
continue;
206287
}
207-
208288
h->ldp(Xbyak_aarch64::QReg(reg_idx),
209289
Xbyak_aarch64::QReg(prev_reg_idx),
210290
post_ptr(h->sp, get_vec_length() * 2));
211291
prev_reg_idx = -1;
212292
}
213293

214-
OPENVINO_ASSERT(ignore_registers_count == ignore_registers.size(),
294+
OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(),
215295
"ignored registers size is not equal actual ignored registers count");
216296

217297
// 2. General-purpose Registers
218298
// 2.1. restore the remaining register
219-
const auto save_gpr_regs_size = store_gpr_regs.size();
299+
const auto save_gpr_regs_size = gpr_regs.size();
220300
const auto last = save_gpr_regs_size % 2;
221301
if (last != 0) {
222-
h->ldr(Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1]),
223-
post_ptr(h->sp, get_gpr_length() * 2));
302+
h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]),
303+
post_ptr(h->sp, get_gpr_length()));
224304
}
225305

226306
// 2.2. restore pair registers
227307
for (size_t i = last; i < save_gpr_regs_size; i += 2) {
228-
h->ldp(Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1 - (i + 1)]),
229-
Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1 - i]),
308+
h->ldp(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - (i + 1)]),
309+
Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - i]),
230310
post_ptr(h->sp, get_gpr_length() * 2));
231311
}
232312
}

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,11 @@ class jit_emitter : public ov::snippets::Emitter {
141141
}
142142

143143
private:
144+
mutable std::vector<size_t> preserved_vec_idxs;
144145
mutable std::vector<size_t> preserved_gpr_idxs;
145146

146147
// General-purpose Registers
147-
static const std::vector<uint32_t> store_gpr_regs;
148+
static const std::vector<size_t> store_gpr_regs;
148149

149150
size_t table_off(const std::string& key, const size_t key_off_val_shift = 0) const {
150151
// assumption: all table entries sharing the same key also
@@ -163,6 +164,14 @@ class jit_emitter : public ov::snippets::Emitter {
163164
inline int32_t get_gpr_length() const {
164165
return h->x0.getBit() / 8;
165166
}
167+
168+
void store_context(const std::vector<size_t>& gpr_regs,
169+
const std::vector<size_t>& vec_regs,
170+
const std::unordered_set<size_t>& ignore_vec_regs = {}) const;
171+
172+
void restore_context(const std::vector<size_t>& gpr_regs,
173+
const std::vector<size_t>& vec_regs,
174+
const std::unordered_set<size_t>& ignore_vec_regs = {}) const;
166175
};
167176

168177
} // namespace aarch64

src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "jit_emitter.hpp"
66
#include <vector>
77
#include "utils/general_utils.h"
8+
#include "utils.hpp"
89

910
using namespace dnnl::impl::cpu;
1011
using namespace dnnl::impl;

src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp

+17-17
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
9797
void generate() override;
9898

9999
private:
100-
const Xbyak_aarch64::XReg X_TMP_0 = x10;
101-
const Xbyak_aarch64::XReg X_TMP_1 = x11;
102-
103-
XReg reg_post_op_ptrs = X_TMP_0;
100+
XReg reg_post_op_ptrs = x10;
104101
XReg start_to_offsets = reg_post_op_ptrs;
105102

106103
XReg reg_oc_off = x12;
@@ -126,10 +123,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
126123
// X10 | ker temporary| R10 | src ptr
127124
// X11 | ker temporary| R11 | src ptr
128125
// X12 | ker temporary (abi_not_param1) | R12 | src ptr
129-
// X13 | [not used] | R13 | src ptr
130-
// X14 | [not used] | R14 | src ptr
131-
// X15 | dst | R15 | temporary
132-
// X16 | [not used: IP1]
126+
// X13 | temporary | R13 | src ptr
127+
// X14 | temporary | R14 | src ptr
128+
// X15 | temporary | R15 | temporary
129+
// X16 | dst
133130
// X17 | [not used: IP0]
134131
// X18 | [not used: Apple: The platforms reserve register x18. Don't use this register.]
135132

@@ -138,32 +135,35 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
138135
// X20 | src ptr
139136
// X21 | src ptr
140137
// X22 | src ptr
141-
// X23 | src ptr
142-
// X24 | src ptr
138+
// X23 | kernel used (oneDNN: X_TMP_0)
139+
// X24 | kernel used (oneDNN: X_TMP_1)
143140
// X25 | src ptr
144-
// X26 | temporary
145-
// X27 | temporary
146-
// X28 | kernel used (X_DEFAULT_ADDR)
141+
// X26 | src ptr
142+
// X27 | src ptr
143+
// X28 | kernel used (oneDNN: X_DEFAULT_ADDR)
147144

148145
// X29 | [not used: The Frame Pointer (FP)]
149146
// X30 | [not used: The Link Register (LR)]
150147
// X31 | [not used: The Stack Pointer (SP)]
151148

152149
const XReg reg_work_amount = x9;
153-
const XReg reg_dst = x15;
150+
const XReg reg_dst = x16;
154151

155152
inline XReg get_src_reg(uint32_t idx) {
156153
if (idx > MAX_ELTWISE_INPUTS) {
157154
OPENVINO_THROW("source vector ptr register " + std::to_string(idx) + " is not supported");
158155
}
159-
return XReg(19 + idx);
156+
157+
static const std::vector<uint32_t> src_gprs = { 19, 20, 21, 22, 25, 26, 27 };
158+
return XReg(src_gprs[idx]);
160159
}
161160

162161
inline XReg get_aux_gpr(const uint32_t idx) {
163-
if (idx > 2) {
162+
if (idx > 3) {
164163
OPENVINO_THROW("aux gpr register " + std::to_string(idx) + " is not supported");
165164
}
166-
return XReg(26 + idx);
165+
166+
return XReg(13 + idx);
167167
}
168168

169169
// Vector registers mapping

0 commit comments

Comments
 (0)