Vectorize tan(x)

libeigen/eigen!2086

Co-authored-by: Rasmus Munk Larsen <rmlarsen@google.com>
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
index 5ee67a5..f4f6794 100644
--- a/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -32,6 +32,7 @@
 #ifdef EIGEN_VECTORIZE_AVX2
 EIGEN_DOUBLE_PACKET_FUNCTION(sin, Packet4d)
 EIGEN_DOUBLE_PACKET_FUNCTION(cos, Packet4d)
+EIGEN_DOUBLE_PACKET_FUNCTION(tan, Packet4d)
 #endif
 EIGEN_GENERIC_PACKET_FUNCTION(atan, Packet4d)
 EIGEN_GENERIC_PACKET_FUNCTION(exp2, Packet4d)
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index 318b375..eafff3d 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -110,6 +110,7 @@
     HasReciprocal = EIGEN_FAST_MATH,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasACos = 1,
     HasASin = 1,
     HasATan = 1,
@@ -143,6 +144,7 @@
 #ifdef EIGEN_VECTORIZE_AVX2
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
 #endif
     HasTanh = EIGEN_FAST_MATH,
     HasErf = 1,
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index ddc766b..c69ba15 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -119,6 +119,7 @@
     HasConj = 1,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasACos = 1,
     HasASin = 1,
     HasATan = 1,
@@ -154,6 +155,7 @@
     HasCbrt = 1,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasLog = 1,
     HasExp = 1,
     HasLog1p = 1,
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index c98f217..acc2048 100644
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -178,6 +178,7 @@
     HasAbs = 1,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasACos = 1,
     HasASin = 1,
     HasATan = 1,
@@ -3098,6 +3099,7 @@
     HasAbs = 1,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasTanh = EIGEN_FAST_MATH,
     HasErf = EIGEN_FAST_MATH,
     HasErfc = EIGEN_FAST_MATH,
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 487127b..13cdba7 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -773,6 +773,11 @@
   return pselect(zero_mask, cst_zero, pmax(pldexp(x, fx), _x));
 }
 
+// Enum for selecting which function to compute. SinCos is intended to compute
+// pairs of Sin and Cos of the even entries in the packet, e.g.
+// SinCos([a, *, b, *]) = [sin(a), cos(a), sin(b), cos(b)].
+enum class TrigFunction : uint8_t { Sin, Cos, Tan, SinCos };
+
 // The following code is inspired by the following stack-overflow answer:
 //   https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
 // It has been largely optimized:
@@ -829,7 +834,7 @@
   return float(double(int64_t(p)) * pio2_62);
 }
 
-template <bool ComputeSine, typename Packet, bool ComputeBoth = false>
+template <TrigFunction Func, typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
 #if EIGEN_COMP_GNUC_STRICT
     __attribute__((optimize("-fno-unsafe-math-optimizations")))
@@ -859,7 +864,7 @@
 #if defined(EIGEN_VECTORIZE_FMA)
   // This version requires true FMA for high accuracy.
   // It provides a max error of 1ULP up to (with absolute_error < 5.9605e-08):
-  const float huge_th = ComputeSine ? 117435.992f : 71476.0625f;
+  constexpr float huge_th = (Func == TrigFunction::Sin) ? 117435.992f : 71476.0625f;
   x = pmadd(y, pset1<Packet>(-1.57079601287841796875f), x);
   x = pmadd(y, pset1<Packet>(-3.1391647326017846353352069854736328125e-07f), x);
   x = pmadd(y, pset1<Packet>(-5.390302529957764765544681040410068817436695098876953125e-15f), x);
@@ -870,7 +875,7 @@
 
   // The following set of coefficients maintain 1ULP up to 9.43 and 14.16 for sin and cos respectively.
   // and 2 ULP up to:
-  const float huge_th = ComputeSine ? 25966.f : 18838.f;
+  constexpr float huge_th = (Func == TrigFunction::Sin) ? 25966.f : 18838.f;
   x = pmadd(y, pset1<Packet>(-1.5703125), x);  // = 0xbfc90000
   EIGEN_OPTIMIZATION_BARRIER(x)
   x = pmadd(y, pset1<Packet>(-0.000483989715576171875), x);  // = 0xb9fdc000
@@ -908,13 +913,6 @@
     y_int = ploadu<PacketI>(y_int2);
   }
 
