Fix left scalar multiplication for TriangularView and SparseSelfAdjointView libeigen/eigen!2400 Closes #1398 and #1536
diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 62d0729..ff450eb 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h
@@ -114,6 +114,12 @@ return Product<OtherDerived, SelfAdjointView>(lhs.derived(), rhs); } + EIGEN_DEVICE_FUNC const + SelfAdjointView<const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(MatrixType, Scalar, product), UpLo> + operator*(const Scalar& s) const { + return (nestedExpression() * s).template selfadjointView<UpLo>(); + } + friend EIGEN_DEVICE_FUNC const SelfAdjointView<const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar, MatrixType, product), UpLo> operator*(const Scalar& s, const SelfAdjointView& mat) {
diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index e219e51..2b0f56b 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h
@@ -413,6 +413,21 @@ return Product<OtherDerived, TriangularViewType>(lhs.derived(), rhs.derived()); } + // Scaling a unit triangular view would break its implicit unit diagonal, so only non-unit modes participate. + template <unsigned int M = Mode, std::enable_if_t<(M & UnitDiag) == 0, int> = 0> + EIGEN_DEVICE_FUNC const + TriangularView<const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(MatrixType, Scalar, product), Mode> + operator*(const Scalar& s) const { + return (derived().nestedExpression() * s).template triangularView<Mode>(); + } + + template <unsigned int M = Mode, std::enable_if_t<(M & UnitDiag) == 0, int> = 0> + friend EIGEN_DEVICE_FUNC const + TriangularView<const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar, MatrixType, product), Mode> + operator*(const Scalar& s, const TriangularViewImpl& mat) { + return (s * mat.derived().nestedExpression()).template triangularView<Mode>(); + } + /** \returns the product of the inverse of \c *this with \a other, \a *this being triangular. * * This function computes the inverse-matrix matrix product inverse(\c *this) * \a other if
diff --git a/Eigen/src/SparseCore/SparseSelfAdjointView.h b/Eigen/src/SparseCore/SparseSelfAdjointView.h index 9b290dd..32fbcd4 100644 --- a/Eigen/src/SparseCore/SparseSelfAdjointView.h +++ b/Eigen/src/SparseCore/SparseSelfAdjointView.h
@@ -64,6 +64,8 @@ typedef Matrix<StorageIndex, Dynamic, 1> VectorI; typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested; typedef internal::remove_all_t<MatrixTypeNested> MatrixTypeNested_; + typedef SparseMatrix<Scalar, (MatrixTypeNested_::Flags & RowMajorBit) ? RowMajor : ColMajor, StorageIndex> + PlainObject; explicit inline SparseSelfAdjointView(MatrixType& matrix) : m_matrix(matrix) { eigen_assert(rows() == cols() && "SelfAdjointView is only for squared matrices"); @@ -114,6 +116,16 @@ return Product<OtherDerived, SparseSelfAdjointView>(lhs.derived(), rhs); } + // Scalar multiplication intentionally materializes the full matrix, unlike dense SelfAdjointView's lazy wrapper, + // matching the existing SparseSelfAdjointView products. + PlainObject operator*(const Scalar& s) const { return s * *this; } + + friend PlainObject operator*(const Scalar& s, const SparseSelfAdjointView& mat) { + PlainObject res(mat); + res *= s; + return res; + } + /** Perform a symmetric rank K update of the selfadjoint matrix \c *this: * \f$ this = this + \alpha ( u u^* ) \f$ where \a u is a vector or matrix. *
diff --git a/test/selfadjoint.cpp b/test/selfadjoint.cpp index 01e3806..65d4e79 100644 --- a/test/selfadjoint.cpp +++ b/test/selfadjoint.cpp
@@ -42,6 +42,18 @@ m4 = m2; m4 -= m1.template selfadjointView<Lower>(); VERIFY_IS_APPROX(m4, m2 - m3); + + Scalar s = internal::random<Scalar>(); + + m4 = s * m1.template selfadjointView<Upper>(); + VERIFY_IS_APPROX(m4, MatrixType((s * m1).template selfadjointView<Upper>())); + m4 = m1.template selfadjointView<Upper>() * s; + VERIFY_IS_APPROX(m4, MatrixType((m1 * s).template selfadjointView<Upper>())); + + m4 = s * m1.template selfadjointView<Lower>(); + VERIFY_IS_APPROX(m4, MatrixType((s * m1).template selfadjointView<Lower>())); + m4 = m1.template selfadjointView<Lower>() * s; + VERIFY_IS_APPROX(m4, MatrixType((m1 * s).template selfadjointView<Lower>())); } void bug_159() {
diff --git a/test/sparse_basic.cpp b/test/sparse_basic.cpp index 4f1d447..d583d9d 100644 --- a/test/sparse_basic.cpp +++ b/test/sparse_basic.cpp
@@ -796,10 +796,28 @@ m3 -= m2.template selfadjointView<Lower>(); VERIFY_IS_APPROX(m3, refMat3); + Scalar s2 = internal::random<Scalar>(); + refMat3 = DenseMatrix(refMat2.template selfadjointView<Upper>()); + refMat3 *= s2; + SparseMatrixType m4 = s2 * m2.template selfadjointView<Upper>(); + VERIFY_IS_APPROX(m4, refMat3); + refMat3 = DenseMatrix(refMat2.template selfadjointView<Upper>()); + refMat3 *= s2; + m4 = m2.template selfadjointView<Upper>() * s2; + VERIFY_IS_APPROX(m4, refMat3); + refMat3 = DenseMatrix(refMat2.template selfadjointView<Lower>()); + refMat3 *= s2; + m4 = s2 * m2.template selfadjointView<Lower>(); + VERIFY_IS_APPROX(m4, refMat3); + refMat3 = DenseMatrix(refMat2.template selfadjointView<Lower>()); + refMat3 *= s2; + m4 = m2.template selfadjointView<Lower>() * s2; + VERIFY_IS_APPROX(m4, refMat3); + // selfadjointView only works for square matrices: - SparseMatrixType m4(rows, rows + 1); - VERIFY_RAISES_ASSERT(m4.template selfadjointView<Lower>()); - VERIFY_RAISES_ASSERT(m4.template selfadjointView<Upper>()); + SparseMatrixType m5(rows, rows + 1); + VERIFY_RAISES_ASSERT(m5.template selfadjointView<Lower>()); + VERIFY_RAISES_ASSERT(m5.template selfadjointView<Upper>()); } // test sparseView
diff --git a/test/triangular.cpp b/test/triangular.cpp index a539715..ea4420b 100644 --- a/test/triangular.cpp +++ b/test/triangular.cpp
@@ -13,6 +13,55 @@ #include "main.h" +template <typename ViewType, typename = void> +struct has_left_scalar_multiply : std::false_type {}; + +template <typename ViewType> +struct has_left_scalar_multiply< + ViewType, internal::void_t<decltype(std::declval<typename ViewType::Scalar>() * std::declval<const ViewType&>())>> + : std::true_type {}; + +template <typename ViewType, typename = void> +struct has_right_scalar_multiply : std::false_type {}; + +template <typename ViewType> +struct has_right_scalar_multiply< + ViewType, internal::void_t<decltype(std::declval<const ViewType&>() * std::declval<typename ViewType::Scalar>())>> + : std::true_type {}; + +template <unsigned int Mode, typename MatrixType> +void triangular_scalar_multiply(const MatrixType& m) { + typedef typename MatrixType::Scalar Scalar; + + const Index rows = m.rows(); + const Index cols = m.cols(); + + const Scalar s = internal::random<Scalar>(); + const MatrixType triangular = MatrixType::Random(rows, cols); + + VERIFY_IS_APPROX((s * triangular.template triangularView<Mode>()).toDenseMatrix(), + (s * triangular).template triangularView<Mode>().toDenseMatrix()); + VERIFY_IS_APPROX((triangular.template triangularView<Mode>() * s).toDenseMatrix(), + (triangular * s).template triangularView<Mode>().toDenseMatrix()); +} + +template <typename MatrixType> +void triangular_scalar_multiply_sfinae() { + typedef decltype(std::declval<MatrixType&>().template triangularView<Lower>()) LowerView; + typedef decltype(std::declval<MatrixType&>().template triangularView<StrictlyLower>()) StrictlyLowerView; + typedef decltype(std::declval<MatrixType&>().template triangularView<StrictlyUpper>()) StrictlyUpperView; + typedef decltype(std::declval<MatrixType&>().template triangularView<UnitLower>()) UnitLowerView; + + STATIC_CHECK((has_left_scalar_multiply<LowerView>::value)); + STATIC_CHECK((has_right_scalar_multiply<LowerView>::value)); + STATIC_CHECK((has_left_scalar_multiply<StrictlyLowerView>::value)); + STATIC_CHECK((has_right_scalar_multiply<StrictlyLowerView>::value)); + STATIC_CHECK((has_left_scalar_multiply<StrictlyUpperView>::value)); + STATIC_CHECK((has_right_scalar_multiply<StrictlyUpperView>::value)); + STATIC_CHECK((!has_left_scalar_multiply<UnitLowerView>::value)); + STATIC_CHECK((!has_right_scalar_multiply<UnitLowerView>::value)); +} + template <typename MatrixType> void triangular_deprecated(const MatrixType& m) { Index rows = m.rows(); @@ -42,6 +91,8 @@ typedef typename NumTraits<Scalar>::Real RealScalar; typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> VectorType; + triangular_scalar_multiply_sfinae<MatrixType>(); + RealScalar largerEps = 10 * test_precision<RealScalar>(); Index rows = m.rows(); @@ -151,6 +202,10 @@ m6.setRandom(); VERIFY_IS_APPROX(m1.template triangularView<Upper>() * m5, m3 * m5); VERIFY_IS_APPROX(m6 * m1.template triangularView<Upper>(), m6 * m3); + triangular_scalar_multiply<Upper>(m1); + triangular_scalar_multiply<Lower>(m1); + triangular_scalar_multiply<StrictlyUpper>(m1); + triangular_scalar_multiply<StrictlyLower>(m1); m1up = m1.template triangularView<Upper>(); VERIFY_IS_APPROX(m1.template selfadjointView<Upper>().template triangularView<Upper>().toDenseMatrix(), m1up); @@ -228,6 +283,10 @@ m1.setZero(); m1.template triangularView<StrictlyLower>() = 3 * m2; VERIFY_IS_APPROX(m3.template triangularView<StrictlyLower>().toDenseMatrix(), m1); + triangular_scalar_multiply<Upper>(m1); + triangular_scalar_multiply<Lower>(m1); + triangular_scalar_multiply<StrictlyUpper>(m1); + triangular_scalar_multiply<StrictlyLower>(m1); m1.setRandom(); m2 = m1.template triangularView<Upper>(); VERIFY(m2.isUpperTriangular());