Specialize numext::madd for half/bfloat16.
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
index f2e55f3..b93c4bc 100644
--- a/Eigen/src/Core/arch/Default/BFloat16.h
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -793,6 +793,12 @@
return numext::bit_cast<bfloat16>(from_bits);
}
+// Specialize multiply-add to match packet operations and reduce conversions to/from float.
+template<>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(const Eigen::bfloat16& x, const Eigen::bfloat16& y, const Eigen::bfloat16& z) {
+ return Eigen::bfloat16(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
+}
+
} // namespace numext
} // namespace Eigen
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index c073fe8..210dfff 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -955,6 +955,12 @@
return Eigen::half_impl::raw_half_as_uint16(src);
}
+// Specialize multiply-add to match packet operations and reduce conversions to/from float.
+template<>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half madd<Eigen::half>(const Eigen::half& x, const Eigen::half& y, const Eigen::half& z) {
+ return Eigen::half(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
+}
+
} // namespace numext
} // namespace Eigen