TensorConcatenation: fix packet() fast-path when concat axis is not innermost

libeigen/eigen!2498

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/unsupported/Eigen/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/src/Tensor/TensorConcatenation.h
index 15811a6..ae8b14c 100644
--- a/unsupported/Eigen/src/Tensor/TensorConcatenation.h
+++ b/unsupported/Eigen/src/Tensor/TensorConcatenation.h
@@ -243,11 +243,14 @@
 
   // When the packet sits entirely on one side of the concat boundary, delegate
   // to that operand's packet<>() rather than assembling PacketSize coeff()
-  // calls. The packet can straddle the boundary when either (a) the concat
-  // axis is the innermost dim and subs[axis] crosses left_dims[axis] within
-  // the packet, or (b) the innermost dim has fewer elements than PacketSize
-  // and the packet spills into a higher dim that happens to be the concat
-  // axis. Check the first and last linear index explicitly to cover both.
+  // calls. The packet stays on one side iff only the innermost dim varies
+  // across the packet -- i.e. all other subs match between the first and last
+  // index. When that holds, subs[m_axis] is either constant (m_axis is not
+  // innermost) or monotonic non-decreasing (m_axis is innermost), so checking
+  // just the endpoints decides the side. Otherwise subs[m_axis] can wrap back
+  // through the boundary mid-packet (as when the inner dim has fewer than
+  // PacketSize elements and the packet spills past the concat axis), so fall
+  // back to scalars.
   template <int LoadMode>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
     const int packetSize = PacketType<CoeffReturnType, Device>::size;
@@ -281,8 +284,19 @@
     const Dimensions& left_dims = m_leftImpl.dimensions();
     const Index left_axis_size = left_dims[m_axis];
 
-    const bool on_left = subs[m_axis] < left_axis_size && subs_end[m_axis] < left_axis_size;
-    const bool on_right = subs[m_axis] >= left_axis_size && subs_end[m_axis] >= left_axis_size;
+    const int innermost = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? 0 : NumDims - 1;
+    bool packet_in_single_inner_row = true;
+    EIGEN_UNROLL_LOOP
+    for (int i = 0; i < NumDims; ++i) {
+      if (i != innermost && subs[i] != subs_end[i]) {
+        packet_in_single_inner_row = false;
+      }
+    }
+
+    const bool on_left =
+        packet_in_single_inner_row && subs[m_axis] < left_axis_size && subs_end[m_axis] < left_axis_size;
+    const bool on_right =
+        packet_in_single_inner_row && subs[m_axis] >= left_axis_size && subs_end[m_axis] >= left_axis_size;
 
     if (on_left) {
       Index left_index;
@@ -320,8 +334,8 @@
       return m_rightImpl.template packet<LoadMode>(right_index);
     }
 
-    // Straddling case (m_axis == innermost and the packet crosses the boundary):
-    // fall back to assembling scalars.
+    // The packet straddles the boundary or spans multiple inner rows: fall
+    // back to assembling scalars.
     EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
     EIGEN_UNROLL_LOOP
     for (int i = 0; i < packetSize; ++i) {
diff --git a/unsupported/test/tensor_concatenation.cpp b/unsupported/test/tensor_concatenation.cpp
index e7f53b6..c520a02 100644
--- a/unsupported/test/tensor_concatenation.cpp
+++ b/unsupported/test/tensor_concatenation.cpp
@@ -95,8 +95,41 @@
   }
 }
 
-// TODO(phli): Add test once we have a real vectorized implementation.
-// static void test_vectorized_concatenation() {}
+// Exercise the packet() fast path when the concat axis is not the innermost
+// dim and the inner dim is small enough that a packet load spans multiple
+// rows -- including rows that fall on the right side of the boundary. The
+// guard in packet() must reject this case and fall back to scalars.
+template <int DataLayout>
+static void test_concatenation_packet_axis_not_innermost() {
+  // Output shape (8, 6, 1) with concat along axis 1: each packet load whose
+  // first/last linear indices land on the left side will sweep through right
+  // rows in between unless the fast path is correctly guarded.
+  Tensor<float, 3, DataLayout> left(8, 3, 1);
+  Tensor<float, 3, DataLayout> right(8, 3, 1);
+  left.setRandom();
+  right.setRandom();
+
+  Tensor<float, 3, DataLayout> concatenation = left.concatenate(right, 1);
+  VERIFY_IS_EQUAL(concatenation.dimension(0), 8);
+  VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
+  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
+  for (int i = 0; i < 8; ++i) {
+    for (int j = 0; j < 3; ++j) {
+      VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
+      VERIFY_IS_EQUAL(concatenation(i, j + 3, 0), right(i, j, 0));
+    }
+  }
+
+  // Force evaluation through the packet path with a coefficient-wise op so
+  // the executor will request packets aligned to the output strides.
+  Tensor<float, 3, DataLayout> doubled = concatenation * concatenation.constant(2.0f);
+  for (int i = 0; i < 8; ++i) {
+    for (int j = 0; j < 3; ++j) {
+      VERIFY_IS_APPROX(doubled(i, j, 0), 2.0f * left(i, j, 0));
+      VERIFY_IS_APPROX(doubled(i, j + 3, 0), 2.0f * right(i, j, 0));
+    }
+  }
+}
 
 static void test_concatenation_as_lvalue() {
   Tensor<int, 2> t1(2, 3);
@@ -123,6 +156,7 @@
   CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
   CALL_SUBTEST(test_simple_concatenation<ColMajor>());
   CALL_SUBTEST(test_simple_concatenation<RowMajor>());
-  // CALL_SUBTEST(test_vectorized_concatenation());
+  CALL_SUBTEST(test_concatenation_packet_axis_not_innermost<ColMajor>());
+  CALL_SUBTEST(test_concatenation_packet_axis_not_innermost<RowMajor>());
   CALL_SUBTEST(test_concatenation_as_lvalue());
 }