Skip to content

Commit 35e1a93

Browse files
antonvor authored and luweizhou2016 committed
[FORK][FEATURE] Added prelu as binary post op
1 parent c63e16c commit 35e1a93

18 files changed

+210
-30
lines changed

include/oneapi/dnnl/dnnl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,8 @@ enum class algorithm {
459459
binary_eq = dnnl_binary_eq,
460460
/// Binary not equal
461461
binary_ne = dnnl_binary_ne,
462+
/// Binary prelu
463+
binary_prelu = dnnl_binary_prelu,
462464
/// Nearest Neighbor resampling method
463465
resampling_nearest = dnnl_resampling_nearest,
464466
/// Linear (Bilinear, Trilinear) resampling method

include/oneapi/dnnl/dnnl_types.h

+2
Original file line numberDiff line numberDiff line change
@@ -2166,6 +2166,8 @@ typedef enum {
21662166
dnnl_binary_eq = 0x1fffa,
21672167
/// Binary not equal
21682168
dnnl_binary_ne = 0x1fffb,
2169+
/// Binary prelu
2170+
dnnl_binary_prelu = 0x1fffc,
21692171
/// Nearest Neighbor Resampling Method
21702172
dnnl_resampling_nearest = 0x2fff0,
21712173
/// Linear Resampling Method

src/common/binary.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ status_t dnnl_binary_primitive_desc_create(
9191
VCHECK_BINARY(
9292
one_of(alg_kind, binary_add, binary_mul, binary_max, binary_min,
9393
binary_div, binary_sub, binary_ge, binary_gt, binary_le,
94-
binary_lt, binary_eq, binary_ne),
94+
binary_lt, binary_eq, binary_ne, binary_prelu),
9595
VERBOSE_BAD_ALGORITHM);
9696
// TODO - Add support for mutual or bi-directional broadcasts
9797
VCHECK_BINARY(!memory_desc_wrapper(src0_md).format_any(),

src/common/c_types_map.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ const alg_kind_t binary_le = dnnl_binary_le;
129129
const alg_kind_t binary_lt = dnnl_binary_lt;
130130
const alg_kind_t binary_eq = dnnl_binary_eq;
131131
const alg_kind_t binary_ne = dnnl_binary_ne;
132+
const alg_kind_t binary_prelu = dnnl_binary_prelu;
132133
const alg_kind_t resampling_nearest = dnnl_resampling_nearest;
133134
const alg_kind_t resampling_linear = dnnl_resampling_linear;
134135
const alg_kind_t reduction_max = dnnl_reduction_max;

src/common/dnnl_debug_autogenerated.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1825,6 +1825,7 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
18251825
if (v == dnnl_binary_lt) return "binary_lt";
18261826
if (v == dnnl_binary_eq) return "binary_eq";
18271827
if (v == dnnl_binary_ne) return "binary_ne";
1828+
if (v == dnnl_binary_prelu) return "binary_prelu";
18281829
if (v == dnnl_resampling_nearest) return "resampling_nearest";
18291830
if (v == dnnl_resampling_linear) return "resampling_linear";
18301831
if (v == dnnl_reduction_max) return "reduction_max";

src/common/primitive_attr.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ status_t post_ops_t::validate_binary(
270270
using namespace alg_kind;
271271
bool alg_ok = one_of(alg, binary_add, binary_mul, binary_max, binary_min,
272272
binary_div, binary_sub, binary_ge, binary_gt, binary_le, binary_lt,
273-
binary_eq, binary_ne);
273+
binary_eq, binary_ne, binary_prelu);
274274
if (!alg_ok) return invalid_arguments;
275275
if (!memory_desc_sanity_check(*user_src1_desc)) return invalid_arguments;
276276

src/cpu/gemm_x8s8s32x_convolution.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t {
7777
| skip_mask_t::post_ops
7878
| skip_mask_t::sum_dt
7979
| primitive_attr_t::skip_mask_t::input_zero_points
80-
| primitive_attr_t::skip_mask_t::output_compensations,
80+
| primitive_attr_t::skip_mask_t::output_compensations
81+
| primitive_attr_t::skip_mask_t::sum_dt,
8182
dst_type),
8283
VERBOSE_UNSUPPORTED_ATTR);
8384

src/cpu/primitive_attr_postops.cpp

+40-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ float compute_binary_scalar(alg_kind_t alg, float x, float y) {
4040
case binary_lt: return x < y;
4141
case binary_eq: return x == y;
4242
case binary_ne: return x != y;
43+
case binary_prelu: return x >= 0 ? x : x * y;
4344
default: assert(!"not supported operation!"); return NAN;
4445
}
4546
}
@@ -139,7 +140,7 @@ ref_binary_scalar_t::ref_binary_scalar_t(alg_kind_t alg) : alg_(alg) {
139140
alg_kind::binary_min, alg_kind::binary_mul, alg_kind::binary_div,
140141
alg_kind::binary_sub, alg_kind::binary_ge, alg_kind::binary_gt,
141142
alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq,
142-
alg_kind::binary_ne));
143+
alg_kind::binary_ne, alg_kind::binary_prelu));
143144
}
144145

