Add signbit function
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index b67c4ed..af773dd 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h
@@ -563,13 +563,13 @@
 parg(const Packet& a) { using numext::arg; return arg(a); }
 
 
-/** \internal \returns \a a logically shifted by N bits to the right */
+/** \internal \returns \a a arithmetically shifted by N bits to the right */
 template<int N> EIGEN_DEVICE_FUNC inline int
 parithmetic_shift_right(const int& a) { return a >> N; }
 template<int N> EIGEN_DEVICE_FUNC inline long int
 parithmetic_shift_right(const long int& a) { return a >> N; }
 
-/** \internal \returns \a a arithmetically shifted by N bits to the right */
+/** \internal \returns \a a logically shifted by N bits to the right */
 template<int N> EIGEN_DEVICE_FUNC inline int
 plogical_shift_right(const int& a) { return static_cast<int>(static_cast<unsigned int>(a) >> N); }
 template<int N> EIGEN_DEVICE_FUNC inline long int
@@ -1191,6 +1191,34 @@
   return preciprocal<Packet>(psqrt(a));
 }
 
+template <typename Packet, bool IsScalar = is_scalar<Packet>::value,
+    bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
+    struct psignbit_impl;  // primary template: declared only, every case is handled by a specialization below
+template <typename Packet, bool IsInteger>
+struct psignbit_impl<Packet, true, IsInteger> {  // scalar "packets": forward to numext::signbit
+    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return numext::signbit(a); }
+};
+template <typename Packet>
+struct psignbit_impl<Packet, false, false> {  // non-scalar packets of a non-integer scalar type
+    // generic implementation, used when psignbit is not specialized in the architecture's PacketMath.h
+    // slower than the arithmetic-shift specializations
+    typedef typename unpacket_traits<Packet>::type Scalar;
+    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Packet run(const Packet& a) {
+        const Packet cst_pos_one = pset1<Packet>(Scalar(1));
+        const Packet cst_neg_one = pset1<Packet>(Scalar(-1));
+        return pcmp_eq(por(pand(a, cst_neg_one), cst_pos_one), cst_neg_one);  // NOTE(review): appears to rely on Scalar(1) and Scalar(-1) differing only in the sign bit, so (a & -1) | 1 equals -1 exactly when the sign bit of a is set
+    }
+};
+template <typename Packet>
+struct psignbit_impl<Packet, false, true> {  // non-scalar packets of an integer scalar type
+    // generic implementation for integer packets: compares each lane against zero
+    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return pcmp_lt(a, pzero(a)); }
+};
+/** \internal \returns the sign bit of \a a as a bitmask; dispatches to psignbit_impl on is_scalar and NumTraits::IsInteger */
+template <typename Packet>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE constexpr Packet
+psignbit(const Packet& a) { return psignbit_impl<Packet>::run(a); }
+
 } // end namespace internal
 
 } // end namespace Eigen
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 0eee333..b194353 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h
@@ -1531,6 +1531,37 @@
 }
 #endif
 
+template <typename Scalar, bool IsInteger = NumTraits<Scalar>::IsInteger, bool IsSigned = NumTraits<Scalar>::IsSigned>
+struct signbit_impl;  // primary template: declared only; the three specializations below cover <false,true>, <true,true> and <true,false>
+template <typename Scalar>
+struct signbit_impl<Scalar, false, true> {  // signed, non-integer scalars
+  static constexpr size_t Size = sizeof(Scalar);
+  static constexpr size_t Shift = (CHAR_BIT * Size) - 1;  // index of the most significant bit
+  using intSize_t = typename get_integer_by_size<Size>::signed_type;  // signed integer type selected by byte size
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Scalar run(const Scalar& x) {
+    intSize_t a = bit_cast<intSize_t, Scalar>(x);  // reinterpret the bits of x as a signed integer
+    a = a >> Shift;  // NOTE(review): right-shifting a negative signed value is implementation-defined before C++20; the callers' tests expect the sign bit to be replicated
+    Scalar result = bit_cast<Scalar, intSize_t>(a);  // reinterpret the shifted bits back as Scalar
+    return result;
+  }
+};
+template <typename Scalar>
+struct signbit_impl<Scalar, true, true> {  // signed integers: shift the value itself
+  static constexpr size_t Size = sizeof(Scalar);
+  static constexpr size_t Shift = (CHAR_BIT * Size) - 1;  // index of the most significant bit
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar run(const Scalar& x) { return x >> Shift; }
+};
+template <typename Scalar>
+struct signbit_impl<Scalar, true, false> {  // unsigned integers: the argument is ignored
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar run(const Scalar& ) {
+    return Scalar(0);
+  }
+};
+template <typename Scalar>  // entry point: returns signbit_impl<Scalar>::run(x), a value of type Scalar rather than a bool
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar signbit(const Scalar& x) {
+  return signbit_impl<Scalar>::run(x);
+}
+
 template<typename T>
 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
 T exp(const T &x) {
diff --git a/Eigen/src/Core/NumTraits.h b/Eigen/src/Core/NumTraits.h index 4f1f992..53362ef 100644 --- a/Eigen/src/Core/NumTraits.h +++ b/Eigen/src/Core/NumTraits.h
@@ -95,7 +95,7 @@ // Load src into registers first. This allows the memcpy to be elided by CUDA. const Src staged = src; EIGEN_USING_STD(memcpy) - memcpy(&tgt, &staged, sizeof(Tgt)); + memcpy(static_cast<void*>(&tgt),static_cast<const void*>(&staged), sizeof(Tgt)); return tgt; } } // namespace numext
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index ecbb73c..33a4dee 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -229,10 +229,7 @@
     Vectorizable = 1,
     AlignedOnScalar = 1,
     HasCmp = 1,
-    size=4,
-
-    // requires AVX512
-    HasShift = 0,
+    size=4
   };
 };
 #endif
