@@ -199,12 +199,14 @@ rhs_arg_static_params_t::rhs_arg_static_params_t(
199
199
bool preserve_vmm_helper, std::size_t abi_param_offset,
200
200
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
201
201
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
202
- bool use_exact_tail_scalar_bcast)
202
+ bool use_exact_tail_scalar_bcast, std:: size_t rhs_prelu_helper_vmm_idx )
203
203
: rhs_arg_static_params_t (rhs_dt_helper_vmm_idx, rhs_addr_reg,
204
204
rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers,
205
205
preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d,
206
206
tail_size, tail_opmask, use_exact_tail_scalar_bcast, rhs_helper_reg,
207
- true /* is_opmask_set*/ ) {}
207
+ true /* is_opmask_set*/ ) {
208
+ this ->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx;
209
+ }
208
210
209
211
rhs_arg_static_params_t::rhs_arg_static_params_t (
210
212
std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg,
@@ -213,12 +215,14 @@ rhs_arg_static_params_t::rhs_arg_static_params_t(
213
215
bool preserve_vmm_helper, std::size_t abi_param_offset,
214
216
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
215
217
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
216
- const Xbyak::Reg64 ®_tail_size, bool use_exact_tail_scalar_bcast)
218
+ const Xbyak::Reg64 ®_tail_size, bool use_exact_tail_scalar_bcast, std:: size_t rhs_prelu_helper_vmm_idx )
217
219
: rhs_arg_static_params_t (rhs_dt_helper_vmm_idx, rhs_addr_reg,
218
220
rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers,
219
221
preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d,
220
222
tail_size, tail_opmask, use_exact_tail_scalar_bcast, reg_tail_size,
221
- true /* is_opmask_set*/ ) {}
223
+ true /* is_opmask_set*/ ) {
224
+ this ->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx;
225
+ }
222
226
223
227
rhs_arg_static_params_t::rhs_arg_static_params_t (
224
228
std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg,
@@ -2295,7 +2299,7 @@ void jit_uni_binary_injector_t<isa, Vmm>::inject_binary(
2295
2299
= rhs_arg_data_type != data_type::f32 || (scalar_f32 && !is_avx512_)
2296
2300
|| with_tail_not_fusable_to_binary_op
2297
2301
|| !binary_op_with_unaligned_mem_operand_allowed_
2298
- || (cmp_op && !is_avx512_);
2302
+ || (( cmp_op || alg == alg_kind::binary_prelu) && !is_avx512_);
2299
2303
2300
2304
if (process_rhs_arg_using_tmp_vmm) {
2301
2305
@@ -3192,6 +3196,23 @@ jit_uni_binary_injector_t<isa, Vmm>::execute_cmp_binary(const Vmm &dst,
3192
3196
pop_opmask (host_, cmp_mask);
3193
3197
}
3194
3198
3199
+ template <cpu_isa_t isa, typename Vmm>
3200
+ template <typename T>
3201
+ typename std::enable_if<std::is_same<T, Xbyak::Zmm>::value
3202
+ || std::is_same<T, Xbyak::Address>::value>::type
3203
+ jit_uni_binary_injector_t <isa, Vmm>::execute_prelu_binary(const Vmm &dst, const Vmm &lhs, const T &rhs) const {
3204
+ const auto &cmp_mask = rhs_arg_static_params_.tail_opmask ;
3205
+ const Xbyak::Zmm zmm_aux0
3206
+ = Xbyak::Zmm (rhs_arg_static_params_.rhs_prelu_helper_vmm_idx );
3207
+
3208
+ push_opmask (host_, cmp_mask);
3209
+ host_->uni_vpxor (zmm_aux0, zmm_aux0, zmm_aux0);
3210
+ host_->vcmpps (cmp_mask, lhs, zmm_aux0, jit_generator::_cmp_lt_os);
3211
+ host_->uni_vmulps (dst | cmp_mask, lhs, rhs);
3212
+ pop_opmask (host_, cmp_mask);
3213
+ }
3214
+
3215
+
3195
3216
// SSE4.1., AVX and AVX2 implementation
3196
3217
template <cpu_isa_t isa, typename Vmm>
3197
3218
template <typename T>
@@ -3211,6 +3232,23 @@ jit_uni_binary_injector_t<isa, Vmm>::execute_cmp_binary(const Vmm &dst,
3211
3232
host_->uni_vminps (dst, dst, vreg_one);
3212
3233
}
3213
3234
3235
+ // todo: [antonvor] check sse41 path
3236
+ template <cpu_isa_t isa, typename Vmm>
3237
+ template <typename T>
3238
+ typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
3239
+ || std::is_same<T, Xbyak::Address>::value)>::type
3240
+ jit_uni_binary_injector_t <isa, Vmm>::execute_prelu_binary(const Vmm &dst,
3241
+ const Vmm &lhs, const T &rhs) const {
3242
+ const Vmm vmm_aux0 = Vmm (rhs_arg_static_params_.rhs_prelu_helper_vmm_idx );
3243
+
3244
+ push_vmm (host_, vmm_aux0);
3245
+ host_->uni_vmulps (rhs, rhs, lhs);
3246
+ host_->vpxor (vmm_aux0, vmm_aux0, vmm_aux0);
3247
+ host_->vcmpltps (vmm_aux0, lhs, vmm_aux0);
3248
+ host_->uni_vblendvps (dst, lhs, rhs, vmm_aux0);
3249
+ pop_vmm (host_, vmm_aux0);
3250
+ }
3251
+
3214
3252
template <cpu_isa_t isa, typename Vmm>
3215
3253
template <typename T>
3216
3254
void jit_uni_binary_injector_t <isa, Vmm>::execute_binary(alg_kind_t binary_alg,
@@ -3240,6 +3278,9 @@ void jit_uni_binary_injector_t<isa, Vmm>::execute_binary(alg_kind_t binary_alg,
3240
3278
case alg_kind::binary_ne:
3241
3279
execute_cmp_binary (dst, lhs, rhs, jit_generator::_cmp_neq_uq);
3242
3280
break ;
3281
+ case alg_kind::binary_prelu:
3282
+ execute_prelu_binary (dst, lhs, rhs);
3283
+ break ;
3243
3284
default : assert (!" unsupported algorithm" );
3244
3285
}
3245
3286
}
0 commit comments