Block-tile and bypass setZero for SelfAdjointView × Diagonal (perf for !2486)

libeigen/eigen!2501

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 08afb1f..2363c1f 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -858,6 +858,12 @@
                                                         const Alpha& alpha) {
     dst += alpha * diagonal.segment(begin, dst.size()).cwiseProduct(coeffs);
   }
+  template <typename DstSegment, typename Coeffs, typename DiagonalType>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void runOverwrite(DstSegment& dst, const Coeffs& coeffs,
+                                                                 const DiagonalType& diagonal, Index begin,
+                                                                 Index /*col*/) {
+    dst = diagonal.segment(begin, dst.size()).cwiseProduct(coeffs);
+  }
 };
 
 template <>
@@ -868,6 +874,12 @@
                                                         const Alpha& alpha) {
     dst += alpha * (coeffs * diagonal.coeff(col));
   }
+  template <typename DstSegment, typename Coeffs, typename DiagonalType>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void runOverwrite(DstSegment& dst, const Coeffs& coeffs,
+                                                                 const DiagonalType& diagonal, Index /*begin*/,
+                                                                 Index col) {
+    dst = coeffs * diagonal.coeff(col);
+  }
 };
 
 template <int Mode, int ProductOrder, typename MatrixType, typename DiagonalType>
@@ -977,43 +989,131 @@
 
 template <int Mode, int ProductOrder, typename MatrixType, typename DiagonalType>
 struct selfadjoint_diagonal_product_impl {
+  // Accumulating path: dst += alpha * (selfadjoint matrix * diagonal), with the
+  // diagonal factor applied on the right or left as ProductOrder dictates.
   template <typename Dest, typename Alpha>
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Dest& dst, const MatrixType& matrix,
                                                         const DiagonalType& diagonal, const Alpha& alpha) {
+    runImpl<true>(dst, matrix, diagonal, alpha);
+  }
+
+  // Overwriting path: dst = selfadjoint matrix * diagonal (order per ProductOrder).
+  // Each output entry is written exactly once, so the caller can skip the
+  // dst.setZero() pass that generic_product_impl_base::evalTo would do.
+  template <typename Dest>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void runOverwrite(Dest& dst, const MatrixType& matrix,
+                                                                 const DiagonalType& diagonal) {
+    using Scalar = typename traits<MatrixType>::Scalar;
+    runImpl<false>(dst, matrix, diagonal, Scalar(1));
+  }
+
+ private:
+  // Tile size for the blocked mirror pass. Tuned so that one BlockSize x
+  // BlockSize source tile fits comfortably in L1 across common scalar / SIMD
+  // combinations: 32x32 of double = 8 KB, of complex<double> = 16 KB, both
+  // well under typical 32 KB L1. Smaller tiles leave SIMD work on the table
+  // for AVX/AVX-512; larger tiles spill out of L1 on machines with smaller
+  // caches.
+  static constexpr Index BlockSize = 32;
+
+  // The mirror half writes the strict-other-triangle of dst. The naive
+  // per-column form reads matrix.row(col).segment(...) which has stride =
+  // leading dimension, so on an N x N source it streams cold cache lines.
+  // We walk the mirror in BlockSize x BlockSize tiles instead: off-diagonal
+  // tiles use a blocked conjugate-transpose, and the small diagonal tile
+  // falls back to the per-column row-strided loop where the working set is
+  // L1-hot.
+  template <bool Accumulate, typename Dest, typename Alpha>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void runImpl(Dest& dst, const MatrixType& matrix,
+                                                            const DiagonalType& diagonal, const Alpha& alpha) {
     eigen_assert(matrix.rows() == matrix.cols() && "SelfAdjointView is only for squared matrices");
     eigen_assert(diagonal.size() == matrix.rows() && "invalid matrix product");
 
     const Index size = matrix.rows();
+
+    // Stored half: one column-strided segment per output column.
     for (Index col = 0; col < size; ++col) {
       if ((Mode & Upper) == Upper) {
-        addStoredSegment(dst, matrix, diagonal, 0, col + 1, col, alpha);
-        addConjugateSegment(dst, matrix, diagonal, col + 1, size - col - 1, col, alpha);
+        storedSegment<Accumulate>(dst, matrix, diagonal, 0, col + 1, col, alpha);
       } else {
-        addConjugateSegment(dst, matrix, diagonal, 0, col, col, alpha);
-        addStoredSegment(dst, matrix, diagonal, col, size - col, col, alpha);
+        storedSegment<Accumulate>(dst, matrix, diagonal, col, size - col, col, alpha);
+      }
+    }
+
+    // Mirror half.
+    for (Index ib = 0; ib < size; ib += BlockSize) {
+      const Index ib_end = numext::mini(size, ib + BlockSize);
+      const Index br = ib_end - ib;
+      if ((Mode & Upper) == Upper) {
+        // Off-diagonal: write strict-lower of dst from strict-upper of source.
+        for (Index jb = 0; jb < ib; jb += BlockSize) {
+          const Index bc = numext::mini(jb + BlockSize, ib) - jb;
+          mirrorBlock<Accumulate>(dst, matrix, diagonal, ib, jb, br, bc, alpha);
+        }
+        // Diagonal tile: in-tile strict-lower mirror.
+        for (Index col = ib; col < ib_end; ++col)
+          conjugateSegment<Accumulate>(dst, matrix, diagonal, col + 1, ib_end - col - 1, col, alpha);
+      } else {
+        // Off-diagonal: write strict-upper of dst from strict-lower of source.
+        for (Index jb = ib_end; jb < size; jb += BlockSize) {
+          const Index bc = numext::mini(size, jb + BlockSize) - jb;
+          mirrorBlock<Accumulate>(dst, matrix, diagonal, ib, jb, br, bc, alpha);
+        }
+        // Diagonal tile: in-tile strict-upper mirror.
+        for (Index col = ib; col < ib_end; ++col)
+          conjugateSegment<Accumulate>(dst, matrix, diagonal, ib, col - ib, col, alpha);
       }
     }
   }
 