@@ -360,6 +357,35 @@
 EIGEN_STRONG_INLINE Packet4l plogical_shift_left(Packet4l a) {
   return _mm256_slli_epi64(a, N);
 }
+#ifdef EIGEN_VECTORIZE_AVX512FP16  // NOTE(review): _mm256_srai_epi64 is an AVX-512 intrinsic; confirm this macro is the intended guard for it
+template <int N>
+EIGEN_STRONG_INLINE Packet4l parithmetic_shift_right(Packet4l a) { return _mm256_srai_epi64(a, N); }
+#else  // fallback: build the 64-bit arithmetic shift from 32-bit shifts, a shuffle and a blend; overloads are selected on N via enable_if
+template <int N>  // N == 0: nothing to shift
+EIGEN_STRONG_INLINE std::enable_if_t< (N == 0), Packet4l> parithmetic_shift_right(Packet4l a) {
+  return a;
+}
+template <int N>  // 0 < N < 32
+EIGEN_STRONG_INLINE std::enable_if_t< (N > 0) && (N < 32), Packet4l> parithmetic_shift_right(Packet4l a) {
+  __m256i hi_word = _mm256_srai_epi32(a, N);  // arithmetic shift of every 32-bit word; supplies the upper words of the result
+  __m256i lo_word = _mm256_srli_epi64(a, N);  // logical 64-bit shift; supplies the lower words of the result
+  return _mm256_blend_epi32(hi_word, lo_word, 0b01010101);  // even 32-bit lanes from lo_word, odd lanes from hi_word
+}
+template <int N>  // 32 <= N < 63
+EIGEN_STRONG_INLINE std::enable_if_t< (N >= 32) && (N < 63), Packet4l> parithmetic_shift_right(Packet4l a) {
+  __m256i hi_word = _mm256_srai_epi32(a, 31);  // every 32-bit word filled with its own sign
+  __m256i lo_word = _mm256_shuffle_epi32(_mm256_srai_epi32(a, N - 32), (shuffle_mask<1, 1, 3, 3>::mask));  // presumably copies each shifted odd (upper) word into the even slot below it - confirm shuffle_mask ordering
+  return _mm256_blend_epi32(hi_word, lo_word, 0b01010101);  // even 32-bit lanes from lo_word, odd lanes from hi_word
+}
+template <int N>  // N == 63: only the sign survives
+EIGEN_STRONG_INLINE std::enable_if_t< (N == 63), Packet4l> parithmetic_shift_right(Packet4l a) {
+  return _mm256_shuffle_epi32(_mm256_srai_epi32(a, 31), (shuffle_mask<1, 1, 3, 3>::mask));
+}
+template <int N>  // N outside [0, 63]: re-dispatch with N & 63
+EIGEN_STRONG_INLINE std::enable_if_t< (N < 0) || (N > 63), Packet4l> parithmetic_shift_right(Packet4l a) {
+  return parithmetic_shift_right<int(N&63)>(a);
+}
+#endif
 template <>
 EIGEN_STRONG_INLINE Packet4l pload<Packet4l>(const int64_t* from) {
   EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
@@ -1103,6 +1129,11 @@
 #endif
 }
 
