Fix AVX512 nomalloc issues in trsm.
diff --git a/Eigen/src/Core/arch/AVX512/TrsmKernel.h b/Eigen/src/Core/arch/AVX512/TrsmKernel.h
index 8fbf161..714afac 100644
--- a/Eigen/src/Core/arch/AVX512/TrsmKernel.h
+++ b/Eigen/src/Core/arch/AVX512/TrsmKernel.h
@@ -16,6 +16,13 @@
#define EIGEN_USE_AVX512_TRSM_KERNELS 1
#endif
+// TRSM kernels currently unconditionally rely on malloc with AVX512.
+// Disable them if malloc is explicitly disabled at compile-time.
+#ifdef EIGEN_NO_MALLOC
+#undef EIGEN_USE_AVX512_TRSM_KERNELS
+#define EIGEN_USE_AVX512_TRSM_KERNELS 0
+#endif
+
#if EIGEN_USE_AVX512_TRSM_KERNELS
#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
@@ -860,32 +867,6 @@
}
}
-#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
-/**
- * Reduce blocking sizes so that the size of the temporary workspace needed is less than "limit" bytes,
- * - kB must be at least psize
- * - numM must be at least EIGEN_AVX_MAX_NUM_ROW
- */
-template <typename Scalar, bool isBRowMajor>
-constexpr std::pair<int64_t, int64_t> trsmBlocking(const int64_t limit) {
- constexpr int64_t psize = packet_traits<Scalar>::size;
- int64_t kB = 15 * psize;
- int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
- // If B is rowmajor, no temp workspace needed, so use default blocking sizes.
- if (isBRowMajor) return {kB, numM};
-
- // Very simple heuristic, prefer keeping kB as large as possible to fully use vector registers.
- for (int64_t k = kB; k > psize; k -= psize) {
- for (int64_t m = numM; m > EIGEN_AVX_MAX_NUM_ROW; m -= EIGEN_AVX_MAX_NUM_ROW) {
- if ((((k + psize - 1) / psize + 4) * psize) * m * sizeof(Scalar) < limit) {
- return {k, m};
- }
- }
- }
- return {psize, EIGEN_AVX_MAX_NUM_ROW}; // Minimum blocking size required
-}
-#endif // (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
-
/**
* Main triangular solve driver
*
@@ -930,30 +911,8 @@
* large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
* determine optimal values for numM.
*/
-#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
- /**
- * If EIGEN_NO_MALLOC is requested, we try to reduce kB and numM so the maximum temp workspace required is less
- * than EIGEN_STACK_ALLOCATION_LIMIT. Actual workspace size may be less, depending on the number of vectors to
- * solve.
- * - kB must be at least psize
- * - numM must be at least EIGEN_AVX_MAX_NUM_ROW
- *
- * If B is row-major, the blocking sizes are not reduced (no temp workspace needed).
- */
- constexpr std::pair<int64_t, int64_t> blocking_ = trsmBlocking<Scalar, isBRowMajor>(EIGEN_STACK_ALLOCATION_LIMIT);
- constexpr int64_t kB = blocking_.first;
- constexpr int64_t numM = blocking_.second;
- /**
- * If the temp workspace size exceeds EIGEN_STACK_ALLOCATION_LIMIT even with the minimum blocking sizes,
- * we throw an assertion. Use -DEIGEN_USE_AVX512_TRSM_L_KERNELS=0 if necessary
- */
- static_assert(!(((((kB + psize - 1) / psize + 4) * psize) * numM * sizeof(Scalar) >= EIGEN_STACK_ALLOCATION_LIMIT) &&
- !isBRowMajor),
- "Temp workspace required is too large.");
-#else
constexpr int64_t kB = (3 * psize) * 5; // 5*U3
constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
-#endif
int64_t sizeBTemp = 0;
Scalar *B_temp = NULL;
@@ -966,13 +925,7 @@
sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
}
-#if !defined(EIGEN_NO_MALLOC)
EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
-#elif (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
- // Use alloca if malloc not allowed, requested temp workspace size should be less than EIGEN_STACK_ALLOCATION_LIMIT
- ei_declare_aligned_stack_constructed_variable(Scalar, B_temp_alloca, sizeBTemp, 0);
- B_temp = B_temp_alloca;
-#endif
for (int64_t k = 0; k < numRHS; k += kB) {
int64_t bK = numRHS - k > kB ? kB : numRHS - k;
@@ -1102,43 +1055,55 @@
}
}
-#if !defined(EIGEN_NO_MALLOC)
EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
-#endif
}
// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
#if (EIGEN_USE_AVX512_TRSM_KERNELS)
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride, bool Specialized>
struct trsmKernelR;
template <typename Index, int Mode, int TriStorageOrder>
-struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1> {
+struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
-struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1> {
+struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
-EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
+EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
Index otherStride) {
EIGEN_UNUSED_VARIABLE(otherIncr);
+#ifdef EIGEN_NO_RUNTIME_MALLOC
+ if (!is_malloc_allowed()) {
+ trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
+ size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
+ return;
+ }
+#endif
triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}
template <typename Index, int Mode, int TriStorageOrder>
-EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
+EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
Index otherStride) {
EIGEN_UNUSED_VARIABLE(otherIncr);
+#ifdef EIGEN_NO_RUNTIME_MALLOC
+ if (!is_malloc_allowed()) {
+ trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
+ size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
+ return;
+ }
+#endif
triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
@@ -1146,35 +1111,49 @@
// These trsm kernels require temporary memory allocation
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride, bool Specialized = true>
struct trsmKernelL;
template <typename Index, int Mode, int TriStorageOrder>
-struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1> {
+struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
-struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1> {
+struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
-EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
+EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
Index otherStride) {
EIGEN_UNUSED_VARIABLE(otherIncr);
+#ifdef EIGEN_NO_RUNTIME_MALLOC
+ if (!is_malloc_allowed()) {
+ trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
+ size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
+ return;
+ }
+#endif
triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}
template <typename Index, int Mode, int TriStorageOrder>
-EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
+EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
Index otherStride) {
EIGEN_UNUSED_VARIABLE(otherIncr);
+#ifdef EIGEN_NO_RUNTIME_MALLOC
+ if (!is_malloc_allowed()) {
+ trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
+ size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
+ return;
+ }
+#endif
triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h
index b148d9c..22b4a7f 100644
--- a/Eigen/src/Core/products/TriangularSolverMatrix.h
+++ b/Eigen/src/Core/products/TriangularSolverMatrix.h
@@ -17,7 +17,7 @@
namespace internal {
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride, bool Specialized>
struct trsmKernelL {
// Generic Implementation of triangular solve for triangular matrix on left and multiple rhs.
// Handles non-packed matrices.
@@ -27,7 +27,7 @@
Scalar* _other, Index otherIncr, Index otherStride);
};
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride, bool Specialized>
struct trsmKernelR {
// Generic Implementation of triangular solve for triangular matrix on right and multiple lhs.
// Handles non-packed matrices.
@@ -37,8 +37,8 @@
Scalar* _other, Index otherIncr, Index otherStride);
};
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
-EIGEN_STRONG_INLINE void trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::kernel(
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride, bool Specialized>
+EIGEN_STRONG_INLINE void trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, Specialized>::kernel(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride)
@@ -88,8 +88,8 @@
}
-template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
-EIGEN_STRONG_INLINE void trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::kernel(
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride, bool Specialized>
+EIGEN_STRONG_INLINE void trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, Specialized>::kernel(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride)
@@ -180,7 +180,7 @@
// TODO: Investigate better heuristics for cutoffs.
double L2Cap = 0.5; // 50% of L2 size
if (size < avx512_trsm_cutoff<Scalar>(l2, cols, L2Cap)) {
- trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, 1>::kernel(
+ trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, 1, /*Specialized=*/true>::kernel(
size, cols, _tri, triStride, _other, 1, otherStride);
return;
}
@@ -253,7 +253,7 @@
i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth;
}
#endif
- trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::kernel(
+ trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, /*Specialized=*/true>::kernel(
actualPanelWidth, actual_cols,
_tri + i + (i)*triStride, triStride,
_other + i*OtherInnerStride + j2*otherStride, otherIncr, otherStride);
@@ -327,7 +327,7 @@
manage_caching_sizes(GetAction, &l1, &l2, &l3);
double L2Cap = 0.5; // 50% of L2 size
if (size < avx512_trsm_cutoff<Scalar>(l2, rows, L2Cap)) {
- trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
+ trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, /*Specialized=*/true>::
kernel(size, rows, _tri, triStride, _other, 1, otherStride);
return;
}
@@ -423,7 +423,7 @@
{
// unblocked triangular solve
- trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
+ trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, /*Specialized=*/true>::
kernel(actualPanelWidth, actual_mc,
_tri + absolute_j2 + absolute_j2*triStride, triStride,
_other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride);