Skip to content

Commit 64e26dd

Browse files
maxnickazhai219
authored and committed
[FORK][FIX][x64] Refactor avx2 binary PReLU and fix reg conflicts
1 parent 8ec6ca6 commit 64e26dd

File tree

1 file changed

+48
-15
lines changed

1 file changed

+48
-15
lines changed

src/cpu/x64/injectors/jit_uni_binary_injector.cpp

+48-15
Original file line number | Diff line number | Diff line change
@@ -3332,22 +3332,55 @@ template <typename T>
33323332
typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
33333333
|| std::is_same<T, Xbyak::Address>::value)>::type
33343334
jit_uni_binary_injector_t<isa, Vmm>::execute_prelu_binary(const Vmm &dst,
3335-
const Vmm &lhs, const T &rhs) const {
3336-
// in sse4 vmm_aux0 as mask it's index must be 0
3337-
Vmm vmm_aux0 = Vmm(rhs_arg_static_params_.rhs_prelu_helper_vmm_idx);
3338-
3339-
if (dst == vmm_aux0) {
3340-
vmm_aux0 = Vmm(14);
3341-
if (isa == sse41)
3342-
assert(!"conflict mask register");
3343-
}
3335+
const Vmm &lhs, const T &rhs) const {
3336+
if (is_superset(isa, avx)) {
3337+
host_->uni_vmulps(rhs, rhs, lhs);
3338+
host_->uni_vblendvps(dst, lhs, rhs, lhs);
3339+
} else {
3340+
using dnnl::impl::utils::one_of;
3341+
// in sse4 vmm_aux0 as mask it's index must be 0
3342+
Vmm vmm_aux0 = Vmm(rhs_arg_static_params_.rhs_prelu_helper_vmm_idx);
3343+
3344+
if (one_of(vmm_aux0, dst, lhs, rhs)) {
3345+
//let's find a vacant XMM register
3346+
int occupied_idices[] = {dst.getIdx(), lhs.getIdx(), rhs.getIdx()};
3347+
3348+
int fixup_reg_indx = 14;
3349+
while (std::any_of(std::begin(occupied_idices), std::end(occupied_idices),
3350+
[&](const int x) {return x == fixup_reg_indx;}) && --fixup_reg_indx > 0) {}
3351+
if (fixup_reg_indx < 0) assert(!"couldn't find a vacant XMM reg");
3352+
3353+
vmm_aux0 = Vmm(fixup_reg_indx);
3354+
}
33443355

3345-
push_vmm(host_, vmm_aux0);
3346-
host_->uni_vmulps(rhs, rhs, lhs);
3347-
host_->vpxor(vmm_aux0, vmm_aux0, vmm_aux0);
3348-
host_->vcmpltps(vmm_aux0, lhs, vmm_aux0);
3349-
host_->uni_vblendvps(dst, lhs, rhs, vmm_aux0);
3350-
pop_vmm(host_, vmm_aux0);
3356+
push_vmm(host_, vmm_aux0);
3357+
3358+
auto swap_aux0 = [&](const Vmm &reg) {
3359+
Vmm vmm(reg.getIdx());
3360+
host_->vmovups(vmm_aux0, vmm);
3361+
std::swap(vmm_aux0, vmm);
3362+
return vmm;
3363+
};
3364+
3365+
const auto aux_orig_indx = vmm_aux0.getIdx();
3366+
// if XMM0 is occupied, we swap XMM0 with vmm_aux0 to use XMM0 as the mask register
3367+
const auto& dst_ = 0 == dst.getIdx() ? swap_aux0(dst) : dst;
3368+
const auto& lhs_ = 0 == lhs.getIdx() ? swap_aux0(lhs) : lhs;
3369+
const auto& rhs_ = 0 == rhs.getIdx() ? swap_aux0(rhs) : rhs;
3370+
3371+
host_->uni_vmulps(rhs_, rhs_, lhs_);
3372+
host_->vpxor(vmm_aux0, vmm_aux0, vmm_aux0);
3373+
host_->vcmpltps(vmm_aux0, lhs_, vmm_aux0);
3374+
host_->uni_vblendvps(dst_, lhs_, rhs_, vmm_aux0);
3375+
3376+
if (aux_orig_indx != 0) {
3377+
auto vmm_aux_orig = Vmm(aux_orig_indx);
3378+
host_->vmovups(vmm_aux0, vmm_aux_orig); // restore original Xmm0 value
3379+
std::swap(vmm_aux0, vmm_aux_orig);
3380+
}
3381+
3382+
pop_vmm(host_, vmm_aux0);
3383+
}
33513384
}
33523385

33533386
template <cpu_isa_t isa, typename Vmm>

0 commit comments

Comments
 (0)