145146
ref_binary_scalar_t::ref_binary_scalar_t(
@@ -183,6 +184,8 @@ ref_post_ops_t::ref_post_ops_t(const post_ops_t &po, bool skip_sum)
183184
eltwise_po_.emplace_back(e.eltwise);
184185
} else if (po_.contain(primitive_kind::binary, idx)) {
185186
binary_po_.emplace_back(e.binary);
187+
} else if (po_.contain(primitive_kind::depthwise, idx)) {
188+
depthwise_po_.emplace_back(e.depthwise.alg);
186189
}
187190
}
188191
}
@@ -264,12 +267,13 @@ status_t ref_post_ops_t::init(const memory_desc_t *dst_md) {
264267
return status::success;
265268
}
266269

267-
void ref_post_ops_t::execute(float &res, const args_t &args) const {
270+
void ref_post_ops_t::execute(float &res, const args_t &args, const size_t oc) const {
268271
if (po_.len() == 0) return;
269272

270273
auto it_eltwise_po = eltwise_po_.begin();
271274
auto it_binary_po = binary_po_.begin();
272275
auto it_prelu_md = prelu_md_.begin();
276+
auto it_depthwise_po = depthwise_po_.begin();
273277
for (auto idx = 0; idx < po_.len(); ++idx) {
274278
const auto &e = po_.entry_[idx];
275279
switch (e.kind) {
@@ -330,6 +334,40 @@ void ref_post_ops_t::execute(float &res, const args_t &args) const {
330334
res = weights_value * res;
331335
++it_prelu_md;
332336
} break;
337+
case primitive_kind::depthwise: {
338+
auto depthwise_weights = e.depthwise.weights_data;
339+
auto depthwise_bias = e.depthwise.biases_data;
340+
res = it_depthwise_po->compute_scalar(res, depthwise_weights + oc, depthwise_bias + oc);
341+
++it_depthwise_po;
342+
} break;
343+
case primitive_kind::quantization: {
344+
bool do_dequantization = e.quantization.alg == alg_kind::quantization_quantize_dequantize;
345+
bool do_rounding = do_dequantization || args.dst_md->data_type == dnnl_f32 || idx != po_.len() - 1;
346+
347+
auto quant = e.quantization;
348+
auto pcl = quant.crop_low_data->shifts_;
349+
auto pch = quant.crop_high_data->shifts_;
350+
auto pisc = quant.input_scale_data->scales_;
351+
auto pish = quant.input_shift_data->shifts_;
352+
auto posc = quant.output_scale_data->scales_;
353+
auto posh = quant.output_shift_data->shifts_;
354+
355+
int cl_idx = quant.crop_low_data->count_ == 1 ? 0 : oc;
356+
int ch_idx = quant.crop_high_data->count_ == 1 ? 0 : oc;
357+
int isc_idx = quant.input_scale_data->count_ == 1 ? 0 : oc;
358+
int ish_idx = quant.input_shift_data->count_ == 1 ? 0 : oc;
359+
int osc_idx = quant.output_scale_data->count_ == 1 ? 0 : oc;
360+
int osh_idx = quant.output_shift_data->count_ == 1 ? 0 : oc;
361+
362+
res = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], res));
363+
res = res * pisc[isc_idx] + pish[ish_idx];
364+
365+
if (do_rounding)
366+
res = roundf(res);
367+
368+
if (do_dequantization)
369+
res = res * posc[osc_idx] + posh[osh_idx];
370+
} break;
333371
default: assert(!"unsupported post op primitive kind!");
334372
}
335373
}

