Vectorize erfc() for float
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h index 86a49b6..5169f1c 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
@@ -345,6 +345,49 @@ /*************************************************************************** * Implementation of erfc, requires C++11/C99 * ****************************************************************************/ +template <typename T> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erfc_float(const T& x) { + const T x_abs = pmin(pabs(x), pset1<T>(10.0f)); + const T one = pset1<T>(1.0f); + const T x_abs_gt_one_mask = pcmp_lt(one, x_abs); + + // erfc(x) = 1 + x * S(x^2), |x| <= 1. + // + // Coefficients for S and T generated with Rminimax command: + // ./ratapprox --function="erfc(x)-1" --dom='[-1,1]' --type=[11,0] --num="odd" + // --numF="[SG]" --denF="[SG]" --log --dispCoeff="dec" + constexpr float alpha[] = {5.61802298761904239654541015625e-04, -4.91381669417023658752441406250e-03, + 2.67075151205062866210937500000e-02, -1.12800106406211853027343750000e-01, + 3.76122951507568359375000000000e-01, -1.12837910652160644531250000000e+00}; + const T x2 = pmul(x, x); + const T erfc_small = pmadd(x, ppolevl<T, 5>::run(x2, alpha), one); + + // Return early if we don't need the more expensive approximation for any + // entry in a. + if (!predux_any(x_abs_gt_one_mask)) return erfc_small; + + // erfc(x) = exp(-x^2) * 1/x * P(1/x^2) / Q(1/x^2), 1 < x < 9. + // + // Coefficients for P and Q generated with Rminimax command: + // ./ratapprox --function="erfc(1/sqrt(x))*exp(1/x)/sqrt(x)" + // --dom='[0.01,1]' --type=[3,4] --numF="[SG]" --denF="[SG]" --log + // --dispCoeff="dec" + constexpr float gamma[] = {1.0208116471767425537109375e-01f, 4.2920666933059692382812500e-01f, + 3.2379078865051269531250000e-01f, 5.3971976041793823242187500e-02f}; + constexpr float delta[] = {1.7251677811145782470703125e-02f, 3.9137163758277893066406250e-01f, + 1.0000000000000000000000000e+00f, 6.2173241376876831054687500e-01f, + 9.5662862062454223632812500e-02f}; + const T z = pexp(pnegate(x2)); + const T q2 = preciprocal(x2); + const T num = ppolevl<T, 3>::run(q2, gamma); + const T denom = pmul(x_abs, ppolevl<T, 4>::run(q2, delta)); + const T r = pdiv(num, denom); + // If x < -1 then use erfc(x) = 2 - erfc(|x|). + const T x_negative = pcmp_lt(x, pset1<T>(0.0f)); + const T erfc_large = pselect(x_negative, pnmadd(z, r, pset1<T>(2.0f)), pmul(z, r)); + + return pselect(x_abs_gt_one_mask, erfc_large, erfc_small); +} template <typename Scalar> struct erfc_impl { @@ -365,7 +408,7 @@ #if defined(SYCL_DEVICE_ONLY) return cl::sycl::erfc(x); #else - return ::erfcf(x); + return generic_fast_erfc_float(x); #endif } }; @@ -462,17 +505,17 @@ template <typename T, typename ScalarType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(const T& b) { const ScalarType p0[] = {ScalarType(-5.99633501014107895267e1), ScalarType(9.80010754185999661536e1), - ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1), - ScalarType(-1.23916583867381258016e0)}; + ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1), + ScalarType(-1.23916583867381258016e0)}; const ScalarType q0[] = {ScalarType(1.0), - ScalarType(1.95448858338141759834e0), - ScalarType(4.67627912898881538453e0), - ScalarType(8.63602421390890590575e1), - ScalarType(-2.25462687854119370527e2), - ScalarType(2.00260212380060660359e2), - ScalarType(-8.20372256168333339912e1), - ScalarType(1.59056225126211695515e1), - ScalarType(-1.18331621121330003142e0)}; + ScalarType(1.95448858338141759834e0), + ScalarType(4.67627912898881538453e0), + ScalarType(8.63602421390890590575e1), + ScalarType(-2.25462687854119370527e2), + ScalarType(2.00260212380060660359e2), + ScalarType(-8.20372256168333339912e1), + ScalarType(1.59056225126211695515e1), + ScalarType(-1.18331621121330003142e0)}; const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0)); const T half = pset1<T>(ScalarType(0.5)); T c, c2, ndtri_gt_exp_neg_two;