Fix scalar_logistic_function overflow for complex inputs.
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index 3c7dfb7..a3fc44c 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -1091,12 +1091,9 @@
};
};
-/** \internal
- * \brief Template functor to compute the logistic function of a scalar
- * \sa class CwiseUnaryOp, ArrayBase::logistic()
- */
-template <typename T>
-struct scalar_logistic_op {
+// Real-valued implementation.
+template <typename T, typename EnableIf = void>
+struct scalar_logistic_op_impl {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return packetOp(x); }
template <typename Packet>
@@ -1109,6 +1106,22 @@
}
};
+// Complex-valud implementation.
+template <typename T>
+struct scalar_logistic_op_impl<T, std::enable_if_t<NumTraits<T>::IsComplex>> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
+ const T e = numext::exp(x);
+ return (numext::isinf)(numext::real(e)) ? T(1) : e / (e + T(1));
+ }
+};
+
+/** \internal
+ * \brief Template functor to compute the logistic function of a scalar
+ * \sa class CwiseUnaryOp, ArrayBase::logistic()
+ */
+template <typename T>
+struct scalar_logistic_op : scalar_logistic_op_impl<T> {};
+
// TODO(rmlarsen): Enable the following on host when integer_packet is defined
// for the relevant packet types.
#ifdef EIGEN_GPU_CC
@@ -1206,7 +1219,7 @@
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
(internal::is_same<T, float>::value ? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11
: NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T>>::Cost),
- PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
+ PacketAccess = !NumTraits<T>::IsComplex && packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
(internal::is_same<T, float>::value
? packet_traits<T>::HasMul && packet_traits<T>::HasMax && packet_traits<T>::HasMin
: packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp
index bfea96a..9b62969 100644
--- a/test/array_cwise.cpp
+++ b/test/array_cwise.cpp
@@ -976,7 +976,14 @@
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
- VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0 + exp(-m1))));
+ VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
+ if (m1.size() > 0) {
+ // Complex exponential overflow edge-case.
+ Scalar old_m1_val = m1(0, 0);
+ m1(0, 0) = std::complex<RealScalar>(1000.0, 1000.0);
+ VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
+ m1(0, 0) = old_m1_val; // Restore value for future tests.
+ }
for (Index i = 0; i < m.rows(); ++i)
for (Index j = 0; j < m.cols(); ++j)