Add support for sparse * dense and dense * sparse matrix/vector products
diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index eecd24c..d342a89 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h
@@ -250,10 +250,6 @@ Derived& lazyAssign(const Flagged<OtherDerived, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other) { return lazyAssign(other._expression()); } - /** Overloaded for sparse product evaluation */ - /*template<typename Derived1, typename Derived2> - Derived& lazyAssign(const Product<Derived1,Derived2,SparseProduct>& product);*/ - CommaInitializer<Derived> operator<< (const Scalar& s); template<typename OtherDerived> @@ -615,6 +611,15 @@ PlainMatrixType unitOrthogonal(void) const; Matrix<Scalar,3,1> eulerAngles(int a0, int a1, int a2) const; +/////////// Sparse module /////////// + + // dense = spasre * dense + template<typename Derived1, typename Derived2> + Derived& lazyAssign(const SparseProduct<Derived1,Derived2,SparseTimeDenseProduct>& product); + // dense = dense * spasre + template<typename Derived1, typename Derived2> + Derived& lazyAssign(const SparseProduct<Derived1,Derived2,DenseTimeSparseProduct>& product); + #ifdef EIGEN_MATRIXBASE_PLUGIN #include EIGEN_MATRIXBASE_PLUGIN #endif
diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index f2c76cc..05df011 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h
@@ -201,7 +201,7 @@ enum { ConditionalJumpCost = 5 }; enum CornerType { TopLeft, TopRight, BottomLeft, BottomRight }; enum DirectionType { Vertical, Horizontal }; -enum ProductEvaluationMode { NormalProduct, CacheFriendlyProduct, DiagonalProduct }; +enum ProductEvaluationMode { NormalProduct, CacheFriendlyProduct, DiagonalProduct, SparseTimeSparseProduct, SparseTimeDenseProduct, DenseTimeSparseProduct }; enum { /** \internal Equivalent to a slice vectorization for fixed-size matrices having good alignment
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index c194882..a45210e 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -122,4 +122,7 @@ template<typename Scalar,int Dim> class Translation; template<typename Scalar,int Dim> class Scaling; +// Sparse module: +template<typename Lhs, typename Rhs, int ProductMode> class SparseProduct; + #endif // EIGEN_FORWARDDECLARATIONS_H
diff --git a/Eigen/src/Sparse/SparseMatrix.h b/Eigen/src/Sparse/SparseMatrix.h index a732bdc..07fc0be 100644 --- a/Eigen/src/Sparse/SparseMatrix.h +++ b/Eigen/src/Sparse/SparseMatrix.h
@@ -314,9 +314,10 @@ // 1 - compute the number of coeffs per dest inner vector // 2 - do the actual copy/eval // Since each coeff of the rhs has to be evaluated twice, let's evauluate it if needed - typedef typename ei_nested<OtherDerived,2>::type OtherCopy; - OtherCopy otherCopy(other.derived()); + //typedef typename ei_nested<OtherDerived,2>::type OtherCopy; + typedef typename ei_eval<OtherDerived>::type OtherCopy; typedef typename ei_cleantype<OtherCopy>::type _OtherCopy; + OtherCopy otherCopy(other.derived()); resize(other.rows(), other.cols()); Eigen::Map<VectorXi>(m_outerIndex,outerSize()).setZero();
diff --git a/Eigen/src/Sparse/SparseMatrixBase.h b/Eigen/src/Sparse/SparseMatrixBase.h index d01fa1e..14ac4e1 100644 --- a/Eigen/src/Sparse/SparseMatrixBase.h +++ b/Eigen/src/Sparse/SparseMatrixBase.h
@@ -213,7 +213,7 @@ } template<typename Lhs, typename Rhs> - inline Derived& operator=(const SparseProduct<Lhs,Rhs>& product); + inline Derived& operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product); friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m) { @@ -291,6 +291,16 @@ template<typename OtherDerived> const typename SparseProductReturnType<Derived,OtherDerived>::Type operator*(const SparseMatrixBase<OtherDerived> &other) const; + + // dense * sparse (return a dense object) + template<typename OtherDerived> friend + const typename SparseProductReturnType<OtherDerived,Derived>::Type + operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs) + { return typename SparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); } + + template<typename OtherDerived> + const typename SparseProductReturnType<Derived,OtherDerived>::Type + operator*(const MatrixBase<OtherDerived> &other) const; template<typename OtherDerived> Derived& operator*=(const SparseMatrixBase<OtherDerived>& other);
diff --git a/Eigen/src/Sparse/SparseProduct.h b/Eigen/src/Sparse/SparseProduct.h index b4ba2ee..29f5208 100644 --- a/Eigen/src/Sparse/SparseProduct.h +++ b/Eigen/src/Sparse/SparseProduct.h
@@ -25,9 +25,29 @@ #ifndef EIGEN_SPARSEPRODUCT_H #define EIGEN_SPARSEPRODUCT_H +template<typename Lhs, typename Rhs> struct ei_sparse_product_mode +{ + enum { + + value = (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit + ? SparseTimeSparseProduct + : (Lhs::Flags&SparseBit)==SparseBit + ? SparseTimeDenseProduct + : DenseTimeSparseProduct }; +}; + +template<typename Lhs, typename Rhs, int ProductMode> +struct SparseProductReturnType +{ + typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested; + typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested; + + typedef SparseProduct<LhsNested, RhsNested, ProductMode> Type; +}; + // sparse product return type specialization template<typename Lhs, typename Rhs> -struct SparseProductReturnType +struct SparseProductReturnType<Lhs,Rhs,SparseTimeSparseProduct> { typedef typename ei_traits<Lhs>::Scalar Scalar; enum { @@ -47,11 +67,11 @@ SparseMatrix<Scalar,0>, const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested; - typedef SparseProduct<LhsNested, RhsNested> Type; + typedef SparseProduct<LhsNested, RhsNested, SparseTimeSparseProduct> Type; }; -template<typename LhsNested, typename RhsNested> -struct ei_traits<SparseProduct<LhsNested, RhsNested> > +template<typename LhsNested, typename RhsNested, int ProductMode> +struct ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> > { // clean the nested types: typedef typename ei_cleantype<LhsNested>::type _LhsNested; @@ -71,12 +91,13 @@ MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime, MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime, - LhsRowMajor = LhsFlags & RowMajorBit, - RhsRowMajor = RhsFlags & RowMajorBit, +// LhsIsRowMajor = (LhsFlags & RowMajorBit)==RowMajorBit, +// RhsIsRowMajor = (RhsFlags & RowMajorBit)==RowMajorBit, EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit), + ResultIsSparse = ProductMode==SparseTimeSparseProduct, - RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), + RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ), Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) | EvalBeforeAssigningBit @@ -84,11 +105,14 @@ CoeffReadCost = Dynamic }; + + typedef typename ei_meta_if<ResultIsSparse, + SparseMatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> >, + MatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> > >::ret Base; }; -template<typename LhsNested, typename RhsNested> -class SparseProduct : ei_no_assignment_operator, - public SparseMatrixBase<SparseProduct<LhsNested, RhsNested> > +template<typename LhsNested, typename RhsNested, int ProductMode> +class SparseProduct : ei_no_assignment_operator, public ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >::Base { public: @@ -102,17 +126,33 @@ public: template<typename Lhs, typename Rhs> - inline SparseProduct(const Lhs& lhs, const Rhs& rhs) + EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { ei_assert(lhs.cols() == rhs.rows()); + + enum { + ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic + || _RhsNested::RowsAtCompileTime==Dynamic + || int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime), + AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime, + SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested) + }; + // note to the lost user: + // * for a dot product use: v1.dot(v2) + // * for a coeff-wise product use: v1.cwise()*v2 + EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), + INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) + EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), + INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) + EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) } - inline int rows() const { return m_lhs.rows(); } - inline int cols() const { return m_rhs.cols(); } + EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); } + EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); } - const _LhsNested& lhs() const { return m_lhs; } - const _LhsNested& rhs() const { return m_rhs; } + EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } + EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } protected: LhsNested m_lhs; @@ -240,9 +280,10 @@ // return derived(); // } +// sparse = sparse * sparse template<typename Derived> template<typename Lhs, typename Rhs> -inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs>& product) +inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product) { // std::cout << "sparse product to sparse\n"; ei_sparse_product_selector< @@ -252,26 +293,51 @@ return derived(); } +// dense = sparse * dense +template<typename Derived> +template<typename Lhs, typename Rhs> +Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product) +{ + typedef typename ei_cleantype<Lhs>::type _Lhs; + typedef typename _Lhs::InnerIterator LhsInnerIterator; + enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit }; + derived().setZero(); + for (int j=0; j<product.lhs().outerSize(); ++j) + for (LhsInnerIterator i(product.lhs(),j); i; ++i) + derived().row(LhsIsRowMajor ? j : i.index()) += i.value() * product.rhs().row(LhsIsRowMajor ? i.index() : j); + return derived(); +} + +// dense = dense * sparse +template<typename Derived> +template<typename Lhs, typename Rhs> +Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,DenseTimeSparseProduct>& product) +{ + typedef typename ei_cleantype<Rhs>::type _Rhs; + typedef typename _Rhs::InnerIterator RhsInnerIterator; + enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit }; + derived().setZero(); + for (int j=0; j<product.rhs().outerSize(); ++j) + for (RhsInnerIterator i(product.rhs(),j); i; ++i) + derived().col(RhsIsRowMajor ? i.index() : j) += i.value() * product.lhs().col(RhsIsRowMajor ? j : i.index()); + return derived(); +} + +// sparse * sparse template<typename Derived> template<typename OtherDerived> -inline const typename SparseProductReturnType<Derived,OtherDerived>::Type +EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const { - enum { - ProductIsValid = Derived::ColsAtCompileTime==Dynamic - || OtherDerived::RowsAtCompileTime==Dynamic - || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime), - AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime, - SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived) - }; - // note to the lost user: - // * for a dot product use: v1.dot(v2) - // * for a coeff-wise product use: v1.cwise()*v2 - EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), - INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) - EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), - INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) - EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) + return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); +} + +// sparse * dense +template<typename Derived> +template<typename OtherDerived> +EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type +SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const +{ return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); }
diff --git a/Eigen/src/Sparse/SparseUtil.h b/Eigen/src/Sparse/SparseUtil.h index 724fb9e..046523d 100644 --- a/Eigen/src/Sparse/SparseUtil.h +++ b/Eigen/src/Sparse/SparseUtil.h
@@ -109,10 +109,10 @@ template<typename Derived> class SparseCwise; template<typename UnaryOp, typename MatrixType> class SparseCwiseUnaryOp; template<typename BinaryOp, typename Lhs, typename Rhs> class SparseCwiseBinaryOp; -template<typename Lhs, typename Rhs> class SparseProduct; template<typename ExpressionType, unsigned int Added, unsigned int Removed> class SparseFlagged; -template<typename Lhs, typename Rhs> struct SparseProductReturnType; +template<typename Lhs, typename Rhs> struct ei_sparse_product_mode; +template<typename Lhs, typename Rhs, int ProductMode = ei_sparse_product_mode<Lhs,Rhs>::value> struct SparseProductReturnType; const int AccessPatternNotSupported = 0x0; const int AccessPatternSupported = 0x1;
diff --git a/test/sparse_basic.cpp b/test/sparse_basic.cpp index 54272d8..07a38dd 100644 --- a/test/sparse_basic.cpp +++ b/test/sparse_basic.cpp
@@ -216,6 +216,7 @@ DenseMatrix refMat2 = DenseMatrix::Zero(rows, rows); DenseMatrix refMat3 = DenseMatrix::Zero(rows, rows); DenseMatrix refMat4 = DenseMatrix::Zero(rows, rows); + DenseMatrix dm4 = DenseMatrix::Zero(rows, rows); SparseMatrix<Scalar> m2(rows, rows); SparseMatrix<Scalar> m3(rows, rows); SparseMatrix<Scalar> m4(rows, rows); @@ -226,6 +227,18 @@ VERIFY_IS_APPROX(m4=m2.transpose()*m3, refMat4=refMat2.transpose()*refMat3); VERIFY_IS_APPROX(m4=m2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose()); VERIFY_IS_APPROX(m4=m2*m3.transpose(), refMat4=refMat2*refMat3.transpose()); + + // sparse * dense + VERIFY_IS_APPROX(dm4=m2*refMat3, refMat4=refMat2*refMat3); + VERIFY_IS_APPROX(dm4=m2*refMat3.transpose(), refMat4=refMat2*refMat3.transpose()); + VERIFY_IS_APPROX(dm4=m2.transpose()*refMat3, refMat4=refMat2.transpose()*refMat3); + VERIFY_IS_APPROX(dm4=m2.transpose()*refMat3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose()); + + // dense * sparse + VERIFY_IS_APPROX(dm4=refMat2*m3, refMat4=refMat2*refMat3); + VERIFY_IS_APPROX(dm4=refMat2*m3.transpose(), refMat4=refMat2*refMat3.transpose()); + VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3, refMat4=refMat2.transpose()*refMat3); + VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose()); } }