Fix scalar_logistic_function overflow for complex inputs.
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 3c7dfb7..a3fc44c 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -1091,12 +1091,9 @@ }; }; -/** \internal - * \brief Template functor to compute the logistic function of a scalar - * \sa class CwiseUnaryOp, ArrayBase::logistic() - */ -template <typename T> -struct scalar_logistic_op { +// Real-valued implementation. +template <typename T, typename EnableIf = void> +struct scalar_logistic_op_impl { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return packetOp(x); } template <typename Packet> @@ -1109,6 +1106,22 @@ } }; +// Complex-valud implementation. +template <typename T> +struct scalar_logistic_op_impl<T, std::enable_if_t<NumTraits<T>::IsComplex>> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { + const T e = numext::exp(x); + return (numext::isinf)(numext::real(e)) ? T(1) : e / (e + T(1)); + } +}; + +/** \internal + * \brief Template functor to compute the logistic function of a scalar + * \sa class CwiseUnaryOp, ArrayBase::logistic() + */ +template <typename T> +struct scalar_logistic_op : scalar_logistic_op_impl<T> {}; + // TODO(rmlarsen): Enable the following on host when integer_packet is defined // for the relevant packet types. #ifdef EIGEN_GPU_CC @@ -1206,7 +1219,7 @@ Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value + (internal::is_same<T, float>::value ? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11 : NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T>>::Cost), - PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv && + PacketAccess = !NumTraits<T>::IsComplex && packet_traits<T>::HasAdd && packet_traits<T>::HasDiv && (internal::is_same<T, float>::value ? packet_traits<T>::HasMul && packet_traits<T>::HasMax && packet_traits<T>::HasMin : packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index bfea96a..9b62969 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp
@@ -976,7 +976,14 @@ VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1))); VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1))); VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1)))); - VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0 + exp(-m1)))); + VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1)))); + if (m1.size() > 0) { + // Complex exponential overflow edge-case. + Scalar old_m1_val = m1(0, 0); + m1(0, 0) = std::complex<RealScalar>(1000.0, 1000.0); + VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1)))); + m1(0, 0) = old_m1_val; // Restore value for future tests. + } for (Index i = 0; i < m.rows(); ++i) for (Index j = 0; j < m.cols(); ++j)