Use native _Float16 for AVX512FP16 and update vectorization.
This enables faster native scalar operations. The half and quarter
packets were also updated to use the native type when it is available.
Benchmark improvement:
```
Comparing ./2910_without_float16 to ./2910_with_float16
Benchmark                                  Time      CPU    Time Old    Time New     CPU Old     CPU New
---------------------------------------------------------------------------------------------------------
BM_CalcMat<float>/10000/768/500         -0.0041  -0.0040    58276392    58039442    58273420    58039582
BM_CalcMat<_Float16>/10000/768/500      +0.0073  +0.0073   642506339   647214446   642481384   647188303
BM_CalcMat<Eigen::half>/10000/768/500   -0.3170  -0.3170    92511115    63182101    92506771    63179258
BM_CalcVec<float>/10000/768/500         +0.0022  +0.0022     5198157     5209469     5197913     5209334
BM_CalcVec<_Float16>/10000/768/500      +0.0025  +0.0026    10133324    10159111    10132641    10158507
BM_CalcVec<Eigen::half>/10000/768/500   -0.7760  -0.7760    45337937    10156952    45336532    10156389
OVERALL_GEOMEAN                         -0.2677  -0.2677           0           0           0           0
```
Fixes #2910.
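
For context: with EIGEN_VECTORIZE_AVX512FP16, Eigen::half wraps a native
_Float16, so scalar and packet arithmetic both run on FP16 instructions
instead of widening through float. A minimal sketch of the kind of loop that
benefits (calc_vec is a hypothetical stand-in for the BM_CalcVec benchmark;
its actual source is not part of this patch):
```
// Sketch only: assumes GCC or Clang with -mavx512fp16 and this patch applied.
#include <Eigen/Core>

template <typename Scalar>
float calc_vec(int n, int reps) {
  Eigen::Matrix<Scalar, Eigen::Dynamic, 1> x(n), y(n);
  x.setOnes();
  y.setConstant(Scalar(2));
  Scalar acc = Scalar(0);
  for (int r = 0; r < reps; ++r) {
    // With this patch, Packet32h/Packet16h/Packet8h are native __m512h/
    // __m256h/__m128h, so the product and reduction below use FP16 vector
    // instructions instead of converting each packet to float and back.
    acc += (x.array() * y.array()).sum();
  }
  return static_cast<float>(acc);
}
```
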
diff --git a/Eigen/Core b/Eigen/Core
index 99cd473..6ae069a 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -193,21 +193,27 @@
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
#if defined EIGEN_VECTORIZE_AVX512
+#include "src/Core/arch/SSE/PacketMath.h"
+#include "src/Core/arch/AVX/PacketMath.h"
+#include "src/Core/arch/AVX512/PacketMath.h"
#if defined EIGEN_VECTORIZE_AVX512FP16
#include "src/Core/arch/AVX512/PacketMathFP16.h"
#endif
-#include "src/Core/arch/SSE/PacketMath.h"
#include "src/Core/arch/SSE/TypeCasting.h"
-#include "src/Core/arch/SSE/Complex.h"
-#include "src/Core/arch/AVX/PacketMath.h"
#include "src/Core/arch/AVX/TypeCasting.h"
-#include "src/Core/arch/AVX/Complex.h"
-#include "src/Core/arch/AVX512/PacketMath.h"
#include "src/Core/arch/AVX512/TypeCasting.h"
+#if defined EIGEN_VECTORIZE_AVX512FP16
+#include "src/Core/arch/AVX512/TypeCastingFP16.h"
+#endif
+#include "src/Core/arch/SSE/Complex.h"
+#include "src/Core/arch/AVX/Complex.h"
#include "src/Core/arch/AVX512/Complex.h"
#include "src/Core/arch/SSE/MathFunctions.h"
#include "src/Core/arch/AVX/MathFunctions.h"
#include "src/Core/arch/AVX512/MathFunctions.h"
+#if defined EIGEN_VECTORIZE_AVX512FP16
+#include "src/Core/arch/AVX512/MathFunctionsFP16.h"
+#endif
#include "src/Core/arch/AVX512/TrsmKernel.h"
#elif defined EIGEN_VECTORIZE_AVX
// Use AVX for floats and doubles, SSE for integers
diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h
index 8e2705b..cf8dcc3 100644
--- a/Eigen/src/Core/MathFunctionsImpl.h
+++ b/Eigen/src/Core/MathFunctionsImpl.h
@@ -76,7 +76,7 @@
static_assert(Steps > 0, "Steps must be at least 1.");
using Scalar = typename unpacket_traits<Packet>::type;
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_rsqrt) {
- constexpr Scalar kMinusHalf = Scalar(-1) / Scalar(2);
+ const Scalar kMinusHalf = Scalar(-1) / Scalar(2);
const Packet cst_minus_half = pset1<Packet>(kMinusHalf);
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
index a5c38e7..eb0011c 100644
--- a/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -106,6 +106,8 @@
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
+
+#ifndef EIGEN_VECTORIZE_AVX512FP16
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp2)
@@ -118,6 +120,7 @@
F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
+#endif
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index c29523a..aa93e45 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -1839,10 +1839,13 @@
return a;
}
+#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) {
return _mm_cmpgt_epi16(_mm_setzero_si128(), a);
}
+#endif // EIGEN_VECTORIZE_AVX512FP16
+
template <>
EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) {
return _mm_cmpgt_epi16(_mm_setzero_si128(), a);
@@ -2044,10 +2047,13 @@
return _mm256_movemask_ps(_mm256_castsi256_ps(x)) != 0;
}
+#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8h& x) {
return _mm_movemask_epi8(x) != 0;
}
+#endif // EIGEN_VECTORIZE_AVX512FP16
+
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8bf& x) {
return _mm_movemask_epi8(x) != 0;
@@ -2211,7 +2217,6 @@
};
typedef Packet8h half;
};
-#endif
template <>
EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
@@ -2446,14 +2451,12 @@
to[stride * 7] = aux[7];
}
-#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) {
Packet8f af = half2float(a);
float reduced = predux<Packet8f>(af);
return Eigen::half(reduced);
}
-#endif
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) {
@@ -2553,6 +2556,8 @@
kernel.packet[3] = pload<Packet8h>(out[3]);
}
+#endif
+
// BFloat16 implementation.
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h
index 9dcd6ef..5b73ffe 100644
--- a/Eigen/src/Core/arch/AVX/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX/TypeCasting.h
@@ -279,20 +279,22 @@
}
#endif
+#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
return half2float(a);
}
template <>
-EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
- return Bf16ToF32(a);
-}
-
-template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
return float2half(a);
}
+#endif
+
+template <>
+EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
+ return Bf16ToF32(a);
+}
template <>
EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index 6039254..04499a0 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -47,16 +47,16 @@
#if EIGEN_FAST_MATH
template <>
-EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt<Packet16f>(const Packet16f& _x) {
- return generic_sqrt_newton_step<Packet16f>::run(_x, _mm512_rsqrt14_ps(_x));
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt<Packet16f>(const Packet16f& x) {
+ return generic_sqrt_newton_step<Packet16f>::run(x, _mm512_rsqrt14_ps(x));
}
template <>
-EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt<Packet8d>(const Packet8d& _x) {
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt<Packet8d>(const Packet8d& x) {
#ifdef EIGEN_VECTORIZE_AVX512ER
- return generic_sqrt_newton_step<Packet8d, /*Steps=*/1>::run(_x, _mm512_rsqrt28_pd(_x));
+ return generic_sqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
#else
- return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(_x, _mm512_rsqrt14_pd(_x));
+ return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
#endif
}
#else
@@ -80,19 +80,19 @@
#elif EIGEN_FAST_MATH
template <>
-EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt<Packet16f>(const Packet16f& _x) {
- return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(_x, _mm512_rsqrt14_ps(_x));
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt<Packet16f>(const Packet16f& x) {
+ return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(x, _mm512_rsqrt14_ps(x));
}
#endif
// prsqrt for double.
#if EIGEN_FAST_MATH
template <>
-EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt<Packet8d>(const Packet8d& _x) {
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt<Packet8d>(const Packet8d& x) {
#ifdef EIGEN_VECTORIZE_AVX512ER
- return generic_rsqrt_newton_step<Packet8d, /*Steps=*/1>::run(_x, _mm512_rsqrt28_pd(_x));
+ return generic_rsqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
#else
- return generic_rsqrt_newton_step<Packet8d, /*Steps=*/2>::run(_x, _mm512_rsqrt14_pd(_x));
+ return generic_rsqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
#endif
}
@@ -118,6 +118,8 @@
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
+
+#ifndef EIGEN_VECTORIZE_AVX512FP16
F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp2)
@@ -130,6 +132,7 @@
F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
+#endif // EIGEN_VECTORIZE_AVX512FP16
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h b/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h
new file mode 100644
index 0000000..240ade4
--- /dev/null
+++ b/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h
@@ -0,0 +1,75 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2025 The Eigen Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
+#define EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
+
+// IWYU pragma: private
+#include "../../InternalHeaderCheck.h"
+
+namespace Eigen {
+namespace internal {
+
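+// Helpers to concatenate two Packet16h into one Packet32h (and split back),
+// operating on the raw bit patterns via integer casts.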
+EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
+ __m512i result = _mm512_castsi256_si512(_mm256_castph_si256(a));
+ result = _mm512_inserti64x4(result, _mm256_castph_si256(b), 1);
+ return _mm512_castsi512_ph(result);
+}
+
+EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
+ a = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_castph_si512(x)));
+ b = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(_mm512_castph_si512(x), 1));
+}
+
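+// Implement each half-precision math function by widening to float, applying
+// the float packet implementation, and narrowing the result back to half.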
+#define _EIGEN_GENERATE_FP16_MATH_FUNCTION(func) \
+ template <> \
+ EIGEN_STRONG_INLINE Packet8h func<Packet8h>(const Packet8h& a) { \
+ return float2half(func(half2float(a))); \
+ } \
+ \
+ template <> \
+ EIGEN_STRONG_INLINE Packet16h func<Packet16h>(const Packet16h& a) { \
+ return float2half(func(half2float(a))); \
+ } \
+ \
+ template <> \
+ EIGEN_STRONG_INLINE Packet32h func<Packet32h>(const Packet32h& a) { \
+ Packet16h low; \
+ Packet16h high; \
+ extract2Packet16h(a, low, high); \
+ return combine2Packet16h(func(low), func(high)); \
+ }
+
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(psin)
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(pcos)
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog)
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog2)
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog1p)
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp)
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexpm1)
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp2)
+_EIGEN_GENERATE_FP16_MATH_FUNCTION(ptanh)
+#undef _EIGEN_GENERATE_FP16_MATH_FUNCTION
+
+// pfrexp
+template <>
+EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
+ return pfrexp_generic(a, exponent);
+}
+
+// pldexp
+template <>
+EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
+ return pldexp_generic(a, exponent);
+}
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
\ No newline at end of file
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 5d869e4..c077749 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -40,6 +40,10 @@
#endif
typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
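+// Minimal 16-bit signed integer packets, implemented near the end of this
+// file; the native FP16 packets use them as integer_packet for pfrexp/pldexp.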
+typedef eigen_packet_wrapper<__m512i, 6> Packet32s;
+typedef eigen_packet_wrapper<__m256i, 6> Packet16s;
+typedef eigen_packet_wrapper<__m128i, 6> Packet8s;
+
template <>
struct is_arithmetic<__m512> {
enum { value = true };
@@ -249,6 +253,39 @@
#endif
template <>
+struct unpacket_traits<Packet32s> {
+ typedef numext::int16_t type;
+ typedef Packet16s half;
+ enum {
+ size = 32,
+ alignment = Aligned64,
+ vectorizable = false,
+ };
+};
+
+template <>
+struct unpacket_traits<Packet16s> {
+ typedef numext::int16_t type;
+ typedef Packet8s half;
+ enum {
+ size = 16,
+ alignment = Aligned32,
+ vectorizable = false,
+ };
+};
+
+template <>
+struct unpacket_traits<Packet8s> {
+ typedef numext::int16_t type;
+ typedef Packet8s half;
+ enum {
+ size = 8,
+ alignment = Aligned16,
+ vectorizable = false,
+ };
+};
+
+template <>
EIGEN_STRONG_INLINE Packet16f pset1<Packet16f>(const float& from) {
return _mm512_set1_ps(from);
}
@@ -1335,10 +1372,13 @@
return _mm512_abs_epi64(a);
}
+#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) {
return _mm256_srai_epi16(a, 15);
}
+#endif // EIGEN_VECTORIZE_AVX512FP16
+
template <>
EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) {
return _mm256_srai_epi16(a, 15);
@@ -2199,6 +2239,7 @@
}
// Packet math for Eigen::half
+#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
return _mm256_set1_epi16(from.x);
@@ -2369,7 +2410,6 @@
return _mm256_xor_si256(a, sign_mask);
}
-#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
@@ -2408,8 +2448,6 @@
return half(predux(from_float));
}
-#endif
-
template <>
EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
Packet8h lane0 = _mm256_extractf128_si256(a, 0);
@@ -2643,6 +2681,8 @@
kernel.packet[3] = pload<Packet16h>(out[3]);
}
+#endif // EIGEN_VECTORIZE_AVX512FP16
+
template <>
struct is_arithmetic<Packet16bf> {
enum { value = true };
@@ -3095,6 +3135,158 @@
kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
}
+// Minimal implementation of 16-bit int packets for use in pfrexp, pldexp.
+
+template <>
+EIGEN_STRONG_INLINE Packet32s pset1<Packet32s>(const numext::int16_t& x) {
+ return _mm512_set1_epi16(x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16s pset1<Packet16s>(const numext::int16_t& x) {
+ return _mm256_set1_epi16(x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const numext::int16_t& x) {
+ return _mm_set1_epi16(x);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
+ _mm512_storeu_epi16(out, x);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
+ _mm256_storeu_epi16(out, x);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
+ _mm_storeu_epi16(out, x);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
+ _mm512_storeu_epi16(out, x);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
+ _mm256_storeu_epi16(out, x);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
+ _mm_storeu_epi16(out, x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32s padd(const Packet32s& a, const Packet32s& b) {
+ return _mm512_add_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16s padd(const Packet16s& a, const Packet16s& b) {
+ return _mm256_add_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8s padd(const Packet8s& a, const Packet8s& b) {
+ return _mm_add_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32s psub(const Packet32s& a, const Packet32s& b) {
+ return _mm512_sub_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16s psub(const Packet16s& a, const Packet16s& b) {
+ return _mm256_sub_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8s psub(const Packet8s& a, const Packet8s& b) {
+ return _mm_sub_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32s pmul(const Packet32s& a, const Packet32s& b) {
+ return _mm512_mullo_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16s pmul(const Packet16s& a, const Packet16s& b) {
+ return _mm256_mullo_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8s pmul(const Packet8s& a, const Packet8s& b) {
+ return _mm_mullo_epi16(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32s pnegate(const Packet32s& a) {
+ return _mm512_sub_epi16(_mm512_setzero_si512(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16s pnegate(const Packet16s& a) {
+ return _mm256_sub_epi16(_mm256_setzero_si256(), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) {
+ return _mm_sub_epi16(_mm_setzero_si128(), a);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet32s parithmetic_shift_right(Packet32s a) {
+ return _mm512_srai_epi16(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet16s parithmetic_shift_right(Packet16s a) {
+ return _mm256_srai_epi16(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) {
+ return _mm_srai_epi16(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet32s plogical_shift_left(Packet32s a) {
+ return _mm512_slli_epi16(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet16s plogical_shift_left(Packet16s a) {
+ return _mm256_slli_epi16(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) {
+ return _mm_slli_epi16(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet32s plogical_shift_right(Packet32s a) {
+ return _mm512_srli_epi16(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet16s plogical_shift_right(Packet16s a) {
+ return _mm256_srli_epi16(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a) {
+ return _mm_srli_epi16(a, N);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
index df5a0ef..a040bbe 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
@@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
-//
+// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -18,8 +18,8 @@
namespace internal {
typedef __m512h Packet32h;
-typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
-typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
+typedef __m256h Packet16h;
+typedef __m128h Packet8h;
template <>
struct is_arithmetic<Packet8h> {
@@ -68,6 +68,7 @@
struct unpacket_traits<Packet32h> {
typedef Eigen::half type;
typedef Packet16h half;
+ typedef Packet32s integer_packet;
enum {
size = 32,
alignment = Aligned64,
@@ -81,6 +82,7 @@
struct unpacket_traits<Packet16h> {
typedef Eigen::half type;
typedef Packet8h half;
+ typedef Packet16s integer_packet;
enum {
size = 16,
alignment = Aligned32,
@@ -94,6 +96,7 @@
struct unpacket_traits<Packet8h> {
typedef Eigen::half type;
typedef Packet8h half;
+ typedef Packet8s integer_packet;
enum {
size = 8,
alignment = Aligned16,
@@ -103,14 +106,33 @@
};
};
+// Conversions
+
+EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { return _mm512_cvtxph_ps(a); }
+
+EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) { return _mm256_cvtxph_ps(a); }
+
+EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) { return _mm512_cvtxps_ph(a); }
+
+EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { return _mm256_cvtxps_ph(a); }
+
// Memory functions
// pset1
template <>
EIGEN_STRONG_INLINE Packet32h pset1<Packet32h>(const Eigen::half& from) {
- // half/half_raw is bit compatible
- return _mm512_set1_ph(numext::bit_cast<_Float16>(from));
+ return _mm512_set1_ph(from.x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
+ return _mm256_set1_ph(from.x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
+ return _mm_set1_ph(from.x);
}
template <>
@@ -118,24 +140,47 @@
return _mm512_setzero_ph();
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pzero(const Packet16h& /*a*/) {
+ return _mm256_setzero_ph();
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pzero(const Packet8h& /*a*/) {
+ return _mm_setzero_ph();
+}
+
// pset1frombits
template <>
EIGEN_STRONG_INLINE Packet32h pset1frombits<Packet32h>(unsigned short from) {
return _mm512_castsi512_ph(_mm512_set1_epi16(from));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pset1frombits<Packet16h>(unsigned short from) {
+ return _mm256_castsi256_ph(_mm256_set1_epi16(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pset1frombits<Packet8h>(unsigned short from) {
+ return _mm_castsi128_ph(_mm_set1_epi16(from));
+}
+
// pfirst
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet32h>(const Packet32h& from) {
-#ifdef EIGEN_VECTORIZE_AVX512DQ
- return half_impl::raw_uint16_to_half(
- static_cast<unsigned short>(_mm256_extract_epi16(_mm512_extracti32x8_epi32(_mm512_castph_si512(from), 0), 0)));
-#else
- Eigen::half dest[32];
- _mm512_storeu_ph(dest, from);
- return dest[0];
-#endif
+ return Eigen::half(_mm512_cvtsh_h(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
+ return Eigen::half(_mm256_cvtsh_h(from));
+}
+
+template <>
+EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) {
+ return Eigen::half(_mm_cvtsh_h(from));
}
// pload
@@ -145,6 +190,16 @@
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ph(from);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ph(from);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ph(from);
+}
+
// ploadu
template <>
@@ -152,6 +207,16 @@
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ph(from);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_ph(from);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_ph(from);
+}
+
// pstore
template <>
@@ -159,6 +224,16 @@
EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ph(to, from);
}
+template <>
+EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
+ EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ph(to, from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet8h& from) {
+ EIGEN_DEBUG_ALIGNED_STORE _mm_store_ph(to, from);
+}
+
// pstoreu
template <>
@@ -166,6 +241,16 @@
EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ph(to, from);
}
+template <>
+EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_ph(to, from);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet8h& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ph(to, from);
+}
+
// ploaddup
template <>
EIGEN_STRONG_INLINE Packet32h ploaddup<Packet32h>(const Eigen::half* from) {
@@ -175,6 +260,17 @@
a);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h ploaddup<Packet16h>(const Eigen::half* from) {
+ __m256h a = _mm256_castph128_ph256(_mm_loadu_ph(from));
+ return _mm256_permutexvar_ph(_mm256_set_epi16(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0), a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h ploaddup<Packet8h>(const Eigen::half* from) {
+ return _mm_set_ph(from[3].x, from[3].x, from[2].x, from[2].x, from[1].x, from[1].x, from[0].x, from[0].x);
+}
+
// ploadquad
template <>
EIGEN_STRONG_INLINE Packet32h ploadquad<Packet32h>(const Eigen::half* from) {
@@ -184,6 +280,17 @@
a);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h ploadquad<Packet16h>(const Eigen::half* from) {
+ return _mm256_set_ph(from[3].x, from[3].x, from[3].x, from[3].x, from[2].x, from[2].x, from[2].x, from[2].x,
+ from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h ploadquad<Packet8h>(const Eigen::half* from) {
+ return _mm_set_ph(from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x);
+}
+
// pabs
template <>
@@ -191,6 +298,16 @@
return _mm512_abs_ph(a);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pabs<Packet16h>(const Packet16h& a) {
+ return _mm256_abs_ph(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pabs<Packet8h>(const Packet8h& a) {
+ return _mm_abs_ph(a);
+}
+
// psignbit
template <>
@@ -198,6 +315,16 @@
return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h psignbit<Packet16h>(const Packet16h& a) {
+ return _mm256_castsi256_ph(_mm256_srai_epi16(_mm256_castph_si256(a), 15));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h psignbit<Packet8h>(const Packet8h& a) {
+ return _mm_castsi128_ph(_mm_srai_epi16(_mm_castph_si128(a), 15));
+}
+
// pmin
template <>
@@ -205,6 +332,16 @@
return _mm512_min_ph(a, b);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, const Packet16h& b) {
+ return _mm256_min_ph(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ return _mm_min_ph(a, b);
+}
+
// pmax
template <>
@@ -212,6 +349,16 @@
return _mm512_max_ph(a, b);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, const Packet16h& b) {
+ return _mm256_max_ph(a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a, const Packet8h& b) {
+ return _mm_max_ph(a, b);
+}
+
// plset
template <>
EIGEN_STRONG_INLINE Packet32h plset<Packet32h>(const half& a) {
@@ -219,6 +366,16 @@
16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
+ return _mm256_add_ph(pset1<Packet16h>(a), _mm256_set_ph(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) {
+ return _mm_add_ph(pset1<Packet8h>(a), _mm_set_ph(7, 6, 5, 4, 3, 2, 1, 0));
+}
+
// por
template <>
@@ -226,6 +383,16 @@
return _mm512_castsi512_ph(_mm512_or_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a, const Packet16h& b) {
+ return _mm256_castsi256_ph(_mm256_or_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a, const Packet8h& b) {
+ return _mm_castsi128_ph(_mm_or_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
+}
+
// pxor
template <>
@@ -233,6 +400,16 @@
return _mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a, const Packet16h& b) {
+ return _mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a, const Packet8h& b) {
+ return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
+}
+
// pand
template <>
@@ -240,6 +417,16 @@
return _mm512_castsi512_ph(_mm512_and_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a, const Packet16h& b) {
+ return _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a, const Packet8h& b) {
+ return _mm_castsi128_ph(_mm_and_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
+}
+
// pandnot
template <>
@@ -247,6 +434,16 @@
return _mm512_castsi512_ph(_mm512_andnot_si512(_mm512_castph_si512(b), _mm512_castph_si512(a)));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a, const Packet16h& b) {
+ return _mm256_castsi256_ph(_mm256_andnot_si256(_mm256_castph_si256(b), _mm256_castph_si256(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a, const Packet8h& b) {
+ return _mm_castsi128_ph(_mm_andnot_si128(_mm_castph_si128(b), _mm_castph_si128(a)));
+}
+
// pselect
template <>
@@ -255,6 +452,18 @@
return _mm512_mask_blend_ph(mask32, a, b);
}
+template <>
+EIGEN_DEVICE_FUNC inline Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
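+  // A zero mask element selects the corresponding element of b; a non-zero
+  // (all-ones) element selects a, matching the bitwise pselect semantics.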
+ __mmask16 mask16 = _mm256_cmp_epi16_mask(_mm256_castph_si256(mask), _mm256_setzero_si256(), _MM_CMPINT_EQ);
+ return _mm256_mask_blend_ph(mask16, a, b);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) {
+ __mmask8 mask8 = _mm_cmp_epi16_mask(_mm_castph_si128(mask), _mm_setzero_si128(), _MM_CMPINT_EQ);
+ return _mm_mask_blend_ph(mask8, a, b);
+}
+
// pcmp_eq
template <>
@@ -263,6 +472,18 @@
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) {
+ __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_EQ_OQ);
+ return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) {
+ __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_EQ_OQ);
+ return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
+}
+
// pcmp_le
template <>
@@ -271,6 +492,18 @@
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a, const Packet16h& b) {
+ __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LE_OQ);
+ return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) {
+ __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LE_OQ);
+ return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
+}
+
// pcmp_lt
template <>
@@ -279,6 +512,18 @@
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a, const Packet16h& b) {
+ __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LT_OQ);
+ return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) {
+ __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LT_OQ);
+ return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
+}
+
// pcmp_lt_or_nan
template <>
@@ -287,6 +532,18 @@
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi16(0), mask, static_cast<short>(0xffffu)));
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a, const Packet16h& b) {
+ __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_NGE_UQ);
+ return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) {
+ __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_NGE_UQ);
+ return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
+}
+
// padd
template <>
@@ -296,12 +553,12 @@
template <>
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
- return _mm256_castph_si256(_mm256_add_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
+ return _mm256_add_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
- return _mm_castph_si128(_mm_add_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
+ return _mm_add_ph(a, b);
}
// psub
@@ -313,12 +570,12 @@
template <>
EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
- return _mm256_castph_si256(_mm256_sub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
+ return _mm256_sub_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
- return _mm_castph_si128(_mm_sub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
+ return _mm_sub_ph(a, b);
}
// pmul
@@ -330,12 +587,12 @@
template <>
EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
- return _mm256_castph_si256(_mm256_mul_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
+ return _mm256_mul_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
- return _mm_castph_si128(_mm_mul_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
+ return _mm_mul_ph(a, b);
}
// pdiv
@@ -347,12 +604,13 @@
template <>
EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
- return _mm256_castph_si256(_mm256_div_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
+ return _mm256_div_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
- return _mm_castph_si128(_mm_div_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
+ return _mm_div_ph(a, b);
}
// pround
@@ -361,14 +619,40 @@
EIGEN_STRONG_INLINE Packet32h pround<Packet32h>(const Packet32h& a) {
// Work-around for default std::round rounding mode.
- // Mask for the sign bit
- const Packet32h signMask = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x8000u));
- // The largest half-preicision float less than 0.5
+ // Mask for the sign bit.
+  const Packet32h signMask = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x8000u));
+ // The largest half-precision float less than 0.5.
const Packet32h prev0dot5 = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x37FFu));
return _mm512_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
+ // Work-around for default std::round rounding mode.
+
+ // Mask for the sign bit.
+  const Packet16h signMask = pset1frombits<Packet16h>(static_cast<numext::uint16_t>(0x8000u));
+ // The largest half-precision float less than 0.5.
+ const Packet16h prev0dot5 = pset1frombits<Packet16h>(static_cast<numext::uint16_t>(0x37FFu));
+
+ return _mm256_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) {
+ // Work-around for default std::round rounding mode.
+
+ // Mask for the sign bit.
+  const Packet8h signMask = pset1frombits<Packet8h>(static_cast<numext::uint16_t>(0x8000u));
+ // The largest half-precision float less than 0.5.
+ const Packet8h prev0dot5 = pset1frombits<Packet8h>(static_cast<numext::uint16_t>(0x37FFu));
+
+ return _mm_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
+}
+
// print
template <>
@@ -376,6 +660,16 @@
return _mm512_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
+ return _mm256_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) {
+ return _mm_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
+}
+
// pceil
template <>
@@ -383,6 +677,16 @@
return _mm512_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
+ return _mm256_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) {
+ return _mm_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
+}
+
// pfloor
template <>
@@ -390,6 +694,16 @@
return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
+ return _mm256_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) {
+ return _mm_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
+}
+
// ptrunc
template <>
@@ -397,47 +711,99 @@
return _mm512_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h ptrunc<Packet16h>(const Packet16h& a) {
+ return _mm256_roundscale_ph(a, _MM_FROUND_TO_ZERO);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h ptrunc<Packet8h>(const Packet8h& a) {
+ return _mm_roundscale_ph(a, _MM_FROUND_TO_ZERO);
+}
+
// predux
template <>
EIGEN_STRONG_INLINE half predux<Packet32h>(const Packet32h& a) {
- return (half)_mm512_reduce_add_ph(a);
+ return half(_mm512_reduce_add_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& a) {
- return (half)_mm256_reduce_add_ph(_mm256_castsi256_ph(a));
+ return half(_mm256_reduce_add_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet8h>(const Packet8h& a) {
- return (half)_mm_reduce_add_ph(_mm_castsi128_ph(a));
+ return half(_mm_reduce_add_ph(a));
}
// predux_half_dowto4
template <>
EIGEN_STRONG_INLINE Packet16h predux_half_dowto4<Packet32h>(const Packet32h& a) {
-#ifdef EIGEN_VECTORIZE_AVX512DQ
- __m256i lowHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 0));
- __m256i highHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 1));
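+  // Reinterpret as an integer register, split into the two 256-bit halves,
+  // and add them with a single native FP16 add.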
+ const __m512i bits = _mm512_castph_si512(a);
+ Packet16h lo = _mm256_castsi256_ph(_mm512_castsi512_si256(bits));
+ Packet16h hi = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(bits, 1));
+ return padd(lo, hi);
+}
- return Packet16h(padd<Packet16h>(lowHalf, highHalf));
-#else
- Eigen::half data[32];
- _mm512_storeu_ph(data, a);
-
- __m256i lowHalf = _mm256_castph_si256(_mm256_loadu_ph(data));
- __m256i highHalf = _mm256_castph_si256(_mm256_loadu_ph(data + 16));
-
- return Packet16h(padd<Packet16h>(lowHalf, highHalf));
-#endif
+template <>
+EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
+ Packet8h lo = _mm_castsi128_ph(_mm256_castsi256_si128(_mm256_castph_si256(a)));
+ Packet8h hi = _mm_castps_ph(_mm256_extractf128_ps(_mm256_castph_ps(a), 1));
+ return padd(lo, hi);
}
// predux_max
+template <>
+EIGEN_STRONG_INLINE half predux_max<Packet32h>(const Packet32h& a) {
+ return half(_mm512_reduce_max_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE half predux_max<Packet16h>(const Packet16h& a) {
+ return half(_mm256_reduce_max_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE half predux_max<Packet8h>(const Packet8h& a) {
+ return half(_mm_reduce_max_ph(a));
+}
+
// predux_min
+template <>
+EIGEN_STRONG_INLINE half predux_min<Packet32h>(const Packet32h& a) {
+ return half(_mm512_reduce_min_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE half predux_min<Packet16h>(const Packet16h& a) {
+ return half(_mm256_reduce_min_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE half predux_min<Packet8h>(const Packet8h& a) {
+ return half(_mm_reduce_min_ph(a));
+}
+
// predux_mul
+template <>
+EIGEN_STRONG_INLINE half predux_mul<Packet32h>(const Packet32h& a) {
+ return half(_mm512_reduce_mul_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& a) {
+ return half(_mm256_reduce_mul_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE half predux_mul<Packet8h>(const Packet8h& a) {
+ return half(_mm_reduce_mul_ph(a));
+}
+
#ifdef EIGEN_VECTORIZE_FMA
// pmadd
@@ -449,12 +815,12 @@
template <>
EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
- return _mm256_castph_si256(_mm256_fmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
+ return _mm256_fmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
- return _mm_castph_si128(_mm_fmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
+ return _mm_fmadd_ph(a, b, c);
}
// pmsub
@@ -466,12 +832,12 @@
template <>
EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
- return _mm256_castph_si256(_mm256_fmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
+ return _mm256_fmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
- return _mm_castph_si128(_mm_fmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
+ return _mm_fmsub_ph(a, b, c);
}
// pnmadd
@@ -483,12 +849,12 @@
template <>
EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
- return _mm256_castph_si256(_mm256_fnmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
+ return _mm256_fnmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pnmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
- return _mm_castph_si128(_mm_fnmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
+ return _mm_fnmadd_ph(a, b, c);
}
// pnmsub
@@ -500,12 +866,12 @@
template <>
EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
- return _mm256_castph_si256(_mm256_fnmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
+ return _mm256_fnmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
- return _mm_castph_si128(_mm_fnmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
+ return _mm_fnmsub_ph(a, b, c);
}
#endif
@@ -514,35 +880,74 @@
template <>
EIGEN_STRONG_INLINE Packet32h pnegate<Packet32h>(const Packet32h& a) {
- return psub(pzero(a), a);
+ return _mm512_castsi512_ph(
+ _mm512_xor_si512(_mm512_castph_si512(a), _mm512_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pnegate<Packet16h>(const Packet16h& a) {
+ return _mm256_castsi256_ph(
+ _mm256_xor_si256(_mm256_castph_si256(a), _mm256_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pnegate<Packet8h>(const Packet8h& a) {
+ return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}
// pconj
-template <>
-EIGEN_STRONG_INLINE Packet32h pconj<Packet32h>(const Packet32h& a) {
- return a;
-}
+// Nothing, packets are real.
// psqrt
template <>
EIGEN_STRONG_INLINE Packet32h psqrt<Packet32h>(const Packet32h& a) {
- return _mm512_sqrt_ph(a);
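+  // As with Packet16f, refine the hardware rsqrt estimate with a
+  // Newton-Raphson step instead of using the direct sqrt instruction.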
+ return generic_sqrt_newton_step<Packet32h>::run(a, _mm512_rsqrt_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h psqrt<Packet16h>(const Packet16h& a) {
+ return generic_sqrt_newton_step<Packet16h>::run(a, _mm256_rsqrt_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h psqrt<Packet8h>(const Packet8h& a) {
+ return generic_sqrt_newton_step<Packet8h>::run(a, _mm_rsqrt_ph(a));
}
// prsqrt
template <>
EIGEN_STRONG_INLINE Packet32h prsqrt<Packet32h>(const Packet32h& a) {
- return _mm512_rsqrt_ph(a);
+ return generic_rsqrt_newton_step<Packet32h, /*Steps=*/1>::run(a, _mm512_rsqrt_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h prsqrt<Packet16h>(const Packet16h& a) {
+ return generic_rsqrt_newton_step<Packet16h, /*Steps=*/1>::run(a, _mm256_rsqrt_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h prsqrt<Packet8h>(const Packet8h& a) {
+ return generic_rsqrt_newton_step<Packet8h, /*Steps=*/1>::run(a, _mm_rsqrt_ph(a));
}
// preciprocal
template <>
EIGEN_STRONG_INLINE Packet32h preciprocal<Packet32h>(const Packet32h& a) {
- return _mm512_rcp_ph(a);
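+  // Likewise, refine the hardware reciprocal estimate with one
+  // Newton-Raphson step.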
+ return generic_reciprocal_newton_step<Packet32h, /*Steps=*/1>::run(a, _mm512_rcp_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h preciprocal<Packet16h>(const Packet16h& a) {
+ return generic_reciprocal_newton_step<Packet16h, /*Steps=*/1>::run(a, _mm256_rcp_ph(a));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h preciprocal<Packet8h>(const Packet8h& a) {
+ return generic_reciprocal_newton_step<Packet8h, /*Steps=*/1>::run(a, _mm_rcp_ph(a));
}
// ptranspose
@@ -663,6 +1068,246 @@
a.packet[3] = _mm512_castsi512_ph(a3);
}
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 16>& kernel) {
+ __m256i a = _mm256_castph_si256(kernel.packet[0]);
+ __m256i b = _mm256_castph_si256(kernel.packet[1]);
+ __m256i c = _mm256_castph_si256(kernel.packet[2]);
+ __m256i d = _mm256_castph_si256(kernel.packet[3]);
+ __m256i e = _mm256_castph_si256(kernel.packet[4]);
+ __m256i f = _mm256_castph_si256(kernel.packet[5]);
+ __m256i g = _mm256_castph_si256(kernel.packet[6]);
+ __m256i h = _mm256_castph_si256(kernel.packet[7]);
+ __m256i i = _mm256_castph_si256(kernel.packet[8]);
+ __m256i j = _mm256_castph_si256(kernel.packet[9]);
+ __m256i k = _mm256_castph_si256(kernel.packet[10]);
+ __m256i l = _mm256_castph_si256(kernel.packet[11]);
+ __m256i m = _mm256_castph_si256(kernel.packet[12]);
+ __m256i n = _mm256_castph_si256(kernel.packet[13]);
+ __m256i o = _mm256_castph_si256(kernel.packet[14]);
+ __m256i p = _mm256_castph_si256(kernel.packet[15]);
+
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
+ __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
+ __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
+ __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
+ __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
+ __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
+ __m256i op_07 = _mm256_unpacklo_epi16(o, p);
+
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
+ __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
+ __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
+ __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
+ __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
+ __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
+ __m256i op_8f = _mm256_unpackhi_epi16(o, p);
+
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
+ __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
+ __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
+ __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
+ __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
+ __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
+ __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
+
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
+ __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
+ __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
+ __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
+ __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
+ __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
+ __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
+
+ __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
+ __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
+ __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
+ __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
+ __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
+ __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
+ __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
+ __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
+ __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
+ __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
+ __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
+ __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
+ __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
+ __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
+ __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
+ __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
+
+  // NOTE: unpacklo/hi cannot cross 128-bit lanes, so use permute2x128 instead.
+ __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
+ __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
+ __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
+ __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
+ __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
+ __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
+ __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
+ __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
+ __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
+ __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
+ __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
+ __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
+ __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
+ __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
+ __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
+ __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
+
+ kernel.packet[0] = _mm256_castsi256_ph(a_p_0);
+ kernel.packet[1] = _mm256_castsi256_ph(a_p_1);
+ kernel.packet[2] = _mm256_castsi256_ph(a_p_2);
+ kernel.packet[3] = _mm256_castsi256_ph(a_p_3);
+ kernel.packet[4] = _mm256_castsi256_ph(a_p_4);
+ kernel.packet[5] = _mm256_castsi256_ph(a_p_5);
+ kernel.packet[6] = _mm256_castsi256_ph(a_p_6);
+ kernel.packet[7] = _mm256_castsi256_ph(a_p_7);
+ kernel.packet[8] = _mm256_castsi256_ph(a_p_8);
+ kernel.packet[9] = _mm256_castsi256_ph(a_p_9);
+ kernel.packet[10] = _mm256_castsi256_ph(a_p_a);
+ kernel.packet[11] = _mm256_castsi256_ph(a_p_b);
+ kernel.packet[12] = _mm256_castsi256_ph(a_p_c);
+ kernel.packet[13] = _mm256_castsi256_ph(a_p_d);
+ kernel.packet[14] = _mm256_castsi256_ph(a_p_e);
+ kernel.packet[15] = _mm256_castsi256_ph(a_p_f);
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 8>& kernel) {
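+  // Transpose through aligned scratch memory: store all eight packets,
+  // repack in scalar code, then reload.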
+ EIGEN_ALIGN64 half in[8][16];
+ pstore<half>(in[0], kernel.packet[0]);
+ pstore<half>(in[1], kernel.packet[1]);
+ pstore<half>(in[2], kernel.packet[2]);
+ pstore<half>(in[3], kernel.packet[3]);
+ pstore<half>(in[4], kernel.packet[4]);
+ pstore<half>(in[5], kernel.packet[5]);
+ pstore<half>(in[6], kernel.packet[6]);
+ pstore<half>(in[7], kernel.packet[7]);
+
+ EIGEN_ALIGN64 half out[8][16];
+
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ out[i][j] = in[j][2 * i];
+ }
+ for (int j = 0; j < 8; ++j) {
+ out[i][j + 8] = in[j][2 * i + 1];
+ }
+ }
+
+ kernel.packet[0] = pload<Packet16h>(out[0]);
+ kernel.packet[1] = pload<Packet16h>(out[1]);
+ kernel.packet[2] = pload<Packet16h>(out[2]);
+ kernel.packet[3] = pload<Packet16h>(out[3]);
+ kernel.packet[4] = pload<Packet16h>(out[4]);
+ kernel.packet[5] = pload<Packet16h>(out[5]);
+ kernel.packet[6] = pload<Packet16h>(out[6]);
+ kernel.packet[7] = pload<Packet16h>(out[7]);
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
+ EIGEN_ALIGN64 half in[4][16];
+ pstore<half>(in[0], kernel.packet[0]);
+ pstore<half>(in[1], kernel.packet[1]);
+ pstore<half>(in[2], kernel.packet[2]);
+ pstore<half>(in[3], kernel.packet[3]);
+
+ EIGEN_ALIGN64 half out[4][16];
+
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ out[i][j] = in[j][4 * i];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j + 4] = in[j][4 * i + 1];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j + 8] = in[j][4 * i + 2];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j + 12] = in[j][4 * i + 3];
+ }
+ }
+
+ kernel.packet[0] = pload<Packet16h>(out[0]);
+ kernel.packet[1] = pload<Packet16h>(out[1]);
+ kernel.packet[2] = pload<Packet16h>(out[2]);
+ kernel.packet[3] = pload<Packet16h>(out[3]);
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 8>& kernel) {
+ __m128i a = _mm_castph_si128(kernel.packet[0]);
+ __m128i b = _mm_castph_si128(kernel.packet[1]);
+ __m128i c = _mm_castph_si128(kernel.packet[2]);
+ __m128i d = _mm_castph_si128(kernel.packet[3]);
+ __m128i e = _mm_castph_si128(kernel.packet[4]);
+ __m128i f = _mm_castph_si128(kernel.packet[5]);
+ __m128i g = _mm_castph_si128(kernel.packet[6]);
+ __m128i h = _mm_castph_si128(kernel.packet[7]);
+
+ __m128i a03b03 = _mm_unpacklo_epi16(a, b);
+ __m128i c03d03 = _mm_unpacklo_epi16(c, d);
+ __m128i e03f03 = _mm_unpacklo_epi16(e, f);
+ __m128i g03h03 = _mm_unpacklo_epi16(g, h);
+ __m128i a47b47 = _mm_unpackhi_epi16(a, b);
+ __m128i c47d47 = _mm_unpackhi_epi16(c, d);
+ __m128i e47f47 = _mm_unpackhi_epi16(e, f);
+ __m128i g47h47 = _mm_unpackhi_epi16(g, h);
+
+ __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
+ __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
+ __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
+ __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
+ __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
+ __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
+ __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
+ __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
+
+ __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
+ __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
+ __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
+ __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
+ __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
+ __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
+ __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
+ __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
+
+ kernel.packet[0] = _mm_castsi128_ph(a0b0c0d0e0f0g0h0);
+ kernel.packet[1] = _mm_castsi128_ph(a1b1c1d1e1f1g1h1);
+ kernel.packet[2] = _mm_castsi128_ph(a2b2c2d2e2f2g2h2);
+ kernel.packet[3] = _mm_castsi128_ph(a3b3c3d3e3f3g3h3);
+ kernel.packet[4] = _mm_castsi128_ph(a4b4c4d4e4f4g4h4);
+ kernel.packet[5] = _mm_castsi128_ph(a5b5c5d5e5f5g5h5);
+ kernel.packet[6] = _mm_castsi128_ph(a6b6c6d6e6f6g6h6);
+ kernel.packet[7] = _mm_castsi128_ph(a7b7c7d7e7f7g7h7);
+}
+
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 4>& kernel) {
+ EIGEN_ALIGN32 Eigen::half in[4][8];
+ pstore<Eigen::half>(in[0], kernel.packet[0]);
+ pstore<Eigen::half>(in[1], kernel.packet[1]);
+ pstore<Eigen::half>(in[2], kernel.packet[2]);
+ pstore<Eigen::half>(in[3], kernel.packet[3]);
+
+ EIGEN_ALIGN32 Eigen::half out[4][8];
+
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ out[i][j] = in[j][2 * i];
+ }
+ for (int j = 0; j < 4; ++j) {
+ out[i][j + 4] = in[j][2 * i + 1];
+ }
+ }
+
+ kernel.packet[0] = pload<Packet8h>(out[0]);
+ kernel.packet[1] = pload<Packet8h>(out[1]);
+ kernel.packet[2] = pload<Packet8h>(out[2]);
+ kernel.packet[3] = pload<Packet8h>(out[3]);
+}
+
// preverse
template <>
@@ -672,6 +1317,20 @@
a);
}
+template <>
+EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) {
+ __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
+ return _mm256_castsi256_ph(_mm256_insertf128_si256(
+ _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 1), m)),
+ _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 0), m), 1));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a) {
+ __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
+ return _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a), m));
+}
+
// pscatter
template <>
@@ -684,191 +1343,68 @@
to[stride * i] = aux[i];
}
}
+template <>
+EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride) {
+ EIGEN_ALIGN64 half aux[16];
+ pstore(aux, from);
+ to[stride * 0] = aux[0];
+ to[stride * 1] = aux[1];
+ to[stride * 2] = aux[2];
+ to[stride * 3] = aux[3];
+ to[stride * 4] = aux[4];
+ to[stride * 5] = aux[5];
+ to[stride * 6] = aux[6];
+ to[stride * 7] = aux[7];
+ to[stride * 8] = aux[8];
+ to[stride * 9] = aux[9];
+ to[stride * 10] = aux[10];
+ to[stride * 11] = aux[11];
+ to[stride * 12] = aux[12];
+ to[stride * 13] = aux[13];
+ to[stride * 14] = aux[14];
+ to[stride * 15] = aux[15];
+}
+
+template <>
+EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride) {
+ EIGEN_ALIGN32 Eigen::half aux[8];
+ pstore(aux, from);
+ to[stride * 0] = aux[0];
+ to[stride * 1] = aux[1];
+ to[stride * 2] = aux[2];
+ to[stride * 3] = aux[3];
+ to[stride * 4] = aux[4];
+ to[stride * 5] = aux[5];
+ to[stride * 6] = aux[6];
+ to[stride * 7] = aux[7];
+}
// pgather
template <>
EIGEN_STRONG_INLINE Packet32h pgather<Eigen::half, Packet32h>(const Eigen::half* from, Index stride) {
- return _mm512_castsi512_ph(_mm512_set_epi16(
- from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x, from[27 * stride].x,
- from[26 * stride].x, from[25 * stride].x, from[24 * stride].x, from[23 * stride].x, from[22 * stride].x,
- from[21 * stride].x, from[20 * stride].x, from[19 * stride].x, from[18 * stride].x, from[17 * stride].x,
- from[16 * stride].x, from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
- from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x, from[7 * stride].x,
- from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x, from[2 * stride].x,
- from[1 * stride].x, from[0 * stride].x));
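+ // With the native half type, .x is a _Float16, so the packet can be built
+ // directly instead of round-tripping through integer bit patterns.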
+ return _mm512_set_ph(from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x,
+ from[27 * stride].x, from[26 * stride].x, from[25 * stride].x, from[24 * stride].x,
+ from[23 * stride].x, from[22 * stride].x, from[21 * stride].x, from[20 * stride].x,
+ from[19 * stride].x, from[18 * stride].x, from[17 * stride].x, from[16 * stride].x,
+ from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
+ from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
+ from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
+ from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
template <>
-EIGEN_STRONG_INLINE Packet16h pcos<Packet16h>(const Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h psin<Packet16h>(const Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h plog<Packet16h>(const Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h plog2<Packet16h>(const Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h plog1p<Packet16h>(const Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h pexp<Packet16h>(const Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h pexpm1<Packet16h>(const Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h ptanh<Packet16h>(const Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h pfrexp<Packet16h>(const Packet16h&, Packet16h&);
-template <>
-EIGEN_STRONG_INLINE Packet16h pldexp<Packet16h>(const Packet16h&, const Packet16h&);
-
-EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
- __m512d result = _mm512_undefined_pd();
- result = _mm512_insertf64x4(result, _mm256_castsi256_pd(a), 0);
- result = _mm512_insertf64x4(result, _mm256_castsi256_pd(b), 1);
- return _mm512_castpd_ph(result);
+EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride) {
+ return _mm256_set_ph(from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
+ from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
+ from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
+ from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
-EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
- a = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 0));
- b = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 1));
-}
-
-// psin
template <>
-EIGEN_STRONG_INLINE Packet32h psin<Packet32h>(const Packet32h& a) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h lowOut = psin(low);
- Packet16h highOut = psin(high);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// pcos
-template <>
-EIGEN_STRONG_INLINE Packet32h pcos<Packet32h>(const Packet32h& a) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h lowOut = pcos(low);
- Packet16h highOut = pcos(high);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// plog
-template <>
-EIGEN_STRONG_INLINE Packet32h plog<Packet32h>(const Packet32h& a) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h lowOut = plog(low);
- Packet16h highOut = plog(high);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// plog2
-template <>
-EIGEN_STRONG_INLINE Packet32h plog2<Packet32h>(const Packet32h& a) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h lowOut = plog2(low);
- Packet16h highOut = plog2(high);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// plog1p
-template <>
-EIGEN_STRONG_INLINE Packet32h plog1p<Packet32h>(const Packet32h& a) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h lowOut = plog1p(low);
- Packet16h highOut = plog1p(high);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// pexp
-template <>
-EIGEN_STRONG_INLINE Packet32h pexp<Packet32h>(const Packet32h& a) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h lowOut = pexp(low);
- Packet16h highOut = pexp(high);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// pexpm1
-template <>
-EIGEN_STRONG_INLINE Packet32h pexpm1<Packet32h>(const Packet32h& a) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h lowOut = pexpm1(low);
- Packet16h highOut = pexpm1(high);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// ptanh
-template <>
-EIGEN_STRONG_INLINE Packet32h ptanh<Packet32h>(const Packet32h& a) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h lowOut = ptanh(low);
- Packet16h highOut = ptanh(high);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// pfrexp
-template <>
-EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h exp1 = _mm256_undefined_si256();
- Packet16h exp2 = _mm256_undefined_si256();
-
- Packet16h lowOut = pfrexp(low, exp1);
- Packet16h highOut = pfrexp(high, exp2);
-
- exponent = combine2Packet16h(exp1, exp2);
-
- return combine2Packet16h(lowOut, highOut);
-}
-
-// pldexp
-template <>
-EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
- Packet16h low;
- Packet16h high;
- extract2Packet16h(a, low, high);
-
- Packet16h exp1;
- Packet16h exp2;
- extract2Packet16h(exponent, exp1, exp2);
-
- Packet16h lowOut = pldexp(low, exp1);
- Packet16h highOut = pldexp(high, exp2);
-
- return combine2Packet16h(lowOut, highOut);
+EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride) {
+ return _mm_set_ph(from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x,
+ from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h
index 9508ac6..fc55fd8 100644
--- a/Eigen/src/Core/arch/AVX512/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h
@@ -237,17 +237,13 @@
return _mm512_castsi512_si128(a);
}
+#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
return _mm256_castsi256_si128(a);
}
template <>
-EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
- return _mm256_castsi256_si128(a);
-}
-
-template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
return half2float(a);
}
@@ -257,6 +253,13 @@
return float2half(a);
}
+#endif
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
+ return _mm256_castsi256_si128(a);
+}
+
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
return Bf16ToF32(a);
@@ -267,68 +270,6 @@
return F32ToBf16(a);
}
-#ifdef EIGEN_VECTORIZE_AVX512FP16
-
-template <>
-EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet32h>(const Packet32h& a) {
- return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
-}
-template <>
-EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet32h>(const Packet32h& a) {
- return _mm256_castsi256_si128(preinterpret<Packet16h>(a));
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
- // Discard second-half of input.
- Packet16h low = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
- return _mm512_cvtxph_ps(_mm256_castsi256_ph(low));
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
- __m512d result = _mm512_undefined_pd();
- result = _mm512_insertf64x4(
- result, _mm256_castsi256_pd(_mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
- result = _mm512_insertf64x4(
- result, _mm256_castsi256_pd(_mm512_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
- return _mm512_castpd_ph(result);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
- // Discard second-half of input.
- Packet8h low = _mm_castps_si128(_mm256_extractf32x4_ps(_mm256_castsi256_ps(a), 0));
- return _mm256_cvtxph_ps(_mm_castsi128_ph(low));
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
- __m256d result = _mm256_undefined_pd();
- result = _mm256_insertf64x2(result,
- _mm_castsi128_pd(_mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
- result = _mm256_insertf64x2(result,
- _mm_castsi128_pd(_mm256_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
- return _mm256_castpd_si256(result);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
- Packet8f full = _mm256_cvtxph_ps(_mm_castsi128_ph(a));
- // Discard second-half of input.
- return _mm256_extractf32x4_ps(full, 0);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
- __m256 result = _mm256_undefined_ps();
- result = _mm256_insertf128_ps(result, a, 0);
- result = _mm256_insertf128_ps(result, b, 1);
- return _mm256_cvtps_ph(result, _MM_FROUND_TO_NEAREST_INT);
-}
-
-#endif
-
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h b/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h
new file mode 100644
index 0000000..f06f13d
--- /dev/null
+++ b/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h
@@ -0,0 +1,130 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2025 The Eigen Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TYPE_CASTING_FP16_AVX512_H
+#define EIGEN_TYPE_CASTING_FP16_AVX512_H
+
+// IWYU pragma: private
+#include "../../InternalHeaderCheck.h"
+
+namespace Eigen {
+namespace internal {
+
+template <>
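+// Reinterpret-casts between FP16 packets and same-width int16 packets are
+// plain bitcasts; the value conversions below use the native AVX512-FP16
+// instructions.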
+EIGEN_STRONG_INLINE Packet32s preinterpret<Packet32s, Packet32h>(const Packet32h& a) {
+ return _mm512_castph_si512(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16s preinterpret<Packet16s, Packet16h>(const Packet16h& a) {
+ return _mm256_castph_si256(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8s preinterpret<Packet8s, Packet8h>(const Packet8h& a) {
+ return _mm_castph_si128(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32h preinterpret<Packet32h, Packet32s>(const Packet32s& a) {
+ return _mm512_castsi512_ph(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet16s>(const Packet16s& a) {
+ return _mm256_castsi256_ph(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet8s>(const Packet8s& a) {
+ return _mm_castsi128_ph(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
+ return half2float(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
+ return half2float(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
+ return float2half(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
+ return float2half(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
+ // Discard the second half of the input.
+ Packet16h low = _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
+ return _mm512_cvtxph_ps(low);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
+ // Discard the second half of the input.
+ Packet8h low = _mm_castps_ph(_mm256_extractf32x4_ps(_mm256_castph_ps(a), 0));
+ return _mm256_cvtxph_ps(low);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
+ Packet8f full = _mm256_cvtxph_ps(a);
+ // Keep only the lower half of the converted result.
+ return _mm256_extractf32x4_ps(full, 0);
+}
+
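+// Narrowing casts convert each float input separately and concatenate the two
+// half-width FP16 results.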
+template <>
+EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
+ __m512 result = _mm512_castsi512_ps(_mm512_castsi256_si512(_mm256_castph_si256(_mm512_cvtxps_ph(a))));
+ result = _mm512_insertf32x8(result, _mm256_castph_ps(_mm512_cvtxps_ph(b)), 1);
+ return _mm512_castps_ph(result);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
+ __m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castph_si128(_mm256_cvtxps_ph(a))));
+ result = _mm256_insertf32x4(result, _mm_castph_ps(_mm256_cvtxps_ph(b)), 1);
+ return _mm256_castps_ph(result);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
+ __m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castps_si128(a)));
+ result = _mm256_insertf128_ps(result, b, 1);
+ return _mm256_cvtxps_ph(result);
+}
+
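+// FP16 <-> int16 value conversions (VCVTPH2W / VCVTW2PH).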
+template <>
+EIGEN_STRONG_INLINE Packet32s pcast<Packet32h, Packet32s>(const Packet32h& a) {
+ return _mm512_cvtph_epi16(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16s pcast<Packet16h, Packet16s>(const Packet16h& a) {
+ return _mm256_cvtph_epi16(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8s pcast<Packet8h, Packet8s>(const Packet8h& a) {
+ return _mm_cvtph_epi16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet32h pcast<Packet32s, Packet32h>(const Packet32s& a) {
+ return _mm512_cvtepi16_ph(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16h pcast<Packet16s, Packet16h>(const Packet16s& a) {
+ return _mm256_cvtepi16_ph(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8h pcast<Packet8s, Packet8h>(const Packet8s& a) {
+ return _mm_cvtepi16_ph(a);
+}
+
+} // namespace internal
+} // namespace Eigen
+
+#endif // EIGEN_TYPE_CASTING_FP16_AVX512_H
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index 95697f3..d8c9d5a 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -37,21 +37,23 @@
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
-#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
// When compiling with GPU support, the "__half_raw" base class as well as
// some other routines are defined in the GPU compiler header files
// (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr
// As a consequence, we get compile failures when compiling Eigen with
// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
-// Eigen with GPU support
-#pragma push_macro("EIGEN_CONSTEXPR")
-#undef EIGEN_CONSTEXPR
-#define EIGEN_CONSTEXPR
+// Eigen with GPU support.
+// Any functions that require `numext::bit_cast` may also not be constexpr,
+// including any native types when setting via raw bit values.
+#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
+#define _EIGEN_MAYBE_CONSTEXPR
+#else
+#define _EIGEN_MAYBE_CONSTEXPR constexpr
#endif
#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
template <> \
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
+ EIGEN_UNUSED EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
return float2half(METHOD<PACKET_F>(half2float(_x))); \
}
@@ -81,8 +83,10 @@
// Making the host side compile phase of hipcc use the same Eigen::half impl, as the gcc compile, resolves
// this error, and hence the following convoluted #if condition
#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
+
// Make our own __half_raw definition that is similar to CUDA's.
struct __half_raw {
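+ // Tag type used to construct __half_raw directly from the native
+ // representation, which may not be uint16_t when a builtin half type exists.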
+ struct construct_from_rep_tag {};
#if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE))
// Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF)
// The element type for shared memory cannot have non-trivial constructors
@@ -91,43 +95,53 @@
// hence the need for this
EIGEN_DEVICE_FUNC __half_raw() {}
#else
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {}
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw() : x(0) {}
#endif
+
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
- explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {}
+ explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {}
+ EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, __fp16 rep) : x{rep} {}
__fp16 x;
+#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
+ explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<_Float16>(raw)) {}
+ EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, _Float16 rep) : x{rep} {}
+ _Float16 x;
#else
- explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {}
+ explicit EIGEN_DEVICE_FUNC constexpr __half_raw(numext::uint16_t raw) : x(raw) {}
+ EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, numext::uint16_t rep) : x{rep} {}
numext::uint16_t x;
#endif
};
#elif defined(EIGEN_HAS_HIP_FP16)
-// Nothing to do here
+// HIP GPU compile phase: nothing to do here.
// HIP fp16 header file has a definition for __half_raw
#elif defined(EIGEN_HAS_CUDA_FP16)
+
+// CUDA GPU compile phase.
#if EIGEN_CUDA_SDK_VER < 90000
// In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw
typedef __half __half_raw;
#endif // defined(EIGEN_HAS_CUDA_FP16)
+
#elif defined(SYCL_DEVICE_ONLY)
typedef cl::sycl::half __half_raw;
#endif
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h);
struct half_base : public __half_raw {
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {}
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base() {}
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
#if defined(EIGEN_HAS_GPU_FP16)
#if defined(EIGEN_HAS_HIP_FP16)
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
#elif defined(EIGEN_HAS_CUDA_FP16)
#if EIGEN_CUDA_SDK_VER >= 90000
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
#endif
#endif
#endif
@@ -156,21 +170,29 @@
#endif
#endif
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {}
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half() {}
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
#if defined(EIGEN_HAS_GPU_FP16)
#if defined(EIGEN_HAS_HIP_FP16)
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
#elif defined(EIGEN_HAS_CUDA_FP16)
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
#endif
#endif
#endif
- explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b)
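+ // Construct directly from the native scalar representation when one exists.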
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(__fp16 b)
+ : half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {}
+#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
+ explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(_Float16 b)
+ : half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {}
+#endif
+
+ explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(bool b)
: half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
template <class T>
explicit EIGEN_DEVICE_FUNC half(T val)
@@ -201,99 +223,99 @@
namespace half_impl {
template <typename = void>
struct numeric_limits_half_impl {
- static EIGEN_CONSTEXPR const bool is_specialized = true;
- static EIGEN_CONSTEXPR const bool is_signed = true;
- static EIGEN_CONSTEXPR const bool is_integer = false;
- static EIGEN_CONSTEXPR const bool is_exact = false;
- static EIGEN_CONSTEXPR const bool has_infinity = true;
- static EIGEN_CONSTEXPR const bool has_quiet_NaN = true;
- static EIGEN_CONSTEXPR const bool has_signaling_NaN = true;
+ static constexpr const bool is_specialized = true;
+ static constexpr const bool is_signed = true;
+ static constexpr const bool is_integer = false;
+ static constexpr const bool is_exact = false;
+ static constexpr const bool has_infinity = true;
+ static constexpr const bool has_quiet_NaN = true;
+ static constexpr const bool has_signaling_NaN = true;
EIGEN_DIAGNOSTICS(push)
EIGEN_DISABLE_DEPRECATED_WARNING
- static EIGEN_CONSTEXPR const std::float_denorm_style has_denorm = std::denorm_present;
- static EIGEN_CONSTEXPR const bool has_denorm_loss = false;
+ static constexpr const std::float_denorm_style has_denorm = std::denorm_present;
+ static constexpr const bool has_denorm_loss = false;
EIGEN_DIAGNOSTICS(pop)
- static EIGEN_CONSTEXPR const std::float_round_style round_style = std::round_to_nearest;
- static EIGEN_CONSTEXPR const bool is_iec559 = true;
+ static constexpr const std::float_round_style round_style = std::round_to_nearest;
+ static constexpr const bool is_iec559 = true;
// The C++ standard defines this as "true if the set of values representable
// by the type is finite." Half has finite precision.
- static EIGEN_CONSTEXPR const bool is_bounded = true;
- static EIGEN_CONSTEXPR const bool is_modulo = false;
- static EIGEN_CONSTEXPR const int digits = 11;
- static EIGEN_CONSTEXPR const int digits10 =
+ static constexpr const bool is_bounded = true;
+ static constexpr const bool is_modulo = false;
+ static constexpr const int digits = 11;
+ static constexpr const int digits10 =
3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
- static EIGEN_CONSTEXPR const int max_digits10 =
+ static constexpr const int max_digits10 =
5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
- static EIGEN_CONSTEXPR const int radix = std::numeric_limits<float>::radix;
- static EIGEN_CONSTEXPR const int min_exponent = -13;
- static EIGEN_CONSTEXPR const int min_exponent10 = -4;
- static EIGEN_CONSTEXPR const int max_exponent = 16;
- static EIGEN_CONSTEXPR const int max_exponent10 = 4;
- static EIGEN_CONSTEXPR const bool traps = std::numeric_limits<float>::traps;
+ static constexpr const int radix = std::numeric_limits<float>::radix;
+ static constexpr const int min_exponent = -13;
+ static constexpr const int min_exponent10 = -4;
+ static constexpr const int max_exponent = 16;
+ static constexpr const int max_exponent10 = 4;
+ static constexpr const bool traps = std::numeric_limits<float>::traps;
// IEEE754: "The implementer shall choose how tininess is detected, but shall
// detect tininess in the same way for all operations in radix two"
- static EIGEN_CONSTEXPR const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
+ static constexpr const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
- static EIGEN_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); }
- static EIGEN_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
- static EIGEN_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
- static EIGEN_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); }
- static EIGEN_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); }
- static EIGEN_CONSTEXPR Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
- static EIGEN_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
- static EIGEN_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
- static EIGEN_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); }
};
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_specialized;
+constexpr const bool numeric_limits_half_impl<T>::is_specialized;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_signed;
+constexpr const bool numeric_limits_half_impl<T>::is_signed;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_integer;
+constexpr const bool numeric_limits_half_impl<T>::is_integer;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_exact;
+constexpr const bool numeric_limits_half_impl<T>::is_exact;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_infinity;
+constexpr const bool numeric_limits_half_impl<T>::has_infinity;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_quiet_NaN;
+constexpr const bool numeric_limits_half_impl<T>::has_quiet_NaN;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_signaling_NaN;
+constexpr const bool numeric_limits_half_impl<T>::has_signaling_NaN;
EIGEN_DIAGNOSTICS(push)
EIGEN_DISABLE_DEPRECATED_WARNING
template <typename T>
-EIGEN_CONSTEXPR const std::float_denorm_style numeric_limits_half_impl<T>::has_denorm;
+constexpr const std::float_denorm_style numeric_limits_half_impl<T>::has_denorm;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_denorm_loss;
+constexpr const bool numeric_limits_half_impl<T>::has_denorm_loss;
EIGEN_DIAGNOSTICS(pop)
template <typename T>
-EIGEN_CONSTEXPR const std::float_round_style numeric_limits_half_impl<T>::round_style;
+constexpr const std::float_round_style numeric_limits_half_impl<T>::round_style;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_iec559;
+constexpr const bool numeric_limits_half_impl<T>::is_iec559;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_bounded;
+constexpr const bool numeric_limits_half_impl<T>::is_bounded;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_modulo;
+constexpr const bool numeric_limits_half_impl<T>::is_modulo;
template <typename T>
-EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::digits;
+constexpr const int numeric_limits_half_impl<T>::digits;
template <typename T>
-EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::digits10;
+constexpr const int numeric_limits_half_impl<T>::digits10;
template <typename T>
-EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_digits10;
+constexpr const int numeric_limits_half_impl<T>::max_digits10;
template <typename T>
-EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::radix;
+constexpr const int numeric_limits_half_impl<T>::radix;
template <typename T>
-EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::min_exponent;
+constexpr const int numeric_limits_half_impl<T>::min_exponent;
template <typename T>
-EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::min_exponent10;
+constexpr const int numeric_limits_half_impl<T>::min_exponent10;
template <typename T>
-EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_exponent;
+constexpr const int numeric_limits_half_impl<T>::max_exponent;
template <typename T>
-EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_exponent10;
+constexpr const int numeric_limits_half_impl<T>::max_exponent10;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::traps;
+constexpr const bool numeric_limits_half_impl<T>::traps;
template <typename T>
-EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::tinyness_before;
+constexpr const bool numeric_limits_half_impl<T>::tinyness_before;
} // end namespace half_impl
} // end namespace Eigen
@@ -320,8 +342,7 @@
(defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE))
// Note: We deliberately do *not* define this to 1 even if we have Arm's native
// fp16 type since GPU half types are rather different from native CPU half types.
-// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16
-#define EIGEN_HAS_NATIVE_FP16
+#define EIGEN_HAS_NATIVE_GPU_FP16
#endif
// Intrinsics for native fp16 support. Note that on current hardware,
@@ -329,7 +350,7 @@
// versions to get the ALU speed increased), but you do save the
// conversion steps back and forth.
-#if defined(EIGEN_HAS_NATIVE_FP16)
+#if defined(EIGEN_HAS_NATIVE_GPU_FP16)
EIGEN_STRONG_INLINE __device__ half operator+(const half& a, const half& b) {
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
return __hadd(::__half(a), ::__half(b));
@@ -371,7 +392,8 @@
EIGEN_STRONG_INLINE __device__ bool operator<=(const half& a, const half& b) { return __hle(a, b); }
EIGEN_STRONG_INLINE __device__ bool operator>(const half& a, const half& b) { return __hgt(a, b); }
EIGEN_STRONG_INLINE __device__ bool operator>=(const half& a, const half& b) { return __hge(a, b); }
-#endif
+
+#endif // EIGEN_HAS_NATIVE_GPU_FP16
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) && !defined(EIGEN_GPU_COMPILE_PHASE)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(vaddh_f16(a.x, b.x)); }
@@ -401,16 +423,47 @@
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return vcleh_f16(a.x, b.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return vcgth_f16(a.x, b.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return vcgeh_f16(a.x, b.x); }
+
+#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) && !defined(EIGEN_GPU_COMPILE_PHASE)
+
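+// Scalar operators on the builtin _Float16 representation; each maps to a
+// single native instruction on AVX512-FP16 hardware.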
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(a.x + b.x); }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator*(const half& a, const half& b) { return half(a.x * b.x); }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a, const half& b) { return half(a.x - b.x); }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator/(const half& a, const half& b) { return half(a.x / b.x); }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a) { return half(-a.x); }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator+=(half& a, const half& b) {
+ a = a + b;
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator*=(half& a, const half& b) {
+ a = a * b;
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator-=(half& a, const half& b) {
+ a = a - b;
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) {
+ a = a / b;
+ return a;
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) { return a.x == b.x; }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return a.x != b.x; }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return a.x < b.x; }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return a.x <= b.x; }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return a.x > b.x; }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return a.x >= b.x; }
+
// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
// of the functions, while the latter can only deal with one of them.
-#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
+#elif !defined(EIGEN_HAS_NATIVE_GPU_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
// We need to provide emulated *host-side* FP16 operators for clang.
#pragma push_macro("EIGEN_DEVICE_FUNC")
#undef EIGEN_DEVICE_FUNC
-#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16)
+#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_GPU_FP16)
#define EIGEN_DEVICE_FUNC __host__
#else // both host and device need emulated ops.
#define EIGEN_DEVICE_FUNC __host__ __device__
@@ -458,6 +511,7 @@
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
#endif
+
#endif // Emulate support for half floats
// Division by an index. Do it in full float precision to avoid accuracy
@@ -493,7 +547,7 @@
// these in hardware. If we need more performance on older/other CPUs, they are
// also possible to vectorize directly.
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
// We cannot simply do a "return __half_raw(x)" here, because __half_raw is union type
// in the hip_fp16 header file, and that will trigger a compile error
// On the other hand, having anything but a return statement also triggers a compile error
@@ -515,6 +569,8 @@
// For SYCL, cl::sycl::half is _Float16, so cast directly.
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return numext::bit_cast<numext::uint16_t>(h.x);
+#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
+ return numext::bit_cast<numext::uint16_t>(h.x);
#elif defined(SYCL_DEVICE_ONLY)
return numext::bit_cast<numext::uint16_t>(h);
#else
@@ -528,6 +584,16 @@
__half tmp_ff = __float2half(ff);
return *(__half_raw*)&tmp_ff;
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ __half_raw h;
+ h.x = static_cast<__fp16>(ff);
+ return h;
+
+#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
+ __half_raw h;
+ h.x = static_cast<_Float16>(ff);
+ return h;
+
#elif defined(EIGEN_HAS_FP16_C)
__half_raw h;
#if EIGEN_COMP_MSVC
@@ -538,11 +604,6 @@
#endif
return h;
-#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
- __half_raw h;
- h.x = static_cast<__fp16>(ff);
- return h;
-
#else
uint32_t f_bits = Eigen::numext::bit_cast<uint32_t>(ff);
const uint32_t f32infty_bits = {255 << 23};
@@ -595,6 +656,8 @@
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __half2float(h);
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
+ return static_cast<float>(h.x);
#elif defined(EIGEN_HAS_FP16_C)
#if EIGEN_COMP_MSVC
// MSVC does not have scalar instructions.
@@ -602,8 +665,6 @@
#else
return _cvtsh_ss(h.x);
#endif
-#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
- return static_cast<float>(h.x);
#else
const float magic = Eigen::numext::bit_cast<float>(static_cast<uint32_t>(113 << 23));
const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
@@ -628,7 +689,7 @@
// --- standard functions ---
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const half& a) {
-#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00;
#else
return (a.x & 0x7fff) == 0x7c00;
@@ -638,7 +699,7 @@
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __hisnan(a);
-#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00;
#else
return (a.x & 0x7fff) > 0x7c00;
@@ -651,6 +712,11 @@
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return half(vabsh_f16(a.x));
+#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
+ half result;
+ result.x =
+ numext::bit_cast<_Float16>(static_cast<numext::uint16_t>(numext::bit_cast<numext::uint16_t>(a.x) & 0x7FFF));
+ return result;
#else
half result;
result.x = a.x & 0x7FFF;
@@ -734,26 +800,9 @@
return half(::fmodf(float(a), float(b)));
}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) {
-#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
- return __hlt(b, a) ? b : a;
-#else
- const float f1 = static_cast<float>(a);
- const float f2 = static_cast<float>(b);
- return f2 < f1 ? b : a;
-#endif
-}
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) {
-#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
- return __hlt(a, b) ? b : a;
-#else
- const float f1 = static_cast<float>(a);
- const float f2 = static_cast<float>(b);
- return f1 < f2 ? b : a;
-#endif
-}
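+// operator< on half now resolves to a native comparison where one is
+// available (and to the float-emulated comparison otherwise), so the
+// device-specific branches and explicit float round-trips are unnecessary.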
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) { return b < a ? b : a; }
+
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) { return a < b ? b : a; }
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const half& v) {
@@ -794,31 +843,29 @@
struct NumTraits<Eigen::half> : GenericNumTraits<Eigen::half> {
enum { IsSigned = true, IsInteger = false, IsComplex = false, RequireInitialization = false };
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
return half_impl::raw_uint16_to_half(0x0800);
}
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f);
}
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
return half_impl::raw_uint16_to_half(0x7bff);
}
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
return half_impl::raw_uint16_to_half(0xfbff);
}
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
return half_impl::raw_uint16_to_half(0x7c00);
}
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
return half_impl::raw_uint16_to_half(0x7e00);
}
};
} // end namespace Eigen
-#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
-#pragma pop_macro("EIGEN_CONSTEXPR")
-#endif
+#undef _EIGEN_MAYBE_CONSTEXPR
namespace Eigen {
namespace numext {
@@ -976,6 +1023,36 @@
}
};
+#ifdef EIGEN_VECTORIZE_FMA
+
+template <>
+EIGEN_DEVICE_FUNC inline half pmadd(const half& a, const half& b, const half& c) {
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
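+ // vfmah_f16 takes the accumulator first: vfmah_f16(c, a, b) computes c + a * b.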
+ return half(vfmah_f16(c.x, a.x, b.x));
+#elif defined(EIGEN_VECTORIZE_AVX512FP16)
+ // Reduces to vfmadd213sh.
+ return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
+#else
+ // Emulate FMA via float.
+ return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
+#endif
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline half pmsub(const half& a, const half& b, const half& c) {
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+ return half(vfmah_f16(-c.x, a.x, b.x));
+#elif defined(EIGEN_VECTORIZE_AVX512FP16)
+ // Reduces to vfmsub213sh.
+ return half(_mm_cvtsh_h(_mm_fmsub_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
+#else
+ // Emulate FMA via float.
+ return half(static_cast<float>(a) * static_cast<float>(b) - static_cast<float>(c));
+#endif
+}
+
+#endif
+
} // namespace internal
} // namespace Eigen
diff --git a/Eigen/src/Core/util/ConfigureVectorization.h b/Eigen/src/Core/util/ConfigureVectorization.h
index 5d3f1cf..49f307c 100644
--- a/Eigen/src/Core/util/ConfigureVectorization.h
+++ b/Eigen/src/Core/util/ConfigureVectorization.h
@@ -285,6 +285,8 @@
#ifdef __AVX512FP16__
#ifdef __AVX512VL__
#define EIGEN_VECTORIZE_AVX512FP16
+// Built-in _Float16.
+#define EIGEN_HAS_BUILTIN_FLOAT16 1
#else
#if EIGEN_COMP_GNUC
#error Please add -mavx512vl to your compiler flags: compiling with -mavx512fp16 alone without AVX512-VL is not supported.
diff --git a/test/packet_ostream.h b/test/packet_ostream.h
index 49e1bb0..4a3ee9c 100644
--- a/test/packet_ostream.h
+++ b/test/packet_ostream.h
@@ -7,7 +7,8 @@
// Include this header to be able to print Packets while debugging.
template <typename Packet,
- typename EnableIf = std::enable_if_t<Eigen::internal::unpacket_traits<Packet>::vectorizable> >
+ typename EnableIf = std::enable_if_t<(Eigen::internal::unpacket_traits<Packet>::vectorizable ||
+ Eigen::internal::unpacket_traits<Packet>::size > 1)> >
std::ostream& operator<<(std::ostream& os, const Packet& packet) {
using Scalar = typename Eigen::internal::unpacket_traits<Packet>::type;
Scalar v[Eigen::internal::unpacket_traits<Packet>::size];
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index 102817f..64c55fb 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -26,19 +26,19 @@
}
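+// Use the scalar pmadd/pmsub/pnmadd/pnmsub implementations as the reference
+// so that packets with true FMA are checked against identically rounded
+// scalar results.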
template <typename T>
inline T REF_MADD(const T& a, const T& b, const T& c) {
- return a * b + c;
+ return internal::pmadd(a, b, c);
}
template <typename T>
inline T REF_MSUB(const T& a, const T& b, const T& c) {
- return a * b - c;
+ return internal::pmsub(a, b, c);
}
template <typename T>
inline T REF_NMADD(const T& a, const T& b, const T& c) {
- return c - a * b;
+ return internal::pnmadd(a, b, c);
}
template <typename T>
inline T REF_NMSUB(const T& a, const T& b, const T& c) {
- return test::negate(a * b + c);
+ return internal::pnmsub(a, b, c);
}
template <typename T>
inline T REF_DIV(const T& a, const T& b) {