+template<> EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { return _mm_srai_epi16(a, 15); }  // 16-bit lanes, arithmetic shift right by 15
+template<> EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) { return _mm_srai_epi16(a, 15); }  // same treatment as Packet8h
+template<> EIGEN_STRONG_INLINE Packet8f psignbit(const Packet8f& a) { return _mm256_castsi256_ps(parithmetic_shift_right<31>((Packet8i)_mm256_castps_si256(a))); }  // reinterpret as Packet8i, shift by 31, reinterpret back
+template<> EIGEN_STRONG_INLINE Packet4d psignbit(const Packet4d& a) { return _mm256_castsi256_pd(parithmetic_shift_right<63>((Packet4l)_mm256_castpd_si256(a))); }  // NOTE(review): uses Packet4l - confirm it is declared in every configuration that compiles this line
+
 template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) {
   return pfrexp_generic(a,exponent);
 }
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 5f37740..c210f2f 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -1127,6 +1127,11 @@
   return _mm512_abs_epi32(a);
 }
 
+template<> EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) { return _mm256_srai_epi16(a, 15); }  // 16-bit lanes, arithmetic shift right by 15
+template<> EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) { return _mm256_srai_epi16(a, 15); }  // same treatment as Packet16h
+template<> EIGEN_STRONG_INLINE Packet16f psignbit(const Packet16f& a) { return _mm512_castsi512_ps(_mm512_srai_epi32(_mm512_castps_si512(a), 31)); }  // reinterpret as 32-bit integers, shift by 31, reinterpret back
+template<> EIGEN_STRONG_INLINE Packet8d psignbit(const Packet8d& a) { return _mm512_castsi512_pd(_mm512_srai_epi64(_mm512_castpd_si512(a), 63)); }  // reinterpret as 64-bit integers, shift by 63, reinterpret back
+
 template<>
 EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
   return pfrexp_generic(a, exponent);
diff --git a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h index 58621d9..13f285e 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +++ b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
@@ -196,6 +196,13 @@
   return _mm512_abs_ph(a);
 }
 
+// psignbit: reinterpret as 16-bit integers, arithmetic shift right by 15, reinterpret back
+
+template <>
+EIGEN_STRONG_INLINE Packet32h psignbit<Packet32h>(const Packet32h& a) {
+  return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
+}
+
 // pmin
 
 template <>
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index d9ddb5e..d30ead4 100644 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -1575,6 +1575,9 @@
   return pand<Packet8us>(p8us_abs_mask, a);
 }
 
+template<> EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) { return vec_sra(a.m_val, vec_splat_u16(15)); }  // vec_sra on the wrapped m_val, shift count 15 splatted to all lanes
+template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { return (Packet4f)vec_sra((Packet4i)a, vec_splats(uint32_t(31))); }  // cast to Packet4i, vec_sra by 31, cast back
+
 template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a)
 { return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
 template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(const Packet4i& a)
@@ -2928,7 +2931,7 @@
   return vec_sld(a, a, 8);
 }
 template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); }
-
+template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { return (Packet2d)vec_sra((Packet2l)a, vec_splats(uint64_t(63))); }  // cast to Packet2l, vec_sra by 63, cast back
 // VSX support varies between different compilers and even different
 // versions of the same compiler.  For gcc version >= 4.9.3, we can use
 // vec_cts to efficiently convert Packet2d to Packet2l.  Otherwise, use
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 5cbf4ac..067b725 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -2372,6 +2372,14 @@
 }
 template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; }
 
+// psignbit: reinterpret each lane as a signed integer, shift right by (lane width - 1), reinterpret back.
+// NOTE(review): Packet4h/Packet8h/Packet2d may only exist under FP16 / 64-bit ARM guards - confirm these specializations sit inside matching #if blocks.
+template<> EIGEN_STRONG_INLINE Packet4h psignbit(const Packet4h& a) { return vreinterpret_f16_s16( vshr_n_s16( vreinterpret_s16_f16(a), 15)); }
+template<> EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { return vreinterpretq_f16_s16(vshrq_n_s16(vreinterpretq_s16_f16(a), 15)); }
+template<> EIGEN_STRONG_INLINE Packet2f psignbit(const Packet2f& a) { return vreinterpret_f32_s32( vshr_n_s32( vreinterpret_s32_f32(a), 31)); }
+template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { return vreinterpretq_f32_s32(vshrq_n_s32(vreinterpretq_s32_f32(a), 31)); }
+template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { return vreinterpretq_f64_s64(vshrq_n_s64(vreinterpretq_s64_f64(a), 63)); }
+
 template<> EIGEN_STRONG_INLINE Packet2f pfrexp<Packet2f>(const Packet2f& a, Packet2f& exponent)
 { return pfrexp_generic(a,exponent); }
 template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent)
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 847ff07..a0ff359 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -649,6 +649,17 @@
 #endif
 }
 
