@@ -3332,22 +3332,55 @@ template <typename T>
3332
3332
typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
3333
3333
|| std::is_same<T, Xbyak::Address>::value)>::type
3334
3334
jit_uni_binary_injector_t <isa, Vmm>::execute_prelu_binary(const Vmm &dst,
3335
- const Vmm &lhs, const T &rhs) const {
3336
- // in sse4 vmm_aux0 as mask it's index must be 0
3337
- Vmm vmm_aux0 = Vmm (rhs_arg_static_params_.rhs_prelu_helper_vmm_idx );
3338
-
3339
- if (dst == vmm_aux0) {
3340
- vmm_aux0 = Vmm (14 );
3341
- if (isa == sse41)
3342
- assert (!" conflict mask register" );
3343
- }
3335
+ const Vmm &lhs, const T &rhs) const {
3336
+ if (is_superset (isa, avx)) {
3337
+ host_->uni_vmulps (rhs, rhs, lhs);
3338
+ host_->uni_vblendvps (dst, lhs, rhs, lhs);
3339
+ } else {
3340
+ using dnnl::impl::utils::one_of;
3341
+ // in sse4 vmm_aux0 is used as the mask register, so its index must be 0
3342
+ Vmm vmm_aux0 = Vmm (rhs_arg_static_params_.rhs_prelu_helper_vmm_idx );
3343
+
3344
+ if (one_of (vmm_aux0, dst, lhs, rhs)) {
3345
+ // let's find a vacant XMM register
3346
+ int occupied_idices[] = {dst.getIdx (), lhs.getIdx (), rhs.getIdx ()};
3347
+
3348
+ int fixup_reg_indx = 14 ;
3349
+ while (std::any_of (std::begin (occupied_idices), std::end (occupied_idices),
3350
+ [&](const int x) {return x == fixup_reg_indx;}) && --fixup_reg_indx > 0 ) {}
3351
+ if (fixup_reg_indx < 0 ) assert (!" couldn't find a vacant XMM reg" );
3352
+
3353
+ vmm_aux0 = Vmm (fixup_reg_indx);
3354
+ }
3344
3355
3345
- push_vmm (host_, vmm_aux0);
3346
- host_->uni_vmulps (rhs, rhs, lhs);
3347
- host_->vpxor (vmm_aux0, vmm_aux0, vmm_aux0);
3348
- host_->vcmpltps (vmm_aux0, lhs, vmm_aux0);
3349
- host_->uni_vblendvps (dst, lhs, rhs, vmm_aux0);
3350
- pop_vmm (host_, vmm_aux0);
3356
+ push_vmm (host_, vmm_aux0);
3357
+
3358
+ auto swap_aux0 = [&](const Vmm &reg) {
3359
+ Vmm vmm (reg.getIdx ());
3360
+ host_->vmovups (vmm_aux0, vmm);
3361
+ std::swap (vmm_aux0, vmm);
3362
+ return vmm;
3363
+ };
3364
+
3365
+ const auto aux_orig_indx = vmm_aux0.getIdx ();
3366
+ // if XMM0 is occupied, we swap XMM0 with vmm_aux0 to use XMM0 as the mask register
3367
+ const auto & dst_ = 0 == dst.getIdx () ? swap_aux0 (dst) : dst;
3368
+ const auto & lhs_ = 0 == lhs.getIdx () ? swap_aux0 (lhs) : lhs;
3369
+ const auto & rhs_ = 0 == rhs.getIdx () ? swap_aux0 (rhs) : rhs;
3370
+
3371
+ host_->uni_vmulps (rhs_, rhs_, lhs_);
3372
+ host_->vpxor (vmm_aux0, vmm_aux0, vmm_aux0);
3373
+ host_->vcmpltps (vmm_aux0, lhs_, vmm_aux0);
3374
+ host_->uni_vblendvps (dst_, lhs_, rhs_, vmm_aux0);
3375
+
3376
+ if (aux_orig_indx != 0 ) {
3377
+ auto vmm_aux_orig = Vmm (aux_orig_indx);
3378
+ host_->vmovups (vmm_aux0, vmm_aux_orig); // restore original Xmm0 value
3379
+ std::swap (vmm_aux0, vmm_aux_orig);
3380
+ }
3381
+
3382
+ pop_vmm (host_, vmm_aux0);
3383
+ }
3351
3384
}
3352
3385
3353
3386
template <cpu_isa_t isa, typename Vmm>
0 commit comments