Introduce `numext::copysign` libeigen/eigen!2436
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 100650c..4b0ca97 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h
@@ -901,6 +901,37 @@ typedef Scalar type; }; +template <typename Scalar, bool IsComplex = (NumTraits<Scalar>::IsComplex != 0), + bool IsInteger = (NumTraits<Scalar>::IsInteger != 0)> +struct copysign_impl { + EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& a, const Scalar& b) { + EIGEN_USING_STD(copysign); + return Scalar(copysign(a, b)); + } +}; + +template <typename Scalar, bool IsInteger> +struct copysign_impl<Scalar, true, IsInteger> { + EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& a, const Scalar& b) { + EIGEN_USING_STD(copysign); + return Scalar(copysign(numext::real(a), numext::real(b)), copysign(numext::imag(a), numext::imag(b))); + } +}; + +template <typename Scalar> +struct copysign_impl<Scalar, false, true> { + EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& a, const Scalar& b) { + EIGEN_IF_CONSTEXPR(!NumTraits<Scalar>::IsSigned) return a; + const Scalar abs_a = a < Scalar(0) ? -a : a; + return b < Scalar(0) ? -abs_a : abs_a; + } +}; + +template <typename Scalar> +struct copysign_retval { + typedef Scalar type; +}; + // suppress "unary minus operator applied to unsigned type, result still unsigned" warnings on MSVC // note: `0 - a` is distinct from `-a` when Scalar is a floating point type and `a` is zero @@ -1181,6 +1212,11 @@ } template <typename Scalar> +EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(copysign, Scalar) copysign(const Scalar& x, const Scalar& y) { + return EIGEN_MATHFUNC_IMPL(copysign, Scalar)::run(x, y); +} + +template <typename Scalar> EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(negate, Scalar) negate(const Scalar& x) { return EIGEN_MATHFUNC_IMPL(negate, Scalar)::run(x); }
diff --git a/test/numext.cpp b/test/numext.cpp index 32e9dca..a52987a 100644 --- a/test/numext.cpp +++ b/test/numext.cpp
@@ -34,6 +34,101 @@ #define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b)) template <typename T> +struct check_copysign_impl { + static void run() { + const T pos_zero = T(0); + const T pos_one = T(1); + + // Tests valid for all types. + VERIFY_IS_EQUAL(numext::copysign(pos_one, pos_one), pos_one); + VERIFY_IS_EQUAL(numext::copysign(pos_zero, pos_one), pos_zero); + + // Tests valid for all signed types (integer and floating-point). + if (NumTraits<T>::IsSigned) { + const T neg_one = numext::negate(pos_one); + VERIFY_IS_EQUAL(numext::copysign(pos_one, neg_one), neg_one); + VERIFY_IS_EQUAL(numext::copysign(neg_one, pos_one), pos_one); + VERIFY_IS_EQUAL(numext::copysign(neg_one, neg_one), neg_one); + } + + // Tests specific to floating-point types (negative zero, infinity, NaN). + if (!NumTraits<T>::IsInteger) { + const T neg_zero = numext::negate(pos_zero); + const T neg_one = numext::negate(pos_one); + const T pos_inf = std::numeric_limits<T>::infinity(); + const T neg_inf = numext::negate(pos_inf); + const T pos_nan = std::numeric_limits<T>::quiet_NaN(); + const T neg_nan = numext::negate(pos_nan); + // Sign transferred from zero. + VERIFY_IS_EQUAL(numext::copysign(pos_one, pos_zero), pos_one); + VERIFY_IS_EQUAL(numext::copysign(pos_one, neg_zero), neg_one); + // Sign transferred from infinity. + VERIFY_IS_EQUAL(numext::copysign(pos_one, pos_inf), pos_one); + VERIFY_IS_EQUAL(numext::copysign(pos_one, neg_inf), neg_one); + // Sign transferred from NaN. + VERIFY_IS_EQUAL(numext::copysign(pos_one, pos_nan), pos_one); + VERIFY_IS_EQUAL(numext::copysign(pos_one, neg_nan), neg_one); + } + + for (int k = 0; k < 100; ++k) { + // For signed integers avoid lowest() so that abs(a) does not overflow. + const T a = (NumTraits<T>::IsSigned && NumTraits<T>::IsInteger) + ? internal::random<T>(numext::negate(NumTraits<T>::highest()), NumTraits<T>::highest()) + : internal::random<T>(); + const T b = internal::random<T>(); + const T result = numext::copysign(a, b); + // Magnitude is preserved. + VERIFY_IS_EQUAL(numext::abs(result), numext::abs(a)); + // Sign matches sign source. Integers have no negative zero, so the sign + // of the result is only meaningful when a != 0. + if (!NumTraits<T>::IsInteger || a != T(0)) { + VERIFY_IS_EQUAL(numext::copysign(pos_one, result), numext::copysign(pos_one, b)); + } + } + } +}; + +template <typename T> +struct check_copysign_impl<std::complex<T>> { + static void run() { + typedef std::complex<T> ComplexT; + const T pos_one = T(1); + const T neg_one = numext::negate(pos_one); + + // Complex copysign is applied component-wise. + VERIFY_IS_EQUAL(numext::copysign(ComplexT(pos_one, pos_one), ComplexT(pos_one, neg_one)), + ComplexT(pos_one, neg_one)); + VERIFY_IS_EQUAL(numext::copysign(ComplexT(neg_one, pos_one), ComplexT(pos_one, neg_one)), + ComplexT(pos_one, neg_one)); + VERIFY_IS_EQUAL(numext::copysign(ComplexT(pos_one, neg_one), ComplexT(neg_one, pos_one)), + ComplexT(neg_one, pos_one)); + + for (int k = 0; k < 100; ++k) { + const ComplexT a = internal::random<ComplexT>(); + const ComplexT b = internal::random<ComplexT>(); + const ComplexT result = numext::copysign(a, b); + // Each component is independently copysigned. + VERIFY_IS_EQUAL(numext::real(result), numext::copysign(numext::real(a), numext::real(b))); + VERIFY_IS_EQUAL(numext::imag(result), numext::copysign(numext::imag(a), numext::imag(b))); + } + } +}; + +template <typename T> +void check_copysign() { + check_copysign_impl<T>::run(); +} + +template <> +void check_copysign<bool>() { + for (bool a : {false, true}) { + for (bool b : {false, true}) { + VERIFY_IS_EQUAL(numext::copysign(a, b), a); + } + } +} + +template <typename T> void check_negate() { Index size = 1000; for (Index i = 0; i < size; i++) { @@ -334,6 +429,24 @@ EIGEN_DECLARE_TEST(numext) { for (int k = 0; k < g_repeat; ++k) { + CALL_SUBTEST(check_copysign<half>()); + CALL_SUBTEST(check_copysign<bfloat16>()); + CALL_SUBTEST(check_copysign<float>()); + CALL_SUBTEST(check_copysign<double>()); + CALL_SUBTEST(check_copysign<long double>()); + CALL_SUBTEST(check_copysign<std::complex<float>>()); + CALL_SUBTEST(check_copysign<std::complex<double>>()); + + CALL_SUBTEST(check_copysign<bool>()); + CALL_SUBTEST(check_copysign<int8_t>()); + CALL_SUBTEST(check_copysign<int16_t>()); + CALL_SUBTEST(check_copysign<int32_t>()); + CALL_SUBTEST(check_copysign<int64_t>()); + CALL_SUBTEST(check_copysign<uint8_t>()); + CALL_SUBTEST(check_copysign<uint16_t>()); + CALL_SUBTEST(check_copysign<uint32_t>()); + CALL_SUBTEST(check_copysign<uint64_t>()); + CALL_SUBTEST(check_negate<signed char>()); CALL_SUBTEST(check_negate<unsigned char>()); CALL_SUBTEST(check_negate<short>());