@@ -41,6 +41,7 @@ bool is_alg_supported(alg_kind_t alg) {
41
41
eltwise_gelu_tanh, eltwise_hardsigmoid, eltwise_hardswish,
42
42
eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2,
43
43
eltwise_pow, eltwise_gelu_erf, eltwise_round,
44
+ eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even,
44
45
eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd,
45
46
eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd,
46
47
eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd,
@@ -1833,6 +1834,49 @@ size_t jit_uni_eltwise_injector<isa, Wmm>::op_vecs_count(
1833
1834
return ret;
1834
1835
}
1835
1836
1837
+ template <cpu_isa_t isa, typename Wmm>
1838
+ void jit_uni_eltwise_injector<isa, Wmm>::hsigmoid_compute_vector_fwd(
1839
+ const Vmm &vmm_src) {
1840
+ // x + 3
1841
+ h->uni_vaddps (vmm_src, vmm_src, table_val (hsigmoid, 0 ));
1842
+ // relu6(x + 3)
1843
+ h->uni_vmaxps (vmm_src, vmm_src, table_val (zero));
1844
+ h->uni_vminps (vmm_src, vmm_src, table_val (hsigmoid, 1 ));
1845
+ // relu6(x + 3) / 6
1846
+ h->uni_vmulps (vmm_src, vmm_src, table_val (hsigmoid, 2 ));
1847
+ }
1848
+
1849
+ template <cpu_isa_t isa, typename Wmm>
1850
+ void jit_uni_eltwise_injector<isa, Wmm>::round_half_to_even_compute_vector_fwd(
1851
+ const Vmm &vmm_src) {
1852
+ h->uni_vroundps (vmm_src, vmm_src, _op_near);
1853
+ }
1854
+
1855
+ template <cpu_isa_t isa, typename Wmm>
1856
+ void jit_uni_eltwise_injector<isa, Wmm>::round_half_away_from_zero_compute_vector_fwd(
1857
+ const Vmm &vmm_src) {
1858
+ // create a mask of negative numbers for later returning sign
1859
+ compute_cmp_mask (vmm_src, table_val (zero), _cmp_lt_os);
1860
+
1861
+ // round half away from zero for positive numbers
1862
+ h->uni_vandps (vmm_src, vmm_src, table_val (positive_mask));
1863
+ h->uni_vaddps (vmm_src, vmm_src, table_val (half));
1864
+ h->uni_vroundps (vmm_src, vmm_src, _op_floor);
1865
+
1866
+ // return a sign for negative numbers using the mask
1867
+ if (isa == sse41) {
1868
+ h->movups (vmm_aux (1 ), vmm_src);
1869
+ h->mulps (vmm_aux (1 ), table_val (minus_one));
1870
+ h->blendvps (vmm_src, vmm_aux (1 ));
1871
+ } else if (isa == avx2) {
1872
+ h->vmulps (vmm_aux (1 ), vmm_src, table_val (minus_one));
1873
+ h->vblendvps (vmm_src, vmm_src, vmm_aux (1 ), vmm_mask_);
1874
+ } else if (isa == avx512_core) {
1875
+ h->vmulps (vmm_aux (1 ), vmm_src, table_val (minus_one));
1876
+ h->vblendmps (vmm_src | k_mask_, vmm_src, vmm_aux (1 ));
1877
+ }
1878
+ }
1879
+
1836
1880
template <cpu_isa_t isa, typename Wmm>
1837
1881
size_t jit_uni_eltwise_injector<isa, Wmm>::aux_vecs_count(
1838
1882
alg_kind_t alg, bool is_fwd, float alpha) {
@@ -1873,6 +1917,9 @@ size_t jit_uni_eltwise_injector<isa, Wmm>::aux_vecs_count(
1873
1917
case eltwise_round: n_vmms = 0 ; break ;
1874
1918
case eltwise_hardswish: n_vmms = 1 ; break ;
1875
1919
case eltwise_hardsigmoid: n_vmms = 0 ; break ;
1920
+ case eltwise_hsigmoid: n_vmms = 0 ; break ;
1921
+ case eltwise_round_half_to_even: n_vmms = 0 ; break ;
1922
+ case eltwise_round_half_away_from_zero: n_vmms = 2 ; break ;
1876
1923
default : assert (!" unsupported eltwise algorithm" );
1877
1924
}
1878
1925
} else {
@@ -2042,6 +2089,9 @@ void jit_uni_eltwise_injector<isa, Wmm>::compute_body(
2042
2089
case eltwise_hardsigmoid:
2043
2090
hardsigmoid_compute_vector_fwd (Vmm (idx));
2044
2091
break ;
2092
+ case eltwise_hsigmoid: hsigmoid_compute_vector_fwd (Vmm (idx)); break ;
2093
+ case eltwise_round_half_to_even: round_half_to_even_compute_vector_fwd (Vmm (idx)); break ;
2094
+ case eltwise_round_half_away_from_zero: round_half_away_from_zero_compute_vector_fwd (Vmm (idx)); break ;
2045
2095
default : assert (!" unsupported eltwise algorithm" );
2046
2096
}
2047
2097
} else {
@@ -2826,6 +2876,13 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
2826
2876
{0xc2b00f34 , true }}, // 63: -88.029693603515625
2827
2877
};
2828
2878
2879
+ // hsigmoid(x) = relu6(x + 3) / 6 constants: 3, 6 and 1 / 6
2880
+ static const table_t hsigmoid_values {
2881
+ {hsigmoid, {0x40400000 , true }}, // 3
2882
+ {hsigmoid, {0x40C00000 , true }}, // 6
2883
+ {hsigmoid, {0x3e2aaaaa , true }}, // 1 / 6
2884
+ };
2885
+
2829
2886
// This object takes care about which constants and polynomials to include.
2830
2887
struct need_t {
2831
2888
need_t (alg_kind_t alg) {
@@ -2845,6 +2902,7 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
2845
2902
case eltwise_mish: mish_ = true ; break ;
2846
2903
case eltwise_tanh_use_dst_for_bwd:
2847
2904
case eltwise_tanh: tanh_ = true ; break ;
2905
+ case eltwise_hsigmoid: hsigmoid_ = true ; break ;
2848
2906
default : break ;
2849
2907
}
2850
2908
}
@@ -2856,6 +2914,7 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
2856
2914
bool gelu_tanh_ = false ;
2857
2915
bool gelu_erf_ = false ;
2858
2916
bool log_ = false ;
2917
+ bool hsigmoid_ = false ;
2859
2918
2860
2919
bool exp () const { return exp_ || soft_relu_ || gelu_erf_ || mish_; }
2861
2920
bool mish () const { return mish_; }
@@ -2864,6 +2923,7 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
2864
2923
bool gelu_tanh () const { return gelu_tanh_; }
2865
2924
bool gelu_erf () const { return gelu_erf_; }
2866
2925
bool log () const { return log_; }
2926
+ bool hsigmoid () const { return hsigmoid_; }
2867
2927
};
2868
2928
2869
2929
need_t need (alg_);
@@ -2903,6 +2963,7 @@ void jit_uni_eltwise_injector<isa, Wmm>::register_table_entries() {
2903
2963
if (need.log ()) push_entries_of (log_consts);
2904
2964
if (need.log ()) push_entries_of (log_polynomial);
2905
2965
if (need.log ()) push_entries_of (log_predefined_values);
2966
+ if (need.hsigmoid ()) push_entries_of (hsigmoid_values);
2906
2967
2907
2968
// Now that we registered the entries, we set the offsets. No
2908
2969
// entries should be registered after this point. This allows to
0 commit comments