Detect "effectively inner/outer" chipping in TensorChipping
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h b/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h index 000b1fb..f5172cd 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h
@@ -173,6 +173,26 @@ } m_inputStride *= input_dims[m_dim.actualDim()]; m_inputOffset = m_stride * op.offset(); + + // Check if chipping is effectively inner or outer: products of dimensions + // before or after the chipped dimension is `1`. + Index after_chipped_dim_product = 1; + for (int i = m_dim.actualDim() + 1; i < NumInputDims; ++i) { + after_chipped_dim_product *= input_dims[i]; + } + + Index before_chipped_dim_product = 1; + for (int i = 0; i < m_dim.actualDim(); ++i) { + before_chipped_dim_product *= input_dims[i]; + } + + if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { + m_isEffectivelyInnerChipping = before_chipped_dim_product == 1; + m_isEffectivelyOuterChipping = after_chipped_dim_product == 1; + } else { + m_isEffectivelyInnerChipping = after_chipped_dim_product == 1; + m_isEffectivelyOuterChipping = before_chipped_dim_product == 1; + } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -336,13 +356,11 @@ } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isInnerChipping() const { - return IsInnerChipping || (static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == 0) || - (static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == NumInputDims - 1); + return IsInnerChipping || m_isEffectivelyInnerChipping; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isOuterChipping() const { - return IsOuterChipping || (static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == NumInputDims - 1) || - (static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == 0); + return IsOuterChipping || m_isEffectivelyOuterChipping; } Dimensions m_dimensions; @@ -352,6 +370,11 @@ TensorEvaluator<ArgType, Device> m_impl; const internal::DimensionId<DimId> m_dim; const Device EIGEN_DEVICE_REF m_device; + + // If product of all dimensions after or before the chipped dimension is `1`, + // it is effectively the same as chipping innermost or outermost dimension. + bool m_isEffectivelyInnerChipping; + bool m_isEffectivelyOuterChipping; }; // Eval as lvalue