src/cpu/primitive_attr_postops.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "common/primitive.hpp"
2323
#include "common/primitive_attr.hpp"
2424

25+
#include "ref_depthwise_injector.hpp"
26+
2527
namespace dnnl {
2628
namespace impl {
2729
namespace cpu {
@@ -71,7 +73,7 @@ struct ref_post_ops_t {
7173

7274
status_t init(const memory_desc_t *dst_md);
7375

74-
void execute(float &res, const args_t &args = args_t()) const;
76+
void execute(float &res, const args_t &args = args_t(), const size_t oc = 0) const;
7577

7678
static bool primitive_kind_ok(const post_ops_t &po) {
7779
using namespace primitive_kind;
@@ -86,6 +88,7 @@ struct ref_post_ops_t {
8688

8789
std::vector<ref_eltwise_scalar_fwd_t> eltwise_po_;
8890
std::vector<ref_binary_scalar_t> binary_po_;
91+
std::vector<ref_depthwise_scalar_fwd_t> depthwise_po_;
8992
std::vector<memory_desc_t> prelu_md_;
9093
};
9194

src/cpu/ref_depthwise_injector.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ ref_depthwise_scalar_fwd_t::ref_depthwise_scalar_fwd_t(const alg_kind_t alg_)
6464
assert(utils::one_of(alg, depthwise_scale_shift, depthwise_prelu));
6565
}
6666

67-
float ref_depthwise_scalar_fwd_t::compute_scalar(float s, const float* weights, const float* bias) {
67+
float ref_depthwise_scalar_fwd_t::compute_scalar(float s, const float* weights, const float* bias) const {
6868
switch (alg) {
6969
case depthwise_scale_shift: return scale_shift_fwd(s, *weights, *bias);
7070
case depthwise_prelu: return prelu_fwd(s, *weights);

src/cpu/ref_depthwise_injector.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace cpu {
2727
struct ref_depthwise_scalar_fwd_t {
2828
public:
2929
explicit ref_depthwise_scalar_fwd_t(alg_kind_t alg);
30-
float compute_scalar(float s, const float* weights, const float* bias);
30+
float compute_scalar(float s, const float* weights, const float* bias) const;
3131

3232
private:
3333
alg_kind_t alg;

src/cpu/x64/injectors/jit_uni_binary_injector.cpp

+46-5
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,14 @@ rhs_arg_static_params_t::rhs_arg_static_params_t(
199199
bool preserve_vmm_helper, std::size_t abi_param_offset,
200200
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
201201
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
202-
bool use_exact_tail_scalar_bcast)
202+
bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx)
203203
: rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg,
204204
rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers,
205205
preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d,
206206
tail_size, tail_opmask, use_exact_tail_scalar_bcast, rhs_helper_reg,
207-
true /*is_opmask_set*/) {}
207+
true /*is_opmask_set*/) {
208+
this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx;
209+
}
208210

209211
rhs_arg_static_params_t::rhs_arg_static_params_t(
210212
std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg,
@@ -213,12 +215,14 @@ rhs_arg_static_params_t::rhs_arg_static_params_t(
213215
bool preserve_vmm_helper, std::size_t abi_param_offset,
214216
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
215217
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
216-
const Xbyak::Reg64 &reg_tail_size, bool use_exact_tail_scalar_bcast)
218+
const Xbyak::Reg64 &reg_tail_size, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx)
217219
: rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg,
218220
rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers,
219221
preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d,
220222
tail_size, tail_opmask, use_exact_tail_scalar_bcast, reg_tail_size,
221-
true /*is_opmask_set*/) {}
223+
true /*is_opmask_set*/) {
224+
this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx;
225+
}
222226

223227
rhs_arg_static_params_t::rhs_arg_static_params_t(
224228
std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg,
@@ -2295,7 +2299,7 @@ void jit_uni_binary_injector_t<isa, Vmm>::inject_binary(
22952299
= rhs_arg_data_type != data_type::f32 || (scalar_f32 && !is_avx512_)
22962300
|| with_tail_not_fusable_to_binary_op
22972301
|| !binary_op_with_unaligned_mem_operand_allowed_
2298-
|| (cmp_op && !is_avx512_);
2302+
|| ((cmp_op || alg == alg_kind::binary_prelu) && !is_avx512_);
22992303

23002304
if (process_rhs_arg_using_tmp_vmm) {
23012305

@@ -3192,6 +3196,23 @@ jit_uni_binary_injector_t<isa, Vmm>::execute_cmp_binary(const Vmm &dst,
31923196
pop_opmask(host_, cmp_mask);
31933197
}
31943198

3199+
// Emits the binary PReLU post-op for the overload selected when the rhs
// operand type T is Xbyak::Zmm or Xbyak::Address (see the enable_if below).
// Lanes of lhs that compare "less than zero" are multiplied by rhs and
// written to dst under an opmask.
// NOTE(review): lanes outside the mask are presumably left untouched in dst
// (merge-masking), so the result is only a full PReLU if dst already holds
// lhs -- TODO confirm that callers always pass dst == lhs.
template <cpu_isa_t isa, typename Vmm>
3200+
template <typename T>
3201+
typename std::enable_if<std::is_same<T, Xbyak::Zmm>::value
3202+
|| std::is_same<T, Xbyak::Address>::value>::type
3203+
jit_uni_binary_injector_t<isa, Vmm>::execute_prelu_binary(const Vmm &dst, const Vmm &lhs, const T &rhs) const {
3204+
// The tail opmask register is borrowed as the comparison mask.
const auto &cmp_mask = rhs_arg_static_params_.tail_opmask;
3205+
// Helper vector register chosen by the caller through rhs_prelu_helper_vmm_idx.
const Xbyak::Zmm zmm_aux0
3206+
= Xbyak::Zmm(rhs_arg_static_params_.rhs_prelu_helper_vmm_idx);
3207+
3208+
// push_opmask/pop_opmask bracket the use of the borrowed mask
// (presumably saving and restoring its previous contents).
push_opmask(host_, cmp_mask);
3209+
// XOR of the helper register with itself yields the zero vector to compare against.
// NOTE(review): zmm_aux0 is overwritten here without a push/pop, unlike the
// non-Zmm overload, which wraps its helper in push_vmm/pop_vmm -- confirm the
// helper index never aliases a live register.
host_->uni_vpxor(zmm_aux0, zmm_aux0, zmm_aux0);
3210+
// cmp_mask := (lhs < 0) per lane, using the _cmp_lt_os predicate.
host_->vcmpps(cmp_mask, lhs, zmm_aux0, jit_generator::_cmp_lt_os);
3211+
// dst := lhs * rhs, written only for the lanes selected by cmp_mask.
host_->uni_vmulps(dst | cmp_mask, lhs, rhs);
3212+
pop_opmask(host_, cmp_mask);
3213+
}
3214+
3215+
31953216
// SSE4.1., AVX and AVX2 implementation
31963217
template <cpu_isa_t isa, typename Vmm>
31973218
template <typename T>
@@ -3211,6 +3232,23 @@ jit_uni_binary_injector_t<isa, Vmm>::execute_cmp_binary(const Vmm &dst,
32113232
host_->uni_vminps(dst, dst, vreg_one);
32123233
}
32133234

3235+
// todo: [antonvor] check sse41 path
3236+
// Emits the binary PReLU post-op for the overload selected when the rhs
// operand type T is neither Xbyak::Zmm nor Xbyak::Address (see the enable_if
// below). The product lhs * rhs is computed first, then blended with lhs
// using a "lhs < 0" comparison result held in a helper vector register.
template <cpu_isa_t isa, typename Vmm>
3237+
template <typename T>
3238+
typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
3239+
|| std::is_same<T, Xbyak::Address>::value)>::type
3240+
jit_uni_binary_injector_t<isa, Vmm>::execute_prelu_binary(const Vmm &dst,
3241+
const Vmm &lhs, const T &rhs) const {
3242+
// Helper vector register chosen by the caller through rhs_prelu_helper_vmm_idx.
const Vmm vmm_aux0 = Vmm(rhs_arg_static_params_.rhs_prelu_helper_vmm_idx);
3243+
3244+
// push_vmm/pop_vmm bracket the use of the helper register
// (presumably saving and restoring its previous contents).
push_vmm(host_, vmm_aux0);
3245+
// rhs := rhs * lhs. NOTE(review): this overwrites the rhs operand in place
// even though it is received as const T& -- confirm callers do not reuse rhs.
host_->uni_vmulps(rhs, rhs, lhs);
3246+
// XOR of the helper register with itself yields the zero vector.
// NOTE(review): vpxor and vcmpltps are called without the uni_ prefix used
// by the surrounding calls; whether they are valid on every ISA this
// overload is instantiated for is not visible here (see the sse41 todo).
host_->vpxor(vmm_aux0, vmm_aux0, vmm_aux0);
3247+
// vmm_aux0 := per-lane result of comparing lhs < 0.
host_->vcmpltps(vmm_aux0, lhs, vmm_aux0);
3248+
// dst := blend of lhs and rhs (now lhs * rhs) controlled by vmm_aux0;
// presumably the product is taken where lhs < 0 and lhs elsewhere.
host_->uni_vblendvps(dst, lhs, rhs, vmm_aux0);
3249+
pop_vmm(host_, vmm_aux0);
3250+
}
3251+
32143252
template <cpu_isa_t isa, typename Vmm>
32153253
template <typename T>
32163254
void jit_uni_binary_injector_t<isa, Vmm>::execute_binary(alg_kind_t binary_alg,
@@ -3240,6 +3278,9 @@ void jit_uni_binary_injector_t<isa, Vmm>::execute_binary(alg_kind_t binary_alg,
32403278
case alg_kind::binary_ne:
32413279
execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_neq_uq);
32423280
break;
3281+
case alg_kind::binary_prelu:
3282+
execute_prelu_binary(dst, lhs, rhs);
3283+
break;
32433284
default: assert(!"unsupported algorithm");
32443285
}
32453286
}

src/cpu/x64/injectors/jit_uni_binary_injector.hpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct rhs_arg_static_params_t {
110110
bool preserve_vmm_helper, std::size_t abi_param_offset,
111111
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
112112
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
113-
bool use_exact_tail_scalar_bcast);
113+
bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0);
114114
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
115115
const Xbyak::Reg64 &rhs_addr_reg,
116116
const Xbyak::Reg64 &rhs_helper_reg,
@@ -119,7 +119,7 @@ struct rhs_arg_static_params_t {
119119
std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
120120
std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
121121
const Xbyak::Reg64 &reg_tail_size,
122-
bool use_exact_tail_scalar_bcast);
122+
bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0);
123123

124124
bool is_opmask_set() const noexcept { return is_opmask_set_; }
125125

@@ -138,6 +138,8 @@ struct rhs_arg_static_params_t {
138138
Xbyak::Reg64 reg_tail_size;
139139
bool is_tail;
140140

141+
mutable std::size_t rhs_prelu_helper_vmm_idx;
142+
141143
private:
142144
rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
143145
const Xbyak::Reg64 &rhs_addr_reg,
@@ -509,11 +511,19 @@ class jit_uni_binary_injector_t {
509511
execute_cmp_binary(const Vmm &dst, const Vmm &lhs, const T &rhs,
510512
const unsigned int cmp_predicate) const;
511513
template <typename T>
514+
typename std::enable_if<std::is_same<T, Xbyak::Zmm>::value
515+
|| std::is_same<T, Xbyak::Address>::value>::type
516+
execute_prelu_binary(const Vmm &dst, const Vmm &lhs, const T &rhs) const;
517+
template <typename T>
512518
typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
513519
|| std::is_same<T, Xbyak::Address>::value)>::type
514520
execute_cmp_binary(const Vmm &dst, const Vmm &lhs, const T &rhs,
515521
const unsigned int cmp_predicate) const;
516522
template <typename T>
523+
typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
524+
|| std::is_same<T, Xbyak::Address>::value)>::type
525+
execute_prelu_binary(const Vmm &dst, const Vmm &lhs, const T &rhs) const;
526+
template <typename T>
517527
void execute_binary(alg_kind_t binary_alg, const Vmm &dst, const Vmm &lhs,
518528
const T &rhs) const;
519529
void execute_prelu(const Vmm &dst, const Xbyak::Operand &rhs) const;

0 commit comments

Comments
 (0)