Enable packet segment in partial redux
diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h
index a33a21a..b4e8794 100644
--- a/Eigen/src/Core/AssignEvaluator.h
+++ b/Eigen/src/Core/AssignEvaluator.h
@@ -136,8 +136,7 @@
: Traversal == SliceVectorizedTraversal ? (MayUnrollInner ? InnerUnrolling : NoUnrolling)
#endif
: NoUnrolling;
- static constexpr bool UsePacketSegment =
- enable_packet_segment<Src>::value && enable_packet_segment<Dst>::value && has_packet_segment<PacketType>::value;
+ static constexpr bool UsePacketSegment = has_packet_segment<PacketType>::value;
#ifdef EIGEN_DEBUG_ASSIGN
static void debug() {
diff --git a/Eigen/src/Core/PartialReduxEvaluator.h b/Eigen/src/Core/PartialReduxEvaluator.h
index 7b2c8dc..1f638f9 100644
--- a/Eigen/src/Core/PartialReduxEvaluator.h
+++ b/Eigen/src/Core/PartialReduxEvaluator.h
@@ -103,19 +103,36 @@
EIGEN_DEVICE_FUNC static PacketType run(const Evaluator& eval, const Func& func, Index size) {
if (size == 0) return packetwise_redux_empty_value<PacketType>(func);
- const Index size4 = (size - 1) & (~3);
+ const Index size4 = 1 + numext::round_down(size - 1, 4);  // tail-loop start; NOTE(review): assumes round_down(x, 4) == x & ~3 as in the replaced line - confirm
PacketType p = eval.template packetByOuterInner<Unaligned, PacketType>(0, 0);
- Index i = 1;
// This loop is optimized for instruction pipelining:
// - each iteration generates two independent instructions
// - thanks to branch prediction and out-of-order execution we have independent instructions across loops
- for (; i < size4; i += 4)
+ for (Index i = 1; i < size4; i += 4)  // outer index 0 is already in p; four more packets are combined per pass
p = func.packetOp(
p, func.packetOp(func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 0, 0),
eval.template packetByOuterInner<Unaligned, PacketType>(i + 1, 0)),
func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 2, 0),
eval.template packetByOuterInner<Unaligned, PacketType>(i + 3, 0))));
- for (; i < size; ++i) p = func.packetOp(p, eval.template packetByOuterInner<Unaligned, PacketType>(i, 0));
+ for (Index i = size4; i < size; ++i)  // tail: fold in the packets at [size4, size) one at a time
+ p = func.packetOp(p, eval.template packetByOuterInner<Unaligned, PacketType>(i, 0));
+ return p;
+ }
+};
+
+template <typename Func, typename Evaluator>
+struct packetwise_segment_redux_impl {  // run() folds packetSegmentByOuterInner loads together with func.packetOp
+ typedef typename Evaluator::Scalar Scalar;  // not referenced inside this struct
+ typedef typename redux_traits<Func, Evaluator>::PacketType PacketScalar;  // not referenced inside this struct
+
+ template <typename PacketType>
+ EIGEN_DEVICE_FUNC static PacketType run(const Evaluator& eval, const Func& func, Index size, Index begin,
+ Index count) {  // size: number of outer indices to fold; begin/count: forwarded to every segment load
+ if (size == 0) return packetwise_redux_empty_value<PacketType>(func);  // empty input: no load is issued
+
+ PacketType p = eval.template packetSegmentByOuterInner<Unaligned, PacketType>(0, 0, begin, count);  // seed: outer 0
+ for (Index i = 1; i < size; ++i)  // fold in outer indices 1..size-1 with the same (begin, count), one per pass
+ p = func.packetOp(p, eval.template packetSegmentByOuterInner<Unaligned, PacketType>(i, 0, begin, count));
return p;
}
};
@@ -174,14 +191,13 @@
template <int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packet(Index idx) const {
- enum { PacketSize = internal::unpacket_traits<PacketType>::size };
- typedef Block<const ArgTypeNestedCleaned, Direction == Vertical ? int(ArgType::RowsAtCompileTime) : int(PacketSize),
- Direction == Vertical ? int(PacketSize) : int(ArgType::ColsAtCompileTime), true /* InnerPanel */>
- PanelType;
-
- PanelType panel(m_arg, Direction == Vertical ? 0 : idx, Direction == Vertical ? idx : 0,
- Direction == Vertical ? m_arg.rows() : Index(PacketSize),
- Direction == Vertical ? Index(PacketSize) : m_arg.cols());
+ static constexpr int PacketSize = internal::unpacket_traits<PacketType>::size;
+ static constexpr int PanelRows = Direction == Vertical ? ArgType::RowsAtCompileTime : PacketSize;
+ static constexpr int PanelCols = Direction == Vertical ? PacketSize : ArgType::ColsAtCompileTime;
+ using PanelType = Block<const ArgTypeNestedCleaned, PanelRows, PanelCols, true /* InnerPanel */>;
+ using PanelEvaluator = typename internal::redux_evaluator<PanelType>;
+ using BinaryOp = typename MemberOp::BinaryOp;
+ using Impl = internal::packetwise_redux_impl<BinaryOp, PanelEvaluator>;
// FIXME
// See bug 1612, currently if PacketSize==1 (i.e. complex<double> with 128bits registers) then the storage-order of
@@ -189,11 +205,39 @@
// by pass "vectorization" in this case:
if (PacketSize == 1) return internal::pset1<PacketType>(coeff(idx));
- typedef typename internal::redux_evaluator<PanelType> PanelEvaluator;
+ Index startRow = Direction == Vertical ? 0 : idx;
+ Index startCol = Direction == Vertical ? idx : 0;
+ Index numRows = Direction == Vertical ? m_arg.rows() : PacketSize;
+ Index numCols = Direction == Vertical ? PacketSize : m_arg.cols();
+
+ PanelType panel(m_arg, startRow, startCol, numRows, numCols);
PanelEvaluator panel_eval(panel);
- typedef typename MemberOp::BinaryOp BinaryOp;
- PacketType p = internal::packetwise_redux_impl<BinaryOp, PanelEvaluator>::template run<PacketType>(
- panel_eval, m_functor.binaryFunc(), m_arg.outerSize());
+ PacketType p = Impl::template run<PacketType>(panel_eval, m_functor.binaryFunc(), m_arg.outerSize());
+ return p;
+ }
+
+ template <int LoadMode, typename PacketType>  // two-index overload of packetSegment
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index i, Index j, Index begin, Index count) const {
+ return packetSegment<LoadMode, PacketType>(Direction == Vertical ? j : i, begin, count);  // keeps j if Vertical, else i
+ }
+
+ template <int LoadMode, typename PacketType>  // single-index overload: reduces a panel of m_arg into one packet
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packetSegment(Index idx, Index begin, Index count) const {
+ static constexpr int PanelRows = Direction == Vertical ? ArgType::RowsAtCompileTime : Dynamic;  // see numRows
+ static constexpr int PanelCols = Direction == Vertical ? Dynamic : ArgType::ColsAtCompileTime;  // see numCols
+ using PanelType = Block<const ArgTypeNestedCleaned, PanelRows, PanelCols, true /* InnerPanel */>;
+ using PanelEvaluator = typename internal::redux_evaluator<PanelType>;
+ using BinaryOp = typename MemberOp::BinaryOp;
+ using Impl = internal::packetwise_segment_redux_impl<BinaryOp, PanelEvaluator>;  // LoadMode is not forwarded to it
+
+ Index startRow = Direction == Vertical ? 0 : idx;  // panel origin: idx along one axis, 0 along the other
+ Index startCol = Direction == Vertical ? idx : 0;
+ Index numRows = Direction == Vertical ? m_arg.rows() : begin + count;  // runtime extent, hence Dynamic PanelRows
+ Index numCols = Direction == Vertical ? begin + count : m_arg.cols();  // runtime extent, hence Dynamic PanelCols
+
+ PanelType panel(m_arg, startRow, startCol, numRows, numCols);
+ PanelEvaluator panel_eval(panel);  // Impl::run below folds m_arg.outerSize() segment loads from this evaluator
+ PacketType p = Impl::template run<PacketType>(panel_eval, m_functor.binaryFunc(), m_arg.outerSize(), begin, count);
return p;
}
diff --git a/Eigen/src/Core/Redux.h b/Eigen/src/Core/Redux.h
index 0c5f2d9..4e9ab0e 100644
--- a/Eigen/src/Core/Redux.h
+++ b/Eigen/src/Core/Redux.h
@@ -414,6 +414,13 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetByOuterInner(Index outer, Index inner) const {
return Base::template packet<LoadMode, PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer);
}
+
+ template <int LoadMode, typename PacketType>  // segment variant of packetByOuterInner
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegmentByOuterInner(Index outer, Index inner, Index begin,
+ Index count) const {  // forwards to Base::packetSegment with the indices ordered by IsRowMajor
+ return Base::template packetSegment<LoadMode, PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer,
+ begin, count);  // (outer, inner) when IsRowMajor, (inner, outer) otherwise; begin/count unchanged
+ }
};
} // end namespace internal
diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h
index 1342478..b861b23 100644
--- a/Eigen/src/Core/VectorwiseOp.h
+++ b/Eigen/src/Core/VectorwiseOp.h
@@ -37,9 +37,6 @@
namespace internal {
-template <typename ArgType, typename MemberOp, int Direction>
-struct enable_packet_segment<PartialReduxExpr<ArgType, MemberOp, Direction>> : std::false_type {};
-
template <typename MatrixType, typename MemberOp, int Direction>
struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> > : traits<MatrixType> {
typedef typename MemberOp::result_type Scalar;
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index 8d1073c..3c0bc46 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -517,9 +517,6 @@
template <typename Packet>
struct has_packet_segment : std::false_type {};
-
-template <typename Xpr>
-struct enable_packet_segment : std::true_type {};
} // namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h
index 24268bf..a42bb0f 100644
--- a/Eigen/src/Core/util/XprHelper.h
+++ b/Eigen/src/Core/util/XprHelper.h
@@ -996,36 +996,6 @@
template <typename XprType>
struct is_permutation_base_xpr : std::is_base_of<PermutationBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
-/*---------------- load/store segment support ----------------*/
-
-// recursively traverse unary, binary, and ternary expressions to determine if packet segments are supported
-
-template <typename Func, typename Xpr>
-struct enable_packet_segment<CwiseNullaryOp<Func, Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
-
-template <typename Func, typename Xpr>
-struct enable_packet_segment<CwiseUnaryOp<Func, Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
-
-template <typename Func, typename LhsXpr, typename RhsXpr>
-struct enable_packet_segment<CwiseBinaryOp<Func, LhsXpr, RhsXpr>>
- : bool_constant<enable_packet_segment<remove_all_t<LhsXpr>>::value &&
- enable_packet_segment<remove_all_t<RhsXpr>>::value> {};
-
-template <typename Func, typename LhsXpr, typename MidXpr, typename RhsXpr>
-struct enable_packet_segment<CwiseTernaryOp<Func, LhsXpr, MidXpr, RhsXpr>>
- : bool_constant<enable_packet_segment<remove_all_t<LhsXpr>>::value &&
- enable_packet_segment<remove_all_t<MidXpr>>::value &&
- enable_packet_segment<remove_all_t<RhsXpr>>::value> {};
-
-template <typename Xpr>
-struct enable_packet_segment<ArrayWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
-
-template <typename Xpr>
-struct enable_packet_segment<MatrixWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
-
-template <typename Xpr>
-struct enable_packet_segment<DiagonalWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
-
} // end namespace internal
/** \class ScalarBinaryOpTraits