Vectorize fp16 tanh and logistic functions on Neon
Activates vectorization of the Eigen::half versions of the tanh and
logistic functions when they run on Neon. Both functions convert their
inputs to float before computing the output, and as a result of this
commit, the conversions and the computation in float are vectorized.
diff --git a/Eigen/Core b/Eigen/Core
index e0da499..d6cc162 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -263,6 +263,11 @@
#include "src/Core/arch/GPU/Complex.h"
#endif
+// Specializations of vectorized activation functions for NEON.
+#ifdef EIGEN_VECTORIZE_NEON
+#include "src/Core/arch/NEON/UnaryFunctors.h"
+#endif
+
#include "src/Core/util/IndexedViewHelper.h"
#include "src/Core/util/ReshapedHelper.h"
#include "src/Core/ArithmeticSequence.h"
diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h
index d34882a..0111cf3 100644
--- a/Eigen/src/Core/arch/NEON/MathFunctions.h
+++ b/Eigen/src/Core/arch/NEON/MathFunctions.h
@@ -40,6 +40,25 @@
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f ptanh<Packet4f>(const Packet4f& x)
{ return internal::generic_fast_tanh_float(x); }
+#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED
+Packet4hf ptanh<Packet4hf>(const Packet4hf& x) {
+ // Convert to float, call the float ptanh, and then convert back.
+ return vcvt_f16_f32(ptanh<Packet4f>(vcvt_f32_f16(x)));
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED
+Packet8hf ptanh<Packet8hf>(const Packet8hf& x) {
+ // Convert each 4 halfs to float, call the float ptanh, and then convert back.
+ return vcombine_f16(
+ vcvt_f16_f32(ptanh<Packet4f>(vcvt_f32_f16(vget_low_f16(x)))),
+ vcvt_f16_f32(ptanh<Packet4f>(vcvt_high_f32_f16(x))));
+}
+#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+
+
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 382a2c8..e908bf5 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -4028,6 +4028,7 @@
HasCos = 0,
HasLog = 0,
HasExp = 0,
+ HasTanh = packet_traits<float>::HasTanh, // tanh<half> calls tanh<float>
HasSqrt = 1,
HasRsqrt = 1,
HasErf = EIGEN_FAST_MATH,
diff --git a/Eigen/src/Core/arch/NEON/UnaryFunctors.h b/Eigen/src/Core/arch/NEON/UnaryFunctors.h
new file mode 100644
index 0000000..131746d
--- /dev/null
+++ b/Eigen/src/Core/arch/NEON/UnaryFunctors.h
@@ -0,0 +1,64 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_NEON_UNARY_FUNCTORS_H
+#define EIGEN_NEON_UNARY_FUNCTORS_H
+
+#include "../../InternalHeaderCheck.h"
+
+namespace Eigen {
+
+namespace internal {
+
+#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+/** \internal
+ * \brief Template specialization of the logistic function for Eigen::half.
+ */
+template <>
+struct scalar_logistic_op<Eigen::half> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Eigen::half operator()(const Eigen::half& x) const {
+ // Convert to float and call scalar_logistic_op<float>.
+ const scalar_logistic_op<float> float_op;
+ return Eigen::half(float_op(float(x)));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Eigen::half packetOp(const Eigen::half& x) const {
+ return this->operator()(x);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Packet4hf packetOp(const Packet4hf& x) const {
+ const scalar_logistic_op<float> float_op;
+ return vcvt_f16_f32(float_op.packetOp(vcvt_f32_f16(x)));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Packet8hf packetOp(const Packet8hf& x) const {
+ const scalar_logistic_op<float> float_op;
+ return vcombine_f16(
+ vcvt_f16_f32(float_op.packetOp(vcvt_f32_f16(vget_low_f16(x)))),
+ vcvt_f16_f32(float_op.packetOp(vcvt_high_f32_f16(x))));
+ }
+};
+
+template<>
+struct functor_traits<scalar_logistic_op<Eigen::half>> {
+ enum {
+ Cost = functor_traits<scalar_logistic_op<float>>::Cost,
+ PacketAccess = functor_traits<scalar_logistic_op<float>>::PacketAccess,
+ };
+};
+#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_NEON_UNARY_FUNCTORS_H