5
5
#include " jit_emitter.hpp"
6
6
#include < vector>
7
7
#include " utils/general_utils.h"
8
+ #include " emitters/utils.hpp"
8
9
9
10
using namespace dnnl ::impl::cpu;
10
11
using namespace dnnl ::impl;
@@ -13,7 +14,7 @@ namespace ov {
13
14
namespace intel_cpu {
14
15
namespace aarch64 {
15
16
16
- const std::vector<uint32_t > jit_emitter::store_gpr_regs = {
17
+ const std::vector<size_t > jit_emitter::store_gpr_regs = {
17
18
// Parameter/result registers
18
19
0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,
19
20
// r8: Indirect result location register
@@ -24,6 +25,13 @@ const std::vector<uint32_t> jit_emitter::store_gpr_regs = {
24
25
29 , 30
25
26
};
26
27
28
+ static const std::vector<size_t > vec_regs = {
29
+ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,
30
+ 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ,
31
+ 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 ,
32
+ 24 , 25 , 26 , 27 , 28 , 29 , 30 , 31
33
+ };
34
+
27
35
void jit_emitter::emit_code (const std::vector<size_t > &in_idxs,
28
36
const std::vector<size_t > &out_idxs,
29
37
const std::vector<size_t > &pool_vec_idxs,
@@ -98,135 +106,207 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
98
106
OPENVINO_THROW (" Failed to allocate required number of gpr registers" );
99
107
}
100
108
109
+ using namespace Xbyak_aarch64 ::util;
110
+ const bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) ||
111
+ (in_out_type_ == emitter_in_out_map::vec_to_gpr);
112
+ const bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) ||
113
+ (in_out_type_ == emitter_in_out_map::gpr_to_vec);
114
+
115
+ // vector registers
101
116
for (auto idx : pool_aux_vec_idxs) {
102
117
aux_vec_idxs.push_back (static_cast <uint32_t >(idx));
103
118
}
104
119
120
+ for (size_t idx = 0 ; idx < get_max_vecs_count (); idx++) {
121
+ if (aux_vec_idxs.size () >= get_aux_vecs_count ()) break ;
122
+
123
+ if (is_vec_input) {
124
+ if (std::find (in_idxs.begin (), in_idxs.end (), idx) != in_idxs.end ()) continue ;
125
+ }
126
+ if (is_vec_output) {
127
+ if (std::find (out_idxs.begin (), out_idxs.end (), idx) != out_idxs.end ()) continue ;
128
+ }
129
+
130
+ if (std::find (in_idxs.begin (), in_idxs.end (), idx) != in_idxs.end ()) continue ;
131
+ if (std::find (out_idxs.begin (), out_idxs.end (), idx) != out_idxs.end ()) continue ;
132
+
133
+ if (std::find (aux_vec_idxs.begin (), aux_vec_idxs.end (), idx) != aux_vec_idxs.end ()) continue ;
134
+ if (std::find (preserved_vec_idxs.begin (), preserved_vec_idxs.end (), idx) != preserved_vec_idxs.end ()) continue ;
135
+
136
+ aux_vec_idxs.push_back (idx);
137
+ preserved_vec_idxs.push_back (idx);
138
+ }
139
+ if (aux_vec_idxs.size () < get_aux_vecs_count ())
140
+ OV_CPU_JIT_EMITTER_THROW (" Failed to allocate required number of vector registers" );
141
+
142
+ // gpr registers
105
143
for (auto idx : pool_aux_gpr_idxs) {
106
- aux_gpr_idxs.push_back (static_cast <uint32_t >(idx));
107
- preserved_gpr_idxs.push_back (static_cast <uint32_t >(idx));
144
+ aux_gpr_idxs.push_back (idx);
145
+ }
146
+
147
+ const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X30;
148
+ for (size_t gpr_idx = 0 ; gpr_idx <= end_gpr_idx; ++gpr_idx) {
149
+ size_t _idx = end_gpr_idx - gpr_idx; // we allocate from the end
150
+
151
+ if (aux_gpr_idxs.size () >= get_aux_gprs_count ()) break ;
152
+ if ((_idx == Xbyak_aarch64::Operand::X18) ||
153
+ (_idx == Xbyak_aarch64::Operand::X23) ||
154
+ (_idx == Xbyak_aarch64::Operand::X28)) continue ;
155
+
156
+ if (!is_vec_input) {
157
+ if (std::find (in_idxs.begin (), in_idxs.end (), _idx) != in_idxs.end ()) continue ;
158
+ }
159
+ if (!is_vec_output) {
160
+ if (std::find (out_idxs.begin (), out_idxs.end (), _idx) != out_idxs.end ()) continue ;
161
+ }
162
+
163
+ if (std::find (aux_gpr_idxs.begin (), aux_gpr_idxs.end (), _idx) != aux_gpr_idxs.end ()) continue ;
164
+ if (std::find (preserved_gpr_idxs.begin (), preserved_gpr_idxs.end (), _idx) != preserved_gpr_idxs.end ()) continue ;
165
+
166
+ aux_gpr_idxs.push_back (_idx);
167
+ preserved_gpr_idxs.push_back (_idx);
108
168
}
169
+ if (aux_gpr_idxs.size () < get_aux_gprs_count ())
170
+ OV_CPU_JIT_EMITTER_THROW (" Failed to allocate required number of general-purpose registers" );
109
171
110
172
if (!entry_map_.empty ()) {
111
173
// last aux_gpr_idx is for p_table, we can use aux_gpr_idxs from idx 0 for other purpose
112
174
p_table = Xbyak_aarch64::XReg (aux_gpr_idxs[aux_gpr_idxs.size () - 1 ]);
113
175
aux_gpr_idxs.erase (aux_gpr_idxs.end () - 1 );
114
176
}
115
177
116
- for (size_t i = 0 ; i < preserved_gpr_idxs.size (); ++i) {
117
- h->str (Xbyak_aarch64::XReg (preserved_gpr_idxs[i]), pre_ptr (h->sp , -16 ));
118
- }
178
+ store_context (preserved_gpr_idxs, preserved_vec_idxs);
119
179
120
180
if (!entry_map_.empty ()) {
121
181
load_table_addr ();
122
182
}
123
183
}
124
184
125
185
void jit_emitter::emitter_postamble () const {
126
- const int size = static_cast <int >(preserved_gpr_idxs.size ());
127
- for (int i = (size - 1 ); i >= 0 ; --i) {
128
- h->ldr (Xbyak_aarch64::XReg (preserved_gpr_idxs[i]), post_ptr (h->sp , 16 ));
129
- }
186
+ restore_context (preserved_gpr_idxs, preserved_vec_idxs);
187
+
188
+ preserved_vec_idxs.clear ();
130
189
preserved_gpr_idxs.clear ();
131
190
132
191
aux_vec_idxs.clear ();
133
192
aux_gpr_idxs.clear ();
134
193
}
135
194
136
195
void jit_emitter::store_context (const std::unordered_set<size_t >& ignore_registers) const {
196
+ store_context (store_gpr_regs, vec_regs, ignore_registers);
197
+ }
198
+
199
+ void jit_emitter::store_context (
200
+ const std::vector<size_t >& gpr_regs,
201
+ const std::vector<size_t >& vec_regs,
202
+ const std::unordered_set<size_t >& ignore_vec_regs) const {
137
203
// 1. General-purpose Registers
138
204
// 1.1. store pair registers
139
- const auto store_gpr_regs_size = store_gpr_regs .size ();
205
+ const auto store_gpr_regs_size = gpr_regs .size ();
140
206
const auto last = store_gpr_regs_size % 2 ;
141
207
for (size_t i = 0 ; i < (store_gpr_regs_size - last); i += 2 ) {
142
- h->stp (Xbyak_aarch64::XReg (store_gpr_regs [i]),
143
- Xbyak_aarch64::XReg (store_gpr_regs [i + 1 ]),
144
- pre_ptr (h->sp , -get_gpr_length () * 2 ));
208
+ h->stp (Xbyak_aarch64::XReg (gpr_regs [i]),
209
+ Xbyak_aarch64::XReg (gpr_regs [i + 1 ]),
210
+ pre_ptr (h->sp , -get_gpr_length () * 2 ));
145
211
}
146
-
147
- // 1.1. store the remaining register
212
+ // 1.2. store the remaining register
148
213
if (last != 0 ) {
149
- h->str (Xbyak_aarch64::XReg (store_gpr_regs [store_gpr_regs_size - 1 ]),
150
- pre_ptr (h->sp , -get_gpr_length () * 2 ));
214
+ h->str (Xbyak_aarch64::XReg (gpr_regs [store_gpr_regs_size - 1 ]),
215
+ pre_ptr (h->sp , -get_gpr_length ()));
151
216
}
152
217
153
218
// 2. SIMD and Floating-Point registers
154
219
// 2.1. store pair registers
155
220
int prev_reg_idx = -1 ;
156
221
size_t ignore_registers_count = 0 ;
157
- for (size_t reg_idx = 0 ; reg_idx < get_asimd_vectors_count (); reg_idx++) {
158
- if (ignore_registers .find (reg_idx) != ignore_registers .end ()) {
222
+ for (size_t reg_idx = 0 ; reg_idx < vec_regs. size (); reg_idx++) {
223
+ if (ignore_vec_regs .find (reg_idx) != ignore_vec_regs .end ()) {
159
224
ignore_registers_count++;
160
225
continue ;
161
226
}
162
-
163
227
if (prev_reg_idx == -1 ) {
164
228
prev_reg_idx = static_cast <int >(reg_idx);
165
229
continue ;
166
230
}
167
-
168
231
h->stp (Xbyak_aarch64::QReg (prev_reg_idx),
169
232
Xbyak_aarch64::QReg (reg_idx),
170
233
pre_ptr (h->sp , -get_vec_length () * 2 ));
171
234
prev_reg_idx = -1 ;
172
235
}
173
- OPENVINO_ASSERT (ignore_registers_count == ignore_registers.size (),
174
- " ignored registers size is not equal actual ignored registers count" );
175
236
176
237
// 2.1. store the remaining register
177
238
if (prev_reg_idx != -1 ) {
178
- h->str (Xbyak_aarch64::QReg (prev_reg_idx),
179
- pre_ptr (h->sp , -get_vec_length ()));
239
+ if (ignore_vec_regs.find (prev_reg_idx) == ignore_vec_regs.end ()) {
240
+ h->str (Xbyak_aarch64::QReg (prev_reg_idx),
241
+ pre_ptr (h->sp , -get_vec_length ()));
242
+ } else {
243
+ ignore_registers_count++;
244
+ }
180
245
}
246
+
247
+ OPENVINO_ASSERT (ignore_registers_count == ignore_vec_regs.size (),
248
+ " ignored registers size is not equal actual ignored registers count" );
181
249
}
182
250
183
- void jit_emitter::restore_context (const std::unordered_set<size_t >& ignore_registers) const {
251
+ void jit_emitter::restore_context (const std::unordered_set<size_t >& ignore_vec_regs) const {
252
+ restore_context (store_gpr_regs, vec_regs, ignore_vec_regs);
253
+ }
254
+
255
+ void jit_emitter::restore_context (
256
+ const std::vector<size_t >& gpr_regs,
257
+ const std::vector<size_t >& vec_regs,
258
+ const std::unordered_set<size_t >& ignore_vec_regs) const {
184
259
// 1. SIMD and Floating-Point registers
185
260
// 1.1. restore the remaining register
186
- const auto v_last = (get_asimd_vectors_count () - ignore_registers .size ()) % 2 ;
261
+ auto v_last = (vec_regs. size () - ignore_vec_regs .size ()) % 2 ;
187
262
if (v_last != 0 ) {
188
- const auto reg_idx = get_asimd_vectors_count () - 1 ;
189
- h->ldr (Xbyak_aarch64::QReg (reg_idx),
190
- post_ptr (h->sp , get_vec_length ()));
263
+ for (size_t i = 0 ; i < vec_regs.size (); i++) {
264
+ const auto reg_idx = vec_regs.size () - 1 - i;
265
+ if (ignore_vec_regs.find (reg_idx) != ignore_vec_regs.end ()) {
266
+ v_last++;
267
+ continue ;
268
+ }
269
+
270
+ h->ldr (Xbyak_aarch64::QReg (reg_idx),
271
+ post_ptr (h->sp , get_vec_length ()));
272
+ break ;
273
+ }
191
274
}
192
-
193
- // 2.2. restore pair registers
275
+ // 1.2. restore pair registers
194
276
size_t ignore_registers_count = 0 ;
195
277
int prev_reg_idx = -1 ;
196
- for (size_t i = v_last; i < get_asimd_vectors_count (); i++) {
197
- const auto reg_idx = get_asimd_vectors_count () - 1 - i;
198
- if (ignore_registers .find (reg_idx) != ignore_registers .end ()) {
278
+ for (size_t i = v_last; i < vec_regs. size (); i++) {
279
+ const auto reg_idx = vec_regs. size () - 1 - i;
280
+ if (ignore_vec_regs .find (reg_idx) != ignore_vec_regs .end ()) {
199
281
ignore_registers_count++;
200
282
continue ;
201
283
}
202
-
203
284
if (prev_reg_idx == -1 ) {
204
285
prev_reg_idx = static_cast <int >(reg_idx);
205
286
continue ;
206
287
}
207
-
208
288
h->ldp (Xbyak_aarch64::QReg (reg_idx),
209
289
Xbyak_aarch64::QReg (prev_reg_idx),
210
290
post_ptr (h->sp , get_vec_length () * 2 ));
211
291
prev_reg_idx = -1 ;
212
292
}
213
293
214
- OPENVINO_ASSERT (ignore_registers_count == ignore_registers .size (),
294
+ OPENVINO_ASSERT (ignore_registers_count == ignore_vec_regs .size (),
215
295
" ignored registers size is not equal actual ignored registers count" );
216
296
217
297
// 2. General-purpose Registers
218
298
// 2.1. restore the remaining register
219
- const auto save_gpr_regs_size = store_gpr_regs .size ();
299
+ const auto save_gpr_regs_size = gpr_regs .size ();
220
300
const auto last = save_gpr_regs_size % 2 ;
221
301
if (last != 0 ) {
222
- h->ldr (Xbyak_aarch64::XReg (store_gpr_regs [save_gpr_regs_size - 1 ]),
223
- post_ptr (h->sp , get_gpr_length () * 2 ));
302
+ h->ldr (Xbyak_aarch64::XReg (gpr_regs [save_gpr_regs_size - 1 ]),
303
+ post_ptr (h->sp , get_gpr_length ()));
224
304
}
225
305
226
306
// 2.2. restore pair registers
227
307
for (size_t i = last; i < save_gpr_regs_size; i += 2 ) {
228
- h->ldp (Xbyak_aarch64::XReg (store_gpr_regs [save_gpr_regs_size - 1 - (i + 1 )]),
229
- Xbyak_aarch64::XReg (store_gpr_regs [save_gpr_regs_size - 1 - i]),
308
+ h->ldp (Xbyak_aarch64::XReg (gpr_regs [save_gpr_regs_size - 1 - (i + 1 )]),
309
+ Xbyak_aarch64::XReg (gpr_regs [save_gpr_regs_size - 1 - i]),
230
310
post_ptr (h->sp , get_gpr_length () * 2 ));
231
311
}
232
312
}
0 commit comments