TriangularView: alias-aware fallback for structured-diagonal product fast path

libeigen/eigen!2504

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h
index fec3a2d..9995bcd 100644
--- a/Eigen/src/Core/TriangularMatrix.h
+++ b/Eigen/src/Core/TriangularMatrix.h
@@ -987,6 +987,32 @@
   }
 };
 
+// Returns the data pointer of the vector wrapped by a diagonal operand of a structured x
+// diagonal product (obtained through extract_data), or nullptr for any operand whose shape tag is
+// not DiagonalShape. Only the diagonal operand is inspected: the structured (triangular/
+// selfadjoint) operand can safely share storage with dst because the kernel reads each (row, col)
+// cell before writing it, whereas a diagonal that aliases dst may be re-read after being written.
+template <typename Op>
+EIGEN_DEVICE_FUNC inline const void* diagonal_operand_data(const Op& op, DiagonalShape) {
+  return extract_data(op.diagonal());  // NOTE(review): presumably null without direct access - confirm
+}
+template <typename Op, typename Shape>
+EIGEN_DEVICE_FUNC inline const void* diagonal_operand_data(const Op& /*op*/, Shape) {
+  return nullptr;  // non-diagonal operand: nothing to compare against dst
+}
+
+// True iff a diagonal operand of src has the same base pointer as dst's nested expression (exact match only).
+template <typename DstXprType, typename SrcXprType>
+EIGEN_DEVICE_FUNC inline bool structured_diagonal_product_aliases(const DstXprType& dst, const SrcXprType& src) {
+  typedef typename evaluator_traits<typename SrcXprType::Lhs>::Shape LhsShape;
+  typedef typename evaluator_traits<typename SrcXprType::Rhs>::Shape RhsShape;
+  const void* const dst_ptr = dst.nestedExpression().data();
+  if (dst_ptr == nullptr) return false;
+  // dst_ptr is non-null from here on, so an equal operand pointer is necessarily non-null as well.
+  return diagonal_operand_data(src.lhs(), LhsShape{}) == dst_ptr ||
+         diagonal_operand_data(src.rhs(), RhsShape{}) == dst_ptr;
+}
+
 template <>
 struct triangular_product_assignment_dispatcher<true> {
   template <typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
@@ -996,7 +1022,20 @@
     EIGEN_UNUSED_VARIABLE(beta);
     EIGEN_STATIC_ASSERT((int(DstXprType::Mode) & int(UnitDiag)) == 0,
                         WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED);
-    call_triangular_assignment_loop<DstXprType::Mode, false>(dst, src, func);
+    // The triangular assignment loop reads src.coeff(row, col) lazily while writing
+    // dst.coeffRef(row, col). When the diagonal operand of the product shares storage with
+    // dst (e.g.
+    //   A.triangularView<Upper>() = A.diagonal().asDiagonal() * A.triangularView<Upper>())
+    // the diagonal entries already written in earlier columns would feed back as modified
+    // values, corrupting later reads. Materialize the source into a temporary first when
+    // overlap is detected at run time. The structured (triangular/selfadjoint) operand may
+    // safely alias dst because the kernel reads each cell before writing it.
+    if (structured_diagonal_product_aliases(dst, src)) {
+      typename SrcXprType::PlainObject tmp(src);
+      call_triangular_assignment_loop<DstXprType::Mode, false>(dst, tmp, func);
+    } else {
+      call_triangular_assignment_loop<DstXprType::Mode, false>(dst, src, func);
+    }
   }
 };
 
diff --git a/test/diagonalmatrices.cpp b/test/diagonalmatrices.cpp
index 2af0da6..b2cc914 100644
--- a/test/diagonalmatrices.cpp
+++ b/test/diagonalmatrices.cpp
@@ -444,6 +444,70 @@
   selfadjoint_diagonal_products_at<std::complex<double>>(65);
 }
 
