Fix accuracy of logistic sigmoid
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index f792c8e..72d4858 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -1026,83 +1026,93 @@
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& x) const {
const Packet one = pset1<Packet>(T(1));
- return pdiv(one, padd(one, pexp(pnegate(x))));
+ const Packet e = pexp(x);
+ return pdiv(e, padd(one, e));
}
};
#ifndef EIGEN_GPU_COMPILE_PHASE
+
/** \internal
* \brief Template specialization of the logistic function for float.
- *
- * Uses just a 9/10-degree rational interpolant which
- * interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulps in the range
- * [-9, 18]. Below -9 we use the more accurate approximation
- * 1/(1+exp(-x)) ~= exp(x), and above 18 the logistic function is 1 within
- * one ulp. The shifted logistic is interpolated because it was easier to
- * make the fit converge.
- *
+ * Computes S(x) = exp(x) / (1 + exp(x)), where exp(x) is implemented
+ * using an algorithm partly adopted from the implementation of
+ * pexp_float. See the individual steps described in the code below.
+ * Note that compared to pexp, we use an additional outer multiplicative
+ * range reduction step using the identity exp(x) = exp(x/2)^2.
+ * This prevert us from having to call ldexp on values that could produce
+ * a denormal result, which allows us to call the faster implementation in
+ * pldexp_fast_impl<Packet>::run(p, m).
+ * The final squaring, however, doubles the error bound on the final
+ * approximation. Exhaustive testing shows that we have a worst case error
+ * of 4.5 ulps (compared to computing S(x) in double precision), which is
+ * acceptable.
*/
template <>
struct scalar_logistic_op<float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
- return packetOp(x);
+ const float e = numext::exp(x);
+ return e / (1.0f + e);
}
- template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- Packet packetOp(const Packet& _x) const {
- const Packet cutoff_lower = pset1<Packet>(-9.f);
- const Packet lt_mask = pcmp_lt<Packet>(_x, cutoff_lower);
- const bool any_small = predux_any(lt_mask);
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
+ packetOp(const Packet& _x) const {
+ const Packet cst_zero = pset1<Packet>(0.0f);
+ const Packet cst_one = pset1<Packet>(1.0f);
+ const Packet cst_half = pset1<Packet>(0.5f);
+ const Packet cst_exp_hi = pset1<Packet>(16.f);
+ const Packet cst_exp_lo = pset1<Packet>(-104.f);
- // The upper cut-off is the smallest x for which the rational approximation evaluates to 1.
- // Choosing this value saves us a few instructions clamping the results at the end.
-#ifdef EIGEN_VECTORIZE_FMA
- const Packet cutoff_upper = pset1<Packet>(15.7243833541870117f);
-#else
- const Packet cutoff_upper = pset1<Packet>(15.6437711715698242f);
-#endif
- const Packet x = pmin(_x, cutoff_upper);
+ // Clamp x to the non-trivial range where S(x). Outside this
+ // interval the correctly rounded value of S(x) is either zero
+ // or one.
+ Packet zero_mask = pcmp_lt(_x, cst_exp_lo);
+ Packet x = pmin(_x, cst_exp_hi);
- // The monomial coefficients of the numerator polynomial (odd).
- const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01f);
- const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03f);
- const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05f);
- const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07f);
- const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11f);
+ // 1. Multiplicative range reduction:
+ // Reduce the range of x by a factor of 2. This avoids having
+ // to compute exp(x) accurately where the result is a denormalized
+ // value.
+ x = pmul(x, cst_half);
- // The monomial coefficients of the denominator polynomial (even).
- const Packet beta_0 = pset1<Packet>(9.93151921023180e-01f);
- const Packet beta_2 = pset1<Packet>(1.16817656904453e-01f);
- const Packet beta_4 = pset1<Packet>(1.70198817374094e-03f);
- const Packet beta_6 = pset1<Packet>(6.29106785017040e-06f);
- const Packet beta_8 = pset1<Packet>(5.76102136993427e-09f);
- const Packet beta_10 = pset1<Packet>(6.10247389755681e-13f);
+ // 2. Subtractive range reduction:
+ // Express exp(x) as exp(m*ln(2) + r) = 2^m*exp(r), start by extracting
+ // m = floor(x/ln(2) + 0.5), such that x = m*ln(2) + r.
+ const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
+ Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));
+ // Get r = x - m*ln(2). We use a trick from Cephes where the term
+ // m*ln(2) is subtracted out in two parts, m*C1+m*C2 = m*ln(2),
+ // to avoid accumulating truncation errors.
+ const Packet cst_cephes_exp_C1 = pset1<Packet>(-0.693359375f);
+ const Packet cst_cephes_exp_C2 = pset1<Packet>(2.12194440e-4f);
+ Packet r = pmadd(m, cst_cephes_exp_C1, x);
+ r = pmadd(m, cst_cephes_exp_C2, r);
- // Since the polynomials are odd/even, we need x^2.
- const Packet x2 = pmul(x, x);
+ // 3. Compute an approximation to exp(r) using a degree 5 minimax polynomial.
+ // We compute even and odd terms separately to increase instruction level
+ // parallelism.
+ Packet r2 = pmul(r, r);
+ const Packet cst_p2 = pset1<Packet>(0.49999141693115234375f);
+ const Packet cst_p3 = pset1<Packet>(0.16666877269744873046875f);
+ const Packet cst_p4 = pset1<Packet>(4.1898667812347412109375e-2f);
+ const Packet cst_p5 = pset1<Packet>(8.33471305668354034423828125e-3f);
- // Evaluate the numerator polynomial p.
- Packet p = pmadd(x2, alpha_9, alpha_7);
- p = pmadd(x2, p, alpha_5);
- p = pmadd(x2, p, alpha_3);
- p = pmadd(x2, p, alpha_1);
- p = pmul(x, p);
+ const Packet p_even = pmadd(r2, cst_p4, cst_p2);
+ const Packet p_odd = pmadd(r2, cst_p5, cst_p3);
+ const Packet p_low = padd(r, cst_one);
+ Packet p = pmadd(r, p_odd, p_even);
+ p = pmadd(r2, p, p_low);
- // Evaluate the denominator polynomial q.
- Packet q = pmadd(x2, beta_10, beta_8);
- q = pmadd(x2, q, beta_6);
- q = pmadd(x2, q, beta_4);
- q = pmadd(x2, q, beta_2);
- q = pmadd(x2, q, beta_0);
- // Divide the numerator by the denominator and shift it up.
- const Packet logistic = padd(pdiv(p, q), pset1<Packet>(0.5f));
- if (EIGEN_PREDICT_FALSE(any_small)) {
- const Packet exponential = pexp(_x);
- return pselect(lt_mask, exponential, logistic);
- } else {
- return logistic;
- }
+ // 4. Undo subtractive range reduction exp(m*ln(2) + r) = 2^m * exp(r).
+ Packet e = pldexp_fast_impl<Packet>::run(p, m);
+
+ // 5. Undo multiplicative range reduction by using exp(r) = exp(r/2)^2.
+ e = pmul(e, e);
+
+ // Return exp(x) / (1 + exp(x))
+ return pselect(zero_mask, cst_zero, pdiv(e, padd(cst_one, e)));
}
};
#endif // #ifndef EIGEN_GPU_COMPILE_PHASE