TensorConcatenation: fix packet() fast-path when concat axis is not innermost libeigen/eigen!2498 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/unsupported/Eigen/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/src/Tensor/TensorConcatenation.h index 15811a6..ae8b14c 100644 --- a/unsupported/Eigen/src/Tensor/TensorConcatenation.h +++ b/unsupported/Eigen/src/Tensor/TensorConcatenation.h
@@ -243,11 +243,14 @@ // When the packet sits entirely on one side of the concat boundary, delegate // to that operand's packet<>() rather than assembling PacketSize coeff() - // calls. The packet can straddle the boundary when either (a) the concat - // axis is the innermost dim and subs[axis] crosses left_dims[axis] within - // the packet, or (b) the innermost dim has fewer elements than PacketSize - // and the packet spills into a higher dim that happens to be the concat - // axis. Check the first and last linear index explicitly to cover both. + // calls. The packet stays on one side iff only the innermost dim varies + // across the packet -- i.e. all other subs match between the first and last + // index. When that holds, subs[m_axis] is either constant (m_axis is not + // innermost) or monotonic non-decreasing (m_axis is innermost), so checking + // just the endpoints decides the side. Otherwise subs[m_axis] can wrap back + // through the boundary mid-packet (as when the inner dim has fewer than + // PacketSize elements and the packet spills past the concat axis), so fall + // back to scalars. template <int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const { const int packetSize = PacketType<CoeffReturnType, Device>::size; @@ -281,8 +284,19 @@ const Dimensions& left_dims = m_leftImpl.dimensions(); const Index left_axis_size = left_dims[m_axis]; - const bool on_left = subs[m_axis] < left_axis_size && subs_end[m_axis] < left_axis_size; - const bool on_right = subs[m_axis] >= left_axis_size && subs_end[m_axis] >= left_axis_size; + const int innermost = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? 0 : NumDims - 1; + bool packet_in_single_inner_row = true; + EIGEN_UNROLL_LOOP + for (int i = 0; i < NumDims; ++i) { + if (i != innermost && subs[i] != subs_end[i]) { + packet_in_single_inner_row = false; + } + } + + const bool on_left = + packet_in_single_inner_row && subs[m_axis] < left_axis_size && subs_end[m_axis] < left_axis_size; + const bool on_right = + packet_in_single_inner_row && subs[m_axis] >= left_axis_size && subs_end[m_axis] >= left_axis_size; if (on_left) { Index left_index; @@ -320,8 +334,8 @@ return m_rightImpl.template packet<LoadMode>(right_index); } - // Straddling case (m_axis == innermost and the packet crosses the boundary): - // fall back to assembling scalars. + // The packet straddles the boundary or spans multiple inner rows: fall + // back to assembling scalars. EIGEN_ALIGN_MAX CoeffReturnType values[packetSize]; EIGEN_UNROLL_LOOP for (int i = 0; i < packetSize; ++i) {
diff --git a/unsupported/test/tensor_concatenation.cpp b/unsupported/test/tensor_concatenation.cpp index e7f53b6..c520a02 100644 --- a/unsupported/test/tensor_concatenation.cpp +++ b/unsupported/test/tensor_concatenation.cpp
@@ -95,8 +95,41 @@ } } -// TODO(phli): Add test once we have a real vectorized implementation. -// static void test_vectorized_concatenation() {} +// Exercise the packet() fast path when the concat axis is not the innermost +// dim and the inner dim is small enough that a packet load spans multiple +// rows -- including rows that fall on the right side of the boundary. The +// guard in packet() must reject this case and fall back to scalars. +template <int DataLayout> +static void test_concatenation_packet_axis_not_innermost() { + // Output shape (8, 6, 1) with concat along axis 1: each packet load whose + // first/last linear indices land on the left side will sweep through right + // rows in between unless the fast path is correctly guarded. + Tensor<float, 3, DataLayout> left(8, 3, 1); + Tensor<float, 3, DataLayout> right(8, 3, 1); + left.setRandom(); + right.setRandom(); + + Tensor<float, 3, DataLayout> concatenation = left.concatenate(right, 1); + VERIFY_IS_EQUAL(concatenation.dimension(0), 8); + VERIFY_IS_EQUAL(concatenation.dimension(1), 6); + VERIFY_IS_EQUAL(concatenation.dimension(2), 1); + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + VERIFY_IS_EQUAL(concatenation(i, j + 3, 0), right(i, j, 0)); + } + } + + // Force evaluation through the packet path with a coefficient-wise op so + // the executor will request packets aligned to the output strides. + Tensor<float, 3, DataLayout> doubled = concatenation * concatenation.constant(2.0f); + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_APPROX(doubled(i, j, 0), 2.0f * left(i, j, 0)); + VERIFY_IS_APPROX(doubled(i, j + 3, 0), 2.0f * right(i, j, 0)); + } + } +} static void test_concatenation_as_lvalue() { Tensor<int, 2> t1(2, 3); @@ -123,6 +156,7 @@ CALL_SUBTEST(test_static_dimension_failure<RowMajor>()); CALL_SUBTEST(test_simple_concatenation<ColMajor>()); CALL_SUBTEST(test_simple_concatenation<RowMajor>()); - // CALL_SUBTEST(test_vectorized_concatenation()); + CALL_SUBTEST(test_concatenation_packet_axis_not_innermost<ColMajor>()); + CALL_SUBTEST(test_concatenation_packet_axis_not_innermost<RowMajor>()); CALL_SUBTEST(test_concatenation_as_lvalue()); }