Simplify and speed up pow() by 5-6%
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index e21d3ef..652892e 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -2016,130 +2016,6 @@
}
};
-// This specialization uses a faster algorithm to compute exp2(x) for floats
-// in [-0.5;0.5] with a relative accuracy of 1 ulp.
-// The minimax polynomial used was calculated using the Sollya tool.
-// See sollya.org.
-template <>
-struct fast_accurate_exp2<float> {
- template <typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet operator()(const Packet& x) {
- // This function approximates exp2(x) by a degree 6 polynomial of the form
- // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in
- // single precision, and the remaining steps are evaluated with extra precision using
- // double word arithmetic. C is an extra precise constant stored as a double word.
- //
- // The polynomial coefficients were calculated using Sollya commands:
- // > n = 6;
- // > f = 2^x;
- // > interval = [-0.5;0.5];
- // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating);
-
- const Packet p4 = pset1<Packet>(1.539513905e-4f);
- const Packet p3 = pset1<Packet>(1.340007293e-3f);
- const Packet p2 = pset1<Packet>(9.618283249e-3f);
- const Packet p1 = pset1<Packet>(5.550328270e-2f);
- const Packet p0 = pset1<Packet>(0.2402264923f);
-
- const Packet C_hi = pset1<Packet>(0.6931471825f);
- const Packet C_lo = pset1<Packet>(2.36836577e-08f);
- const Packet one = pset1<Packet>(1.0f);
-
- // Evaluate P(x) in working precision.
- // We evaluate even and odd parts of the polynomial separately
- // to gain some instruction level parallelism.
- Packet x2 = pmul(x, x);
- Packet p_even = pmadd(p4, x2, p2);
- Packet p_odd = pmadd(p3, x2, p1);
- p_even = pmadd(p_even, x2, p0);
- Packet p = pmadd(p_odd, x, p_even);
-
- // Evaluate the remaining terms of Q(x) with extra precision using
- // double word arithmetic.
- Packet p_hi, p_lo;
- // x * p(x)
- twoprod(p, x, p_hi, p_lo);
- // C + x * p(x)
- Packet q1_hi, q1_lo;
- twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
- // x * (C + x * p(x))
- Packet q2_hi, q2_lo;
- twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
- // 1 + x * (C + x * p(x))
- Packet q3_hi, q3_lo;
- // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
- // for adding it to unity here.
- fast_twosum(one, q2_hi, q3_hi, q3_lo);
- return padd(q3_hi, padd(q2_lo, q3_lo));
- }
-};
-
-// in [-0.5;0.5] with a relative accuracy of 1 ulp.
-// The minimax polynomial used was calculated using the Sollya tool.
-// See sollya.org.
-template <>
-struct fast_accurate_exp2<double> {
- template <typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet operator()(const Packet& x) {
- // This function approximates exp2(x) by a degree 10 polynomial of the form
- // Q(x) = 1 + x * (C + x * P(x)), where the degree 8 polynomial P(x) is evaluated in
- // single precision, and the remaining steps are evaluated with extra precision using
- // double word arithmetic. C is an extra precise constant stored as a double word.
- //
- // The polynomial coefficients were calculated using Sollya commands:
- // > n = 11;
- // > f = 2^x;
- // > interval = [-0.5;0.5];
- // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating);
-
- const Packet p9 = pset1<Packet>(4.431642109085495276e-10);
- const Packet p8 = pset1<Packet>(7.073829923303358410e-9);
- const Packet p7 = pset1<Packet>(1.017822306737031311e-7);
- const Packet p6 = pset1<Packet>(1.321543498017646657e-6);
- const Packet p5 = pset1<Packet>(1.525273342728892877e-5);
- const Packet p4 = pset1<Packet>(1.540353045780084423e-4);
- const Packet p3 = pset1<Packet>(1.333355814685869807e-3);
- const Packet p2 = pset1<Packet>(9.618129107593478832e-3);
- const Packet p1 = pset1<Packet>(5.550410866481961247e-2);
- const Packet p0 = pset1<Packet>(0.240226506959101332);
- const Packet C_hi = pset1<Packet>(0.693147180559945286);
- const Packet C_lo = pset1<Packet>(4.81927865669806721e-17);
- const Packet one = pset1<Packet>(1.0);
-
- // Evaluate P(x) in working precision.
- // We evaluate even and odd parts of the polynomial separately
- // to gain some instruction level parallelism.
- Packet x2 = pmul(x, x);
- Packet p_even = pmadd(p8, x2, p6);
- Packet p_odd = pmadd(p9, x2, p7);
- p_even = pmadd(p_even, x2, p4);
- p_odd = pmadd(p_odd, x2, p5);
- p_even = pmadd(p_even, x2, p2);
- p_odd = pmadd(p_odd, x2, p3);
- p_even = pmadd(p_even, x2, p0);
- p_odd = pmadd(p_odd, x2, p1);
- Packet p = pmadd(p_odd, x, p_even);
-
- // Evaluate the remaining terms of Q(x) with extra precision using
- // double word arithmetic.
- Packet p_hi, p_lo;
- // x * p(x)
- twoprod(p, x, p_hi, p_lo);
- // C + x * p(x)
- Packet q1_hi, q1_lo;
- twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
- // x * (C + x * p(x))
- Packet q2_hi, q2_lo;
- twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
- // 1 + x * (C + x * p(x))
- Packet q3_hi, q3_lo;
- // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
- // for adding it to unity here.
- fast_twosum(one, q2_hi, q3_hi, q3_lo);
- return padd(q3_hi, padd(q2_lo, q3_lo));
- }
-};
-
// This function implements the non-trivial case of pow(x,y) where x is
// positive and y is (possibly) non-integer.
// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
@@ -2186,11 +2062,18 @@
// We now have an accurate split of f = n_z + r_z and can compute
// x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}.
- // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
- // using a specialized algorithm. Multiplication by the second factor can
- // be done exactly using pldexp(), since it is an integer power of 2.
- const Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
- return pldexp(e_r, n_z);
+ // Multiplication by the second factor can be done exactly using pldexp(), since
+ // it is an integer power of 2.
+ const Packet e_r = generic_exp2(r_z);
+
+ // Since we know that e_r is in [1/sqrt(2); sqrt(2)], we can use the fast version
+ // of pldexp to multiply by 2**{n_z} when |n_z| is sufficiently small.
+ constexpr Scalar kPldExpThresh = std::numeric_limits<Scalar>::max_exponent - 2;
+ const Packet pldexp_fast_unsafe = pcmp_lt(pset1<Packet>(kPldExpThresh), pabs(n_z));
+ if (predux_any(pldexp_fast_unsafe)) {
+ return pldexp(e_r, n_z);
+ }
+ return pldexp_fast(e_r, n_z);
}
// Generic implementation of pow(x,y).