TriangularView: alias-aware fallback for structured-diagonal product fast path libeigen/eigen!2504 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index fec3a2d..9995bcd 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h
@@ -987,6 +987,32 @@ } }; +// Underlying-storage data pointer for the diagonal operand of a structured x diagonal +// product, or nullptr for non-diagonal operands. The structured (triangular/selfadjoint) +// operand can safely share storage with dst because the kernel reads each (row, col) cell +// before writing it; only diagonal/dst overlap can corrupt later reads via the diagonal +// entries that have already been written. +template <typename Op> +EIGEN_DEVICE_FUNC inline const void* diagonal_operand_data(const Op& op, DiagonalShape) { + return extract_data(op.diagonal()); +} +template <typename Op, typename Shape> +EIGEN_DEVICE_FUNC inline const void* diagonal_operand_data(const Op& /*op*/, Shape) { + return nullptr; +} + +template <typename DstXprType, typename SrcXprType> +EIGEN_DEVICE_FUNC inline bool structured_diagonal_product_aliases(const DstXprType& dst, const SrcXprType& src) { + const void* dst_data = dst.nestedExpression().data(); + if (dst_data == nullptr) return false; + const void* lhs_diag_data = + diagonal_operand_data(src.lhs(), typename evaluator_traits<typename SrcXprType::Lhs>::Shape{}); + const void* rhs_diag_data = + diagonal_operand_data(src.rhs(), typename evaluator_traits<typename SrcXprType::Rhs>::Shape{}); + return (lhs_diag_data != nullptr && lhs_diag_data == dst_data) || + (rhs_diag_data != nullptr && rhs_diag_data == dst_data); +} + template <> struct triangular_product_assignment_dispatcher<true> { template <typename DstXprType, typename SrcXprType, typename Functor, typename Scalar> @@ -996,7 +1022,20 @@ EIGEN_UNUSED_VARIABLE(beta); EIGEN_STATIC_ASSERT((int(DstXprType::Mode) & int(UnitDiag)) == 0, WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED); - call_triangular_assignment_loop<DstXprType::Mode, false>(dst, src, func); + // The triangular assignment loop reads src.coeff(row, col) lazily while writing + // dst.coeffRef(row, col). When the diagonal operand of the product shares storage with + // dst (e.g. + // A.triangularView<Upper>() = A.diagonal().asDiagonal() * A.triangularView<Upper>()) + // the diagonal entries already written in earlier columns would feed back as modified + // values, corrupting later reads. Materialize the source into a temporary first when + // overlap is detected at run time. The structured (triangular/selfadjoint) operand may + // safely alias dst because the kernel reads each cell before writing it. + if (structured_diagonal_product_aliases(dst, src)) { + typename SrcXprType::PlainObject tmp(src); + call_triangular_assignment_loop<DstXprType::Mode, false>(dst, tmp, func); + } else { + call_triangular_assignment_loop<DstXprType::Mode, false>(dst, src, func); + } } };
diff --git a/test/diagonalmatrices.cpp b/test/diagonalmatrices.cpp index 2af0da6..b2cc914 100644 --- a/test/diagonalmatrices.cpp +++ b/test/diagonalmatrices.cpp
@@ -444,6 +444,70 @@ selfadjoint_diagonal_products_at<std::complex<double>>(65); } +// In-place patterns where the diagonal operand shares storage with the destination view. +// The fast path in triangular_product_assignment_dispatcher detects the run-time overlap +// and materializes a temporary, so the result must match a reference computed against a +// materialized source. The structured (triangular/selfadjoint) operand can safely share +// storage with dst because the kernel reads each cell before writing it. +template <unsigned int Mode, typename Mat> +void verify_triangular_in_place_with_aliased_diagonal(const Mat& m) { + // diagonal * tri_view + { + Mat actual = m, expected = m; + Mat ref_diag = actual.diagonal().asDiagonal(); + Mat ref_tri = actual.template triangularView<Mode>(); + expected.template triangularView<Mode>() = (ref_diag * ref_tri).eval(); + actual.template triangularView<Mode>() = actual.diagonal().asDiagonal() * actual.template triangularView<Mode>(); + VERIFY_IS_APPROX(actual, expected); + } + // tri_view * diagonal + { + Mat actual = m, expected = m; + Mat ref_diag = actual.diagonal().asDiagonal(); + Mat ref_tri = actual.template triangularView<Mode>(); + expected.template triangularView<Mode>() = (ref_tri * ref_diag).eval(); + actual.template triangularView<Mode>() = actual.template triangularView<Mode>() * actual.diagonal().asDiagonal(); + VERIFY_IS_APPROX(actual, expected); + } +} + +template <unsigned int Mode, typename Mat> +void verify_selfadjoint_in_place_with_aliased_diagonal(const Mat& m) { + // diagonal * sa_view + { + Mat actual = m, expected = m; + Mat ref_diag = actual.diagonal().asDiagonal(); + Mat ref_sa = actual.template selfadjointView<Mode>(); + expected.template triangularView<Mode>() = (ref_diag * ref_sa).eval(); + actual.template selfadjointView<Mode>() = actual.diagonal().asDiagonal() * actual.template selfadjointView<Mode>(); + VERIFY_IS_APPROX(actual.template triangularView<Mode>().toDenseMatrix(), + expected.template triangularView<Mode>().toDenseMatrix()); + } + // sa_view * diagonal + { + Mat actual = m, expected = m; + Mat ref_diag = actual.diagonal().asDiagonal(); + Mat ref_sa = actual.template selfadjointView<Mode>(); + expected.template triangularView<Mode>() = (ref_sa * ref_diag).eval(); + actual.template selfadjointView<Mode>() = actual.template selfadjointView<Mode>() * actual.diagonal().asDiagonal(); + VERIFY_IS_APPROX(actual.template triangularView<Mode>().toDenseMatrix(), + expected.template triangularView<Mode>().toDenseMatrix()); + } +} + +template <int> +void structured_diagonal_aliasing() { + for (int n : {3, 5, 8, 17, 32, 33, 64, 65}) { + MatrixXcd m = MatrixXcd::Random(n, n); + m.diagonal() = m.diagonal().real(); // Hermitian-friendly diagonal + + verify_triangular_in_place_with_aliased_diagonal<Upper>(m); + verify_triangular_in_place_with_aliased_diagonal<Lower>(m); + verify_selfadjoint_in_place_with_aliased_diagonal<Upper>(m); + verify_selfadjoint_in_place_with_aliased_diagonal<Lower>(m); + } +} + EIGEN_DECLARE_TEST(diagonalmatrices) { for (int i = 0; i < g_repeat; i++) { CALL_SUBTEST_1(diagonalmatrices(Matrix<float, 1, 1>())); @@ -469,4 +533,5 @@ CALL_SUBTEST_10(bug2013<0>()); CALL_SUBTEST_10(selfadjoint_diagonal_products<0>()); CALL_SUBTEST_10(selfadjoint_diagonal_products_block_path<0>()); + CALL_SUBTEST_10(structured_diagonal_aliasing<0>()); }