Revert "Vectorize cast"
This reverts commit eb5ff1861a4783876564a1a79573c3b9ff566863
diff --git a/Eigen/Core b/Eigen/Core
index 857cffa..bf0b9c7 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -183,6 +183,7 @@
// Generic half float support
#include "src/Core/arch/Default/Half.h"
#include "src/Core/arch/Default/BFloat16.h"
+#include "src/Core/arch/Default/TypeCasting.h"
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
#if defined EIGEN_VECTORIZE_AVX512
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h
index 02c0600..e233efb 100644
--- a/Eigen/src/Core/CoreEvaluators.h
+++ b/Eigen/src/Core/CoreEvaluators.h
@@ -621,110 +621,6 @@
Data m_d;
};
-// ----------------------- Casting ---------------------
-template <typename SrcType, typename DstType, typename ArgType>
-struct unary_evaluator<CwiseUnaryOp<scalar_cast_op<SrcType, DstType>, ArgType>, IndexBased> {
- using CastOp = scalar_cast_op<SrcType, DstType>;
- using XprType = CwiseUnaryOp<CastOp, ArgType>;
- using SrcPacketType = typename packet_traits<SrcType>::type;
-
- static constexpr int SrcPacketSize = packet_traits<SrcType>::size;
- static constexpr int SrcPacketSizeBytes = SrcPacketSize * sizeof(SrcType);
-
- enum {
- CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<CastOp>::Cost),
- Flags = evaluator<ArgType>::Flags &
- (HereditaryBits | LinearAccessBit | (functor_traits<CastOp>::PacketAccess ? PacketAccessBit : 0)),
- Alignment = evaluator<ArgType>::Alignment
- };
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit unary_evaluator(const XprType& xpr)
- : m_argImpl(xpr.nestedExpression()) {
- EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<CastOp>::Cost);
- EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
- }
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index row, Index col) const {
- return CastOp()(m_argImpl.coeff(row, col));
- }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index index) const {
- return CastOp()(m_argImpl.coeff(index));
- }
-
- template <int LoadMode>
- EIGEN_ALWAYS_INLINE SrcPacketType srcPacket(Index row, Index col, Index offset) const {
- EIGEN_STATIC_ASSERT((LoadMode & (LoadMode - 1)) == 0, LoadMode must be a power of two)
- constexpr bool ArgIsRowMajor = evaluator<ArgType>::Flags & RowMajorBit;
- return m_argImpl.template packet<LoadMode, SrcPacketType>(ArgIsRowMajor ? row : row + (offset * SrcPacketSize),
- ArgIsRowMajor ? col + (offset * SrcPacketSize) : col);
- }
- template <int LoadMode>
- EIGEN_ALWAYS_INLINE SrcPacketType srcPacket(Index index, Index offset) const {
- EIGEN_STATIC_ASSERT((LoadMode & (LoadMode - 1)) == 0, LoadMode must be a power of two)
- return m_argImpl.template packet<LoadMode, SrcPacketType>(index + (offset * SrcPacketSize));
- }
-
- template <typename DstPacketType>
- using SrcPacketArgs1 = std::enable_if_t<unpacket_traits<DstPacketType>::size <= (1 * SrcPacketSize), bool>;
- template <typename DstPacketType>
- using SrcPacketArgs2 = std::enable_if_t<unpacket_traits<DstPacketType>::size == (2 * SrcPacketSize), bool>;
- template <typename DstPacketType>
- using SrcPacketArgs4 = std::enable_if_t<unpacket_traits<DstPacketType>::size == (4 * SrcPacketSize), bool>;
-
- template <int LoadMode, typename DstPacketType, SrcPacketArgs1<DstPacketType> = true>
- EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
- constexpr int DstPacketSize = unpacket_traits<DstPacketType>::size;
- constexpr int SrcIncrementBytes = DstPacketSize * sizeof(SrcType);
- constexpr int SrcLoadMode = plain_enum_min(SrcIncrementBytes, LoadMode);
- return CastOp().template packetOp<DstPacketType>(srcPacket<SrcLoadMode>(row, col, 0));
- }
- template <int LoadMode, typename DstPacketType, SrcPacketArgs2<DstPacketType> = true>
- EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
- constexpr int SrcLoadMode0 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
- constexpr int SrcLoadMode1 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode);
- return CastOp().template packetOp<DstPacketType>(srcPacket<SrcLoadMode0>(row, col, 0),
- srcPacket<SrcLoadMode1>(row, col, 1));
- }
- template <int LoadMode, typename DstPacketType, SrcPacketArgs4<DstPacketType> = true>
- EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
- constexpr int SrcLoadMode0 = plain_enum_min(4 * SrcPacketSizeBytes, LoadMode);
- constexpr int SrcLoadMode1 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
- constexpr int SrcLoadMode2 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
- constexpr int SrcLoadMode3 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode);
- return CastOp().template packetOp<DstPacketType>(
- srcPacket<SrcLoadMode0>(row, col, 0), srcPacket<SrcLoadMode1>(row, col, 1),
- srcPacket<SrcLoadMode2>(row, col, 2), srcPacket<SrcLoadMode3>(row, col, 3));
- }
-
- template <int LoadMode, typename DstPacketType, SrcPacketArgs1<DstPacketType> = true>
- EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
- constexpr int DstPacketSize = unpacket_traits<DstPacketType>::size;
- constexpr int SrcIncrementBytes = DstPacketSize * sizeof(SrcType);
- constexpr int SrcLoadMode = plain_enum_min(SrcIncrementBytes, LoadMode);
- return CastOp().template packetOp<DstPacketType>(srcPacket<SrcLoadMode>(index, 0));
- }
- template <int LoadMode, typename DstPacketType, SrcPacketArgs2<DstPacketType> = true>
- EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
- constexpr int SrcLoadMode0 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
- constexpr int SrcLoadMode1 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode);
- return CastOp().template packetOp<DstPacketType>(srcPacket<SrcLoadMode0>(index, 0),
- srcPacket<SrcLoadMode1>(index, 1));
- }
- template <int LoadMode, typename DstPacketType, SrcPacketArgs4<DstPacketType> = true>
- EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
- constexpr int SrcLoadMode0 = plain_enum_min(4 * SrcPacketSizeBytes, LoadMode);
- constexpr int SrcLoadMode1 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
- constexpr int SrcLoadMode2 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
- constexpr int SrcLoadMode3 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode);
- return CastOp().template packetOp<DstPacketType>(
- srcPacket<SrcLoadMode0>(index, 0), srcPacket<SrcLoadMode1>(index, 1),
- srcPacket<SrcLoadMode2>(index, 2), srcPacket<SrcLoadMode3>(index, 3));
- }
-
- protected:
- const evaluator<ArgType> m_argImpl;
-};
-
// -------------------- CwiseTernaryOp --------------------
// this is a ternary expression
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 5fa114f..40ee3f5 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -430,12 +430,6 @@
}
};
-template <typename OldType>
-struct cast_impl<OldType, bool> {
- EIGEN_DEVICE_FUNC
- static inline bool run(const OldType& x) { return x != OldType(0); }
-};
-
// Casting from S -> Complex<T> leads to an implicit conversion from S to T,
// generating warnings on clang. Here we explicitly cast the real component.
template<typename OldType, typename NewType>
diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h
index 41db035..386543e 100644
--- a/Eigen/src/Core/arch/AVX/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX/TypeCasting.h
@@ -62,15 +62,6 @@
TgtCoeffRatio = 1
};
};
-
-template <>
-struct type_casting_traits<float, double> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 2
- };
-};
#endif // EIGEN_VECTORIZE_AVX512
template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
@@ -89,10 +80,6 @@
return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a));
}
-template<> EIGEN_STRONG_INLINE Packet4d pcast<Packet8f, Packet4d>(const Packet8f& a) {
- return _mm256_cvtps_pd(_mm256_castps256_ps128(a));
-}
-
template <>
EIGEN_STRONG_INLINE Packet16b pcast<Packet8f, Packet16b>(const Packet8f& a,
const Packet8f& b) {
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index c08b7c5..c8ca33a 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -1014,49 +1014,4 @@
} // end namespace std
#endif
-namespace Eigen {
-namespace internal {
-
-template <>
-struct cast_impl<float, half> {
- EIGEN_DEVICE_FUNC
- static inline half run(const float& a) {
-#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
- return __float2half(a);
-#else
- return Eigen::half(a);
-#endif
- }
-};
-
-template <>
-struct cast_impl<int, half> {
- EIGEN_DEVICE_FUNC
- static inline half run(const int& a) {
-#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
- return __float2half(static_cast<float>(a));
-#else
- return half(static_cast<float>(a));
-#endif
- }
-};
-
-template <>
-struct cast_impl<half, float> {
- EIGEN_DEVICE_FUNC
- static inline float run(const half& a) {
-#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
- return __half2float(a);
-#else
- return static_cast<float>(a);
-#endif
- }
-};
-
-} // namespace internal
-} // namespace Eigen
-
#endif // EIGEN_HALF_H
diff --git a/Eigen/src/Core/arch/Default/TypeCasting.h b/Eigen/src/Core/arch/Default/TypeCasting.h
new file mode 100644
index 0000000..dc779a7
--- /dev/null
+++ b/Eigen/src/Core/arch/Default/TypeCasting.h
@@ -0,0 +1,116 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
+// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_GENERIC_TYPE_CASTING_H
+#define EIGEN_GENERIC_TYPE_CASTING_H
+
+#include "../../InternalHeaderCheck.h"
+
+namespace Eigen {
+
+namespace internal {
+
+template<>
+struct scalar_cast_op<float, Eigen::half> {
+ typedef Eigen::half result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __float2half(a);
+ #else
+ return Eigen::half(a);
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<float, Eigen::half> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<int, Eigen::half> {
+ typedef Eigen::half result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __float2half(static_cast<float>(a));
+ #else
+ return Eigen::half(static_cast<float>(a));
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, Eigen::half> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<Eigen::half, float> {
+ typedef float result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
+ return __half2float(a);
+ #else
+ return static_cast<float>(a);
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<Eigen::half, float> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<float, Eigen::bfloat16> {
+ typedef Eigen::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
+ return Eigen::bfloat16(a);
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<int, Eigen::bfloat16> {
+ typedef Eigen::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
+ return Eigen::bfloat16(static_cast<float>(a));
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<Eigen::bfloat16, float> {
+ typedef float result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
+ return static_cast<float>(a);
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+}
+}
+
+#endif // EIGEN_GENERIC_TYPE_CASTING_H
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index 8d7f59b..8354c0a 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -173,40 +173,22 @@
*
* \sa class CwiseUnaryOp, MatrixBase::cast()
*/
-template <typename SrcType, typename DstType>
+template<typename Scalar, typename NewType>
struct scalar_cast_op {
-
- using result_type = DstType;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstType operator()(const SrcType& a) const {
- return cast<SrcType, DstType>(a);
- }
-
- using SrcPacket = typename packet_traits<SrcType>::type;
-
- template <typename DstPacket>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a) const {
- return pcast<SrcPacket, DstPacket>(a);
- }
- template <typename DstPacket>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a, const SrcPacket& b) const {
- return pcast<SrcPacket, DstPacket>(a, b);
- }
- template <typename DstPacket>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a, const SrcPacket& b,
- const SrcPacket& c, const SrcPacket& d) const {
- return pcast<SrcPacket, DstPacket>(a, b, c, d);
- }
+ typedef NewType result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const NewType operator() (const Scalar& a) const { return cast<Scalar, NewType>(a); }
};
-template <typename SrcType, typename DstType>
-struct functor_traits<scalar_cast_op<SrcType, DstType>> {
- enum {
- Cost = is_same<SrcType, DstType>::value ? 0 : NumTraits<DstType>::AddCost,
- PacketAccess = (type_casting_traits<SrcType, DstType>::VectorizedCast != 0) &&
- (type_casting_traits<SrcType, DstType>::SrcCoeffRatio <= 4)
- };
+template <typename Scalar>
+struct scalar_cast_op<Scalar, bool> {
+ typedef bool result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const { return a != Scalar(0); }
};
+template<typename Scalar, typename NewType>
+struct functor_traits<scalar_cast_op<Scalar,NewType> >
+{ enum { Cost = is_same<Scalar, NewType>::value ? 0 : NumTraits<NewType>::AddCost, PacketAccess = false }; };
+
/** \internal
* \brief Template functor to arithmetically shift a scalar right by a number of bits
*
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp
index f437c76..dfa81d4 100644
--- a/test/array_cwise.cpp
+++ b/test/array_cwise.cpp
@@ -9,7 +9,6 @@
#include <vector>
#include "main.h"
-#include "random_without_cast_overflow.h"
template <typename Scalar, std::enable_if_t<NumTraits<Scalar>::IsInteger,int> = 0>
std::vector<Scalar> special_values() {
@@ -1184,59 +1183,6 @@
typed_logicals_test_impl<ArrayType>::run(m);
}
-template <typename SrcType, typename DstType>
-struct cast_test_impl {
- using SrcArray = ArrayX<SrcType>;
- using DstArray = ArrayX<DstType>;
-
- static constexpr int SrcPacketSize = internal::packet_traits<SrcType>::size;
- static constexpr int DstPacketSize = internal::packet_traits<DstType>::size;
- static constexpr int MaxPacketSize = internal::plain_enum_max(SrcPacketSize, DstPacketSize);
-
- static void run() {
- const Index testSize = 100 * MaxPacketSize;
- SrcArray src(testSize);
- for (Index i = 0; i < testSize; i++) src(i) = internal::random_without_cast_overflow<SrcType, DstType>::value();
- DstArray dst = src.template cast<DstType>();
- for (Index i = 0; i < testSize; i++) {
- DstType ref = static_cast<DstType>(src(i));
- bool all_nan = ((numext::isnan)(src(i)) && (numext::isnan)(ref) && (numext::isnan)(dst(i)));
- bool is_equal = ref == dst(i);
- bool pass = all_nan || is_equal;
- if (!pass) {
- std::cout << typeid(SrcType).name() << ": [" << +src(i) << "] to " << typeid(DstType).name() << ": [" << +dst(i)
- << "] != [" << +ref << "]\n";
- }
- VERIFY(pass);
- }
- }
-};
-
-template <typename... ScalarTypes>
-struct cast_tests_impl {
- using ScalarTuple = std::tuple<ScalarTypes...>;
- static constexpr size_t ScalarTupleSize = std::tuple_size<ScalarTuple>::value;
-
- template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
- static std::enable_if_t<Done> run() {}
-
- template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
- static std::enable_if_t<!Done> run() {
- using Type1 = typename std::tuple_element<i, ScalarTuple>::type;
- using Type2 = typename std::tuple_element<j, ScalarTuple>::type;
- cast_test_impl<Type1, Type2>::run();
- cast_test_impl<Type2, Type1>::run();
- static constexpr size_t next_i = (j == ScalarTupleSize - 1) ? (i + 1) : (i + 0);
- static constexpr size_t next_j = (j == ScalarTupleSize - 1) ? (i + 2) : (j + 1);
- run<next_i, next_j>();
- }
-};
-
-void cast_test() {
- cast_tests_impl<bool, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t, float, double,
- long double, half, bfloat16>::run();
-}
-
EIGEN_DECLARE_TEST(array_cwise)
{
for(int i = 0; i < g_repeat; i++) {
@@ -1293,9 +1239,6 @@
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<float>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
}
- for (int i = 0; i < g_repeat; i++) {
- cast_test();
- }
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));