-  // Compute the sign to apply to the polynomial.
-  // sin: sign = second_bit(y_int) xor signbit(_x)
-  // cos: sign = second_bit(y_int+1)
-  Packet sign_bit = ComputeSine ? pxor(_x, preinterpret<Packet>(plogical_shift_left<30>(y_int)))
-                                : preinterpret<Packet>(plogical_shift_left<30>(padd(y_int, csti_1)));
-  sign_bit = pand(sign_bit, cst_sign_mask);  // clear all but left most bit
-
   // Get the polynomial selection mask from the second bit of y_int
   // We'll calculate both (sin and cos) polynomials and then select from the two.
   Packet poly_mask = preinterpret<Packet>(pcmp_eq(pand(y_int, csti_1), pzero(y_int)));
@@ -943,7 +941,15 @@
   y2 = pmadd(y2, x, x);
 
   // Select the correct result from the two polynomials.
-  if (ComputeBoth) {
+  // Compute the sign to apply to the polynomial.
+  // sin: sign = second_bit(y_int) xor signbit(_x)
+  // cos: sign = second_bit(y_int+1)
+  Packet sign_bit = (Func == TrigFunction::Sin) ? pxor(_x, preinterpret<Packet>(plogical_shift_left<30>(y_int)))
+                                                : preinterpret<Packet>(plogical_shift_left<30>(padd(y_int, csti_1)));
+  sign_bit = pand(sign_bit, cst_sign_mask);  // clear all but left most bit
+
+  if ((Func == TrigFunction::SinCos) || (Func == TrigFunction::Tan)) {
+    // TODO(rmlarsen): Add single polynomial for tan(x) instead of paying for sin+cos+div.
     Packet peven = peven_mask(x);
     Packet ysin = pselect(poly_mask, y2, y1);
     Packet ycos = pselect(poly_mask, y1, y2);
@@ -951,23 +957,28 @@
     Packet sign_bit_cos = preinterpret<Packet>(plogical_shift_left<30>(padd(y_int, csti_1)));
     sign_bit_sin = pand(sign_bit_sin, cst_sign_mask);  // clear all but left most bit
     sign_bit_cos = pand(sign_bit_cos, cst_sign_mask);  // clear all but left most bit
-    y = pselect(peven, pxor(ysin, sign_bit_sin), pxor(ycos, sign_bit_cos));
+    y = (Func == TrigFunction::SinCos) ? pselect(peven, pxor(ysin, sign_bit_sin), pxor(ycos, sign_bit_cos))
+                                       : pdiv(pxor(ysin, sign_bit_sin), pxor(ycos, sign_bit_cos));
   } else {
-    y = ComputeSine ? pselect(poly_mask, y2, y1) : pselect(poly_mask, y1, y2);
+    y = (Func == TrigFunction::Sin) ? pselect(poly_mask, y2, y1) : pselect(poly_mask, y1, y2);
     y = pxor(y, sign_bit);
   }
-  // Update the sign and filter huge inputs
   return y;
 }
 
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_float(const Packet& x) {
-  return psincos_float<true>(x);
+  return psincos_float<TrigFunction::Sin>(x);
 }
 
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_float(const Packet& x) {
-  return psincos_float<false>(x);
+  return psincos_float<TrigFunction::Cos>(x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_float(const Packet& x) {
+  return psincos_float<TrigFunction::Tan>(x);
 }
 
 // Trigonometric argument reduction for double for inputs smaller than 15.
@@ -1007,7 +1018,7 @@
   return t;
 }
 
-template <bool ComputeSine, typename Packet, bool ComputeBoth = false>
+template <TrigFunction Func, typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
 #if EIGEN_COMP_GNUC_STRICT
     __attribute__((optimize("-fno-unsafe-math-optimizations")))
@@ -1094,19 +1105,26 @@
   Packet sign_sin = pxor(x, preinterpret<Packet>(plogical_shift_left<62>(q_int)));
   Packet sign_cos = preinterpret<Packet>(plogical_shift_left<62>(padd(q_int, cst_one)));
   Packet sign_bit, sFinalRes;
-  if (ComputeBoth) {
+  if (Func == TrigFunction::Sin) {
+    sign_bit = sign_sin;
+    sFinalRes = pselect(poly_mask, ssin, scos);
+  } else if (Func == TrigFunction::Cos) {
+    sign_bit = sign_cos;
+    sFinalRes = pselect(poly_mask, scos, ssin);
+  } else if (Func == TrigFunction::Tan) {
+    // TODO(rmlarsen): Add single polynomial for tan(x) instead of paying for sin+cos+div.
+    sign_bit = pxor(sign_sin, sign_cos);
+    sFinalRes = pdiv(pselect(poly_mask, ssin, scos), pselect(poly_mask, scos, ssin));
+  } else if (Func == TrigFunction::SinCos) {
     Packet peven = peven_mask(x);
     sign_bit = pselect((s), sign_sin, sign_cos);
     sFinalRes = pselect(pxor(peven, poly_mask), ssin, scos);
-  } else {
-    sign_bit = ComputeSine ? sign_sin : sign_cos;
-    sFinalRes = ComputeSine ? pselect(poly_mask, ssin, scos) : pselect(poly_mask, scos, ssin);
   }
   sign_bit = pand(sign_bit, cst_sign_mask);  // clear all but left most bit
   sFinalRes = pxor(sFinalRes, sign_bit);
 
   // If the inputs values are higher than that a value that the argument reduction can currently address, compute them
-  // using std::sin and std::cos
+  // using the C++ standard library.
   // TODO Remove it when huge angle argument reduction is implemented
   if (EIGEN_PREDICT_FALSE(predux_any(pcmp_le(pset1<Packet>(huge_th), x_abs)))) {
     const int PacketSize = unpacket_traits<Packet>::size;
@@ -1117,10 +1135,15 @@
     for (int k = 0; k < PacketSize; ++k) {
       double val = x_cpy[k];
       if (std::abs(val) > huge_th && (numext::isfinite)(val)) {
-        if (ComputeBoth)
+        if (Func == TrigFunction::Sin) {
+          sincos_vals[k] = std::sin(val);
+        } else if (Func == TrigFunction::Cos) {
+          sincos_vals[k] = std::cos(val);
+        } else if (Func == TrigFunction::Tan) {
+          sincos_vals[k] = std::tan(val);
+        } else if (Func == TrigFunction::SinCos) {
           sincos_vals[k] = k % 2 == 0 ? std::sin(val) : std::cos(val);
-        else
-          sincos_vals[k] = ComputeSine ? std::sin(val) : std::cos(val);
+        }
       }
     }
     sFinalRes = ploadu<Packet>(sincos_vals);
@@ -1130,26 +1153,31 @@
 
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_double(const Packet& x) {
-  return psincos_double<true>(x);
+  return psincos_double<TrigFunction::Sin>(x);
 }
 
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_double(const Packet& x) {
-  return psincos_double<false>(x);
+  return psincos_double<TrigFunction::Cos>(x);
 }
 
-template <bool ComputeSin, typename Packet>
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_double(const Packet& x) {
+  return psincos_double<TrigFunction::Tan>(x);
+}
+
+template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
     std::enable_if_t<std::is_same<typename unpacket_traits<Packet>::type, float>::value, Packet>
     psincos_selector(const Packet& x) {
-  return psincos_float<ComputeSin, Packet, true>(x);
+  return psincos_float<TrigFunction::SinCos, Packet>(x);
 }
 
-template <bool ComputeSin, typename Packet>
+template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
     std::enable_if_t<std::is_same<typename unpacket_traits<Packet>::type, double>::value, Packet>
     psincos_selector(const Packet& x) {
-  return psincos_double<ComputeSin, Packet, true>(x);
+  return psincos_double<TrigFunction::SinCos, Packet>(x);
 }
 
 // Generic implementation of acos(x).
@@ -1599,7 +1627,7 @@
   // cis(y):
   RealPacket y = pand(odd_mask, a.v);
   y = por(y, pcplxflip(Packet(y)).v);
-  RealPacket cisy = psincos_selector<false, RealPacket>(y);
+  RealPacket cisy = psincos_selector<RealPacket>(y);
   cisy = pcplxflip(Packet(cisy)).v;  // cos(y) + i * sin(y)
 
   const RealPacket cst_pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity());
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
index 69a5517..942ae12 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
@@ -110,6 +110,10 @@
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_float(const Packet& x);
 
+/** \internal \returns tan(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_float(const Packet& x);
+
 /** \internal \returns sin(x) for double precision float */
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_double(const Packet& x);
@@ -118,6 +122,10 @@
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_double(const Packet& x);
 
+/** \internal \returns tan(x) for double precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptan_double(const Packet& x);
+
 /** \internal \returns asin(x) for single precision float */
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pasin_float(const Packet& x);
@@ -200,6 +208,7 @@
 #define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_FLOAT(PACKET) \
   EIGEN_FLOAT_PACKET_FUNCTION(sin, PACKET)                 \
   EIGEN_FLOAT_PACKET_FUNCTION(cos, PACKET)                 \
+  EIGEN_FLOAT_PACKET_FUNCTION(tan, PACKET)                 \
   EIGEN_FLOAT_PACKET_FUNCTION(asin, PACKET)                \
   EIGEN_FLOAT_PACKET_FUNCTION(acos, PACKET)                \
   EIGEN_FLOAT_PACKET_FUNCTION(tanh, PACKET)                \
@@ -216,6 +225,7 @@
 #define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \
   EIGEN_DOUBLE_PACKET_FUNCTION(sin, PACKET)                 \
   EIGEN_DOUBLE_PACKET_FUNCTION(cos, PACKET)                 \
+  EIGEN_DOUBLE_PACKET_FUNCTION(tan, PACKET)                 \
   EIGEN_DOUBLE_PACKET_FUNCTION(log, PACKET)                 \
   EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET)                \
   EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET)                 \
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 6f93b15..bf50697 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -197,6 +197,7 @@
     HasDiv = 1,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasACos = 1,
     HasASin = 1,
     HasATan = 1,
