Core: skip GEMM/GEMV kernels when alpha == 0

libeigen/eigen!2471

Closes #2173

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 2b45dd2..915f130 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -52,6 +52,9 @@
   static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs_, Index lhsStride, const RhsScalar* rhs_,
                   Index rhsStride, ResScalar* res_, Index resIncr, Index resStride, ResScalar alpha,
                   level3_blocking<LhsScalar, RhsScalar>& blocking, GemmParallelInfo<Index>* info = 0) {
+    // BLAS contract: if alpha == 0, the result is unchanged (and lhs/rhs need not be read).
+    if (numext::is_exactly_zero(alpha)) return;
+
     typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
     typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
     typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h
index b0db9f2..326ee1d 100644
--- a/Eigen/src/Core/products/GeneralMatrixVector.h
+++ b/Eigen/src/Core/products/GeneralMatrixVector.h
@@ -190,6 +190,9 @@
   EIGEN_UNUSED_VARIABLE(resIncr);
   eigen_internal_assert(resIncr == 1);
 
+  // BLAS contract: if alpha == 0, the result is unchanged (and lhs/rhs need not be read).
+  if (numext::is_exactly_zero(alpha)) return;
+
   // The following copy tells the compiler that lhs's attributes are not modified outside this function
   // This helps GCC to generate proper code.
   LhsMapper lhs(alhs);
@@ -337,6 +340,9 @@
 general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
                               Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                             ResScalar* res, Index resIncr, ResScalar alpha) {
+  // BLAS contract: if alpha == 0, the result is unchanged (and lhs/rhs need not be read).
+  if (numext::is_exactly_zero(alpha)) return;
+
   // When cols < full packet size, the main vectorized loops are empty.
   // Dispatch to a separate noinline function to avoid polluting the icache.
   // Only dispatch when cols is large enough that half or quarter packets can be used;
diff --git a/test/product_extra.cpp b/test/product_extra.cpp
index 1e3c665..4751b0d 100644
--- a/test/product_extra.cpp
+++ b/test/product_extra.cpp
@@ -651,6 +651,74 @@
   (void)PS_d;
 }
 
+// Locks the BLAS contract that GEMM/GEMV leave the destination unchanged when
+// alpha == 0, including under non-finite inputs in A/x/B that would otherwise
+// taint the result via 0 * Inf = NaN.
+template <typename Scalar>
+void alpha_zero_skips_kernel() {
+  typedef typename NumTraits<Scalar>::Real RealScalar;
+  typedef Matrix<Scalar, Dynamic, Dynamic, ColMajor> ColMat;
+  typedef Matrix<Scalar, Dynamic, Dynamic, RowMajor> RowMat;
+  typedef Matrix<Scalar, Dynamic, 1> Vec;
+
+  const Index m = 17, k = 13, n = 11;
+  const Scalar inf = Scalar(NumTraits<RealScalar>::infinity());
+  const Scalar nan = Scalar(NumTraits<RealScalar>::quiet_NaN());
+  const Scalar pos_zero = Scalar(0);
+  const Scalar neg_zero = Scalar(-RealScalar(0));
+
+  // GEMM (col-major).
+  {
+    ColMat A = ColMat::Random(m, k);
+    ColMat B = ColMat::Random(k, n);
+    A(0, 0) = inf;
+    B(1, 1) = nan;
+
+    ColMat C = ColMat::Random(m, n);
+    const ColMat C_ref = C;
+
+    C.noalias() += pos_zero * A * B;
+    VERIFY_IS_CWISE_EQUAL(C, C_ref);
+
+    C.noalias() += neg_zero * A * B;
+    VERIFY_IS_CWISE_EQUAL(C, C_ref);
+  }
+
+  // GEMV col-major.
+  {
+    ColMat A = ColMat::Random(m, k);
+    Vec x = Vec::Random(k);
+    A(0, 0) = inf;
+    x(1) = nan;
+
+    Vec y = Vec::Random(m);
+    const Vec y_ref = y;
+
+    y.noalias() += pos_zero * (A * x);
+    VERIFY_IS_CWISE_EQUAL(y, y_ref);
+
+    y.noalias() += neg_zero * (A * x);
+    VERIFY_IS_CWISE_EQUAL(y, y_ref);
+  }
+
+  // GEMV row-major.
+  {
+    RowMat A = RowMat::Random(m, k);
+    Vec x = Vec::Random(k);
+    A(0, 0) = inf;
+    x(1) = nan;
+
+    Vec y = Vec::Random(m);
+    const Vec y_ref = y;
+
+    y.noalias() += pos_zero * (A * x);
+    VERIFY_IS_CWISE_EQUAL(y, y_ref);
+
+    y.noalias() += neg_zero * (A * x);
+    VERIFY_IS_CWISE_EQUAL(y, y_ref);
+  }
+}
+
 EIGEN_DECLARE_TEST(product_extra) {
   for (int i = 0; i < g_repeat; i++) {
     CALL_SUBTEST_1(product_extra(
@@ -678,4 +746,10 @@
 
   // Complex GEMV conjugation at varied sizes (deterministic, outside g_repeat).
   CALL_SUBTEST_11(gemv_complex_conjugate<0>());
+
+  // alpha==0 fast path: GEMM/GEMV must leave the destination unchanged.
+  CALL_SUBTEST_12(alpha_zero_skips_kernel<float>());
+  CALL_SUBTEST_12(alpha_zero_skips_kernel<double>());
+  CALL_SUBTEST_12(alpha_zero_skips_kernel<std::complex<float> >());
+  CALL_SUBTEST_12(alpha_zero_skips_kernel<std::complex<double> >());
 }