Remove inline assembly for FMA (AVX) and add remaining extensions as packet ops: pmsub, pnmadd, and pnmsub.
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index 4f1ff6b..f60724a 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -939,6 +939,35 @@
* The following functions might not have to be overwritten for vectorized types
***************************************************************************/
+// FMA instructions.
+/** \internal \returns a * b + c (coeff-wise) */
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b,
+ const Packet& c) {
+ return padd(pmul(a, b), c);
+}
+
+/** \internal \returns a * b - c (coeff-wise) */
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet pmsub(const Packet& a, const Packet& b,
+ const Packet& c) {
+ return psub(pmul(a, b), c);
+}
+
+/** \internal \returns -(a * b) + c (coeff-wise) */
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet pnmadd(const Packet& a, const Packet& b,
+ const Packet& c) {
+ return padd(pnegate(pmul(a, b)), c);
+}
+
+/** \internal \returns -(a * b) - c (coeff-wise) */
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet pnmsub(const Packet& a, const Packet& b,
+ const Packet& c) {
+ return psub(pnegate(pmul(a, b)), c);
+}
+
/** \internal copy a packet with constant coefficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned */
// NOTE: this function must really be templated on the packet type (think about different packet types for the same scalar type)
template<typename Packet>
@@ -947,13 +976,6 @@
pstore(to, pset1<Packet>(a));
}
-/** \internal \returns a * b + c (coeff-wise) */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pmadd(const Packet& a,
- const Packet& b,
- const Packet& c)
-{ return padd(pmul(a, b),c); }
-
/** \internal \returns a packet version of \a *from.
* The pointer \a from must be aligned on a \a Alignment bytes boundary. */
template<typename Packet, int Alignment>
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index bf832c9..2df899d 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -540,30 +540,46 @@
}
#ifdef EIGEN_VECTORIZE_FMA
-template<> EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
-#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) )
- // Clang stupidly generates a vfmadd213ps instruction plus some vmovaps on registers,
- // and even register spilling with clang>=6.0 (bug 1637).
- // Gcc stupidly generates a vfmadd132ps instruction.
- // So let's enforce it to generate a vfmadd231ps instruction since the most common use
- // case is to accumulate the result of the product.
- Packet8f res = c;
- __asm__("vfmadd231ps %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b));
- return res;
-#else
- return _mm256_fmadd_ps(a,b,c);
-#endif
+
+template <>
+EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
+ return _mm256_fmadd_ps(a, b, c);
}
-template<> EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
-#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) )
- // see above
- Packet4d res = c;
- __asm__("vfmadd231pd %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b));
- return res;
-#else
- return _mm256_fmadd_pd(a,b,c);
-#endif
+template <>
+EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
+ return _mm256_fmadd_pd(a, b, c);
}
+
+template <>
+EIGEN_STRONG_INLINE Packet8f pmsub(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
+ return _mm256_fmsub_ps(a, b, c);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4d pmsub(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
+ return _mm256_fmsub_pd(a, b, c);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8f pnmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
+ return _mm256_fnmadd_ps(a, b, c);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4d pnmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
+ return _mm256_fnmadd_pd(a, b, c);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8f pnmsub(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
+ return _mm256_fnmsub_ps(a, b, c);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4d pnmsub(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
+ return _mm256_fnmsub_pd(a, b, c);
+}
+
#endif
template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); }
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 8a00c62..d0ccfd8 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -359,6 +359,39 @@
const Packet8d& c) {
return _mm512_fmadd_pd(a, b, c);
}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pmsub(const Packet16f& a, const Packet16f& b,
+ const Packet16f& c) {
+ return _mm512_fmsub_ps(a, b, c);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pmsub(const Packet8d& a, const Packet8d& b,
+ const Packet8d& c) {
+ return _mm512_fmsub_pd(a, b, c);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pnmadd(const Packet16f& a, const Packet16f& b,
+ const Packet16f& c) {
+ return _mm512_fnmadd_ps(a, b, c);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pnmadd(const Packet8d& a, const Packet8d& b,
+ const Packet8d& c) {
+ return _mm512_fnmadd_pd(a, b, c);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pnmsub(const Packet16f& a, const Packet16f& b,
+ const Packet16f& c) {
+ return _mm512_fnmsub_ps(a, b, c);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pnmsub(const Packet8d& a, const Packet8d& b,
+ const Packet8d& c) {
+ return _mm512_fnmsub_pd(a, b, c);
+}
#endif
template <>
@@ -2281,13 +2314,13 @@
template <>
EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
- return Packet16bf(pand<Packet8i>((Packet8i)a, (Packet8i)b));
+ return Packet16bf(pand<Packet8i>(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
const Packet16bf& b) {
- return Packet16bf(pandnot<Packet8i>((Packet8i)a, (Packet8i)b));
+ return Packet16bf(pandnot<Packet8i>(Packet8i(a), Packet8i(b)));
}
template <>
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 4de3d47..80f86ff 100755
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -364,6 +364,12 @@
#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet4f pmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmsub_ps(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet2d pmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmsub_pd(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmadd_ps(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmadd_pd(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmsub_ps(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmsub_pd(a,b,c); }
#endif
#ifdef EIGEN_VECTORIZE_SSE4_1
@@ -1263,6 +1269,24 @@
template<> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
return ::fma(a,b,c);
}
+template<> EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) {
+ return ::fmaf(a,b,-c);
+}
+template<> EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) {
+ return ::fma(a,b,-c);
+}
+template<> EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) {
+ return ::fmaf(-a,b,c);
+}
+template<> EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) {
+ return ::fma(-a,b,c);
+}
+template<> EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) {
+ return ::fmaf(-a,b,-c);
+}
+template<> EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) {
+ return ::fma(-a,b,-c);
+}
#endif
#ifdef EIGEN_VECTORIZE_SSE4_1
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index 455ecab..e60308a 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -24,6 +24,22 @@
return a * b;
}
template <typename T>
+inline T REF_MADD(const T& a, const T& b, const T& c) {
+ return a * b + c;
+}
+template <typename T>
+inline T REF_MSUB(const T& a, const T& b, const T& c) {
+ return a * b - c;
+}
+template <typename T>
+inline T REF_NMADD(const T& a, const T& b, const T& c) {
+ return (-a * b) + c;
+}
+template <typename T>
+inline T REF_NMSUB(const T& a, const T& b, const T& c) {
+ return (-a * b) - c;
+}
+template <typename T>
inline T REF_DIV(const T& a, const T& b) {
return a / b;
}
@@ -49,6 +65,10 @@
inline bool REF_MUL(const bool& a, const bool& b) {
return a && b;
}
+template <>
+inline bool REF_MADD(const bool& a, const bool& b, const bool& c) {
+ return (a && b) || c;
+}
template <typename T>
inline T REF_FREXP(const T& x, T& exp) {
@@ -622,6 +642,12 @@
}
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
+ CHECK_CWISE3_IF(true, REF_MADD, internal::pmadd);
+ if (!std::is_same<Scalar, bool>::value) {
+ CHECK_CWISE3_IF(true, REF_MSUB, internal::pmsub);
+ CHECK_CWISE3_IF(true, REF_NMADD, internal::pnmadd);
+ CHECK_CWISE3_IF(true, REF_NMSUB, internal::pnmsub);
+ }
}
// Notice that this definition works for complex types as well.