Improve the accuracy of the fast approximate tanh and logistic functions in Eigen, such that they preserve relative accuracy to within a few ULPs where their function values tend to zero (around x = 0 for tanh, and for large negative x for the logistic function).
This change reinstates the fast rational approximation of the logistic function for float32 in Eigen (removed in https://gitlab.com/libeigen/eigen/commit/66f07efeaed39d6a67005343d7e0caf7d9eeacdb), but uses the more accurate approximation 1/(1+exp(-x)) ~= exp(x) for x below -9. The exponential is calculated on the vectorized path only if at least one element in the SIMD input vector is less than -9.
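For reference, a minimal scalar sketch of this evaluation strategy (illustrative only; the actual implementation is vectorized via the packet primitives added in the diff below, and the rational coefficients are omitted here):

#include <cmath>

// Sketch of the float32 logistic strategy: below x = -9 the rational
// approximation loses relative accuracy, but there 1/(1+exp(-x)) ~= exp(x),
// so we return exp(x) instead.
float logistic_sketch(float x) {
  const float kCutoff = -9.0f;  // crossover point described above
  if (x < kCutoff) {
    return std::exp(x);
  }
  // Stand-in for the fast rational approximation p(x)/q(x); the real
  // coefficients live in Eigen's float specialization of logistic.
  return 1.0f / (1.0f + std::exp(-x));
}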
This change also contains a few improvements to speed up the original float specialization of logistic:
- Introduce EIGEN_PREDICT_{FALSE,TRUE} wrappers for __builtin_expect and use them to predict that the logistic-only path is most likely (~2-3% speedup for the common case); see the sketch after this list.
- Carefully set the upper clipping point to the smallest x where the approximation evaluates to exactly 1, which makes explicit clamping of the output unnecessary (~7% speedup).
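A sketch of the prediction macros (assuming a GCC/Clang-style compiler; the real definitions in Eigen/src/Core/util/Macros.h also cover compilers without the builtin):

#ifdef __GNUC__
#define EIGEN_PREDICT_FALSE(x) (__builtin_expect(x, false))
#define EIGEN_PREDICT_TRUE(x) (__builtin_expect(!!(x), true))
#else
#define EIGEN_PREDICT_FALSE(x) (x)
#define EIGEN_PREDICT_TRUE(x) (x)
#endif

// Usage in the logistic kernel: mark the exp fallback as unlikely, e.g.
// if (EIGEN_PREDICT_FALSE(some_lane_below_cutoff)) { /* compute pexp */ }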
The increased accuracy for tanh comes at the cost of a 10-20% slowdown, depending on the instruction set.
The benchmarks below time repeated calls to
u = v.logistic() (u = v.tanh(), respectively)
where u and v are of type Eigen::ArrayXf, have length 8k, and v contains random numbers in [-1, 1].
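A hedged sketch of such a benchmark (assuming the Google Benchmark harness; the actual harness, and the model_time column in the output below, may come from a different tool):

#include <benchmark/benchmark.h>
#include <Eigen/Core>

// Time u = v.logistic() on 8k-element float arrays; ArrayXf::Random
// draws uniformly from [-1, 1], matching the input range above.
static void BM_eigen_logistic_float(benchmark::State& state) {
  const Eigen::ArrayXf v = Eigen::ArrayXf::Random(8 * 1024);
  Eigen::ArrayXf u(v.size());
  for (auto _ : state) {
    u = v.logistic();
    benchmark::DoNotOptimize(u.data());
  }
}
BENCHMARK(BM_eigen_logistic_float);
BENCHMARK_MAIN();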
Benchmark numbers for logistic:
Before:
Benchmark Time(ns) CPU(ns) Iterations
-----------------------------------------------------------------
SSE
BM_eigen_logistic_float 4467 4468 155835 model_time: 4827
AVX
BM_eigen_logistic_float 2347 2347 299135 model_time: 2926
AVX+FMA
BM_eigen_logistic_float 1467 1467 476143 model_time: 2926
AVX512
BM_eigen_logistic_float 805 805 858696 model_time: 1463
After:
Benchmark Time(ns) CPU(ns) Iterations
-----------------------------------------------------------------
SSE
BM_eigen_logistic_float 2589 2590 270264 model_time: 4827
AVX
BM_eigen_logistic_float 1428 1428 489265 model_time: 2926
AVX+FMA
BM_eigen_logistic_float 1059 1059 662255 model_time: 2926
AVX512
BM_eigen_logistic_float 673 673 1000000 model_time: 1463
Benchmark numbers for tanh:
Before:
Benchmark Time(ns) CPU(ns) Iterations
-----------------------------------------------------------------
SSE
BM_eigen_tanh_float 2391 2391 292624 model_time: 4242
AVX
BM_eigen_tanh_float 1256 1256 554662 model_time: 2633
AVX+FMA
BM_eigen_tanh_float 823 823 866267 model_time: 1609
AVX512
BM_eigen_tanh_float 443 443 1578999 model_time: 805
After:
Benchmark Time(ns) CPU(ns) Iterations
-----------------------------------------------------------------
SSE
BM_eigen_tanh_float 2588 2588 273531 model_time: 4242
AVX
BM_eigen_tanh_float 1536 1536 452321 model_time: 2633
AVX+FMA
BM_eigen_tanh_float 1007 1007 694681 model_time: 1609
AVX512
BM_eigen_tanh_float 471 471 1472178 model_time: 805
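The diff below adds the pcmp_lt/pcmp_le/pcmp_lt_or_nan packet predicates the new code paths need. A hedged sketch of how such a predicate combines with a per-lane select (this uses Eigen's internal packet API, which is not a public interface; the helper name is illustrative):

#include <Eigen/Core>

namespace ei = Eigen::internal;

// Float path: blend exp(x) into lanes where x < -9, keeping the
// precomputed rational approximation everywhere else.
template <typename Packet>
Packet select_exp_below_cutoff(const Packet& x, const Packet& rational) {
  const Packet cutoff = ei::pset1<Packet>(-9.0f);
  const Packet mask = ei::pcmp_lt(x, cutoff);  // all-ones where x < -9
  return ei::pselect(mask, ei::pexp(x), rational);  // per-lane blend
}

Per the description above, the actual implementation additionally skips the pexp call entirely when no lane of the input vector is below the cutoff.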
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index 00c3ae5..f83e358 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -297,9 +297,14 @@
template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); }
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); }
-template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
-template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); }
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); }
+template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
+
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt_or_nan(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a, b, _CMP_NGE_UQ); }
+template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); }
+
template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8i& b) {
#ifdef EIGEN_VECTORIZE_AVX2
@@ -311,6 +316,7 @@
#endif
}
+
template<> EIGEN_STRONG_INLINE Packet8f pceil<Packet8f>(const Packet8f& a) { return _mm256_ceil_ps(a); }
template<> EIGEN_STRONG_INLINE Packet4d pceil<Packet4d>(const Packet4d& a) { return _mm256_ceil_pd(a); }
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 19c03cf..a53f9bc 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -365,6 +365,12 @@
}
#endif
+template <>
+EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
+ return _mm512_castsi512_ps(
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
+}
template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
return _mm512_castsi512_ps(
@@ -388,12 +394,6 @@
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
}
-template <>
-EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
- __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
- return _mm512_castsi512_ps(
- _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
-}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
@@ -401,6 +401,24 @@
return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
}
+template <>
+EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
+ return _mm512_castsi512_pd(
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
+ return _mm512_castsi512_pd(
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ);
+ return _mm512_castsi512_pd(
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
+}
template <>
EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h
index 3f90c45..9e18c51 100644
--- a/Eigen/src/Core/arch/GPU/PacketMath.h
+++ b/Eigen/src/Core/arch/GPU/PacketMath.h
@@ -156,6 +156,15 @@
return __longlong_as_double(a == b ? 0xffffffffffffffffull : 0ull);
}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float lt_mask(const float& a,
+ const float& b) {
+ return __int_as_float(a < b ? 0xffffffffu : 0u);
+}
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double lt_mask(const double& a,
+ const double& b) {
+ return __longlong_as_double(a < b ? 0xffffffffffffffffull : 0ull);
+}
+
} // namespace
template <>
@@ -213,10 +222,21 @@
eq_mask(a.w, b.w));
}
template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_lt<float4>(const float4& a,
+ const float4& b) {
+ return make_float4(lt_mask(a.x, b.x), lt_mask(a.y, b.y), lt_mask(a.z, b.z),
+ lt_mask(a.w, b.w));
+}
+template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
pcmp_eq<double2>(const double2& a, const double2& b) {
return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y));
}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pcmp_lt<double2>(const double2& a, const double2& b) {
+ return make_double2(lt_mask(a.x, b.x), lt_mask(a.y, b.y));
+}
#endif // EIGEN_CUDA_ARCH || defined(EIGEN_HIP_DEVICE_COMPILE)
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
@@ -647,6 +667,20 @@
}
template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt<half2>(const half2& a,
+ const half2& b) {
+ half true_half = half_impl::raw_uint16_to_half(0xffffu);
+ half false_half = half_impl::raw_uint16_to_half(0x0000u);
+ half a1 = __low2half(a);
+ half a2 = __high2half(a);
+ half b1 = __low2half(b);
+ half b2 = __high2half(b);
+ half lt1 = __half2float(a1) < __half2float(b1) ? true_half : false_half;
+ half lt2 = __half2float(a2) < __half2float(b2) ? true_half : false_half;
+ return __halves2half2(lt1, lt2);
+}
+
+template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand<half2>(const half2& a,
const half2& b) {
half a1 = __low2half(a);
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 8344754..dcb1fca 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -713,6 +713,8 @@
return vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
}
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vcleq_f64(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vcltq_f64(a,b)); }
template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vceqq_f64(a,b)); }
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from); }
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index c9aa391..2f50326 100755
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -384,10 +384,17 @@
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return _mm_cmple_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return _mm_cmplt_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); }
-template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
-template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { return _mm_cmpnge_ps(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return _mm_cmple_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return _mm_cmplt_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { return _mm_cmpnge_pd(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); }
+
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return _mm_cmplt_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
+
template<> EIGEN_STRONG_INLINE Packet4i ptrue<Packet4i>(const Packet4i& a) { return _mm_cmpeq_epi32(a, a); }
template<> EIGEN_STRONG_INLINE Packet4f