+template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { return _mm_castsi128_ps(_mm_srai_epi32(_mm_castps_si128(a), 31)); }  // reinterpret as 32-bit integers, arithmetic shift by 31, reinterpret back
+template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a)  // built on the Packet4f version instead of a 64-bit shift
+{
+    Packet4f tmp = psignbit<Packet4f>(_mm_castpd_ps(a));  // sign mask of every 32-bit half of the two doubles
+#ifdef EIGEN_VECTORIZE_AVX
+    return _mm_castps_pd(_mm_permute_ps(tmp, (shuffle_mask<1, 1, 3, 3>::mask)));  // presumably duplicates the upper-half mask over both halves of each double - confirm shuffle_mask ordering
+#else
+    return _mm_castps_pd(_mm_shuffle_ps(tmp, tmp, (shuffle_mask<1, 1, 3, 3>::mask)));  // same selection, expressed with the two-operand shuffle
+#endif // EIGEN_VECTORIZE_AVX
+}
+
 #ifdef EIGEN_VECTORIZE_SSE4_1
 template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
 {
diff --git a/Eigen/src/Core/util/Meta.h b/Eigen/src/Core/util/Meta.h index 32152ac..6c6fb71 100644 --- a/Eigen/src/Core/util/Meta.h +++ b/Eigen/src/Core/util/Meta.h
@@ -43,6 +43,32 @@
 typedef std::int32_t  int32_t;
 typedef std::uint64_t uint64_t;
 typedef std::int64_t  int64_t;
+
+template <size_t Size>  // maps a size in bytes to the signed/unsigned fixed-width integer typedefs of that size
+struct get_integer_by_size {  // fallback for any size without a specialization: both typedefs are void
+    typedef void signed_type;
+    typedef void unsigned_type;
+};
+template <>
+struct get_integer_by_size<1> {  // 1 byte
+    typedef int8_t signed_type;
+    typedef uint8_t unsigned_type;
+};
+template <>
+struct get_integer_by_size<2> {  // 2 bytes
+    typedef int16_t signed_type;
+    typedef uint16_t unsigned_type;
+};
+template <>
+struct get_integer_by_size<4> {  // 4 bytes
+    typedef int32_t signed_type;
+    typedef uint32_t unsigned_type;
+};
+template <>
+struct get_integer_by_size<8> {  // 8 bytes
+    typedef int64_t signed_type;
+    typedef uint64_t unsigned_type;
+};
 }
 }
 
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index a7e0ff4..94c9451 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp
@@ -219,7 +219,7 @@
   for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) {
     test_exponent<Base, Exponent>(exponent);
   }
-};
+}
 
 void mixed_pow_test() {
   // The following cases will test promoting a smaller exponent type
@@ -260,6 +260,81 @@
   unary_pow_test<long long, int>();
 }
 
