Add truncation op
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index fc5d757..381d8ff 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -57,6 +57,9 @@
HasConj = 1,
HasSetLinear = 1,
HasSign = 1,
+ // By default, the nearest integer functions (rint, round, floor, ceil, trunc) are enabled for all scalar and packet
+ // types
+ HasRound = 1,
HasArg = 0,
HasAbsDiff = 0,
@@ -64,10 +67,6 @@
// This flag is used to indicate whether packet comparison is supported.
// pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true.
HasCmp = 0,
- HasRound = 0,
- HasRint = 0,
- HasFloor = 0,
- HasCeil = 0,
HasDiv = 0,
HasReciprocal = 0,
@@ -1138,33 +1137,45 @@
return numext::cbrt(a);
}
+template <typename Packet, bool IsScalar = is_scalar<Packet>::value,
+ bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
+struct nearest_integer_packetop_impl {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_floor(const Packet& x) { return numext::floor(x); }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_ceil(const Packet& x) { return numext::ceil(x); }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_rint(const Packet& x) { return numext::rint(x); }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_round(const Packet& x) { return numext::round(x); }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_trunc(const Packet& x) { return numext::trunc(x); }
+};
+
/** \internal \returns the rounded value of \a a (coeff-wise) */
template <typename Packet>
-EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pround(const Packet& a) {
- using numext::round;
- return round(a);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pround(const Packet& a) {
+ return nearest_integer_packetop_impl<Packet>::run_round(a);
}
/** \internal \returns the floor of \a a (coeff-wise) */
template <typename Packet>
-EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pfloor(const Packet& a) {
- using numext::floor;
- return floor(a);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pfloor(const Packet& a) {
+ return nearest_integer_packetop_impl<Packet>::run_floor(a);
}
/** \internal \returns the rounded value of \a a (coeff-wise) with current
* rounding mode */
template <typename Packet>
-EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet print(const Packet& a) {
- using numext::rint;
- return rint(a);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet print(const Packet& a) {
+ return nearest_integer_packetop_impl<Packet>::run_rint(a);
}
/** \internal \returns the ceil of \a a (coeff-wise) */
template <typename Packet>
-EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pceil(const Packet& a) {
- using numext::ceil;
- return ceil(a);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pceil(const Packet& a) {
+ return nearest_integer_packetop_impl<Packet>::run_ceil(a);
+}
+
+/** \internal \returns the truncation of \a a (coeff-wise) */
+template <typename Packet>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet ptrunc(const Packet& a) {
+ return nearest_integer_packetop_impl<Packet>::run_trunc(a);
}
template <typename Packet, typename EnableIf = void>
diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h
index f0ae5a8..3f147b8 100644
--- a/Eigen/src/Core/GlobalFunctions.h
+++ b/Eigen/src/Core/GlobalFunctions.h
@@ -98,9 +98,12 @@
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(round, scalar_round_op,
nearest integer,\sa Eigen::floor DOXCOMMA Eigen::ceil DOXCOMMA ArrayBase::round)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(
- floor, scalar_floor_op, nearest integer not greater than the giben value,\sa Eigen::ceil DOXCOMMA ArrayBase::floor)
+ floor, scalar_floor_op, nearest integer not greater than the given value,\sa Eigen::ceil DOXCOMMA ArrayBase::floor)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(
- ceil, scalar_ceil_op, nearest integer not less than the giben value,\sa Eigen::floor DOXCOMMA ArrayBase::ceil)
+ ceil, scalar_ceil_op, nearest integer not less than the given value,\sa Eigen::floor DOXCOMMA ArrayBase::ceil)
+EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(trunc, scalar_trunc_op,
+ nearest integer not greater in magnitude than the given value,\sa Eigen::floor DOXCOMMA Eigen::ceil DOXCOMMA
+ ArrayBase::trunc)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(
isnan, scalar_isnan_op, not -a - number test,\sa Eigen::isinf DOXCOMMA Eigen::isfinite DOXCOMMA ArrayBase::isnan)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 2a42b18..6bb9a12 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -894,6 +894,9 @@
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_round(const Scalar& x) {
EIGEN_USING_STD(round) return round(x);
}
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) {
+ EIGEN_USING_STD(trunc) return trunc(x);
+ }
};
template <typename Scalar>
struct nearest_integer_impl<Scalar, true> {
@@ -901,6 +904,7 @@
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_ceil(const Scalar& x) { return x; }
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_rint(const Scalar& x) { return x; }
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_round(const Scalar& x) { return x; }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; }
};
} // end namespace internal
@@ -1192,17 +1196,26 @@
return internal::nearest_integer_impl<Scalar>::run_round(x);
}
-#if defined(SYCL_DEVICE_ONLY)
-SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(round, round)
-#endif
-
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar(floor)(const Scalar& x) {
return internal::nearest_integer_impl<Scalar>::run_floor(x);
}
+template <typename Scalar>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar(ceil)(const Scalar& x) {
+ return internal::nearest_integer_impl<Scalar>::run_ceil(x);
+}
+
+template <typename Scalar>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar(trunc)(const Scalar& x) {
+ return internal::nearest_integer_impl<Scalar>::run_trunc(x);
+}
+
#if defined(SYCL_DEVICE_ONLY)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(round, round)
SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(floor, floor)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(ceil, ceil)
+SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(trunc, trunc)
#endif
#if defined(EIGEN_GPUCC)
@@ -1210,32 +1223,26 @@
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float floor(const float& x) {
return ::floorf(x);
}
-
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double floor(const double& x) {
return ::floor(x);
}
-#endif
-
-template <typename Scalar>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar(ceil)(const Scalar& x) {
- return internal::nearest_integer_impl<Scalar>::run_ceil(x);
-}
-
-#if defined(SYCL_DEVICE_ONLY)
-SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(ceil, ceil)
-#endif
-
-#if defined(EIGEN_GPUCC)
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float ceil(const float& x) {
return ::ceilf(x);
}
-
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double ceil(const double& x) {
return ::ceil(x);
}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float trunc(const float& x) {
+ return ::truncf(x);
+}
+template <>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double trunc(const double& x) {
+ return ::trunc(x);
+}
#endif
// Integer division with rounding up.
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index dac43fc..b05429c 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -124,11 +124,7 @@
HasRsqrt = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
- HasBlend = 1,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1
+ HasBlend = 1
};
};
template <>
@@ -151,11 +147,7 @@
HasSqrt = 1,
HasRsqrt = 1,
HasATan = 1,
- HasBlend = 1,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1
+ HasBlend = 1
};
};
@@ -192,10 +184,6 @@
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasBlend = 0,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1,
HasBessel = 1,
HasNdtri = 1
};
@@ -235,10 +223,6 @@
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasBlend = 0,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1,
HasBessel = 1,
HasNdtri = 1
};
@@ -1258,6 +1242,15 @@
}
template <>
+EIGEN_STRONG_INLINE Packet8f ptrunc<Packet8f>(const Packet8f& a) {
+ return _mm256_round_ps(a, _MM_FROUND_TRUNC);
+}
+template <>
+EIGEN_STRONG_INLINE Packet4d ptrunc<Packet4d>(const Packet4d& a) {
+ return _mm256_round_pd(a, _MM_FROUND_TRUNC);
+}
+
+template <>
EIGEN_STRONG_INLINE Packet8i ptrue<Packet8i>(const Packet8i& a) {
#ifdef EIGEN_VECTORIZE_AVX2
// vpcmpeqd has lower latency than the more general vcmpps
@@ -2312,6 +2305,11 @@
}
template <>
+EIGEN_STRONG_INLINE Packet8h ptrunc<Packet8h>(const Packet8h& a) {
+ return float2half(ptrunc<Packet8f>(half2float(a)));
+}
+
+template <>
EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_eq(half2float(a), half2float(b)));
}
@@ -2687,6 +2685,11 @@
}
template <>
+EIGEN_STRONG_INLINE Packet8bf ptrunc<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(ptrunc<Packet8f>(Bf16ToF32(a)));
+}
+
+template <>
EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) {
return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
}
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 8f7662f..9a0edca 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -97,11 +97,7 @@
HasCos = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
- HasBlend = 0,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1
+ HasBlend = 0
};
};
#endif
@@ -138,11 +134,7 @@
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasCmp = 1,
- HasDiv = 1,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1
+ HasDiv = 1
};
};
template <>
@@ -162,11 +154,7 @@
HasExp = 1,
HasATan = 1,
HasCmp = 1,
- HasDiv = 1,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1
+ HasDiv = 1
};
};
@@ -782,6 +770,15 @@
}
template <>
+EIGEN_STRONG_INLINE Packet16f ptrunc<Packet16f>(const Packet16f& a) {
+ return _mm512_roundscale_ps(a, _MM_FROUND_TO_ZERO);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d ptrunc<Packet8d>(const Packet8d& a) {
+ return _mm512_roundscale_pd(a, _MM_FROUND_TO_ZERO);
+}
+
+template <>
EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
return _mm512_set1_epi32(int32_t(-1));
}
@@ -2323,6 +2320,11 @@
}
template <>
+EIGEN_STRONG_INLINE Packet16h ptrunc<Packet16h>(const Packet16h& a) {
+ return float2half(ptrunc<Packet16f>(half2float(a)));
+}
+
+template <>
EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
@@ -2822,6 +2824,11 @@
}
template <>
+EIGEN_STRONG_INLINE Packet16bf ptrunc<Packet16bf>(const Packet16bf& a) {
+ return F32ToBf16(ptrunc<Packet16f>(Bf16ToF32(a)));
+}
+
+template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, const Packet16bf& b) {
return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
}
diff --git a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
index 131e6f1..d4a5816 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
@@ -60,11 +60,7 @@
HasCos = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH,
HasErf = 0, // EIGEN_FAST_MATH,
- HasBlend = 0,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1
+ HasBlend = 0
};
};
@@ -390,6 +386,13 @@
return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}
+// ptrunc
+
+template <>
+EIGEN_STRONG_INLINE Packet32h ptrunc<Packet32h>(const Packet32h& a) {
+ return _mm512_roundscale_ph(a, _MM_FROUND_TO_ZERO);
+}
+
// predux
template <>
EIGEN_STRONG_INLINE half predux<Packet32h>(const Packet32h& a) {
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index b0f7262..4c92e05 100644
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -193,17 +193,12 @@
#endif
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
- HasRint = 1,
#else
HasSqrt = 0,
HasRsqrt = 0,
HasTanh = 0,
HasErf = 0,
- HasRint = 0,
#endif
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
HasNegate = 1,
HasBlend = 1
};
@@ -235,17 +230,12 @@
#else
HasRsqrt = 0,
#endif
- HasRint = 1,
#else
HasSqrt = 0,
HasRsqrt = 0,
- HasRint = 0,
#endif
HasTanh = 0,
HasErf = 0,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
HasNegate = 1,
HasBlend = 1
};
@@ -1506,6 +1496,10 @@
EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) {
return vec_floor(a);
}
+template <>
+EIGEN_STRONG_INLINE Packet4f ptrunc<Packet4f>(const Packet4f& a) {
+ return vec_trunc(a);
+}
#ifdef EIGEN_VECTORIZE_VSX
template <>
EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a) {
@@ -2364,6 +2358,10 @@
EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf>(const Packet8bf& a) {
BF16_TO_F32_UNARY_OP_WRAPPER(pround<Packet4f>, a);
}
+template <>
+EIGEN_STRONG_INLINE Packet8bf ptrunc<Packet8bf>(const Packet8bf& a) {
+ BF16_TO_F32_UNARY_OP_WRAPPER(ptrunc<Packet4f>, a);
+}
#ifdef EIGEN_VECTORIZE_VSX
template <>
EIGEN_STRONG_INLINE Packet8bf print<Packet8bf>(const Packet8bf& a) {
@@ -3189,10 +3187,6 @@
#else
HasRsqrt = 0,
#endif
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1,
HasNegate = 1,
HasBlend = 1
};
@@ -3446,6 +3440,10 @@
return vec_floor(a);
}
template <>
+EIGEN_STRONG_INLINE Packet2d ptrunc<Packet2d>(const Packet2d& a) {
+ return vec_trunc(a);
+}
+template <>
EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a) {
Packet2d res;
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
index be44909..f31c6ce 100644
--- a/Eigen/src/Core/arch/Default/BFloat16.h
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -637,6 +637,7 @@
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) { return bfloat16(::ceilf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) { return bfloat16(::rintf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) { return bfloat16(::roundf(float(a))); }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 trunc(const bfloat16& a) { return bfloat16(::truncf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
return bfloat16(::fmodf(float(a), float(b)));
}
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 16ca807..537dffe 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -2469,6 +2469,95 @@
}
};
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_rint(const Packet& a) {
+ using Scalar = typename unpacket_traits<Packet>::type;
+ using IntType = typename numext::get_integer_by_size<sizeof(Scalar)>::signed_type;
+ // Adds and subtracts 2^(digits - 1) to |a| to force rounding to an integer; the sign bits of a are restored below.
+ const IntType kLimit = IntType(1) << (NumTraits<Scalar>::digits() - 1);
+ const Packet cst_limit = pset1<Packet>(static_cast<Scalar>(kLimit));
+ Packet abs_a = pabs(a);
+ Packet sign_a = pandnot(a, abs_a);
+ Packet rint_a = padd(abs_a, cst_limit);
+ // Don't compile-away addition and subtraction.
+ EIGEN_OPTIMIZATION_BARRIER(rint_a);
+ rint_a = psub(rint_a, cst_limit);
+ rint_a = por(rint_a, sign_a);
+ // If |a| is not less than the limit (or a is NaN), simply return a.
+ Packet mask = pcmp_lt(abs_a, cst_limit);
+ Packet result = pselect(mask, rint_a, a);
+ return result;
+}
+
+// Packet floor derived from generic_rint: wherever the rounded value exceeds the input, step down by one.
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_floor(const Packet& a) {
+  typedef typename unpacket_traits<Packet>::type Scalar;
+  const Packet rounded = generic_rint(a);
+  // Lanes with a < rint(a) were rounded up, i.e. rint(a) == ceil(a) there.
+  const Packet rounded_up = pcmp_lt(a, rounded);
+  // Mask the constant 1 with the comparison result so that only those lanes are corrected.
+  const Packet correction = pand(pset1<Packet>(Scalar(1)), rounded_up);
+  return psub(rounded, correction);
+}
+
+// Packet ceil derived from generic_rint: wherever the rounded value fell below the input, step up by one.
+// The sign bit of the input is OR-ed back into the result so that inputs in (-1, 0) yield -0 rather than +0
+// (e.g. rint(-0.7) == -1 and -1 + 1 == +0, whereas ceil(-0.7) must be -0).
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_ceil(const Packet& a) {
+  using Scalar = typename unpacket_traits<Packet>::type;
+  const Packet cst_1 = pset1<Packet>(Scalar(1));
+  // Only the sign bit is set in -0.0.
+  const Packet sign_mask = pset1<Packet>(static_cast<Scalar>(-0.0));
+  Packet rint_a = generic_rint(a);
+  // if rint(a) < a, then rint(a) == floor(a)
+  Packet mask = pcmp_lt(rint_a, a);
+  Packet offset = pand(cst_1, mask);
+  Packet result = padd(rint_a, offset);
+  // Signed zero must remain signed. For negative a the result is never positive, so setting the sign bit only
+  // changes +0 into -0; for positive a the extracted sign bit is clear and the result is unchanged.
+  result = por(result, pand(sign_mask, a));
+  return result;
+}
+
+// Packet truncation: floors the magnitude of a and then ORs the original sign bits back in.
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_trunc(const Packet& a) {
+  const Packet magnitude = pabs(a);
+  // pandnot(a, |a|) keeps only the bits of a that are not present in |a|.
+  const Packet sign_bits = pandnot(a, magnitude);
+  return por(generic_floor(magnitude), sign_bits);
+}
+
+// Packet round: operates on |a|, adds one wherever the fractional part is at least 0.5, then ORs the sign bits
+// of a back into the result.
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_round(const Packet& a) {
+  typedef typename unpacket_traits<Packet>::type Scalar;
+  const Packet magnitude = pabs(a);
+  const Packet sign_bits = pandnot(a, magnitude);
+  const Packet whole = generic_floor(magnitude);
+  const Packet fraction = psub(magnitude, whole);
+  // Lanes where 0.5 <= fraction are bumped up by one.
+  const Packet at_least_half = pcmp_le(pset1<Packet>(Scalar(0.5)), fraction);
+  const Packet bump = pand(pset1<Packet>(Scalar(1)), at_least_half);
+  return por(padd(whole, bump), sign_bits);
+}
+
+template <typename Packet>
+struct nearest_integer_packetop_impl<Packet, /*IsScalar*/ false, /*IsInteger*/ false> {
+ using Scalar = typename unpacket_traits<Packet>::type;
+ static_assert(packet_traits<Scalar>::HasRound, "Generic nearest integer functions are disabled for this type.");
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_floor(const Packet& x) { return generic_floor(x); }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_ceil(const Packet& x) { return generic_ceil(x); }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_rint(const Packet& x) { return generic_rint(x); }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_round(const Packet& x) { return generic_round(x); }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_trunc(const Packet& x) { return generic_trunc(x); }
+};
+
+template <typename Packet>
+struct nearest_integer_packetop_impl<Packet, /*IsScalar*/ false, /*IsInteger*/ true> {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_floor(const Packet& x) { return x; }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_ceil(const Packet& x) { return x; }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_rint(const Packet& x) { return x; }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_round(const Packet& x) { return x; }
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run_trunc(const Packet& x) { return x; }
+};
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
index 05cac5c..41dc068 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
@@ -133,6 +133,21 @@
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& x);
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_rint(const Packet& a);
+
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_floor(const Packet& a);
+
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_ceil(const Packet& a);
+
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_trunc(const Packet& a);
+
+template <typename Packet>
+EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
+
// Macros for instantiating these generic functions for different backends.
#define EIGEN_PACKET_FUNCTION(METHOD, SCALAR, PACKET) \
template <> \
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index 17d534d..9c195c1 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -722,6 +722,7 @@
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half rint(const half& a) { return half(::rintf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half round(const half& a) { return half(::roundf(float(a))); }
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half trunc(const half& a) { return half(::truncf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) {
return half(::fmodf(float(a), float(b)));
}
diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h
index 7900b0e..352c8f5 100644
--- a/Eigen/src/Core/arch/GPU/PacketMath.h
+++ b/Eigen/src/Core/arch/GPU/PacketMath.h
@@ -75,8 +75,7 @@
HasIGammac = 1,
HasBetaInc = 1,
- HasBlend = 0,
- HasFloor = 1,
+ HasBlend = 0
};
};
diff --git a/Eigen/src/Core/arch/HVX/PacketMath.h b/Eigen/src/Core/arch/HVX/PacketMath.h
index 7e139de..ccba96e 100644
--- a/Eigen/src/Core/arch/HVX/PacketMath.h
+++ b/Eigen/src/Core/arch/HVX/PacketMath.h
@@ -161,9 +161,6 @@
HasBlend = 0,
HasDiv = 0,
- HasFloor = 0,
- HasCeil = 0,
- HasRint = 0,
HasSin = 0,
HasCos = 0,
diff --git a/Eigen/src/Core/arch/MSA/PacketMath.h b/Eigen/src/Core/arch/MSA/PacketMath.h
index c1843c3..81da24f 100644
--- a/Eigen/src/Core/arch/MSA/PacketMath.h
+++ b/Eigen/src/Core/arch/MSA/PacketMath.h
@@ -91,9 +91,6 @@
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
HasBlend = 1
};
};
@@ -859,9 +856,6 @@
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
HasBlend = 1
};
};
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 2c18b5d..50cf56f 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -196,12 +196,7 @@
HasConj = 1,
HasSetLinear = 1,
HasBlend = 0,
-
HasDiv = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1,
-
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasACos = 1,
@@ -4470,76 +4465,25 @@
return vrndpq_f32(a);
}
-#else
-
template <>
-EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) {
- // Adds and subtracts signum(a) * 2^23 to force rounding.
- const Packet4f limit = pset1<Packet4f>(static_cast<float>(1 << 23));
- const Packet4f abs_a = pabs(a);
- Packet4f r = padd(abs_a, limit);
- // Don't compile-away addition and subtraction.
- EIGEN_OPTIMIZATION_BARRIER(r);
- r = psub(r, limit);
- // If greater than limit, simply return a. Otherwise, account for sign.
- r = pselect(pcmp_lt(abs_a, limit), pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
- return r;
+EIGEN_STRONG_INLINE Packet2f pround<Packet2f>(const Packet2f& a) {
+ return vrnda_f32(a);
}
template <>
-EIGEN_STRONG_INLINE Packet2f print(const Packet2f& a) {
- // Adds and subtracts signum(a) * 2^23 to force rounding.
- const Packet2f limit = pset1<Packet2f>(static_cast<float>(1 << 23));
- const Packet2f abs_a = pabs(a);
- Packet2f r = padd(abs_a, limit);
- // Don't compile-away addition and subtraction.
- EIGEN_OPTIMIZATION_BARRIER(r);
- r = psub(r, limit);
- // If greater than limit, simply return a. Otherwise, account for sign.
- r = pselect(pcmp_lt(abs_a, limit), pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
- return r;
+EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a) {
+ return vrndaq_f32(a);
}
template <>
-EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) {
- const Packet4f cst_1 = pset1<Packet4f>(1.0f);
- Packet4f tmp = print<Packet4f>(a);
- // If greater, subtract one.
- Packet4f mask = pcmp_lt(a, tmp);
- mask = pand(mask, cst_1);
- return psub(tmp, mask);
+EIGEN_STRONG_INLINE Packet2f ptrunc<Packet2f>(const Packet2f& a) {
+ return vrnd_f32(a);
}
template <>
-EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a) {
- const Packet2f cst_1 = pset1<Packet2f>(1.0f);
- Packet2f tmp = print<Packet2f>(a);
- // If greater, subtract one.
- Packet2f mask = pcmp_lt(a, tmp);
- mask = pand(mask, cst_1);
- return psub(tmp, mask);
+EIGEN_STRONG_INLINE Packet4f ptrunc<Packet4f>(const Packet4f& a) {
+ return vrndq_f32(a);
}
-
-template <>
-EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) {
- const Packet4f cst_1 = pset1<Packet4f>(1.0f);
- Packet4f tmp = print<Packet4f>(a);
- // If smaller, add one.
- Packet4f mask = pcmp_lt(tmp, a);
- mask = pand(mask, cst_1);
- return padd(tmp, mask);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a) {
- const Packet2f cst_1 = pset1<Packet2f>(1.0);
- Packet2f tmp = print<Packet2f>(a);
- // If smaller, add one.
- Packet2f mask = pcmp_lt(tmp, a);
- mask = pand(mask, cst_1);
- return padd(tmp, mask);
-}
-
#endif
/**
@@ -4800,10 +4744,6 @@
HasSetLinear = 1,
HasBlend = 0,
HasDiv = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1,
-
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
@@ -4984,6 +4924,16 @@
}
template <>
+EIGEN_STRONG_INLINE Packet4bf pround<Packet4bf>(const Packet4bf& a) {
+ return F32ToBf16(pround<Packet4f>(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4bf ptrunc<Packet4bf>(const Packet4bf& a) {
+ return F32ToBf16(ptrunc<Packet4f>(Bf16ToF32(a)));
+}
+
+template <>
EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) {
return a;
}
@@ -5168,9 +5118,6 @@
HasBlend = 0,
HasDiv = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1,
#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG
HasExp = 1,
@@ -5461,6 +5408,16 @@
}
template <>
+EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a) {
+ return vrndaq_f64(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2d ptrunc<Packet2d>(const Packet2d& a) {
+ return vrndq_f64(a);
+}
+
+template <>
EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
return pldexp_generic(a, exponent);
}
@@ -5521,9 +5478,6 @@
HasInsert = 1,
HasReduxp = 1,
HasDiv = 1,
- HasFloor = 1,
- HasCeil = 1,
- HasRint = 1,
HasSin = 0,
HasCos = 0,
HasLog = 0,
@@ -5792,6 +5746,26 @@
}
template <>
+EIGEN_STRONG_INLINE Packet8hf pround<Packet8hf>(const Packet8hf& a) {
+ return vrndaq_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pround<Packet4hf>(const Packet4hf& a) {
+ return vrnda_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf ptrunc<Packet8hf>(const Packet8hf& a) {
+ return vrndq_f16(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf ptrunc<Packet4hf>(const Packet4hf& a) {
+ return vrnd_f16(a);
+}
+
+template <>
EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
return vsqrtq_f16(a);
}
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 7bac3f9..e19e948 100644
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -198,12 +198,6 @@
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasBlend = 1,
- HasCeil = 1,
- HasFloor = 1,
-#ifdef EIGEN_VECTORIZE_SSE4_1
- HasRound = 1,
-#endif
- HasRint = 1,
HasSign = 0 // The manually vectorized version is slightly slower for SSE.
};
};
@@ -225,13 +219,7 @@
HasSqrt = 1,
HasRsqrt = 1,
HasATan = 1,
- HasBlend = 1,
- HasFloor = 1,
- HasCeil = 1,
-#ifdef EIGEN_VECTORIZE_SSE4_1
- HasRound = 1,
-#endif
- HasRint = 1
+ HasBlend = 1
};
};
template <>
@@ -1309,73 +1297,14 @@
EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) {
return _mm_floor_pd(a);
}
-#else
-template <>
-EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) {
- // Adds and subtracts signum(a) * 2^23 to force rounding.
- const Packet4f limit = pset1<Packet4f>(static_cast<float>(1 << 23));
- const Packet4f abs_a = pabs(a);
- Packet4f r = padd(abs_a, limit);
- // Don't compile-away addition and subtraction.
- EIGEN_OPTIMIZATION_BARRIER(r);
- r = psub(r, limit);
- // If greater than limit, simply return a. Otherwise, account for sign.
- r = pselect(pcmp_lt(abs_a, limit), pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
- return r;
-}
template <>
-EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) {
- // Adds and subtracts signum(a) * 2^52 to force rounding.
- const Packet2d limit = pset1<Packet2d>(static_cast<double>(1ull << 52));
- const Packet2d abs_a = pabs(a);
- Packet2d r = padd(abs_a, limit);
- // Don't compile-away addition and subtraction.
- EIGEN_OPTIMIZATION_BARRIER(r);
- r = psub(r, limit);
- // If greater than limit, simply return a. Otherwise, account for sign.
- r = pselect(pcmp_lt(abs_a, limit), pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
- return r;
+EIGEN_STRONG_INLINE Packet4f ptrunc<Packet4f>(const Packet4f& a) {
+ return _mm_round_ps(a, _MM_FROUND_TRUNC);
}
-
template <>
-EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) {
- const Packet4f cst_1 = pset1<Packet4f>(1.0f);
- Packet4f tmp = print<Packet4f>(a);
- // If greater, subtract one.
- Packet4f mask = _mm_cmpgt_ps(tmp, a);
- mask = pand(mask, cst_1);
- return psub(tmp, mask);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) {
- const Packet2d cst_1 = pset1<Packet2d>(1.0);
- Packet2d tmp = print<Packet2d>(a);
- // If greater, subtract one.
- Packet2d mask = _mm_cmpgt_pd(tmp, a);
- mask = pand(mask, cst_1);
- return psub(tmp, mask);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) {
- const Packet4f cst_1 = pset1<Packet4f>(1.0f);
- Packet4f tmp = print<Packet4f>(a);
- // If smaller, add one.
- Packet4f mask = _mm_cmplt_ps(tmp, a);
- mask = pand(mask, cst_1);
- return padd(tmp, mask);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) {
- const Packet2d cst_1 = pset1<Packet2d>(1.0);
- Packet2d tmp = print<Packet2d>(a);
- // If smaller, add one.
- Packet2d mask = _mm_cmplt_pd(tmp, a);
- mask = pand(mask, cst_1);
- return padd(tmp, mask);
+EIGEN_STRONG_INLINE Packet2d ptrunc<Packet2d>(const Packet2d& a) {
+ return _mm_round_pd(a, _MM_FROUND_TRUNC);
}
#endif
diff --git a/Eigen/src/Core/arch/SVE/PacketMath.h b/Eigen/src/Core/arch/SVE/PacketMath.h
index 6a03de9..3f847a9 100644
--- a/Eigen/src/Core/arch/SVE/PacketMath.h
+++ b/Eigen/src/Core/arch/SVE/PacketMath.h
@@ -353,7 +353,6 @@
HasReduxp = 0, // Not implemented in SVE
HasDiv = 1,
- HasFloor = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
diff --git a/Eigen/src/Core/arch/ZVector/PacketMath.h b/Eigen/src/Core/arch/ZVector/PacketMath.h
index 8ac8f77..b456813 100644
--- a/Eigen/src/Core/arch/ZVector/PacketMath.h
+++ b/Eigen/src/Core/arch/ZVector/PacketMath.h
@@ -195,9 +195,6 @@
HasRsqrt = 1,
HasTanh = 1,
HasErf = 1,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
HasNegate = 1,
HasBlend = 1
};
@@ -225,9 +222,6 @@
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
- HasRound = 1,
- HasFloor = 1,
- HasCeil = 1,
HasNegate = 1,
HasBlend = 1
};
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index 8d95819..2b0c05c 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -882,7 +882,7 @@
struct functor_traits<scalar_floor_op<Scalar>> {
enum {
Cost = NumTraits<Scalar>::MulCost,
- PacketAccess = packet_traits<Scalar>::HasFloor || NumTraits<Scalar>::IsInteger
+ PacketAccess = packet_traits<Scalar>::HasRound || NumTraits<Scalar>::IsInteger
};
};
@@ -902,7 +902,7 @@
struct functor_traits<scalar_rint_op<Scalar>> {
enum {
Cost = NumTraits<Scalar>::MulCost,
- PacketAccess = packet_traits<Scalar>::HasRint || NumTraits<Scalar>::IsInteger
+ PacketAccess = packet_traits<Scalar>::HasRound || NumTraits<Scalar>::IsInteger
};
};
@@ -922,7 +922,27 @@
struct functor_traits<scalar_ceil_op<Scalar>> {
enum {
Cost = NumTraits<Scalar>::MulCost,
- PacketAccess = packet_traits<Scalar>::HasCeil || NumTraits<Scalar>::IsInteger
+ PacketAccess = packet_traits<Scalar>::HasRound || NumTraits<Scalar>::IsInteger
+ };
+};
+
+/** \internal
+ * \brief Template functor to compute the truncation of a scalar
+ * \sa class CwiseUnaryOp, ArrayBase::trunc()
+ */
+template <typename Scalar>
+struct scalar_trunc_op {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { return numext::trunc(a); }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
+ return internal::ptrunc(a);
+ }
+};
+template <typename Scalar>
+struct functor_traits<scalar_trunc_op<Scalar>> {
+ enum {
+ Cost = NumTraits<Scalar>::MulCost,
+ PacketAccess = packet_traits<Scalar>::HasRound || NumTraits<Scalar>::IsInteger
};
};