Replace calls to numext::fma with numext:madd.
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index 21a1bfc..139b10e 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -1350,20 +1350,20 @@
template <typename Scalar>
struct pmadd_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value && NumTraits<Scalar>::IsSigned>> {
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
- return numext::fma(a, b, c);
+ return numext::madd<Scalar>(a, b, c);
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
- return numext::fma(a, b, Scalar(-c));
+ return numext::madd<Scalar>(a, b, Scalar(-c));
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
- return numext::fma(Scalar(-a), b, c);
+ return numext::madd<Scalar>(Scalar(-a), b, c);
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
- return -Scalar(numext::fma(a, b, c));
+ return -Scalar(numext::madd<Scalar>(a, b, c));
}
};
-// FMA instructions.
+// Multiply-add instructions.
/** \internal \returns a * b + c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 481e057..44b16be 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -941,24 +941,45 @@
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; }
};
+// Extra namespace to prevent leaking std::fma into Eigen::internal.
+namespace has_fma_detail {
+
+template <typename T, typename EnableIf = void>
+struct has_fma_impl : public std::false_type {};
+
+using std::fma;
+
+template <typename T>
+struct has_fma_impl<
+ T, std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>>
+ : public std::true_type {};
+
+} // namespace has_fma_detail
+
+template <typename T>
+struct has_fma : public has_fma_detail::has_fma_impl<T> {};
+
// Default implementation.
-template <typename Scalar, typename Enable = void>
+template <typename T, typename Enable = void>
struct fma_impl {
- static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& a, const Scalar& b, const Scalar& c) {
- return a * b + c;
+ static_assert(has_fma<T>::value, "No function fma(...) for type. Please provide an implementation.");
+};
+
+// STD or ADL version if it exists.
+template <typename T>
+struct fma_impl<T, std::enable_if_t<has_fma<T>::value>> {
+ static T run(const T& a, const T& b, const T& c) {
+ using std::fma;
+ return fma(a, b, c);
}
};
-// ADL version if it exists.
-template <typename T>
-struct fma_impl<
- T,
- std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>> {
- static T run(const T& a, const T& b, const T& c) { return fma(a, b, c); }
-};
-
#if defined(EIGEN_GPUCC)
template <>
+struct has_fma<float> : public true_type {
+}
+
+template <>
struct fma_impl<float, void> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) {
return ::fmaf(a, b, c);
@@ -966,6 +987,10 @@
};
template <>
+struct has_fma<double> : public true_type {
+}
+
+template <>
struct fma_impl<double, void> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) {
return ::fma(a, b, c);
@@ -973,6 +998,24 @@
};
#endif
+// Basic multiply-add.
+template <typename Scalar, typename EnableIf = void>
+struct madd_impl {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
+ return x * y + z;
+ }
+};
+
+// Use FMA if there is a single CPU instruction.
+#ifdef EIGEN_VECTORIZE_FMA
+template <typename Scalar>
+struct madd_impl<Scalar, std::enable_if_t<has_fma<Scalar>::value>> {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
+ return fma_impl<Scalar>::run(x, y, z);
+ }
+};
+#endif
+
} // end namespace internal
/****************************************************************************
@@ -1886,15 +1929,18 @@
return bit_cast<Scalar, SignedScalar>(bit_cast<SignedScalar, Scalar>(a) >> n);
}
-// Use std::fma if available.
-using std::fma;
-
// Otherwise, rely on template implementation.
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) {
return internal::fma_impl<Scalar>::run(x, y, z);
}
+// Multiply-add.
+template <typename Scalar>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar madd(const Scalar& x, const Scalar& y, const Scalar& z) {
+ return internal::madd_impl<Scalar>::run(x, y, z);
+}
+
} // end namespace numext
namespace internal {
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 64ba7ba..b66a4db 100644
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -2026,38 +2026,38 @@
}
// Scalar path for pmadd with FMA to ensure consistency with vectorized path.
-#ifdef EIGEN_VECTORIZE_FMA
+#if defined(EIGEN_VECTORIZE_FMA)
template <>
EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) {
- return ::fmaf(a, b, c);
+ return std::fmaf(a, b, c);
}
template <>
EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
- return ::fma(a, b, c);
+ return std::fma(a, b, c);
}
template <>
EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) {
- return ::fmaf(a, b, -c);
+ return std::fmaf(a, b, -c);
}
template <>
EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) {
- return ::fma(a, b, -c);
+ return std::fma(a, b, -c);
}
template <>
EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) {
- return ::fmaf(-a, b, c);
+ return std::fmaf(-a, b, c);
}
template <>
EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) {
- return ::fma(-a, b, c);
+ return std::fma(-a, b, c);
}
template <>
EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) {
- return ::fmaf(-a, b, -c);
+ return std::fmaf(-a, b, -c);
}
template <>
EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) {
- return ::fma(-a, b, -c);
+ return std::fma(-a, b, -c);
}
#endif
diff --git a/Eigen/src/SparseCore/SparseDot.h b/Eigen/src/SparseCore/SparseDot.h
index 8aeebc8..485605f 100644
--- a/Eigen/src/SparseCore/SparseDot.h
+++ b/Eigen/src/SparseCore/SparseDot.h
@@ -36,10 +36,10 @@
Scalar res1(0);
Scalar res2(0);
for (; i; ++i) {
- res1 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res1);
+ res1 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res1);
++i;
if (i) {
- res2 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res2);
+ res2 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res2);
}
}
return res1 + res2;
@@ -67,7 +67,7 @@
Scalar res(0);
while (i && j) {
if (i.index() == j.index()) {
- res = numext::fma(numext::conj(i.value()), j.value(), res);
+ res = numext::madd<Scalar>(numext::conj(i.value()), j.value(), res);
++i;
++j;
} else if (i.index() < j.index())
diff --git a/Eigen/src/SparseCore/TriangularSolver.h b/Eigen/src/SparseCore/TriangularSolver.h
index fb8c157..684de48 100644
--- a/Eigen/src/SparseCore/TriangularSolver.h
+++ b/Eigen/src/SparseCore/TriangularSolver.h
@@ -41,7 +41,7 @@
lastVal = it.value();
lastIndex = it.index();
if (lastIndex == i) break;
- tmp = numext::fma(-lastVal, other.coeff(lastIndex, col), tmp);
+ tmp = numext::madd<Scalar>(-lastVal, other.coeff(lastIndex, col), tmp);
}
if (Mode & UnitDiag)
other.coeffRef(i, col) = tmp;
@@ -75,7 +75,7 @@
} else if (it && it.index() == i)
++it;
for (; it; ++it) {
- tmp = numext::fma<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
+ tmp = numext::madd<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
}
if (Mode & UnitDiag)
@@ -108,7 +108,7 @@
}
if (it && it.index() == i) ++it;
for (; it; ++it) {
- other.coeffRef(it.index(), col) = numext::fma<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
+ other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
}
}
}
@@ -138,7 +138,7 @@
}
LhsIterator it(lhsEval, i);
for (; it && it.index() < i; ++it) {
- other.coeffRef(it.index(), col) = numext::fma<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
+ other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
}
}
}
@@ -220,11 +220,11 @@
if (IsLower) {
if (it.index() == i) ++it;
for (; it; ++it) {
- tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index()));
+ tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
}
} else {
for (; it && it.index() < i; ++it) {
- tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index()));
+ tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
}
}
}
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index 5f48d71..f21c726 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -44,16 +44,16 @@
struct madd_impl<Scalar,
std::enable_if_t<Eigen::internal::is_scalar<Scalar>::value && Eigen::NumTraits<Scalar>::IsSigned>> {
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
- return numext::fma(a, b, c);
+ return numext::madd(a, b, c);
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
- return numext::fma(a, b, Scalar(-c));
+ return numext::madd(a, b, Scalar(-c));
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
- return numext::fma(Scalar(-a), b, c);
+ return numext::madd(Scalar(-a), b, c);
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
- return -Scalar(numext::fma(a, b, c));
+ return -Scalar(numext::madd(a, b, c));
}
};
diff --git a/test/product.h b/test/product.h
index f37a932..21b4701 100644
--- a/test/product.h
+++ b/test/product.h
@@ -42,7 +42,7 @@
Scalar ref_dot_product(const V1& v1, const V2& v2) {
Scalar out = Scalar(0);
for (Index i = 0; i < v1.size(); ++i) {
- out = Eigen::numext::fma(v1[i], v2[i], out);
+ out = Eigen::numext::madd(v1[i], v2[i], out);
}
return out;
}
@@ -254,8 +254,6 @@
// inner product
{
Scalar x = square2.row(c) * square2.col(c2);
- // NOTE: FMA is necessary here in the reference to ensure accuracy for
- // large vector sizes and float16/bfloat16 types.
Scalar y = ref_dot_product<Scalar>(square2.row(c), square2.col(c2));
VERIFY_IS_APPROX(x, y);
}