Skip to content

Commit 60a8893

Browse files
a-sidorova authored and azhai219 committed
[FORK][FEATURE] Added support of hsigmoid, round_half_to_even, round_half_away_from_zero elementwise algorithms
Note: hsigmoid implementation probably can be removed since there is hardsigmoid implementation
1 parent 5409694 commit 60a8893

11 files changed

+125
-9
lines changed

include/oneapi/dnnl/dnnl.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,12 @@ enum class algorithm {
406406
eltwise_hardswish = dnnl_eltwise_hardswish,
407407
/// Elementwise: hardsigmoid
408408
eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid,
409+
/// Elementwise: hsigmoid
410+
eltwise_hsigmoid = dnnl_eltwise_hsigmoid,
411+
/// Elementwise: round_half_to_even
412+
eltwise_round_half_to_even = dnnl_eltwise_round_half_to_even,
413+
/// Elementwise: round_half_away_from_zero
414+
eltwise_round_half_away_from_zero = dnnl_eltwise_round_half_away_from_zero,
409415
/// Elementwise: rectified linear unit (ReLU) (dst for backward)
410416
eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
411417
/// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)

include/oneapi/dnnl/dnnl_types.h

+6
Original file line numberDiff line numberDiff line change
@@ -2088,6 +2088,12 @@ typedef enum {
20882088
dnnl_eltwise_mish,
20892089
/// Eltwise: hardswish
20902090
dnnl_eltwise_hardswish,
2091+
/// Eltwise: hsigmoid
2092+
dnnl_eltwise_hsigmoid,
2093+
/// Eltwise: round_half_to_even
2094+
dnnl_eltwise_round_half_to_even,
2095+
/// Eltwise: round_half_away_from_zero
2096+
dnnl_eltwise_round_half_away_from_zero,
20912097
/// Eltwise: ReLU (dst for backward)
20922098
dnnl_eltwise_relu_use_dst_for_bwd = 0x100,
20932099
/// Eltwise: hyperbolic tangent non-linearity (tanh) (dst for backward)

src/common/c_types_map.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ const alg_kind_t eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh;
9090
const alg_kind_t eltwise_gelu_erf = dnnl_eltwise_gelu_erf;
9191
const alg_kind_t eltwise_hardswish = dnnl_eltwise_hardswish;
9292
const alg_kind_t eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid;
93+
const alg_kind_t eltwise_hsigmoid = dnnl_eltwise_hsigmoid;
94+
const alg_kind_t eltwise_round_half_to_even = dnnl_eltwise_round_half_to_even;
95+
const alg_kind_t eltwise_round_half_away_from_zero = dnnl_eltwise_round_half_away_from_zero;
9396
const alg_kind_t eltwise_relu_use_dst_for_bwd
9497
= dnnl_eltwise_relu_use_dst_for_bwd;
9598
const alg_kind_t eltwise_tanh_use_dst_for_bwd

src/common/dnnl_debug_autogenerated.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -1789,6 +1789,9 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
17891789
if (v == dnnl_eltwise_round) return "eltwise_round";
17901790
if (v == dnnl_eltwise_mish) return "eltwise_mish";
17911791
if (v == dnnl_eltwise_hardswish) return "eltwise_hardswish";
1792+
if (v == dnnl_eltwise_hsigmoid) return "eltwise_hsigmoid";
1793+
if (v == dnnl_eltwise_round_half_to_even) return "eltwise_round_half_to_even";
1794+
if (v == dnnl_eltwise_round_half_away_from_zero) return "eltwise_round_half_away_from_zero";
17921795
if (v == dnnl_eltwise_relu_use_dst_for_bwd) return "eltwise_relu_use_dst_for_bwd";
17931796
if (v == dnnl_eltwise_tanh_use_dst_for_bwd) return "eltwise_tanh_use_dst_for_bwd";
17941797
if (v == dnnl_eltwise_elu_use_dst_for_bwd) return "eltwise_elu_use_dst_for_bwd";

src/common/eltwise.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
5757
VCHECK_ELTWISE(
5858
IMPLICATION(!is_fwd, !any_null(diff_src_desc, diff_dst_desc)),
5959
VERBOSE_NULL_ARG);
60-
VCHECK_ELTWISE(IMPLICATION(alg_kind == eltwise_round, is_fwd),
60+
VCHECK_ELTWISE(IMPLICATION(one_of(alg_kind, eltwise_round, eltwise_hsigmoid,
61+
eltwise_round_half_away_from_zero, eltwise_round_half_to_even), is_fwd),
6162
VERBOSE_BAD_PROPKIND);
6263
VCHECK_ELTWISE(
6364
IMPLICATION(is_fwd, !memory_desc_wrapper(src_desc).format_any()),

src/common/eltwise_pd.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ struct eltwise_fwd_pd_t : public eltwise_pd_t {
158158
return one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
159159
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_swish,
160160
eltwise_gelu_tanh, eltwise_gelu_erf, eltwise_round,
161-
eltwise_hardswish)
161+
eltwise_hardswish, eltwise_round_half_away_from_zero, eltwise_round_half_to_even)
162162
|| one_of(alg, eltwise_relu_use_dst_for_bwd,
163163
eltwise_tanh_use_dst_for_bwd,
164164
eltwise_elu_use_dst_for_bwd,

src/common/math_utils.hpp

+27-1
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,31 @@ inline U hardswish_bwd(T dd, T s, A alpha, A beta) {
414414
return v <= 0.f ? 0.f : v >= 1.f ? dd : dd * w;
415415
}
416416

417+
// Forward hsigmoid: returns min(max(s + 3, 0), 6) / 6, computed in float and
// cast back to U (T with any reference removed).
// The result is 0 for s <= -3, 1 for s >= 3, and linear in between.
template <typename T,
418+
typename U = typename utils::remove_reference<T>::type>
419+
inline U hsigmoid_fwd(T s) {
420+
// Shift the input so the linear segment starts at zero.
float v = s + 3.0f;
421+
// Clamp from below at 0 ...
v = v > 0.0f ? v : 0.0f;
422+
// ... and from above at 6 (together: relu6(s + 3)).
v = v < 6.0f ? v : 6.0f;
423+
// Scale [0, 6] down to [0, 1].
return (U)(v / 6.0f);
424+
}
425+
426+
// Forward round-half-to-even: rounds s to the nearest integer, sending exact
// .5 ties to the even neighbour. Computed in float and cast back to U
// (T with any reference removed).
template <typename T,
427+
typename U = typename utils::remove_reference<T>::type>
428+
inline U round_half_to_even_fwd(T s) {
429+
// roundf resolves halfway cases away from zero, so r is only wrong when
// s is an exact tie and r turned out odd.
float r = ::roundf((float)s);
430+
// Signed distance from the rounded value; exactly +-0.5 only for ties.
float d = (float)s - r;
431+
// 0 when r is even (fmodf keeps the sign of r, so odd values give +-1).
float remainder = ::fmodf(r, 2.0f);
432+
// Not a tie, or the tie already landed on an even value: keep r.
// Otherwise s + d steps half a unit in the opposite direction, which is
// the even neighbour (e.g. s = 2.5: r = 3, d = -0.5, result 2).
return ((d != 0.5f) && (d != -0.5f)) || (remainder == 0.0f) ? (U)r :
433+
(U)((float)s + d);
434+
}
435+
436+
// Forward round-half-away-from-zero: converts s to float, applies ::roundf
// (which rounds halfway cases away from zero) and casts the result to U
// (T with any reference removed).
template <typename T,
437+
typename U = typename utils::remove_reference<T>::type>
438+
inline U round_half_away_from_zero_fwd(T s) {
439+
return (U)(::roundf((float)s));
440+
}
441+
417442
inline bool is_eltwise_ok(
418443
data_type_t src_dt, alg_kind_t alg, float alpha, float beta) {
419444
using namespace alg_kind;
@@ -426,7 +451,8 @@ inline bool is_eltwise_ok(
426451
eltwise_exp, eltwise_gelu_tanh, eltwise_hardsigmoid,
427452
eltwise_hardswish, eltwise_swish, eltwise_log,
428453
eltwise_clip, eltwise_clip_v2, eltwise_pow,
429-
eltwise_gelu_erf, eltwise_round)
454+
eltwise_gelu_erf, eltwise_round,
455+
eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even)
430456
&& IMPLICATION(
431457
one_of(alg, eltwise_clip, eltwise_clip_v2), beta >= alpha)
432458
&& IMPLICATION(alg == eltwise_round, src_dt == dnnl_f32)

src/cpu/primitive_attr_postops.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ float compute_eltwise_scalar_fwd(
6969
case eltwise_mish: d = mish_fwd(s); break;
7070
case eltwise_hardsigmoid: d = hardsigmoid_fwd(s, alpha, beta); break;
7171
case eltwise_hardswish: d = hardswish_fwd(s, alpha, beta); break;
72+
case eltwise_hsigmoid: d = hsigmoid_fwd(s); break;
73+
case eltwise_round_half_away_from_zero: d = round_half_away_from_zero_fwd(s); break;
74+
case eltwise_round_half_to_even: d = round_half_to_even_fwd(s); break;
7275
case eltwise_relu_use_dst_for_bwd: d = relu_fwd(s, alpha); break;
7376
case eltwise_tanh_use_dst_for_bwd: d = tanh_fwd(s); break;
7477
case eltwise_elu_use_dst_for_bwd: d = elu_fwd(s, alpha); break;
@@ -155,11 +158,12 @@ ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
155158
eltwise_soft_relu, eltwise_mish, eltwise_logistic, eltwise_exp,
156159
eltwise_gelu_tanh, eltwise_swish, eltwise_log, eltwise_clip,
157160
eltwise_clip_v2, eltwise_pow, eltwise_gelu_erf, eltwise_round,
158-
eltwise_hardsigmoid, eltwise_hardswish,
159-
eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd,
160-
eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd,
161-
eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd,
162-
eltwise_clip_v2_use_dst_for_bwd));
161+
eltwise_hardswish, eltwise_hardsigmoid,
162+
eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even,
163+
eltwise_relu_use_dst_for_bwd,
164+
eltwise_tanh_use_dst_for_bwd, eltwise_elu_use_dst_for_bwd,
165+
eltwise_sqrt_use_dst_for_bwd, eltwise_logistic_use_dst_for_bwd,
166+
eltwise_exp_use_dst_for_bwd, eltwise_clip_v2_use_dst_for_bwd));
163167
}
164168

165169
ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(

src/cpu/x64/injectors/jit_uni_eltwise_injector.cpp

+61
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ bool is_alg_supported(alg_kind_t alg) {
4141
eltwise_gelu_tanh, eltwise_hardsigmoid, eltwise_hardswish,
4242
eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2,
4343
eltwise_pow, eltwise_gelu_erf, eltwise_round,
44+
eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even,
4445
eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd,
4546
eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd,
4647
eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd,
@@ -1833,6 +1834,49 @@ size_t jit_uni_eltwise_injector<isa, Wmm>::op_vecs_count(
18331834
return ret;
18341835
}
18351836

1837+
// Emits the vector forward hsigmoid, min(max(x + 3, 0), 6) * (1 / 6),
// in place on vmm_src. The constants come from the `hsigmoid` table entries
// at indices 0, 1 and 2 (commented there as 3, 6 and 1 / 6).
template <cpu_isa_t isa, typename Wmm>
1838+
void jit_uni_eltwise_injector<isa, Wmm>::hsigmoid_compute_vector_fwd(
1839+
const Vmm &vmm_src) {
1840+
// x + 3 (hsigmoid table entry 0)
1841+
h->uni_vaddps(vmm_src, vmm_src, table_val(hsigmoid, 0));
1842+
// relu6(x + 3): clamp to [0, 6] (upper bound is hsigmoid table entry 1)
1843+
h->uni_vmaxps(vmm_src, vmm_src, table_val(zero));
1844+
h->uni_vminps(vmm_src, vmm_src, table_val(hsigmoid, 1));
1845+
// relu6(x + 3) / 6, done as a multiply by hsigmoid table entry 2
1846+
h->uni_vmulps(vmm_src, vmm_src, table_val(hsigmoid, 2));
1847+
}
1848+
1849+
// Emits the vector forward round-half-to-even in place on vmm_src as a single
// uni_vroundps with the _op_near rounding immediate (0 in jit_generator).
// NOTE(review): an immediate of 0 selects round-to-nearest-even for x86
// ROUNDPS; confirm uni_vroundps forwards it unchanged on every ISA.
template <cpu_isa_t isa, typename Wmm>
1850+
void jit_uni_eltwise_injector<isa, Wmm>::round_half_to_even_compute_vector_fwd(
1851+
const Vmm &vmm_src) {
1852+
h->uni_vroundps(vmm_src, vmm_src, _op_near);
1853+
}
1854+
1855+
// Emits the vector forward round-half-away-from-zero in place on vmm_src:
// sign(x) * floor(|x| + 0.5).
//
// The sign-restore step is dispatched on the ISA family rather than on an
// exact isa value: previously only sse41, avx2 and avx512_core matched, so any
// other instantiation skipped the blend and returned the unsigned result for
// negative inputs.
//
// NOTE(review): floor(|x| + 0.5) differs from roundf when |x| + 0.5 is not
// exactly representable in float (e.g. the largest float below 0.5, or odd
// integers >= 2^23); confirm whether matching the reference there is needed.
template <cpu_isa_t isa, typename Wmm>
1856+
void jit_uni_eltwise_injector<isa, Wmm>::round_half_away_from_zero_compute_vector_fwd(
1857+
const Vmm &vmm_src) {
1858+
// create a mask of negative numbers for later returning sign
1859+
compute_cmp_mask(vmm_src, table_val(zero), _cmp_lt_os);
1860+
1861+
// round half away from zero for positive numbers: |x| -> floor(|x| + 0.5)
1862+
h->uni_vandps(vmm_src, vmm_src, table_val(positive_mask));
1863+
h->uni_vaddps(vmm_src, vmm_src, table_val(half));
1864+
h->uni_vroundps(vmm_src, vmm_src, _op_floor);
1865+
1866+
// return a sign for negative numbers using the mask
1867+
if (isa == sse41) {
1868+
// two-operand SSE forms: negate a copy, then blend it in under the mask
h->movups(vmm_aux(1), vmm_src);
1869+
h->mulps(vmm_aux(1), table_val(minus_one));
1870+
h->blendvps(vmm_src, vmm_aux(1));
1871+
} else if (!is_avx512_) {
1872+
// every remaining non-AVX-512 isa takes the VEX-encoded path
h->vmulps(vmm_aux(1), vmm_src, table_val(minus_one));
1873+
h->vblendvps(vmm_src, vmm_src, vmm_aux(1), vmm_mask_);
1874+
} else {
1875+
// avx512_core and its supersets: blend under the opmask register
h->vmulps(vmm_aux(1), vmm_src, table_val(minus_one));
1876+
h->vblendmps(vmm_src | k_mask_, vmm_src, vmm_aux(1));
1877+
}
1878+
}
1879+
18361880
template <cpu_isa_t isa, typename Wmm>
18371881
size_t jit_uni_eltwise_injector<isa, Wmm>::aux_vecs_count(
18381882
alg_kind_t alg, bool is_fwd, float alpha) {
@@ -1873,6 +1917,9 @@ size_t jit_uni_eltwise_injector<isa, Wmm>::aux_vecs_count(
18731917
case eltwise_round: n_vmms = 0; break;
18741918
case eltwise_hardswish: n_vmms = 1; break;
18751919
case eltwise_hardsigmoid: n_vmms = 0; break;
1920+
case eltwise_hsigmoid: n_vmms = 0; break;
1921+
case eltwise_round_half_to_even: n_vmms = 0; break;
1922+
case eltwise_round_half_away_from_zero: n_vmms = 2; break;
18761923
default: assert(!"unsupported eltwise algorithm");
18771924
}
18781925
} else {
@@ -2042,6 +2089,9 @@ void jit_uni_eltwise_injector<isa, Wmm>::compute_body(
20422089
case eltwise_hardsigmoid:
20432090
hardsigmoid_compute_vector_fwd(Vmm(idx));
20442091
break;
2092+
case eltwise_hsigmoid: hsigmoid_compute_vector_fwd(Vmm(idx)); break;
2093+
case eltwise_round_half_to_even: round_half_to_even_compute_vector_fwd(Vmm(idx)); break;
2094+
case eltwise_round_half_away_from_zero: round_half_away_from_zero_compute_vector_fwd(Vmm(idx)); break;
20452095
default: assert(!"unsupported eltwise algorithm");
20462096
}
20472097
} else {
@@ -2826,6 +2876,13 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
28262876
{0xc2b00f34, true}}, // 63: -88.029693603515625
28272877
};
28282878

2879+
// hsigmoid(x) constants (not a polynomial): offset, upper clamp and scale,
// read as table_val(hsigmoid, 0..2).
// The scale is 1 / 6 rounded to nearest (0x3e2aaaab). The truncated value
// 0x3e2aaaaa makes 6 * scale equal 1 - 2^-24, so the saturated output would be
// 0x3f7fffff instead of exactly 1.0f.
2880+
static const table_t hsigmoid_values {
2881+
{hsigmoid, {0x40400000, true}}, // 3
2882+
{hsigmoid, {0x40C00000, true}}, // 6
2883+
{hsigmoid, {0x3e2aaaab, true}}, // 1 / 6
2884+
};
2885+
28292886
// This object takes care about which constants and polynomials to include.
28302887
struct need_t {
28312888
need_t(alg_kind_t alg) {
@@ -2845,6 +2902,7 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
28452902
case eltwise_mish: mish_ = true; break;
28462903
case eltwise_tanh_use_dst_for_bwd:
28472904
case eltwise_tanh: tanh_ = true; break;
2905+
case eltwise_hsigmoid: hsigmoid_ = true; break;
28482906
default: break;
28492907
}
28502908
}
@@ -2856,6 +2914,7 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
28562914
bool gelu_tanh_ = false;
28572915
bool gelu_erf_ = false;
28582916
bool log_ = false;
2917+
bool hsigmoid_ = false;
28592918

28602919
bool exp() const { return exp_ || soft_relu_ || gelu_erf_ || mish_; }
28612920
bool mish() const { return mish_; }
@@ -2864,6 +2923,7 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
28642923
bool gelu_tanh() const { return gelu_tanh_; }
28652924
bool gelu_erf() const { return gelu_erf_; }
28662925
bool log() const { return log_; }
2926+
bool hsigmoid() const { return hsigmoid_; }
28672927
};
28682928

