Skip to content

Commit 3d1e89a

Browse files
vishwascmvishwascm
and
vishwascm
authored
cpu: aarch64: injectors: Improve performance of tanh for block size 16 (#2094)
Co-authored-by: vishwascm <vishwas.bc@futjitsu.com>
1 parent 19f0862 commit 3d1e89a

File tree

2 files changed

+326
-2
lines changed

2 files changed

+326
-2
lines changed

src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp

+320-2
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,87 @@ void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector_fwd(
475475
h->mov(vmm_src, p_mask / T_m, vmm_aux3);
476476
}
477477

478+
// Emits a table-driven polynomial approximation of tanh(x), computed in place
// on vmm_src. Code is generated only when isa is sve_512; for any other isa
// the function returns without emitting anything.
//
// Outline of the emitted sequence:
//   1. take |x|;
//   2. derive a per-lane offset into tanh_pol_table from the bits of |x|;
//   3. reduce the argument and evaluate a degree-6 polynomial whose
//      coefficients are gathered with that offset;
//   4. select between 1, P(x) and x depending on the range of |x|;
//   5. put the sign of the original input back and write vmm_src.
template <cpu_isa_t isa>
479+
void jit_uni_eltwise_injector_f32<
480+
isa>::tanh_polynomial_approx_compute_vector_fwd(const TRegS &vmm_src) {
481+
482+
if (!utils::one_of(isa, sve_512)) return;
483+
484+
using namespace Xbyak_aarch64::util;
485+
486+
// Stride, in table entries, between the coefficient groups of consecutive
// degrees inside tanh_pol_table (see its use in gather_coefficient below).
const int tanh_n_polynomials = 32;
487+
488+
// Register mapping
// Several logical names alias the same auxiliary register (three on
// vmm_aux1, two each on vmm_aux3 and vmm_aux4), so the order of the
// instructions emitted below matters: a name is only valid until one of its
// aliases is written.
489+
TRegS vmm_dst = vmm_aux1, vmm_src_shift = vmm_aux1, vmm_coeff = vmm_aux1,
490+
vmm_pol = vmm_aux2, vmm_indices = vmm_aux3, vmm_tmp = vmm_aux3,
491+
vmm_src_pos = vmm_aux4, vmm_sign = vmm_aux4;
492+
493+
const auto &mask = PReg(6); // avoid pred regs used in *conv_kernel*
494+
495+
// Helper function to gather polynomial coefficients
// X_TMP_1 is set to the address of the first coefficient of degree
// coeff_idx (x_table plus the offset of entry coeff_idx * 32 of
// tanh_pol_table); the gather load then adds the per-lane value of
// vmm_pol_idx (sign-extended, SXTW) to that base.
// NOTE(review): X_TMP_0 is presumably scratch for add_imm, and the per-lane
// offsets are presumably unscaled byte offsets - confirm.
496+
auto gather_coefficient = [&](TRegS vmm_coeff, int coeff_idx,
497+
TRegS vmm_pol_idx) {
498+
h->add_imm(h->X_TMP_1, x_table,
499+
table_off(tanh_pol_table, coeff_idx * tanh_n_polynomials),
500+
h->X_TMP_0);
501+
h->ld1w(ZRegS(IDX(vmm_coeff)), p_all,
502+
ptr(h->X_TMP_1, ZRegS(IDX(vmm_pol_idx)), SXTW));
503+
};
504+
505+
// because tanh(x) = -tanh(-x), we extract sign to make x positive
506+
// and reapply sign at the end
507+
h->fabs(vmm_src_pos, p_all / T_z, vmm_src);
508+
509+
// Compute indices for the table lookup
// Integer-subtract tanh_idx_bias from the raw bits of |x|, keep the bits
// selected by tanh_idx_mask, then shift right by 20.
// NOTE(review): with tanh_idx_mask = 0xffc00000 (as registered in
// tanh_consts) only bits 22..31 survive, so a shift by 20 leaves that field
// multiplied by 4 - presumably the byte size of one f32 table entry. The
// shift is done on 64-bit lanes; this looks safe because the mask leaves the
// low 22 bits of every 32-bit half zero (assuming the mask is replicated in
// every 32-bit lane) - confirm.
510+
h->sub(ZRegS(IDX(vmm_indices)), ZRegS(IDX(vmm_src_pos)),
511+
ZRegS(IDX(table_val(tanh_idx_bias, z_tmp))));
512+
h->and_(ZRegD(IDX(vmm_indices)), ZRegD(IDX(vmm_indices)),
513+
ZRegD(IDX(table_val(tanh_idx_mask, z_tmp))));
514+
h->lsr(ZRegD(IDX(vmm_indices)), ZRegD(IDX(vmm_indices)), 20);
515+
516+
// Argument reduction
// vmm_src_shift = |x| with the bits outside tanh_idx_mask cleared;
// vmm_src_pos   = |x| - vmm_src_shift (overwrites |x|, restored below).
517+
h->and_(ZRegD(IDX(vmm_src_shift)), ZRegD(IDX(vmm_src_pos)),
518+
ZRegD(IDX(table_val(tanh_idx_mask, z_tmp))));
519+
h->fsub(vmm_src_pos, vmm_src_pos, ZRegS(IDX(vmm_src_shift)));
520+
521+
// Polynomial evaluation in Horner form: start from the degree-6
// coefficient, then for each lower degree do pol = pol * reduced_x + coeff.
// vmm_coeff aliases vmm_src_shift, which is no longer needed at this point.
gather_coefficient(vmm_pol, 6, vmm_indices);
522+
for (int deg = 5; deg >= 0; --deg) {
523+
gather_coefficient(vmm_coeff, deg, vmm_indices);
524+
h->fmad(vmm_pol, p_all / T_m, vmm_src_pos, vmm_coeff);
525+
}
526+
527+
// Restore src_pos
// (recompute |x|, since the argument reduction overwrote it)
528+
h->fabs(vmm_src_pos, p_all / T_z, vmm_src);
529+
530+
// Now Blend the results
// The default value is 1; each following compare/select pair overrides it
// for lanes whose |x| is below the respective bound, so the last select
// (the linear range) takes priority.
531+
// [saturation_lbound; +inf) : return +/- 1
532+
table_val(one, vmm_dst);
533+
534+
// [linear_ubound; saturation_lbound] : return +/- P(x)
// mask = (saturation_lbound > |x|). vmm_tmp aliases vmm_indices, which is
// no longer needed once the coefficients have been gathered.
535+
table_val(tanh_saturation_lbound, vmm_tmp);
536+
h->fcmgt(PRegS(IDX(mask)), p_all / T_z, vmm_tmp, vmm_src_pos);
537+
h->sel(vmm_dst, mask / T_m, vmm_pol, vmm_dst);
538+
539+
// [0; linear_ubound] : return x
// mask = (linear_ubound > |x|)
540+
table_val(tanh_linear_ubound, vmm_tmp);
541+
h->fcmgt(PRegS(IDX(mask)), p_all / T_z, vmm_tmp, vmm_src_pos);
542+
h->sel(vmm_dst, mask / T_m, vmm_src_pos, vmm_dst);
543+
544+
// Reapply sign and return
// vmm_sign aliases vmm_src_pos, so |x| is overwritten here; the result
// (vmm_dst XOR the bits of the input selected by sign_mask) is written back
// to vmm_src.
545+
h->and_(ZRegD(IDX(vmm_sign)), ZRegD(IDX(vmm_src)),
546+
ZRegD(IDX(table_val(sign_mask, z_tmp))));
547+
h->eor(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_dst)), ZRegD(IDX(vmm_sign)));
548+
}
549+
478550
template <cpu_isa_t isa>
479551
void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector_fwd(
480552
const TRegS &vmm_src) {
553+
554+
if (utils::one_of(isa, sve_512)) {
555+
tanh_polynomial_approx_compute_vector_fwd(vmm_src);
556+
return;
557+
}
558+
481559
// tanh(x) = x(1 + (-1/3)x^2) for |x| < tanh_range
482560
// tanh(x) = 1 - 2/(1 + exp(2 x)) otherwise
483561

@@ -1734,9 +1812,248 @@ void jit_uni_eltwise_injector_f32<isa>::register_table_entries() {
17341812
{bwd_mish_max_x_for_equation_f, {0x41b17217, true}}};
17351813

17361814
// tanh(x) constants for four interval approximation
1737-
static const table_t tanh_consts {
1738-
{tanh_range, {0x3d4ccccd, true}},
1815+
// and for polynomial approximation
1816+
static const table_t tanh_consts {{tanh_range, {0x3d4ccccd, true}},
17391817
{tanh_m1d3, {0xbeaaaaab, true}},
1818+
{tanh_idx_bias, {0x39800000, true}},
1819+
{tanh_idx_mask, {0xffc00000, true}},
1820+
{tanh_linear_ubound, {0x39ddb3d7, true}},
1821+
{tanh_saturation_lbound, {0x41102cb3, true}}};
1822+
1823+
// tanh(x) polynomial approximation
1824+
// For each coefficient, there are 32 entries
1825+
static const table_t tanh_polynomial_table {
1826+
// coefficients of degree 0
1827+
{tanh_pol_table, {0x00000000, false}},
1828+
{tanh_pol_table, {0x39bfffff, false}},
1829+
{tanh_pol_table, {0x39ffffff, false}},
1830+
{tanh_pol_table, {0x3a3ffffe, false}},
1831+
{tanh_pol_table, {0x3a7ffffb, false}},
1832+
{tanh_pol_table, {0x3abffff7, false}},
1833+
{tanh_pol_table, {0x3affffeb, false}},
1834+
{tanh_pol_table, {0x3b3fffdc, false}},
1835+
{tanh_pol_table, {0x3b7fffab, false}},
1836+
{tanh_pol_table, {0x3bbfff70, false}},
1837+
{tanh_pol_table, {0x3bfffeab, false}},
1838+
{tanh_pol_table, {0x3c3ffdc0, false}},
1839+
{tanh_pol_table, {0x3c7ffaab, false}},
1840+
{tanh_pol_table, {0x3cbff701, false}},
1841+
{tanh_pol_table, {0x3cffeaad, false}},
1842+
{tanh_pol_table, {0x3d3fdc08, false}},
1843+
{tanh_pol_table, {0x3d7faacd, false}},
1844+
{tanh_pol_table, {0x3dbf7081, false}},
1845+
{tanh_pol_table, {0x3dfeacc9, false}},
1846+
{tanh_pol_table, {0x3e3dc7fd, false}},
1847+
{tanh_pol_table, {0x3e7acbf5, false}},
1848+
{tanh_pol_table, {0x3eb77a9f, false}},
1849+
{tanh_pol_table, {0x3eec9a9f, false}},
1850+
{tanh_pol_table, {0x3f22991f, false}},
1851+
{tanh_pol_table, {0x3f42f7d6, false}},
1852+
{tanh_pol_table, {0x3f67b7cc, false}},
1853+
{tanh_pol_table, {0x3f76ca83, false}},
1854+
{tanh_pol_table, {0x3f7ebbe9, false}},
1855+
{tanh_pol_table, {0x3f7fd40c, false}},
1856+
{tanh_pol_table, {0x3f7fff32, false}},
1857+
{tanh_pol_table, {0x3f7ffffc, false}},
1858+
{tanh_pol_table, {0x3f800000, false}},
1859+
// coefficients of degree 1
1860+
{tanh_pol_table, {0x3f800000, false}},
1861+
{tanh_pol_table, {0x3f800018, false}},
1862+
{tanh_pol_table, {0x3f7fffe8, false}},
1863+
{tanh_pol_table, {0x3f7fffda, false}},
1864+
{tanh_pol_table, {0x3f7fffdc, false}},
1865+
{tanh_pol_table, {0x3f7fffdc, false}},
1866+
{tanh_pol_table, {0x3f7fffac, false}},
1867+
{tanh_pol_table, {0x3f7fff70, false}},
1868+
{tanh_pol_table, {0x3f7ffeec, false}},
1869+
{tanh_pol_table, {0x3f7ffdc0, false}},
1870+
{tanh_pol_table, {0x3f7ffbed, false}},
1871+
{tanh_pol_table, {0x3f7ff704, false}},
1872+
{tanh_pol_table, {0x3f7feff5, false}},
1873+
{tanh_pol_table, {0x3f7fdbca, false}},
1874+
{tanh_pol_table, {0x3f7fbfff, false}},
1875+
{tanh_pol_table, {0x3f7f7041, false}},
1876+
{tanh_pol_table, {0x3f7f009b, false}},
1877+
{tanh_pol_table, {0x3f7dc36c, false}},
1878+
{tanh_pol_table, {0x3f7c0aa8, false}},
1879+
{tanh_pol_table, {0x3f7734b8, false}},
1880+
{tanh_pol_table, {0x3f70a4de, false}},
1881+
{tanh_pol_table, {0x3f5f1fd8, false}},
1882+
{tanh_pol_table, {0x3f495493, false}},
1883+
{tanh_pol_table, {0x3f18b9ec, false}},
1884+
{tanh_pol_table, {0x3ed706cb, false}},
1885+
{tanh_pol_table, {0x3e390b06, false}},
1886+
{tanh_pol_table, {0x3d90b11f, false}},
1887+
{tanh_pol_table, {0x3c21a053, false}},
1888+
{tanh_pol_table, {0x3aaf7fdb, false}},
1889+
{tanh_pol_table, {0x37ccc1a3, false}},
1890+
{tanh_pol_table, {0x355c6733, false}},
1891+
{tanh_pol_table, {0x00000000, false}},
1892+
// coefficients of degree 2
1893+
{tanh_pol_table, {0x00000000, false}},
1894+
{tanh_pol_table, {0xbe4e0ff1, false}},
1895+
{tanh_pol_table, {0x3d25b1b1, false}},
1896+
{tanh_pol_table, {0x3d6b6dab, false}},
1897+
{tanh_pol_table, {0x3c9fb1d5, false}},
1898+
{tanh_pol_table, {0xbabff06f, false}},
1899+
{tanh_pol_table, {0x3c07b3f6, false}},
1900+
{tanh_pol_table, {0xbb3fc1bc, false}},
1901+
{tanh_pol_table, {0x3a9f5921, false}},
1902+
{tanh_pol_table, {0xbbbf06f2, false}},
1903+
{tanh_pol_table, {0xbbb0f402, false}},
1904+
{tanh_pol_table, {0xbc47db9e, false}},
1905+
{tanh_pol_table, {0xbc73d5e7, false}},
1906+
{tanh_pol_table, {0xbca25bda, false}},
1907+
{tanh_pol_table, {0xbcfca780, false}},
1908+
{tanh_pol_table, {0xbd40e07c, false}},
1909+
{tanh_pol_table, {0xbd7dab03, false}},
1910+
{tanh_pol_table, {0xbdbe4a0f, false}},
1911+
{tanh_pol_table, {0xbdfb14a5, false}},
1912+
{tanh_pol_table, {0xbe36cc8d, false}},
1913+
{tanh_pol_table, {0xbe6bd102, false}},
1914+
{tanh_pol_table, {0xbe9fe7c5, false}},
1915+
{tanh_pol_table, {0xbeba0f10, false}},
1916+
{tanh_pol_table, {0xbec206a8, false}},
1917+
{tanh_pol_table, {0xbea3c388, false}},
1918+
{tanh_pol_table, {0xbe277d62, false}},
1919+
{tanh_pol_table, {0xbd8b7960, false}},
1920+
{tanh_pol_table, {0xbc209f49, false}},
1921+
{tanh_pol_table, {0xbaad44ca, false}},
1922+
{tanh_pol_table, {0xb7c6eeac, false}},
1923+
{tanh_pol_table, {0xb663aa41, false}},
1924+
{tanh_pol_table, {0x00000000, false}},
1925+
// coefficients of degree 3
1926+
{tanh_pol_table, {0x00000000, false}},
1927+
{tanh_pol_table, {0x45b3ae96, false}},
1928+
{tanh_pol_table, {0xc414eb20, false}},
1929+
{tanh_pol_table, {0xc450e02e, false}},
1930+
{tanh_pol_table, {0xc3152b4e, false}},
1931+
{tanh_pol_table, {0xbead2f56, false}},
1932+
{tanh_pol_table, {0xc2162e02, false}},
1933+
{tanh_pol_table, {0xbeb4bd5a, false}},
1934+
{tanh_pol_table, {0xc11a59a4, false}},
1935+
{tanh_pol_table, {0xbed2f507, false}},
1936+
{tanh_pol_table, {0xc020d32c, false}},
1937+
{tanh_pol_table, {0x3dd0f506, false}},
1938+
{tanh_pol_table, {0xbf2a75e2, false}},
1939+
{tanh_pol_table, {0xbff950e3, false}},
1940+
{tanh_pol_table, {0xbed47334, false}},
1941+
{tanh_pol_table, {0xbe809b8c, false}},
1942+
{tanh_pol_table, {0xbeb64532, false}},
1943+
{tanh_pol_table, {0xbe961a5b, false}},
1944+
{tanh_pol_table, {0xbe9b63ac, false}},
1945+
{tanh_pol_table, {0xbea0d4b2, false}},
1946+
{tanh_pol_table, {0xbe828a77, false}},
1947+
{tanh_pol_table, {0xbe378612, false}},
1948+
{tanh_pol_table, {0xbdc20908, false}},
1949+
{tanh_pol_table, {0x3d2d3957, false}},
1950+
{tanh_pol_table, {0x3dd46e89, false}},
1951+
{tanh_pol_table, {0x3db3f629, false}},
1952+
{tanh_pol_table, {0x3d2c5e7b, false}},
1953+
{tanh_pol_table, {0x3bd20403, false}},
1954+
{tanh_pol_table, {0x3a59dfae, false}},
1955+
{tanh_pol_table, {0x3770af45, false}},
1956+
{tanh_pol_table, {0x372cc014, false}},
1957+
{tanh_pol_table, {0x00000000, false}},
1958+
// coefficients of degree 4
1959+
{tanh_pol_table, {0x00000000, false}},
1960+
{tanh_pol_table, {0xcc981a1b, false}},
1961+
{tanh_pol_table, {0x4a7edd3d, false}},
1962+
{tanh_pol_table, {0x4ab1007c, false}},
1963+
{tanh_pol_table, {0x48fedd9c, false}},
1964+
{tanh_pol_table, {0x41a557b5, false}},
1965+
{tanh_pol_table, {0x477ee32a, false}},
1966+
{tanh_pol_table, {0x422557f5, false}},
1967+
{tanh_pol_table, {0x45ff3ce4, false}},
1968+
{tanh_pol_table, {0x42a55641, false}},
1969+
{tanh_pol_table, {0x446e0867, false}},
1970+
{tanh_pol_table, {0xc33dc19a, false}},
1971+
{tanh_pol_table, {0x42915214, false}},
1972+
{tanh_pol_table, {0x43af4fad, false}},
1973+
{tanh_pol_table, {0x4110fe88, false}},
1974+
{tanh_pol_table, {0xc1099b75, false}},
1975+
{tanh_pol_table, {0x3fc8a8dc, false}},
1976+
{tanh_pol_table, {0xbfbeaef5, false}},
1977+
{tanh_pol_table, {0xbe365aad, false}},
1978+
{tanh_pol_table, {0x3f4d9652, false}},
1979+
{tanh_pol_table, {0x3ddfa08f, false}},
1980+
{tanh_pol_table, {0x3e34e9b8, false}},
1981+
{tanh_pol_table, {0x3e2d07a6, false}},
1982+
{tanh_pol_table, {0x3dc63567, false}},
1983+
{tanh_pol_table, {0x3cdaeb78, false}},
1984+
{tanh_pol_table, {0xbcd17537, false}},
1985+
{tanh_pol_table, {0xbc92829c, false}},
1986+
{tanh_pol_table, {0xbb43ab99, false}},
1987+
{tanh_pol_table, {0xb9b471dd, false}},
1988+
{tanh_pol_table, {0xb6baad5a, false}},
1989+
{tanh_pol_table, {0xb78bafc7, false}},
1990+
{tanh_pol_table, {0x00000000, false}},
1991+
// coefficients of degree 5
1992+
{tanh_pol_table, {0x00000000, false}},
1993+
{tanh_pol_table, {0x52f688d5, false}},
1994+
{tanh_pol_table, {0xd0505c72, false}},
1995+
{tanh_pol_table, {0xd08f98e3, false}},
1996+
{tanh_pol_table, {0xce505cc9, false}},
1997+
{tanh_pol_table, {0xc7162b8a, false}},
1998+
{tanh_pol_table, {0xcc5061d6, false}},
1999+
{tanh_pol_table, {0xc7162bdf, false}},
2000+
{tanh_pol_table, {0xca50b37f, false}},
2001+
{tanh_pol_table, {0xc7162a3a, false}},
2002+
{tanh_pol_table, {0xc8422086, false}},
2003+
{tanh_pol_table, {0x471a714e, false}},
2004+
{tanh_pol_table, {0xc5ece1f1, false}},
2005+
{tanh_pol_table, {0xc70e3d90, false}},
2006+
{tanh_pol_table, {0xc3eba94a, false}},
2007+
{tanh_pol_table, {0x43e0c424, false}},
2008+
{tanh_pol_table, {0xc21f4552, false}},
2009+
{tanh_pol_table, {0x42217cc8, false}},
2010+
{tanh_pol_table, {0x405e7dc4, false}},
2011+
{tanh_pol_table, {0xc10dd401, false}},
2012+
{tanh_pol_table, {0x3e96b602, false}},
2013+
{tanh_pol_table, {0xbd1a6d2f, false}},
2014+
{tanh_pol_table, {0xbd393883, false}},
2015+
{tanh_pol_table, {0xbd674682, false}},
2016+
{tanh_pol_table, {0xbd310016, false}},
2017+
{tanh_pol_table, {0xb961e269, false}},
2018+
{tanh_pol_table, {0x3ba32495, false}},
2019+
{tanh_pol_table, {0x3a7680d5, false}},
2020+
{tanh_pol_table, {0x38b3173c, false}},
2021+
{tanh_pol_table, {0x35a9deea, false}},
2022+
{tanh_pol_table, {0x375c3f2a, false}},
2023+
{tanh_pol_table, {0x00000000, false}},
2024+
// coefficients of degree 6
2025+
{tanh_pol_table, {0x00000000, false}},
2026+
{tanh_pol_table, {0xd8995ed1, false}},
2027+
{tanh_pol_table, {0x558285ea, false}},
2028+
{tanh_pol_table, {0x55b2cd69, false}},
2029+
{tanh_pol_table, {0x53028625, false}},
2030+
{tanh_pol_table, {0x4bc9991f, false}},
2031+
{tanh_pol_table, {0x5082898a, false}},
2032+
{tanh_pol_table, {0x4b4999b3, false}},
2033+
{tanh_pol_table, {0x4e02c07c, false}},
2034+
{tanh_pol_table, {0x4ac99764, false}},
2035+
{tanh_pol_table, {0x4b72c822, false}},
2036+
{tanh_pol_table, {0xca40c0e1, false}},
2037+
{tanh_pol_table, {0x489413e4, false}},
2038+
{tanh_pol_table, {0x49b12224, false}},
2039+
{tanh_pol_table, {0x46134c4e, false}},
2040+
{tanh_pol_table, {0xc60c2d57, false}},
2041+
{tanh_pol_table, {0x43c83910, false}},
2042+
{tanh_pol_table, {0xc3c872d1, false}},
2043+
{tanh_pol_table, {0xc186bc9e, false}},
2044+
{tanh_pol_table, {0x42325bc3, false}},
2045+
{tanh_pol_table, {0xbf2ffa4a, false}},
2046+
{tanh_pol_table, {0x3d9a203c, false}},
2047+
{tanh_pol_table, {0xbc545a43, false}},
2048+
{tanh_pol_table, {0xbae08fee, false}},
2049+
{tanh_pol_table, {0x3c80225d, false}},
2050+
{tanh_pol_table, {0x3b1fd1df, false}},
2051+
{tanh_pol_table, {0xba36b9d1, false}},
2052+
{tanh_pol_table, {0xb91de544, false}},
2053+
{tanh_pol_table, {0xb71f100f, false}},
2054+
{tanh_pol_table, {0xb408e2ed, false}},
2055+
{tanh_pol_table, {0xb685fec8, false}},
2056+
{tanh_pol_table, {0x00000000, false}},
17402057
};
17412058

17422059
// soft_relu(x) constants
@@ -2061,6 +2378,7 @@ void jit_uni_eltwise_injector_f32<isa>::register_table_entries() {
20612378
if (need.exp()) push_entries_of(exp_consts2);
20622379
if (need.mish()) push_entries_of(mish_consts);
20632380
if (need.tanh()) push_entries_of(tanh_consts);
2381+
if (need.tanh()) push_entries_of(tanh_polynomial_table);
20642382
if (need.soft_relu()) push_entries_of(soft_relu_consts);
20652383
if (need.soft_relu()) push_entries_of(soft_relu_polynomial);
20662384
if (need.gelu_tanh()) push_entries_of(gelu_tanh_consts);

src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ struct jit_uni_eltwise_injector_f32 {
215215
void relu_zero_ns_compute_vector_fwd(const TRegS &vmm_src);
216216
void elu_compute_vector_fwd(const TRegS &vmm_src);
217217
void tanh_compute_vector_fwd(const TRegS &vmm_src);
218+
void tanh_polynomial_approx_compute_vector_fwd(const TRegS &vmm_src);
218219
void square_compute_vector_fwd(const TRegS &vmm_src);
219220
void abs_compute_vector_fwd(const TRegS &vmm_src);
220221
void sqrt_compute_vector_fwd(const TRegS &vmm_src);
@@ -277,6 +278,11 @@ struct jit_uni_eltwise_injector_f32 {
277278
bwd_mish_max_x_for_equation_f,
278279
tanh_range, // tanh(x) = x - x^3/3 for |x| < tanh_range
279280
tanh_m1d3, // -1/3
281+
tanh_idx_bias, // bias applied during index computation
282+
tanh_idx_mask, // mask applied to extract index
283+
tanh_linear_ubound, // arg below which tanh(x) = x
284+
tanh_saturation_lbound, // arg after which tanh(x) = 1.f
285+
tanh_pol_table, // table of polynomial coefficients
280286
soft_relu_one_twenty_six, // 126.f
281287
soft_relu_mantissa_sign_mask, // mask for mantissa bits and sign
282288
soft_relu_pol, // see correspondent table for float values

0 commit comments

Comments
 (0)