- private:
-  template <typename Dest, typename Alpha>
-  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addStoredSegment(Dest& dst, const MatrixType& matrix,
+  template <bool Accumulate, typename Dest, typename Alpha>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void storedSegment(Dest& dst, const MatrixType& matrix,
+                                                                  const DiagonalType& diagonal, Index begin, Index size,
+                                                                  Index col, const Alpha& alpha) {
+    if (size <= 0) return;
+    auto dstSegment = dst.col(col).segment(begin, size);
+    auto srcSegment = matrix.col(col).segment(begin, size);
+    if (Accumulate)
+      diagonal_product_segment_impl<ProductOrder>::run(dstSegment, srcSegment, diagonal, begin, col, alpha);
+    else
+      diagonal_product_segment_impl<ProductOrder>::runOverwrite(dstSegment, srcSegment, diagonal, begin, col);
+  }
+
+  template <bool Accumulate, typename Dest, typename Alpha>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void conjugateSegment(Dest& dst, const MatrixType& matrix,
                                                                      const DiagonalType& diagonal, Index begin,
                                                                      Index size, Index col, const Alpha& alpha) {
     if (size <= 0) return;
     auto dstSegment = dst.col(col).segment(begin, size);
-    diagonal_product_segment_impl<ProductOrder>::run(dstSegment, matrix.col(col).segment(begin, size), diagonal, begin,
-                                                     col, alpha);
+    auto srcSegment = matrix.row(col).segment(begin, size).conjugate().transpose();
+    if (Accumulate)
+      diagonal_product_segment_impl<ProductOrder>::run(dstSegment, srcSegment, diagonal, begin, col, alpha);
+    else
+      diagonal_product_segment_impl<ProductOrder>::runOverwrite(dstSegment, srcSegment, diagonal, begin, col);
   }
 
-  template <typename Dest, typename Alpha>
-  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addConjugateSegment(Dest& dst, const MatrixType& matrix,
-                                                                        const DiagonalType& diagonal, Index begin,
-                                                                        Index size, Index col, const Alpha& alpha) {
-    if (size <= 0) return;
-    auto dstSegment = dst.col(col).segment(begin, size);
-    diagonal_product_segment_impl<ProductOrder>::run(
-        dstSegment, matrix.row(col).segment(begin, size).conjugate().transpose(), diagonal, begin, col, alpha);
+  // dst.block(ib, jb, br, bc) [+= alpha *] matrix.block(jb, ib, bc, br).adjoint() * <diag>,
+  // where <diag> scales each output column (OnTheRight) or row (OnTheLeft).
+  // Loop bounds in runImpl guarantee br > 0 && bc > 0.
+  template <bool Accumulate, typename Dest, typename Alpha>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void mirrorBlock(Dest& dst, const MatrixType& matrix,
+                                                                const DiagonalType& diagonal, Index ib, Index jb,
+                                                                Index br, Index bc, const Alpha& alpha) {
+    auto dstBlock = dst.block(ib, jb, br, bc);
+    auto srcAdjoint = matrix.block(jb, ib, bc, br).adjoint();
+    if (ProductOrder == OnTheRight) {
+      auto scaled = srcAdjoint * diagonal.segment(jb, bc).asDiagonal();
+      if (Accumulate)
+        dstBlock.noalias() += alpha * scaled;
+      else
+        dstBlock.noalias() = scaled;
+    } else {
+      auto scaled = diagonal.segment(ib, br).asDiagonal() * srcAdjoint;
+      if (Accumulate)
+        dstBlock.noalias() += alpha * scaled;
+      else
+        dstBlock.noalias() = scaled;
+    }
   }
 };
 
@@ -1045,13 +1145,42 @@
 struct generic_product_impl<Lhs, Rhs, SelfAdjointShape, DiagonalShape, ProductTag>
     : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, SelfAdjointShape, DiagonalShape, ProductTag>> {
   typedef typename Product<Lhs, Rhs>::Scalar Scalar;
+  // The "Dense ?= scalar * Product" rewriting rule folds an outer alpha into the
+  // SelfAdjointView via SelfAdjointView::operator*(scalar), whose nested
+  // expression becomes (matrix * alpha). For complex alpha this is no longer
+  // Hermitian — the mirror half of our kernel would produce conj(alpha) on the
+  // off-triangle. Strip the scalar factor with blas_traits and re-fold it into
+  // the kernel's alpha so the same scalar multiplies every output entry.
+  using LhsBlasTraits = blas_traits<typename Lhs::MatrixType>;
+  using ActualLhsMatrix = decltype(LhsBlasTraits::extract(std::declval<const typename Lhs::MatrixType&>())
+                                       .template conjugateIf<bool(LhsBlasTraits::NeedToConjugate)>());
+  using ActualLhsMatrixType = remove_all_t<ActualLhsMatrix>;
+  typedef selfadjoint_diagonal_product_impl<Lhs::Mode, OnTheRight, ActualLhsMatrixType,
+                                            typename Rhs::DiagonalVectorType>
+      Kernel;
+
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ActualLhsMatrix actualLhsMatrix(const typename Lhs::MatrixType& matrix) {
+    return LhsBlasTraits::extract(matrix).template conjugateIf<bool(LhsBlasTraits::NeedToConjugate)>();
+  }
+
+  template <typename Dest>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
+    if (LhsBlasTraits::HasScalarFactor) {
+      // Folded scalar factor present: zero dst then accumulate at the extracted alpha.
+      Scalar factor = LhsBlasTraits::extractScalarFactor(lhs.nestedExpression());
+      dst.setZero();
+      Kernel::run(dst, actualLhsMatrix(lhs.nestedExpression()), rhs.diagonal(), factor);
+    } else {
+      // No scalar factor: kernel writes every entry exactly once, skip setZero.
+      Kernel::runOverwrite(dst, actualLhsMatrix(lhs.nestedExpression()), rhs.diagonal());
+    }
+  }
 
   template <typename Dest>
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs,
                                                                   const Scalar& alpha) {
-    selfadjoint_diagonal_product_impl<Lhs::Mode, OnTheRight, typename Lhs::MatrixType,
-                                      typename Rhs::DiagonalVectorType>::run(dst, lhs.nestedExpression(),
-                                                                             rhs.diagonal(), alpha);
+    Scalar combinedAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs.nestedExpression());
+    Kernel::run(dst, actualLhsMatrix(lhs.nestedExpression()), rhs.diagonal(), combinedAlpha);
   }
 };
 