@@ -5017,6 +5018,7 @@
 #endif
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasSqrt = 1,
     HasRsqrt = 1,
     HasCbrt = 1,
diff --git a/Eigen/src/Core/arch/RVV10/PacketMath.h b/Eigen/src/Core/arch/RVV10/PacketMath.h
index 54db626..e0e0be4 100644
--- a/Eigen/src/Core/arch/RVV10/PacketMath.h
+++ b/Eigen/src/Core/arch/RVV10/PacketMath.h
@@ -507,6 +507,7 @@
 
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasLog = 1,
     HasExp = 1,
     HasSqrt = 1,
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 1ea23b0..7d53fa2 100644
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -183,6 +183,7 @@
     HasReciprocal = EIGEN_FAST_MATH,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasACos = 1,
     HasASin = 1,
     HasATan = 1,
@@ -216,6 +217,7 @@
     HasDiv = 1,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasTanh = EIGEN_FAST_MATH,
     HasErf = EIGEN_FAST_MATH,
     HasErfc = EIGEN_FAST_MATH,
diff --git a/Eigen/src/Core/arch/SVE/MathFunctions.h b/Eigen/src/Core/arch/SVE/MathFunctions.h
index 8c8ed84..5967433 100644
--- a/Eigen/src/Core/arch/SVE/MathFunctions.h
+++ b/Eigen/src/Core/arch/SVE/MathFunctions.h
@@ -36,6 +36,11 @@
   return pcos_float(x);
 }
 