28692929
need_t need(alg_);
@@ -2903,6 +2963,7 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
29032963
if (need.log()) push_entries_of(log_consts);
29042964
if (need.log()) push_entries_of(log_polynomial);
29052965
if (need.log()) push_entries_of(log_predefined_values);
2966+
if (need.hsigmoid()) push_entries_of(hsigmoid_values);
29062967

29072968
// Now that we registered the entries, we set the offsets. No
29082969
// entries should be registered after this point. This allows to

src/cpu/x64/injectors/jit_uni_eltwise_injector.hpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ struct jit_uni_eltwise_injector {
174174
_cmp_ge_os = jit_generator::_cmp_nlt_us,
175175
_cmp_gt_os = jit_generator::_cmp_nle_us,
176176
_op_floor = jit_generator::_op_floor,
177-
_op_mxcsr = jit_generator::_op_mxcsr
177+
_op_mxcsr = jit_generator::_op_mxcsr,
178+
_op_near = jit_generator::_op_near
178179
};
179180

180181
const bool is_avx512_ = is_superset(isa, avx512_core);
@@ -245,6 +246,9 @@ struct jit_uni_eltwise_injector {
245246
void round_compute_vector_fwd(const Vmm &vmm_src);
246247
void hardswish_compute_vector_fwd(const Vmm &vmm_src);
247248
void hardsigmoid_compute_vector_fwd(const Vmm &vmm_src);
249+
void hsigmoid_compute_vector_fwd(const Vmm &vmm_src);
250+
void round_half_to_even_compute_vector_fwd(const Vmm &vmm_src);
251+
void round_half_away_from_zero_compute_vector_fwd(const Vmm &vmm_src);
248252

249253
void exp_compute_vector_bwd(const Vmm &vmm_src);
250254
void relu_compute_vector_bwd(const Vmm &vmm_src);
@@ -324,6 +328,7 @@ struct jit_uni_eltwise_injector {
324328
log_five_bit_offset, // 5 bits off (31 = 2^5 - 1)
325329
log_pol, // see correspondent table for float values
326330
log_predefined_vals, // see correspondent table for float values
331+
hsigmoid, // hsigmoid
327332
undef_key,
328333
};
329334

src/cpu/x64/jit_generator.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ class jit_generator : public Xbyak::MmapAllocator,
186186
_cmp_nlt_us = 5u,
187187
_cmp_nle_us = 6u,
188188

189+
_op_near = 0u,
189190
_op_floor = 1u,
190191
_op_mxcsr = 4u,
191192
};

0 commit comments

Comments
 (0)