Fix self-adjoint products when multiplying by a compile-time vector.
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index ce8d954..a230044 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -846,7 +846,7 @@
template <typename Dest>
static EIGEN_DEVICE_FUNC void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
- selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::IsVectorAtCompileTime>::run(
+ selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::ColsAtCompileTime == 1>::run(
dst, lhs.nestedExpression(), rhs, alpha);
}
};
@@ -858,7 +858,7 @@
template <typename Dest>
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
- selfadjoint_product_impl<Lhs, 0, Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, Rhs::Mode, false>::run(
+ selfadjoint_product_impl<Lhs, 0, Lhs::RowsAtCompileTime == 1, typename Rhs::MatrixType, Rhs::Mode, false>::run(
dst, lhs, rhs.nestedExpression(), alpha);
}
};
diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h
index f738760..580f6a8 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixVector.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h
@@ -164,6 +164,11 @@
enum { LhsUpLo = LhsMode & (Upper | Lower) };
+ // Verify that the Rhs is a vector in the correct orientation.
+ // Otherwise, we break the assumption that we are multiplying
+ // MxN * Nx1.
+ static_assert(Rhs::ColsAtCompileTime == 1, "The RHS must be a column vector.");
+
template <typename Dest>
static EIGEN_DEVICE_FUNC void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
typedef typename Dest::Scalar ResScalar;
@@ -173,11 +178,6 @@
eigen_assert(dest.rows() == a_lhs.rows() && dest.cols() == a_rhs.cols());
- if (a_lhs.rows() == 1) {
- dest = (alpha * a_lhs.coeff(0, 0)) * a_rhs;
- return;
- }
-
add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);