+// In-place patterns where the diagonal operand shares storage with the destination view.
+// The fast path in triangular_product_assignment_dispatcher detects the run-time overlap and
+// materializes a temporary, so the result must match a reference built from copies (ref_diag,
+// ref_tri) taken before the in-place update. The structured (triangular/selfadjoint) operand can
+// safely share storage with dst because the kernel reads each cell before writing it.
+template <unsigned int Mode, typename Mat>  // Mode: triangular part (e.g. Upper/Lower) that is written
+void verify_triangular_in_place_with_aliased_diagonal(const Mat& m) {  // m: seed, copied per sub-case
+  // Left product, diagonal * tri_view: the diagonal factor is read from the matrix being written.
+  {
+    Mat actual = m, expected = m;
+    Mat ref_diag = actual.diagonal().asDiagonal();  // dense copy of the diagonal factor
+    Mat ref_tri = actual.template triangularView<Mode>();  // dense copy of the triangular factor
+    expected.template triangularView<Mode>() = (ref_diag * ref_tri).eval();
+    actual.template triangularView<Mode>() = actual.diagonal().asDiagonal() * actual.template triangularView<Mode>();
+    VERIFY_IS_APPROX(actual, expected);  // full-matrix check: the other triangle must stay equal to m
+  }
+  // Right product, tri_view * diagonal: same aliasing, with the diagonal factor on the right.
+  {
+    Mat actual = m, expected = m;
+    Mat ref_diag = actual.diagonal().asDiagonal();  // dense copy of the diagonal factor
+    Mat ref_tri = actual.template triangularView<Mode>();  // dense copy of the triangular factor
+    expected.template triangularView<Mode>() = (ref_tri * ref_diag).eval();
+    actual.template triangularView<Mode>() = actual.template triangularView<Mode>() * actual.diagonal().asDiagonal();
+    VERIFY_IS_APPROX(actual, expected);  // full-matrix check: the other triangle must stay equal to m
+  }
+}
+
+template <unsigned int Mode, typename Mat>  // Mode: triangle (e.g. Upper/Lower) stored by the selfadjoint view
+void verify_selfadjoint_in_place_with_aliased_diagonal(const Mat& m) {  // m: seed, copied per sub-case
+  // Left product, diagonal * sa_view: the diagonal factor is read from the matrix being written.
+  {
+    Mat actual = m, expected = m;
+    Mat ref_diag = actual.diagonal().asDiagonal();  // dense copy of the diagonal factor
+    Mat ref_sa = actual.template selfadjointView<Mode>();  // dense copy of the selfadjoint factor
+    expected.template triangularView<Mode>() = (ref_diag * ref_sa).eval();
+    actual.template selfadjointView<Mode>() = actual.diagonal().asDiagonal() * actual.template selfadjointView<Mode>();
+    VERIFY_IS_APPROX(actual.template triangularView<Mode>().toDenseMatrix(),
+                     expected.template triangularView<Mode>().toDenseMatrix());  // only the Mode triangle is compared
+  }
+  // Right product, sa_view * diagonal: same aliasing, with the diagonal factor on the right.
+  {
+    Mat actual = m, expected = m;
+    Mat ref_diag = actual.diagonal().asDiagonal();  // dense copy of the diagonal factor
+    Mat ref_sa = actual.template selfadjointView<Mode>();  // dense copy of the selfadjoint factor
+    expected.template triangularView<Mode>() = (ref_sa * ref_diag).eval();
+    actual.template selfadjointView<Mode>() = actual.template selfadjointView<Mode>() * actual.diagonal().asDiagonal();
+    VERIFY_IS_APPROX(actual.template triangularView<Mode>().toDenseMatrix(),
+                     expected.template triangularView<Mode>().toDenseMatrix());  // only the Mode triangle is compared
+  }
+}
+
+template <int>  // unnamed, unused template parameter; the test driver instantiates this as <0>
+void structured_diagonal_aliasing() {
+  for (int n : {3, 5, 8, 17, 32, 33, 64, 65}) {  // NOTE(review): sizes presumably straddle blocking boundaries
+    MatrixXcd m = MatrixXcd::Random(n, n);
+    m.diagonal() = m.diagonal().real();  // drop imaginary parts: real, Hermitian-friendly diagonal
+
+    verify_triangular_in_place_with_aliased_diagonal<Upper>(m);  // each helper copies m, so m is reused as-is
+    verify_triangular_in_place_with_aliased_diagonal<Lower>(m);
+    verify_selfadjoint_in_place_with_aliased_diagonal<Upper>(m);
+    verify_selfadjoint_in_place_with_aliased_diagonal<Lower>(m);
+  }
+}
+
 EIGEN_DECLARE_TEST(diagonalmatrices) {
   for (int i = 0; i < g_repeat; i++) {
     CALL_SUBTEST_1(diagonalmatrices(Matrix<float, 1, 1>()));
@@ -469,4 +533,5 @@
   CALL_SUBTEST_10(bug2013<0>());
   CALL_SUBTEST_10(selfadjoint_diagonal_products<0>());
   CALL_SUBTEST_10(selfadjoint_diagonal_products_block_path<0>());
+  CALL_SUBTEST_10(structured_diagonal_aliasing<0>());
 }