Vectorize erfc() for float
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
index 86a49b6..5169f1c 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
@@ -345,6 +345,49 @@
/***************************************************************************
* Implementation of erfc, requires C++11/C99 *
****************************************************************************/
+template <typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erfc_float(const T& x) {
+ const T x_abs = pmin(pabs(x), pset1<T>(10.0f));
+ const T one = pset1<T>(1.0f);
+ const T x_abs_gt_one_mask = pcmp_lt(one, x_abs);
+
+ // erfc(x) = 1 + x * S(x^2), |x| <= 1.
+ //
+ // Coefficients for S and T generated with Rminimax command:
+ // ./ratapprox --function="erfc(x)-1" --dom='[-1,1]' --type=[11,0] --num="odd"
+ // --numF="[SG]" --denF="[SG]" --log --dispCoeff="dec"
+ constexpr float alpha[] = {5.61802298761904239654541015625e-04, -4.91381669417023658752441406250e-03,
+ 2.67075151205062866210937500000e-02, -1.12800106406211853027343750000e-01,
+ 3.76122951507568359375000000000e-01, -1.12837910652160644531250000000e+00};
+ const T x2 = pmul(x, x);
+ const T erfc_small = pmadd(x, ppolevl<T, 5>::run(x2, alpha), one);
+
+ // Return early if we don't need the more expensive approximation for any
+ // entry in a.
+ if (!predux_any(x_abs_gt_one_mask)) return erfc_small;
+
+ // erfc(x) = exp(-x^2) * 1/x * P(1/x^2) / Q(1/x^2), 1 < x < 9.
+ //
+ // Coefficients for P and Q generated with Rminimax command:
+ // ./ratapprox --function="erfc(1/sqrt(x))*exp(1/x)/sqrt(x)"
+ // --dom='[0.01,1]' --type=[3,4] --numF="[SG]" --denF="[SG]" --log
+ // --dispCoeff="dec"
+ constexpr float gamma[] = {1.0208116471767425537109375e-01f, 4.2920666933059692382812500e-01f,
+ 3.2379078865051269531250000e-01f, 5.3971976041793823242187500e-02f};
+ constexpr float delta[] = {1.7251677811145782470703125e-02f, 3.9137163758277893066406250e-01f,
+ 1.0000000000000000000000000e+00f, 6.2173241376876831054687500e-01f,
+ 9.5662862062454223632812500e-02f};
+ const T z = pexp(pnegate(x2));
+ const T q2 = preciprocal(x2);
+ const T num = ppolevl<T, 3>::run(q2, gamma);
+ const T denom = pmul(x_abs, ppolevl<T, 4>::run(q2, delta));
+ const T r = pdiv(num, denom);
+ // If x < -1 then use erfc(x) = 2 - erfc(|x|).
+ const T x_negative = pcmp_lt(x, pset1<T>(0.0f));
+ const T erfc_large = pselect(x_negative, pnmadd(z, r, pset1<T>(2.0f)), pmul(z, r));
+
+ return pselect(x_abs_gt_one_mask, erfc_large, erfc_small);
+}
template <typename Scalar>
struct erfc_impl {
@@ -365,7 +408,7 @@
#if defined(SYCL_DEVICE_ONLY)
return cl::sycl::erfc(x);
#else
- return ::erfcf(x);
+ return generic_fast_erfc_float(x);
#endif
}
};
@@ -462,17 +505,17 @@
template <typename T, typename ScalarType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(const T& b) {
const ScalarType p0[] = {ScalarType(-5.99633501014107895267e1), ScalarType(9.80010754185999661536e1),
- ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1),
- ScalarType(-1.23916583867381258016e0)};
+ ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1),
+ ScalarType(-1.23916583867381258016e0)};
const ScalarType q0[] = {ScalarType(1.0),
- ScalarType(1.95448858338141759834e0),
- ScalarType(4.67627912898881538453e0),
- ScalarType(8.63602421390890590575e1),
- ScalarType(-2.25462687854119370527e2),
- ScalarType(2.00260212380060660359e2),
- ScalarType(-8.20372256168333339912e1),
- ScalarType(1.59056225126211695515e1),
- ScalarType(-1.18331621121330003142e0)};
+ ScalarType(1.95448858338141759834e0),
+ ScalarType(4.67627912898881538453e0),
+ ScalarType(8.63602421390890590575e1),
+ ScalarType(-2.25462687854119370527e2),
+ ScalarType(2.00260212380060660359e2),
+ ScalarType(-8.20372256168333339912e1),
+ ScalarType(1.59056225126211695515e1),
+ ScalarType(-1.18331621121330003142e0)};
const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0));
const T half = pset1<T>(ScalarType(0.5));
T c, c2, ndtri_gt_exp_neg_two;