vectorize squaredNorm() for complex types
diff --git a/Eigen/src/Core/Dot.h b/Eigen/src/Core/Dot.h index 82eb9c7..dd4a2c4 100644 --- a/Eigen/src/Core/Dot.h +++ b/Eigen/src/Core/Dot.h
@@ -41,6 +41,20 @@ } }; +template <typename Derived, typename Scalar = typename traits<Derived>::Scalar> +struct squared_norm_impl { + using Real = typename NumTraits<Scalar>::Real; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Real run(const Derived& a) { + Scalar result = a.unaryExpr(squared_norm_functor<Scalar>()).sum(); + return numext::real(result) + numext::imag(result); + } +}; + +template <typename Derived> +struct squared_norm_impl<Derived, bool> { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const Derived& a) { return a.any(); } +}; + } // end namespace internal /** \fn MatrixBase::dot @@ -85,7 +99,7 @@ template <typename Derived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename NumTraits<typename internal::traits<Derived>::Scalar>::Real MatrixBase<Derived>::squaredNorm() const { - return numext::real((*this).cwiseAbs2().sum()); + return internal::squared_norm_impl<Derived>::run(derived()); } /** \returns, for vectors, the \em l2 norm of \c *this, and for matrices the Frobenius norm.
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 5059a54..b3b7d79 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -103,6 +103,26 @@ enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasAbs2 }; }; +template <typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex> +struct squared_norm_functor { + typedef Scalar result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { + return Scalar(numext::real(a) * numext::real(a), numext::imag(a) * numext::imag(a)); + } + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { + return Packet(pmul(a.v, a.v)); + } +}; +template <typename Scalar> +struct squared_norm_functor<Scalar, false> : scalar_abs2_op<Scalar> {}; + +template <typename Scalar> +struct functor_traits<squared_norm_functor<Scalar>> { + using Real = typename NumTraits<Scalar>::Real; + enum { Cost = NumTraits<Real>::MulCost, PacketAccess = packet_traits<Real>::HasMul }; +}; + /** \internal * \brief Template functor to compute the conjugate of a complex value *