@@ -147,6 +147,8 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
147
147
148
148
private:
149
149
const size_t xmm_len = 16 ;
150
+ const size_t ymm_len = 32 ;
151
+ const size_t zmm_len = 64 ;
150
152
#ifdef _WIN32
151
153
const size_t xmm_to_preserve_start = 6 ;
152
154
const size_t xmm_to_preserve = 10 ;
@@ -182,6 +184,91 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
182
184
183
185
inline size_t get_size_of_abi_save_regs () { return size_of_abi_save_regs; }
184
186
187
+ using Xbyak::CodeGenerator::push;
188
+ using Xbyak::CodeGenerator::pop;
189
+
190
+ inline void push (const Xbyak::Xmm &xmm) {
191
+ if (xmm.isXMM ()) {
192
+ sub (rsp, xmm_len);
193
+ uni_vmovdqu (ptr[rsp], xmm);
194
+ } else if (xmm.isYMM ()) {
195
+ sub (rsp, ymm_len);
196
+ uni_vmovdqu (ptr[rsp], Xbyak::Ymm{xmm.getIdx ()});
197
+ } else if (xmm.isZMM ()) {
198
+ sub (rsp, zmm_len);
199
+ uni_vmovdqu (ptr[rsp], Xbyak::Zmm{xmm.getIdx ()});
200
+ }
201
+ }
202
+
203
+ inline void push (const std::vector<Xbyak::Xmm> &xmms) {
204
+ std::vector<std::function<void ()>> deferred_movs{};
205
+ size_t offset = 0 ;
206
+ for (size_t i = 0 ; i < xmms.size (); ++i) {
207
+ const auto & xmm = xmms[i];
208
+ if (xmm.isXMM ()) {
209
+ deferred_movs.emplace_back ([this , offset, &xmm]() {
210
+ uni_vmovdqu (ptr[rsp + offset], xmm);
211
+ });
212
+ offset += xmm_len;
213
+ } else if (xmm.isYMM ()) {
214
+ deferred_movs.emplace_back ([this , offset, &xmm]() {
215
+ uni_vmovdqu (ptr[rsp + offset], Xbyak::Ymm{xmm.getIdx ()});
216
+ });
217
+ offset += ymm_len;
218
+ } else if (xmm.isZMM ()) {
219
+ deferred_movs.emplace_back ([this , offset, &xmm]() {
220
+ uni_vmovdqu (ptr[rsp + offset], Xbyak::Zmm{xmm.getIdx ()});
221
+ });
222
+ offset += zmm_len;
223
+ }
224
+ }
225
+ sub (rsp, offset);
226
+ for (const auto & def_mov : deferred_movs) {
227
+ def_mov ();
228
+ }
229
+ }
230
+
231
+ inline void pop (const Xbyak::Xmm &xmm) {
232
+ if (xmm.isXMM ()) {
233
+ uni_vmovdqu (xmm, ptr[rsp]);
234
+ add (rsp, xmm_len);
235
+ } else if (xmm.isYMM ()) {
236
+ uni_vmovdqu (Xbyak::Ymm{xmm.getIdx ()}, ptr[rsp]);
237
+ add (rsp, ymm_len);
238
+ } else if (xmm.isZMM ()) {
239
+ uni_vmovdqu (Xbyak::Zmm{xmm.getIdx ()}, ptr[rsp]);
240
+ add (rsp, zmm_len);
241
+ }
242
+ }
243
+
244
+ inline void pop (const std::vector<Xbyak::Xmm> &xmms) {
245
+ std::vector<std::function<void ()>> deferred_movs{};
246
+ size_t offset = 0 ;
247
+ for (size_t i = 0 ; i < xmms.size (); ++i) {
248
+ const auto & xmm = xmms[i];
249
+ if (xmm.isXMM ()) {
250
+ deferred_movs.emplace_back ([this , offset, &xmm]() {
251
+ uni_vmovdqu (xmm, ptr[rsp + offset]);
252
+ });
253
+ offset += xmm_len;
254
+ } else if (xmm.isYMM ()) {
255
+ deferred_movs.emplace_back ([this , offset, &xmm]() {
256
+ uni_vmovdqu (Xbyak::Ymm{xmm.getIdx ()}, ptr[rsp + offset]);
257
+ });
258
+ offset += ymm_len;
259
+ } else if (xmm.isZMM ()) {
260
+ deferred_movs.emplace_back ([this , offset, &xmm]() {
261
+ uni_vmovdqu (Xbyak::Zmm{xmm.getIdx ()}, ptr[rsp + offset]);
262
+ });
263
+ offset += zmm_len;
264
+ }
265
+ }
266
+ for (const auto & def_mov : deferred_movs) {
267
+ def_mov ();
268
+ }
269
+ add (rsp, offset);
270
+ }
271
+
185
272
void preamble () {
186
273
if (xmm_to_preserve) {
187
274
sub (rsp, xmm_to_preserve * xmm_len);
0 commit comments