Clean up float16 a.k.a. Eigen::half support in Eigen. Move the definition of half to Core/arch/Default and move arch-specific packet ops to their respective sub-directories.
diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h
index 5084fc7..d1df26e 100644
--- a/Eigen/src/Core/arch/GPU/PacketMath.h
+++ b/Eigen/src/Core/arch/GPU/PacketMath.h
@@ -458,6 +458,546 @@
#endif
+// Packet math for Eigen::half
+// Most of the following operations require arch >= 3.0
+#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDACC) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIPCC) && defined(EIGEN_HIP_DEVICE_COMPILE)) || \
+ (defined(EIGEN_HAS_CUDA_FP16) && defined(__clang__) && defined(__CUDA__))
+
+// Flags half2 as an arithmetic type in Eigen's internal type traits.
+template<>
+struct is_arithmetic<half2> {
+  enum { value = true };
+};
+
+// Packet traits for Eigen::half: the packet type is half2, holding two
+// scalars (size = 2).
+template<>
+struct packet_traits<Eigen::half> : default_packet_traits {
+  typedef half2 type;
+  // Alias named `half`; it resolves to the same half2 type as `type`.
+  typedef half2 half;
+
+  enum {
+    // Packet layout flags.
+    Vectorizable    = 1,
+    AlignedOnScalar = 1,
+    size            = 2,
+    HasHalfPacket   = 0,
+
+    // Elementary arithmetic.
+    HasAdd = 1,
+    HasSub = 1,
+    HasMul = 1,
+    HasDiv = 1,
+
+    // Math functions that have half2 specializations in this file.
+    HasSqrt  = 1,
+    HasRsqrt = 1,
+    HasExp   = 1,
+    HasExpm1 = 1,
+    HasLog   = 1,
+    HasLog1p = 1
+  };
+};
+
+// Maps the half2 packet back to its scalar type (Eigen::half) and records
+// the packet's size, alignment and masked load/store availability.
+template<>
+struct unpacket_traits<half2> {
+  typedef Eigen::half type;
+  enum {
+    size = 2,
+    alignment = Aligned16,
+    vectorizable = true,
+    masked_load_available = false,
+    masked_store_available = false
+  };
+  typedef half2 half;
+};
+
+// Broadcasts the scalar `from` into both lanes of a half2 packet.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1<half2>(const Eigen::half& from) {
+#if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIP_DEVICE_COMPILE)
+  // Device compilation: a single intrinsic does the broadcast.
+  return __half2half2(from);
+#else
+  // Neither EIGEN_CUDA_ARCH nor EIGEN_HIP_DEVICE_COMPILE is defined:
+  // fill the two fields by hand.
+  half2 packet;
+  packet.x = from;
+  packet.y = from;
+  return packet;
+#endif
+}
+
+// Loads a packet by reinterpreting `from` as a pointer to half2 and
+// dereferencing it.
+// NOTE(review): presumably `from` must be suitably aligned for half2 -- confirm with callers.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pload<half2>(const Eigen::half* from) {
+  const half2* packet_ptr = reinterpret_cast<const half2*>(from);
+  return *packet_ptr;
+}
+
+// Builds a packet from the two consecutive scalars at `from`, reading each
+// scalar individually rather than through a half2 pointer.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploadu<half2>(const Eigen::half* from) {
+  return __halves2half2(*from, *(from + 1));
+}
+
+// Loads the single scalar at `from` and duplicates it into both lanes.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploaddup<half2>(const Eigen::half* from) {
+  return __halves2half2(*from, *from);
+}
+
+// Stores the packet by reinterpreting `to` as a pointer to half2 and
+// assigning through it.
+// NOTE(review): presumably `to` must be suitably aligned for half2 -- confirm with callers.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const half2& from) {
+  half2* destination = reinterpret_cast<half2*>(to);
+  *destination = from;
+}
+
+// Stores the two lanes of `from` into to[0] and to[1], one scalar at a time.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const half2& from) {
+#if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIP_DEVICE_COMPILE)
+  // Device compilation: extract the lanes with the intrinsics.
+  to[0] = __low2half(from);
+  to[1] = __high2half(from);
+#else
+  // Neither EIGEN_CUDA_ARCH nor EIGEN_HIP_DEVICE_COMPILE is defined:
+  // read the packet fields directly.
+  to[0] = from.x;
+  to[1] = from.y;
+#endif
+}
+
+// Aligned variant of ploadt_ro for half2.
+// When compiling for HIP devices or for CUDA arch >= 350 the whole packet is
+// fetched with __ldg through a half2 pointer; otherwise the packet is built
+// from two individual scalar reads.
+// NOTE(review): the "_ro" suffix presumably means the data is read-only for
+// the duration of the kernel -- confirm against the generic declaration.
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Aligned>(const Eigen::half* from) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350)
+  // Named cast instead of a C-style cast, matching pload<half2>.
+  return __ldg(reinterpret_cast<const half2*>(from));
+#else
+  return __halves2half2(*(from+0), *(from+1));
+#endif
+}
+
+// Unaligned variant of ploadt_ro for half2: the two scalars are always read
+// individually. When compiling for HIP devices or for CUDA arch >= 350 each
+// scalar goes through __ldg; otherwise plain dereferences are used.
+template<>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Unaligned>(const Eigen::half* from) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350)
+  return __halves2half2(__ldg(from), __ldg(from + 1));
+#else
+  return __halves2half2(*from, *(from + 1));
+#endif
+}
+
+// Gathers two scalars located `stride` elements apart (from[0] and
+// from[stride]) into one packet.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pgather<Eigen::half, half2>(const Eigen::half* from, Index stride) {
+  return __halves2half2(from[0], from[stride]);
+}
+
+// Scatters the packet: the low lane goes to to[0], the high lane to
+// to[stride].
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<Eigen::half, half2>(Eigen::half* to, const half2& from, Index stride) {
+  to[0] = __low2half(from);
+  to[stride] = __high2half(from);
+}
+
+// Returns the low lane of the packet as a scalar.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst<half2>(const half2& a) {
+  Eigen::half low_lane = __low2half(a);
+  return low_lane;
+}
+
+// Lane-wise absolute value, computed by masking each lane's raw 16-bit
+// representation with 0x7FFF (i.e. clearing the top bit).
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pabs<half2>(const half2& a) {
+  half lo = __low2half(a);
+  half hi = __high2half(a);
+  half abs_lo = half_impl::raw_uint16_to_half(lo.x & 0x7FFF);
+  half abs_hi = half_impl::raw_uint16_to_half(hi.x & 0x7FFF);
+  return __halves2half2(abs_lo, abs_hi);
+}
+
+// Returns a packet whose two lanes both hold the raw bit pattern 0xffff
+// (all bits set).
+// The argument is never read, so its name is commented out to avoid
+// unused-parameter warnings.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ptrue<half2>(const half2& /*a*/) {
+  half true_half = half_impl::raw_uint16_to_half(0xffffu);
+  return pset1<half2>(true_half);
+}
+
+// Returns a packet whose two lanes both hold the raw bit pattern 0x0000
+// (all bits clear).
+// The argument is never read, so its name is commented out to avoid
+// unused-parameter warnings.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pzero<half2>(const half2& /*a*/) {
+  half false_half = half_impl::raw_uint16_to_half(0x0000u);
+  return pset1<half2>(false_half);
+}
+
+// Transposes a 2x2 block stored as two half2 packets, in place:
+// packet[0] receives the two low lanes, packet[1] the two high lanes.
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<half2,2>& kernel) {
+  // Read all four lanes before overwriting either packet.
+  __half row0_lo = __low2half(kernel.packet[0]);
+  __half row0_hi = __high2half(kernel.packet[0]);
+  __half row1_lo = __low2half(kernel.packet[1]);
+  __half row1_hi = __high2half(kernel.packet[1]);
+  kernel.packet[0] = __halves2half2(row0_lo, row1_lo);
+  kernel.packet[1] = __halves2half2(row0_hi, row1_hi);
+}
+
+// Returns the packet {a, a + 1}.
+// The increment uses __hadd when compiling for HIP devices or for CUDA
+// arch >= 530; otherwise it is computed in float and converted back.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset<half2>(const Eigen::half& a) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  return __halves2half2(a, __hadd(a, __float2half(1.0f)));
+#else
+  float next = __half2float(a) + 1.0f;
+  return __halves2half2(a, __float2half(next));
+#endif
+}
+
+// Lane-wise select: a lane whose mask compares equal to half(0) takes the
+// corresponding lane of `b`; every other lane takes the lane of `a`.
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect<half2>(const half2& mask,
+                                                          const half2& a,
+                                                          const half2& b) {
+  // Low lane. The test is kept as "== half(0)" on purpose: masks may carry
+  // the 0xffff pattern, and the comparison must behave exactly as written.
+  half lo_mask = __low2half(mask);
+  half lo_pick = lo_mask == half(0) ? __low2half(b) : __low2half(a);
+
+  // High lane.
+  half hi_mask = __high2half(mask);
+  half hi_pick = hi_mask == half(0) ? __high2half(b) : __high2half(a);
+
+  return __halves2half2(lo_pick, hi_pick);
+}
+
+// Lane-wise equality test. Each pair of lanes is compared after conversion
+// to float; a lane that compares equal yields the raw pattern 0xffff,
+// otherwise 0x0000.
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq<half2>(const half2& a,
+                                                          const half2& b) {
+  half all_ones = half_impl::raw_uint16_to_half(0xffffu);
+  half all_zeros = half_impl::raw_uint16_to_half(0x0000u);
+
+  half lhs_lo = __low2half(a);
+  half rhs_lo = __low2half(b);
+  half lo_result = __half2float(lhs_lo) == __half2float(rhs_lo) ? all_ones : all_zeros;
+
+  half lhs_hi = __high2half(a);
+  half rhs_hi = __high2half(b);
+  half hi_result = __half2float(lhs_hi) == __half2float(rhs_hi) ? all_ones : all_zeros;
+
+  return __halves2half2(lo_result, hi_result);
+}
+
+// Lane-wise bitwise AND of the raw 16-bit representations of `a` and `b`.
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand<half2>(const half2& a,
+                                                       const half2& b) {
+  half lhs_lo = __low2half(a);
+  half rhs_lo = __low2half(b);
+  half lo_bits = half_impl::raw_uint16_to_half(lhs_lo.x & rhs_lo.x);
+
+  half lhs_hi = __high2half(a);
+  half rhs_hi = __high2half(b);
+  half hi_bits = half_impl::raw_uint16_to_half(lhs_hi.x & rhs_hi.x);
+
+  return __halves2half2(lo_bits, hi_bits);
+}
+
+// Lane-wise bitwise OR of the raw 16-bit representations of `a` and `b`.
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 por<half2>(const half2& a,
+                                                      const half2& b) {
+  half lhs_lo = __low2half(a);
+  half rhs_lo = __low2half(b);
+  half lo_bits = half_impl::raw_uint16_to_half(lhs_lo.x | rhs_lo.x);
+
+  half lhs_hi = __high2half(a);
+  half rhs_hi = __high2half(b);
+  half hi_bits = half_impl::raw_uint16_to_half(lhs_hi.x | rhs_hi.x);
+
+  return __halves2half2(lo_bits, hi_bits);
+}
+
+// Lane-wise bitwise XOR of the raw 16-bit representations of `a` and `b`.
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pxor<half2>(const half2& a,
+                                                       const half2& b) {
+  half lhs_lo = __low2half(a);
+  half rhs_lo = __low2half(b);
+  half lo_bits = half_impl::raw_uint16_to_half(lhs_lo.x ^ rhs_lo.x);
+
+  half lhs_hi = __high2half(a);
+  half rhs_hi = __high2half(b);
+  half hi_bits = half_impl::raw_uint16_to_half(lhs_hi.x ^ rhs_hi.x);
+
+  return __halves2half2(lo_bits, hi_bits);
+}
+
+// Lane-wise "a AND NOT b" on the raw 16-bit representations: each result
+// lane keeps the bits of `a` that are not set in `b`.
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pandnot<half2>(const half2& a,
+                                                          const half2& b) {
+  half lhs_lo = __low2half(a);
+  half rhs_lo = __low2half(b);
+  half lo_bits = half_impl::raw_uint16_to_half(lhs_lo.x & ~rhs_lo.x);
+
+  half lhs_hi = __high2half(a);
+  half rhs_hi = __high2half(b);
+  half hi_bits = half_impl::raw_uint16_to_half(lhs_hi.x & ~rhs_hi.x);
+
+  return __halves2half2(lo_bits, hi_bits);
+}
+
+// Lane-wise addition.
+// Uses __hadd2 when compiling for HIP devices or for CUDA arch >= 530;
+// otherwise each lane is added in float and the results are converted back
+// with __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd<half2>(const half2& a, const half2& b) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  return __hadd2(a, b);
+#else
+  float sum_lo = __low2float(a) + __low2float(b);
+  float sum_hi = __high2float(a) + __high2float(b);
+  return __floats2half2_rn(sum_lo, sum_hi);
+#endif
+}
+
+// Lane-wise subtraction (a - b).
+// Uses __hsub2 when compiling for HIP devices or for CUDA arch >= 530;
+// otherwise each lane is subtracted in float and the results are converted
+// back with __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psub<half2>(const half2& a, const half2& b) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  return __hsub2(a, b);
+#else
+  float diff_lo = __low2float(a) - __low2float(b);
+  float diff_hi = __high2float(a) - __high2float(b);
+  return __floats2half2_rn(diff_lo, diff_hi);
+#endif
+}
+
+// Lane-wise negation.
+// Uses __hneg2 when compiling for HIP devices or for CUDA arch >= 530;
+// otherwise each lane is negated in float and the results are converted
+// back with __floats2half2_rn.
+// The template argument is spelled out explicitly (pnegate<half2>) so the
+// specialization reads like every other half2 specialization in this file.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pnegate<half2>(const half2& a) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  return __hneg2(a);
+#else
+  float a1 = __low2float(a);
+  float a2 = __high2float(a);
+  return __floats2half2_rn(-a1, -a2);
+#endif
+}
+
+// Conjugate of a half2 packet: returns the packet unchanged.
+// The template argument is spelled out explicitly (pconj<half2>) so the
+// specialization reads like every other half2 specialization in this file.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pconj<half2>(const half2& a) { return a; }
+
+// Lane-wise multiplication.
+// Uses __hmul2 when compiling for HIP devices or for CUDA arch >= 530;
+// otherwise each lane is multiplied in float and the results are converted
+// back with __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul<half2>(const half2& a, const half2& b) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  return __hmul2(a, b);
+#else
+  float prod_lo = __low2float(a) * __low2float(b);
+  float prod_hi = __high2float(a) * __high2float(b);
+  return __floats2half2_rn(prod_lo, prod_hi);
+#endif
+}
+
+// Lane-wise multiply-add: a * b + c.
+// Uses __hfma2 when compiling for HIP devices or for CUDA arch >= 530;
+// otherwise each lane is evaluated in float as a * b + c and the results
+// are converted back with __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmadd<half2>(const half2& a, const half2& b, const half2& c) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  return __hfma2(a, b, c);
+#else
+  float madd_lo = __low2float(a) * __low2float(b) + __low2float(c);
+  float madd_hi = __high2float(a) * __high2float(b) + __high2float(c);
+  return __floats2half2_rn(madd_lo, madd_hi);
+#endif
+}
+
+// Lane-wise division (a / b).
+// HIP device compilation uses __h2div; every other configuration divides
+// each lane in float and converts the results back with __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv<half2>(const half2& a, const half2& b) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE)
+  return __h2div(a, b);
+#else // EIGEN_CUDA_ARCH
+  float quot_lo = __low2float(a) / __low2float(b);
+  float quot_hi = __high2float(a) / __high2float(b);
+  return __floats2half2_rn(quot_lo, quot_hi);
+#endif
+}
+
+// Lane-wise minimum. Lanes are compared as floats; a lane of `a` is chosen
+// only when it is strictly less than the matching lane of `b`, otherwise
+// the lane of `b` is chosen.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin<half2>(const half2& a, const half2& b) {
+  float lhs_lo = __low2float(a);
+  float rhs_lo = __low2float(b);
+  __half min_lo = lhs_lo < rhs_lo ? __low2half(a) : __low2half(b);
+
+  float lhs_hi = __high2float(a);
+  float rhs_hi = __high2float(b);
+  __half min_hi = lhs_hi < rhs_hi ? __high2half(a) : __high2half(b);
+
+  return __halves2half2(min_lo, min_hi);
+}
+
+// Lane-wise maximum. Lanes are compared as floats; a lane of `a` is chosen
+// only when it is strictly greater than the matching lane of `b`, otherwise
+// the lane of `b` is chosen.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax<half2>(const half2& a, const half2& b) {
+  float lhs_lo = __low2float(a);
+  float rhs_lo = __low2float(b);
+  __half max_lo = lhs_lo > rhs_lo ? __low2half(a) : __low2half(b);
+
+  float lhs_hi = __high2float(a);
+  float rhs_hi = __high2float(b);
+  __half max_hi = lhs_hi > rhs_hi ? __high2half(a) : __high2half(b);
+
+  return __halves2half2(max_lo, max_hi);
+}
+
+// Horizontal sum: returns low lane + high lane.
+// Uses __hadd when compiling for HIP devices or for CUDA arch >= 530;
+// otherwise the lanes are added in float and converted with __float2half.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux<half2>(const half2& a) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  return __hadd(__low2half(a), __high2half(a));
+#else
+  float lo = __low2float(a);
+  float hi = __high2float(a);
+  return Eigen::half(__float2half(lo + hi));
+#endif
+}
+
+// Horizontal maximum: returns the low lane when it is greater than the high
+// lane, otherwise the high lane.
+// The comparison uses __hgt when compiling for HIP devices or for CUDA
+// arch >= 530, and a float comparison otherwise.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max<half2>(const half2& a) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  __half lo = __low2half(a);
+  __half hi = __high2half(a);
+  return __hgt(lo, hi) ? lo : hi;
+#else
+  float lo = __low2float(a);
+  float hi = __high2float(a);
+  return lo > hi ? __low2half(a) : __high2half(a);
+#endif
+}
+
+// Horizontal minimum: returns the low lane when it is less than the high
+// lane, otherwise the high lane.
+// The comparison uses __hlt when compiling for HIP devices or for CUDA
+// arch >= 530, and a float comparison otherwise.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min<half2>(const half2& a) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  __half lo = __low2half(a);
+  __half hi = __high2half(a);
+  return __hlt(lo, hi) ? lo : hi;
+#else
+  float lo = __low2float(a);
+  float hi = __high2float(a);
+  return lo < hi ? __low2half(a) : __high2half(a);
+#endif
+}
+
+// Horizontal product: returns low lane * high lane.
+// Uses __hmul when compiling for HIP devices or for CUDA arch >= 530;
+// otherwise the lanes are multiplied in float and converted with
+// __float2half.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul<half2>(const half2& a) {
+#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530)
+  return __hmul(__low2half(a), __high2half(a));
+#else
+  float lo = __low2float(a);
+  float hi = __high2float(a);
+  return Eigen::half(__float2half(lo * hi));
+#endif
+}
+
+// Lane-wise log1p: each lane is converted to float, passed through log1pf,
+// and the pair is converted back with __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog1p<half2>(const half2& a) {
+  float lo = log1pf(__low2float(a));
+  float hi = log1pf(__high2float(a));
+  return __floats2half2_rn(lo, hi);
+}
+
+// Lane-wise expm1: each lane is converted to float, passed through expm1f,
+// and the pair is converted back with __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexpm1<half2>(const half2& a) {
+  float lo = expm1f(__low2float(a));
+  float hi = expm1f(__high2float(a));
+  return __floats2half2_rn(lo, hi);
+}
+
+#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
+ defined(EIGEN_HIP_DEVICE_COMPILE)
+
+// Lane-wise logarithm, forwarded to the h2log intrinsic.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog<half2>(const half2& a) {
+  half2 result = h2log(a);
+  return result;
+}
+
+// Lane-wise exponential, forwarded to the h2exp intrinsic.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp<half2>(const half2& a) {
+  half2 result = h2exp(a);
+  return result;
+}
+
+// Lane-wise square root, forwarded to the h2sqrt intrinsic.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt<half2>(const half2& a) {
+  half2 result = h2sqrt(a);
+  return result;
+}
+
+// Lane-wise reciprocal square root, forwarded to the h2rsqrt intrinsic.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt<half2>(const half2& a) {
+  half2 result = h2rsqrt(a);
+  return result;
+}
+
+#else
+
+// Float fallback for the lane-wise logarithm: each lane is converted to
+// float, passed through logf, and the pair is converted back with
+// __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog<half2>(const half2& a) {
+  float lo = logf(__low2float(a));
+  float hi = logf(__high2float(a));
+  return __floats2half2_rn(lo, hi);
+}
+
+// Float fallback for the lane-wise exponential: each lane is converted to
+// float, passed through expf, and the pair is converted back with
+// __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp<half2>(const half2& a) {
+  float lo = expf(__low2float(a));
+  float hi = expf(__high2float(a));
+  return __floats2half2_rn(lo, hi);
+}
+
+// Float fallback for the lane-wise square root: each lane is converted to
+// float, passed through sqrtf, and the pair is converted back with
+// __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt<half2>(const half2& a) {
+  float lo = sqrtf(__low2float(a));
+  float hi = sqrtf(__high2float(a));
+  return __floats2half2_rn(lo, hi);
+}
+
+// Float fallback for the lane-wise reciprocal square root: each lane is
+// converted to float, passed through rsqrtf, and the pair is converted back
+// with __floats2half2_rn.
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt<half2>(const half2& a) {
+  float lo = rsqrtf(__low2float(a));
+  float hi = rsqrtf(__high2float(a));
+  return __floats2half2_rn(lo, hi);
+}
+#endif
+
+#endif
+
} // end namespace internal
} // end namespace Eigen