@@ -1059,13 +1188,35 @@
 struct generic_product_impl<Lhs, Rhs, DiagonalShape, SelfAdjointShape, ProductTag>
     : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, DiagonalShape, SelfAdjointShape, ProductTag>> {
   typedef typename Product<Lhs, Rhs>::Scalar Scalar;
+  // See note on the SelfAdjointShape, DiagonalShape specialization above for why
+  // we extract the scalar factor with blas_traits.
+  using RhsBlasTraits = blas_traits<typename Rhs::MatrixType>;
+  using ActualRhsMatrix = decltype(RhsBlasTraits::extract(std::declval<const typename Rhs::MatrixType&>())
+                                       .template conjugateIf<bool(RhsBlasTraits::NeedToConjugate)>());
+  using ActualRhsMatrixType = remove_all_t<ActualRhsMatrix>;
+  typedef selfadjoint_diagonal_product_impl<Rhs::Mode, OnTheLeft, ActualRhsMatrixType, typename Lhs::DiagonalVectorType>
+      Kernel;
+
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ActualRhsMatrix actualRhsMatrix(const typename Rhs::MatrixType& matrix) {
+    return RhsBlasTraits::extract(matrix).template conjugateIf<bool(RhsBlasTraits::NeedToConjugate)>();
+  }
+
+  template <typename Dest>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
+    if (RhsBlasTraits::HasScalarFactor) {
+      Scalar factor = RhsBlasTraits::extractScalarFactor(rhs.nestedExpression());
+      dst.setZero();
+      Kernel::run(dst, actualRhsMatrix(rhs.nestedExpression()), lhs.diagonal(), factor);
+    } else {
+      Kernel::runOverwrite(dst, actualRhsMatrix(rhs.nestedExpression()), lhs.diagonal());
+    }
+  }
 
   template <typename Dest>
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs,
                                                                   const Scalar& alpha) {
-    selfadjoint_diagonal_product_impl<Rhs::Mode, OnTheLeft, typename Rhs::MatrixType,
-                                      typename Lhs::DiagonalVectorType>::run(dst, rhs.nestedExpression(),
-                                                                             lhs.diagonal(), alpha);
+    Scalar combinedAlpha = alpha * RhsBlasTraits::extractScalarFactor(rhs.nestedExpression());
+    Kernel::run(dst, actualRhsMatrix(rhs.nestedExpression()), lhs.diagonal(), combinedAlpha);
   }
 };
 