+namespace Eigen {
+namespace internal {
+template <typename Scalar>  // test-only functor: scalar path calls numext::signbit, packet path calls psignbit
+struct test_signbit_op {
+  Scalar constexpr operator()(const Scalar& a) const { return numext::signbit(a); }
+  template <typename Packet>
+  inline Packet packetOp(const Packet& a) const {
+    return psignbit(a);
+  }
+};
+template <typename Scalar>
+struct functor_traits<test_signbit_op<Scalar>> {
+  enum { Cost = 1, PacketAccess = true }; // PacketAccess is unconditionally true; todo: define HasSignbit flag
+};
+}  // namespace internal
+}  // namespace Eigen
+
+template <typename T, bool IsInteger = NumTraits<T>::IsInteger>
+struct ref_signbit_func_impl {  // reference result for non-integer T: std::signbit
+  static bool run(const T& x) { return std::signbit(x); }
+};
+template <typename T>
+struct ref_signbit_func_impl<T, true> {  // reference result for integer T: x < 0
+  // MSVC (perhaps others) does not have a std::signbit overload for integers
+  static bool run(const T& x) { return x < T(0); }
+};
+template <typename T>  // picks the reference implementation from NumTraits<T>::IsInteger
+bool ref_signbit_func(const T& x) {
+  return ref_signbit_func_impl<T>::run(x);
+}
+
+template <typename Scalar>  // checks test_signbit_op against ref_signbit_func element by element
+void signbit_test() {
+  Scalar true_mask;
+  std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(Scalar));  // all bits set: expected when the reference reports a sign bit
+  Scalar false_mask;
+  std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(Scalar));  // all bits clear: expected otherwise
+
+  const size_t size = 100 * internal::packet_traits<Scalar>::size;
+  ArrayX<Scalar> x(size), y(size);
+  x.setRandom();
+  std::vector<Scalar> special_vals = special_values<Scalar>();
+  for (size_t i = 0; i < special_vals.size(); i++) {  // overwrite the head of x with each special value and its negation; NOTE(review): assumes 2 * special_vals.size() <= size
+    x(2 * i + 0) = special_vals[i];
+    x(2 * i + 1) = -special_vals[i];
+  }
+  y = x.unaryExpr(internal::test_signbit_op<Scalar>());
+
+  bool all_pass = true;
+  for (size_t i = 0; i < size; i++) {
+    const Scalar ref_val = ref_signbit_func(x(i)) ? true_mask : false_mask;
+    bool not_same = internal::predux_any(internal::bitwise_helper<Scalar>::bitwise_xor(ref_val, y(i)));  // any differing bit between expected mask and result is a failure
+    if (not_same) std::cout << "signbit(" << x(i) << ") != " << y(i) << "\n";
+    all_pass = all_pass && !not_same;
+  }
+
+  VERIFY(all_pass);
+}
+void signbit_tests() {  // runs signbit_test for floating-point, unsigned and signed scalar types
+  signbit_test<float>();
+  signbit_test<double>();
+  signbit_test<Eigen::half>();
+  signbit_test<Eigen::bfloat16>();
+
+  signbit_test<uint8_t>();
+  signbit_test<uint16_t>();
+  signbit_test<uint32_t>();
+  signbit_test<uint64_t>();
+
+  signbit_test<int8_t>();
+  signbit_test<int16_t>();
+  signbit_test<int32_t>();
+  signbit_test<int64_t>();
+}
+
 template<typename ArrayType> void array(const ArrayType& m)
 {
   typedef typename ArrayType::Scalar Scalar;
@@ -855,6 +930,35 @@
   VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<9>())).all() );
 }
 
