Core: patanh fast-exit and plog10_float hi+lo accuracy fix libeigen/eigen!2497 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index e91bd11..8e54d4c 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -596,15 +596,27 @@ log10(x) ~= log(x) * hi + log(x) * lo, computed via fma. */ template <typename Packet> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog10_float(const Packet& x) { - const Packet cst_log10e = pset1<Packet>(0.4342944819032518f); - return pmul(plog(x), cst_log10e); + typedef typename unpacket_traits<Packet>::type Scalar; + // log10(e) in higher precision, split into hi+lo so log_x*(hi+lo) is reconstructed via FMA. + // hi = round-to-nearest-float of log10(e); lo = float(log10(e) - hi). + const Packet cst_log10e_hi = pset1<Packet>(0.4342944920063018f); + const Packet cst_log10e_lo = pset1<Packet>(-1.0103049952192578e-08f); + const Packet cst_inf = pset1<Packet>(NumTraits<Scalar>::infinity()); + const Packet cst_zero = pzero(x); + + const Packet log_x = plog(x); + const Packet finite_mask = pcmp_lt(pabs(log_x), cst_inf);  // avoids inf*hi + inf*lo = NaN (lo < 0) + const Packet finite_log_x = pselect(finite_mask, log_x, cst_zero); + const Packet split_log10_x = pmadd(finite_log_x, cst_log10e_hi, pmul(finite_log_x, cst_log10e_lo)); + return pselect(finite_mask, split_log10_x, log_x); } /** \internal \returns log10(x) for double precision float. - Computed as log(x) * log10(e). */ + Computed as log(x) * log10(e). For double, single-constant rounding error + is ~1e-17 (sub-ULP), so the simple form is precise enough. */ template <typename Packet> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog10_double(const Packet& x) { - const Packet cst_log10e = pset1<Packet>(0.4342944819032518); + const Packet cst_log10e = pset1<Packet>(0.43429448190325182); return pmul(plog(x), cst_log10e); }
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathTrig.h b/Eigen/src/Core/arch/Default/GenericPacketMathTrig.h index 5514ffd..3350de2 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathTrig.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathTrig.h
@@ -775,13 +775,18 @@ Packet p = ppolevl<Packet, 4>::run(x2, alpha); p = pmadd(x3, p, x); - // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); const Packet half = pset1<Packet>(0.5f); const Packet one = pset1<Packet>(1.0f); + const Packet x_gt_half = pcmp_le(half, pabs(x)); + // Fast exit: if no lane has |x| >= 0.5, skip the expensive plog/pdiv branch. + if (!predux_any(x_gt_half)) { + return p; + } + + // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); Packet r = pdiv(padd(one, x), psub(one, x)); r = pmul(half, plog(r)); - const Packet x_gt_half = pcmp_le(half, pabs(x)); const Packet x_eq_one = pcmp_eq(one, pabs(x)); const Packet x_gt_one = pcmp_lt(one, pabs(x)); const Packet sign_mask = pset1<Packet>(-0.0f); @@ -808,13 +813,18 @@ Packet q = ppolevl<Packet, 5>::run(x2, beta); Packet y_small = pmadd(x3, pdiv(p, q), x); - // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); const Packet half = pset1<Packet>(0.5); const Packet one = pset1<Packet>(1.0); + const Packet x_gt_half = pcmp_le(half, pabs(x)); + // Fast exit: if no lane has |x| >= 0.5, skip the expensive plog/pdiv branch. + if (!predux_any(x_gt_half)) { + return y_small; + } + + // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); Packet y_large = pdiv(padd(one, x), psub(one, x)); y_large = pmul(half, plog(y_large)); - const Packet x_gt_half = pcmp_le(half, pabs(x)); const Packet x_eq_one = pcmp_eq(one, pabs(x)); const Packet x_gt_one = pcmp_lt(one, pabs(x)); const Packet sign_mask = pset1<Packet>(-0.0);
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 2e74e2e..d431046 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp
@@ -191,6 +191,7 @@ unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(exp)); unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(exp2)); unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(log)); + unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(log10)); unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(sin)); unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(cos)); unary_op_test<Scalar>(UNARY_FUNCTOR_TEST_ARGS(tan));
diff --git a/test/packetmath.cpp b/test/packetmath.cpp index f36e6d5..69b607f 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp
@@ -886,6 +886,7 @@ CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog); CHECK_CWISE1_IF(PacketTraits::HasLog, log2, internal::plog2); + CHECK_CWISE1_IF(PacketTraits::HasLog10, std::log10, internal::plog10); CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt); for (int i = 0; i < size; ++i) { @@ -1145,6 +1146,14 @@ h.store(data2, internal::plog(h.load(data1))); VERIFY((numext::isinf)(data2[0])); } + if (PacketTraits::HasLog10) { + test::packet_helper<PacketTraits::HasLog10, Packet> h; + data1[0] = Scalar(0); + data1[1] = NumTraits<Scalar>::infinity(); + h.store(data2, internal::plog10(h.load(data1))); + VERIFY_IS_EQUAL(std::log10(Scalar(0)), data2[0]); + VERIFY_IS_EQUAL(std::log10(NumTraits<Scalar>::infinity()), data2[1]); + } if (PacketTraits::HasLog1p) { test::packet_helper<PacketTraits::HasLog1p, Packet> h; data1[0] = Scalar(-2);