+template <>
+EIGEN_STRONG_INLINE PacketXf ptan<PacketXf>(const PacketXf& x) {
+  return ptan_float(x);
+}
+
 // Hyperbolic Tangent function.
 template <>
 EIGEN_STRONG_INLINE PacketXf ptanh<PacketXf>(const PacketXf& x) {
diff --git a/Eigen/src/Core/arch/SVE/PacketMath.h b/Eigen/src/Core/arch/SVE/PacketMath.h
index 28fc62b..39b29fa 100644
--- a/Eigen/src/Core/arch/SVE/PacketMath.h
+++ b/Eigen/src/Core/arch/SVE/PacketMath.h
@@ -353,6 +353,7 @@
     HasCmp = 1,
     HasSin = EIGEN_FAST_MATH,
     HasCos = EIGEN_FAST_MATH,
+    HasTan = EIGEN_FAST_MATH,
     HasLog = 1,
     HasExp = 1,
     HasPow = 1,
diff --git a/Eigen/src/Core/arch/clang/PacketMath.h b/Eigen/src/Core/arch/clang/PacketMath.h
index e142264..19e5e8f 100644
--- a/Eigen/src/Core/arch/clang/PacketMath.h
+++ b/Eigen/src/Core/arch/clang/PacketMath.h
@@ -56,6 +56,7 @@
     HasReciprocal = 1,
     HasSin = 1,
     HasCos = 1,
+    HasTan = 1,
     HasACos = 1,
     HasASin = 1,
     HasATan = 1,