diff --git a/test/diagonalmatrices.cpp b/test/diagonalmatrices.cpp
index 97046de..2af0da6 100644
--- a/test/diagonalmatrices.cpp
+++ b/test/diagonalmatrices.cpp
@@ -336,6 +336,114 @@
   VERIFY_IS_APPROX(no_malloc_result, dynamic_expected);
 }
 
+// Exercise the block-tile path of the dense selfadjoint x diagonal kernel
+// (BlockSize = 32 in ProductEvaluators.h). Picks a few sizes that hit
+// full blocks (size = 64), partial blocks (size = 33, 65), and a tiny size
+// that never reaches an off-diagonal mirrorBlock tile (size = 8). Also
+// verifies the overwrite path leaves no stale data when dst is pre-filled.
+template <typename Scalar>
+void selfadjoint_diagonal_products_at(Index n) {
+  typedef Matrix<Scalar, Dynamic, Dynamic> MatType;
+  typedef Matrix<Scalar, Dynamic, 1> VecType;
+
+  MatType m = MatType::Random(n, n);
+  m.diagonal() = m.diagonal().real().template cast<Scalar>();  // force a real diagonal, as a Hermitian matrix requires
+  VecType d = VecType::Random(n);
+
+  MatType ref_lower = m.template selfadjointView<Lower>();
+  MatType ref_upper = m.template selfadjointView<Upper>();
+
+  // Plain assignment goes through evalTo (overwrite kernel).
+  // Pre-fill dst with garbage to verify no stale entries remain.
+  MatType dst = MatType::Constant(n, n, Scalar(42));
+  dst.noalias() = m.template selfadjointView<Upper>() * d.asDiagonal();
+  VERIFY_IS_APPROX(dst, ref_upper * d.asDiagonal());
+
+  dst = MatType::Constant(n, n, Scalar(-7));
+  dst.noalias() = m.template selfadjointView<Lower>() * d.asDiagonal();
+  VERIFY_IS_APPROX(dst, ref_lower * d.asDiagonal());
+
+  dst = MatType::Constant(n, n, Scalar(13));
+  dst.noalias() = d.asDiagonal() * m.template selfadjointView<Upper>();
+  VERIFY_IS_APPROX(dst, d.asDiagonal() * ref_upper);
+
+  dst = MatType::Constant(n, n, Scalar(99));
+  dst.noalias() = d.asDiagonal() * m.template selfadjointView<Lower>();
+  VERIFY_IS_APPROX(dst, d.asDiagonal() * ref_lower);
+
+  // Accumulating paths (scaleAndAddTo).
+  MatType base = MatType::Random(n, n);
+  dst = base;
+  dst.noalias() += m.template selfadjointView<Upper>() * d.asDiagonal();
+  VERIFY_IS_APPROX(dst, base + ref_upper * d.asDiagonal());
+
+  dst = base;
+  dst.noalias() -= d.asDiagonal() * m.template selfadjointView<Lower>();
+  VERIFY_IS_APPROX(dst, base - d.asDiagonal() * ref_lower);
+
+  // Scalar-scaled products: the "Dense ?= scalar * Product" rewriting rule
+  // folds alpha into the SelfAdjointView. For a complex alpha that fold is
+  // not Hermitian, so the dispatch must restore alpha rather than apply it
+  // straight to the kernel — verify all four orientation x triangle combos.
+  Scalar alpha = internal::random<Scalar>();
+  dst = base;
+  dst.noalias() += alpha * (m.template selfadjointView<Lower>() * d.asDiagonal());
+  VERIFY_IS_APPROX(dst, base + alpha * (ref_lower * d.asDiagonal()));
+
+  dst = base;
+  dst.noalias() -= alpha * (m.template selfadjointView<Upper>() * d.asDiagonal());
+  VERIFY_IS_APPROX(dst, base - alpha * (ref_upper * d.asDiagonal()));
+
+  dst = base;
+  dst.noalias() += alpha * (d.asDiagonal() * m.template selfadjointView<Upper>());
+  VERIFY_IS_APPROX(dst, base + alpha * (d.asDiagonal() * ref_upper));
+
+  dst = base;
+  dst.noalias() -= alpha * (d.asDiagonal() * m.template selfadjointView<Lower>());
+  VERIFY_IS_APPROX(dst, base - alpha * (d.asDiagonal() * ref_lower));
+
+  // Overwrite-with-scalar path: hits evalTo's HasScalarFactor branch.
+  dst = MatType::Constant(n, n, Scalar(17));
+  dst.noalias() = alpha * (m.template selfadjointView<Lower>() * d.asDiagonal());
+  VERIFY_IS_APPROX(dst, alpha * (ref_lower * d.asDiagonal()));
+
+  dst = MatType::Constant(n, n, Scalar(-3));
+  dst.noalias() = alpha * (d.asDiagonal() * m.template selfadjointView<Upper>());
+  VERIFY_IS_APPROX(dst, alpha * (d.asDiagonal() * ref_upper));
+
+  // Conjugated nested expressions go through the same blas_traits extraction
+  // path. The extracted matrix must keep NeedToConjugate, otherwise the kernel
+  // computes with m instead of m.conjugate().
+  MatType conj_ref_lower = m.conjugate().template selfadjointView<Lower>();
+  MatType conj_ref_upper = m.conjugate().template selfadjointView<Upper>();
+
+  dst = MatType::Constant(n, n, Scalar(23));
+  dst.noalias() = m.conjugate().template selfadjointView<Upper>() * d.asDiagonal();
+  VERIFY_IS_APPROX(dst, conj_ref_upper * d.asDiagonal());
+
+  dst = MatType::Constant(n, n, Scalar(-29));
+  dst.noalias() = d.asDiagonal() * m.conjugate().template selfadjointView<Lower>();
+  VERIFY_IS_APPROX(dst, d.asDiagonal() * conj_ref_lower);
+
+  dst = base;
+  dst.noalias() += alpha * (m.conjugate().template selfadjointView<Lower>() * d.asDiagonal());
+  VERIFY_IS_APPROX(dst, base + alpha * (conj_ref_lower * d.asDiagonal()));
+
+  dst = base;
+  dst.noalias() -= alpha * (d.asDiagonal() * m.conjugate().template selfadjointView<Upper>());
+  VERIFY_IS_APPROX(dst, base - alpha * (d.asDiagonal() * conj_ref_upper));
+}
+
+template <int>
+void selfadjoint_diagonal_products_block_path() {
+  selfadjoint_diagonal_products_at<double>(8);
+  selfadjoint_diagonal_products_at<double>(33);  // partial off-diagonal block
+  selfadjoint_diagonal_products_at<double>(64);  // exact multiple of BlockSize
+  selfadjoint_diagonal_products_at<double>(65);  // off-by-one
+  selfadjoint_diagonal_products_at<std::complex<double>>(33);
+  selfadjoint_diagonal_products_at<std::complex<double>>(65);
+}
+
 EIGEN_DECLARE_TEST(diagonalmatrices) {
   for (int i = 0; i < g_repeat; i++) {
     CALL_SUBTEST_1(diagonalmatrices(Matrix<float, 1, 1>()));
@@ -360,4 +468,5 @@
   CALL_SUBTEST_10(bug987<0>());
   CALL_SUBTEST_10(bug2013<0>());
   CALL_SUBTEST_10(selfadjoint_diagonal_products<0>());
+  CALL_SUBTEST_10(selfadjoint_diagonal_products_block_path<0>());
 }