+template <typename ArrayType>  // compares the scalar shift functors against the built-in >> and << for every N in [0, MaxShift]
+struct signed_shift_test_impl {
+  typedef typename ArrayType::Scalar Scalar;
+  static constexpr size_t Size = sizeof(Scalar);
+  static constexpr size_t MaxShift = (CHAR_BIT * Size) - 1;  // largest shift count exercised
+
+  template <size_t N = 0>
+  static inline std::enable_if_t<(N > MaxShift), void> run(const ArrayType& ) {}  // terminates the compile-time recursion
+  template <size_t N = 0>
+  static inline std::enable_if_t<(N <= MaxShift), void> run(const ArrayType& m) {
+    const Index rows = m.rows();
+    const Index cols = m.cols();
+
+    ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols);
+
+    m2 = m1.unaryExpr([](const Scalar& x) { return x >> N; });  // reference: built-in right shift
+    VERIFY((m2 == m1.unaryExpr(internal::scalar_shift_right_op<Scalar, N>())).all());
+
+    m2 = m1.unaryExpr([](const Scalar& x) { return x << N; });  // reference: built-in left shift; NOTE(review): left-shifting negative values is undefined before C++20
+    VERIFY((m2 == m1.unaryExpr( internal::scalar_shift_left_op<Scalar, N>())).all());
+
+    run<N + 1>(m);  // recurse to the next shift count
+  }
+};
+template <typename ArrayType>  // entry point: starts the recursion at N = 0 (the default template argument)
+void signed_shift_test(const ArrayType& m) {
+    signed_shift_test_impl<ArrayType>::run(m);
+}
+
 EIGEN_DECLARE_TEST(array_cwise)
 {
   for(int i = 0; i < g_repeat; i++) {
@@ -867,6 +971,9 @@
     CALL_SUBTEST_6( array(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
     CALL_SUBTEST_6( array_integer(ArrayXXi(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
     CALL_SUBTEST_6( array_integer(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
+    CALL_SUBTEST_7( signed_shift_test(ArrayXXi(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
+    CALL_SUBTEST_7( signed_shift_test(Array<Index, Dynamic, Dynamic>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
+
   }
   for(int i = 0; i < g_repeat; i++) {
     CALL_SUBTEST_1( comparisons(Array<float, 1, 1>()) );
@@ -897,6 +1004,7 @@
   for(int i = 0; i < g_repeat; i++) {
     CALL_SUBTEST_6( int_pow_test() );
     CALL_SUBTEST_7( mixed_pow_test() );
+    CALL_SUBTEST_8( signbit_tests() );
   }
 
   VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
diff --git a/test/numext.cpp b/test/numext.cpp index ee879c9..5483e5c 100644 --- a/test/numext.cpp +++ b/test/numext.cpp
@@ -239,6 +239,58 @@
   check_rsqrt_impl<T>::run();
 }
 
+template <typename T, bool IsInteger = NumTraits<T>::IsInteger>
+struct ref_signbit_func_impl {  // reference result for non-integer T: std::signbit
+  static bool run(const T& x) { return std::signbit(x); }
+};
+template <typename T>
+struct ref_signbit_func_impl<T, true> {  // reference result for integer T: x < 0
+  // MSVC (perhaps others) does not have a std::signbit overload for integers
+  static bool run(const T& x) { return x < T(0); }
+};
+template <typename T>  // picks the reference implementation from NumTraits<T>::IsInteger
+bool ref_signbit_func(const T& x) {
+  return ref_signbit_func_impl<T>::run(x);
+}
+
+template <typename T>  // checks numext::signbit against ref_signbit_func on a fixed list of values
+struct check_signbit_impl {
+  static void run() {
+    T true_mask;
+    std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(T));  // all bits set: expected when the reference reports a sign bit
+    T false_mask;
+    std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(T));  // all bits clear: expected otherwise
+
+    // has sign bit (for floating-point and signed T; the expected mask always comes from ref_signbit_func)
+    const T neg_zero = static_cast<T>(-0.0);
+    const T neg_one = static_cast<T>(-1);  // integer literal: converting -1.0 to an unsigned T would be undefined behaviour
+    const T neg_inf = -std::numeric_limits<T>::infinity();
+    const T neg_nan = -std::numeric_limits<T>::quiet_NaN();
+    // does not have sign bit
+    const T pos_zero = static_cast<T>(0.0);
+    const T pos_one = static_cast<T>(1);  // kept symmetric with neg_one
+    const T pos_inf = std::numeric_limits<T>::infinity();
+    const T pos_nan = std::numeric_limits<T>::quiet_NaN();
+
+    std::vector<T> values = {neg_zero, neg_one, neg_inf, neg_nan, pos_zero, pos_one, pos_inf, pos_nan};
+
+    bool all_pass = true;
+
+    for (T val : values) {
+      const T numext_val = numext::signbit(val);
+      const T ref_val = ref_signbit_func(val) ? true_mask : false_mask;
+      bool not_same = internal::predux_any(internal::bitwise_helper<T>::bitwise_xor(ref_val, numext_val));  // any differing bit between expected mask and result is a failure
+      all_pass = all_pass && !not_same;
+      if (not_same) std::cout << "signbit(" << val << ") != " << numext_val << "\n";
+    }
+    VERIFY(all_pass);
+  }
+};
+template <typename T>  // entry point used by the CALL_SUBTEST lines in this test
+void check_signbit() {
+  check_signbit_impl<T>::run();
+}
+
 EIGEN_DECLARE_TEST(numext) {
   for(int k=0; k<g_repeat; ++k)
   {
@@ -271,5 +323,20 @@
     CALL_SUBTEST( check_rsqrt<double>() );
     CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
     CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
+
+    CALL_SUBTEST( check_signbit<half>());
+    CALL_SUBTEST( check_signbit<bfloat16>());
+    CALL_SUBTEST( check_signbit<float>());
+    CALL_SUBTEST( check_signbit<double>());
+
+    CALL_SUBTEST( check_signbit<uint8_t>());
+    CALL_SUBTEST( check_signbit<uint16_t>());
+    CALL_SUBTEST( check_signbit<uint32_t>());
+    CALL_SUBTEST( check_signbit<uint64_t>());
+
+    CALL_SUBTEST( check_signbit<int8_t>());
+    CALL_SUBTEST( check_signbit<int16_t>());
+    CALL_SUBTEST( check_signbit<int32_t>());
+    CALL_SUBTEST( check_signbit<int64_t>());
   }
 }