Core: skip GEMM/GEMV kernels when alpha == 0

libeigen/eigen!2471

Closes #2173

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 2b45dd2..915f130 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -52,6 +52,9 @@ static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs_, Index lhsStride, const RhsScalar* rhs_, Index rhsStride, ResScalar* res_, Index resIncr, Index resStride, ResScalar alpha, level3_blocking<LhsScalar, RhsScalar>& blocking, GemmParallelInfo<Index>* info = 0) { + // BLAS contract: if alpha == 0, the result is unchanged (and lhs/rhs need not be read). + if (numext::is_exactly_zero(alpha)) return; + typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper; typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper; typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h
index b0db9f2..326ee1d 100644
--- a/Eigen/src/Core/products/GeneralMatrixVector.h
+++ b/Eigen/src/Core/products/GeneralMatrixVector.h
@@ -190,6 +190,9 @@ EIGEN_UNUSED_VARIABLE(resIncr); eigen_internal_assert(resIncr == 1); + // BLAS contract: if alpha == 0, the result is unchanged (and lhs/rhs need not be read). + if (numext::is_exactly_zero(alpha)) return; + // The following copy tells the compiler that lhs's attributes are not modified outside this function // This helps GCC to generate proper code. LhsMapper lhs(alhs); @@ -337,6 +340,9 @@ general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res, Index resIncr, ResScalar alpha) { + // BLAS contract: if alpha == 0, the result is unchanged (and lhs/rhs need not be read). + if (numext::is_exactly_zero(alpha)) return; + // When cols < full packet size, the main vectorized loops are empty. // Dispatch to a separate noinline function to avoid polluting the icache. // Only dispatch when cols is large enough that half or quarter packets can be used;
diff --git a/test/product_extra.cpp b/test/product_extra.cpp
index 1e3c665..4751b0d 100644
--- a/test/product_extra.cpp
+++ b/test/product_extra.cpp
@@ -651,6 +651,74 @@ (void)PS_d; } +// Locks the BLAS contract that GEMM/GEMV leave the destination unchanged when +// alpha == 0, including under non-finite inputs in A/x/B that would otherwise +// taint the result via 0 * Inf = NaN. +template <typename Scalar> +void alpha_zero_skips_kernel() { + typedef typename NumTraits<Scalar>::Real RealScalar; + typedef Matrix<Scalar, Dynamic, Dynamic, ColMajor> ColMat; + typedef Matrix<Scalar, Dynamic, Dynamic, RowMajor> RowMat; + typedef Matrix<Scalar, Dynamic, 1> Vec; + + const Index m = 17, k = 13, n = 11; + const Scalar inf = Scalar(NumTraits<RealScalar>::infinity()); + const Scalar nan = Scalar(NumTraits<RealScalar>::quiet_NaN()); + const Scalar pos_zero = Scalar(0); + const Scalar neg_zero = Scalar(-RealScalar(0)); + + // GEMM (col-major). + { + ColMat A = ColMat::Random(m, k); + ColMat B = ColMat::Random(k, n); + A(0, 0) = inf; + B(1, 1) = nan; + + ColMat C = ColMat::Random(m, n); + const ColMat C_ref = C; + + C.noalias() += pos_zero * A * B; + VERIFY_IS_CWISE_EQUAL(C, C_ref); + + C.noalias() += neg_zero * A * B; + VERIFY_IS_CWISE_EQUAL(C, C_ref); + } + + // GEMV col-major. + { + ColMat A = ColMat::Random(m, k); + Vec x = Vec::Random(k); + A(0, 0) = inf; + x(1) = nan; + + Vec y = Vec::Random(m); + const Vec y_ref = y; + + y.noalias() += pos_zero * (A * x); + VERIFY_IS_CWISE_EQUAL(y, y_ref); + + y.noalias() += neg_zero * (A * x); + VERIFY_IS_CWISE_EQUAL(y, y_ref); + } + + // GEMV row-major. + { + RowMat A = RowMat::Random(m, k); + Vec x = Vec::Random(k); + A(0, 0) = inf; + x(1) = nan; + + Vec y = Vec::Random(m); + const Vec y_ref = y; + + y.noalias() += pos_zero * (A * x); + VERIFY_IS_CWISE_EQUAL(y, y_ref); + + y.noalias() += neg_zero * (A * x); + VERIFY_IS_CWISE_EQUAL(y, y_ref); + } +} + EIGEN_DECLARE_TEST(product_extra) { for (int i = 0; i < g_repeat; i++) { CALL_SUBTEST_1(product_extra( @@ -678,4 +746,10 @@ // Complex GEMV conjugation at varied sizes (deterministic, outside g_repeat). 
CALL_SUBTEST_11(gemv_complex_conjugate<0>()); + + // alpha==0 fast path: GEMM/GEMV must leave the destination unchanged. + CALL_SUBTEST_12(alpha_zero_skips_kernel<float>()); + CALL_SUBTEST_12(alpha_zero_skips_kernel<double>()); + CALL_SUBTEST_12(alpha_zero_skips_kernel<std::complex<float> >()); + CALL_SUBTEST_12(alpha_zero_skips_kernel<std::complex<double> >()); }