gemm_pack_rhs: enable vectorized transpose for half-width packets

libeigen/eigen!2453

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/Eigen/src/Core/arch/AVX512/GemmKernel.h b/Eigen/src/Core/arch/AVX512/GemmKernel.h
index 556894c..50e26ec 100644
--- a/Eigen/src/Core/arch/AVX512/GemmKernel.h
+++ b/Eigen/src/Core/arch/AVX512/GemmKernel.h
@@ -1006,10 +1006,9 @@
       const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
       const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
       Index k = 0;
-      if ((PacketSize % 8) == 0)  // TODO enable vectorized transposition for PacketSize==4
-      {
+      EIGEN_IF_CONSTEXPR((PacketSize % 8) == 0 || PacketSize == 4) {
         for (; k < peeled_k; k += PacketSize) {
-          PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel;
+          PacketBlock<Packet, 8> kernel;
 
           kernel.packet[0] = dm0.template loadPacket<Packet>(k);
           kernel.packet[1] = dm1.template loadPacket<Packet>(k);
@@ -1020,16 +1019,43 @@
           kernel.packet[6] = dm6.template loadPacket<Packet>(k);
           kernel.packet[7] = dm7.template loadPacket<Packet>(k);
 
-          ptranspose(kernel);
+          EIGEN_IF_CONSTEXPR(PacketSize == 4) {
+            // For PacketSize==4 we cannot ptranspose 8 packets directly; compose two
+            // 4-packet transposes (cols 0-3 and 4-7) and interleave the halves so
+            // the 8 stores produce 4 rows of 8 packed elements.
+            PacketBlock<Packet, 4> tmp_lo;
+            tmp_lo.packet[0] = kernel.packet[0];
+            tmp_lo.packet[1] = kernel.packet[1];
+            tmp_lo.packet[2] = kernel.packet[2];
+            tmp_lo.packet[3] = kernel.packet[3];
+            ptranspose(tmp_lo);
+            PacketBlock<Packet, 4> tmp_hi;
+            tmp_hi.packet[0] = kernel.packet[4];
+            tmp_hi.packet[1] = kernel.packet[5];
+            tmp_hi.packet[2] = kernel.packet[6];
+            tmp_hi.packet[3] = kernel.packet[7];
+            ptranspose(tmp_hi);
+            kernel.packet[0] = tmp_lo.packet[0];
+            kernel.packet[1] = tmp_hi.packet[0];
+            kernel.packet[2] = tmp_lo.packet[1];
+            kernel.packet[3] = tmp_hi.packet[1];
+            kernel.packet[4] = tmp_lo.packet[2];
+            kernel.packet[5] = tmp_hi.packet[2];
+            kernel.packet[6] = tmp_lo.packet[3];
+            kernel.packet[7] = tmp_hi.packet[3];
+          }
+          else {
+            ptranspose(kernel);
+          }
 
           pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
-          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
-          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
-          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
-          pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
-          pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
-          pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
-          pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
+          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1]));
+          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2]));
+          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3]));
+          pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4]));
+          pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5]));
+          pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6]));
+          pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7]));
           count += 8 * PacketSize;
         }
       }
@@ -1059,19 +1085,35 @@
       const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
 
       Index k = 0;
-      if ((PacketSize % 4) == 0)  // TODO: enable vectorized transposition for PacketSize==2.
-      {
+      EIGEN_IF_CONSTEXPR((PacketSize % 4) == 0 || PacketSize == 2) {
         for (; k < peeled_k; k += PacketSize) {
-          PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
+          PacketBlock<Packet, 4> kernel;
           kernel.packet[0] = dm0.template loadPacket<Packet>(k);
-          kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
-          kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
-          kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
-          ptranspose(kernel);
+          kernel.packet[1] = dm1.template loadPacket<Packet>(k);
+          kernel.packet[2] = dm2.template loadPacket<Packet>(k);
+          kernel.packet[3] = dm3.template loadPacket<Packet>(k);
+          EIGEN_IF_CONSTEXPR(PacketSize == 2) {
+            // See the matching note in GeneralBlockPanelKernel.h.
+            PacketBlock<Packet, 2> tmp01;
+            tmp01.packet[0] = kernel.packet[0];
+            tmp01.packet[1] = kernel.packet[1];
+            ptranspose(tmp01);
+            PacketBlock<Packet, 2> tmp23;
+            tmp23.packet[0] = kernel.packet[2];
+            tmp23.packet[1] = kernel.packet[3];
+            ptranspose(tmp23);
+            kernel.packet[0] = tmp01.packet[0];
+            kernel.packet[1] = tmp23.packet[0];
+            kernel.packet[2] = tmp01.packet[1];
+            kernel.packet[3] = tmp23.packet[1];
+          }
+          else {
+            ptranspose(kernel);
+          }
           pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
-          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
-          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
-          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
+          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1]));
+          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2]));
+          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3]));
           count += 4 * PacketSize;
         }
       }
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
index c4f2b83..2989a1b 100644
--- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h
+++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -2222,19 +2222,37 @@
       const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
 
       Index k = 0;
-      if ((PacketSize % 4) == 0)  // TODO: enable vectorized transposition for PacketSize==2.
-      {
+      EIGEN_IF_CONSTEXPR((PacketSize % 4) == 0 || PacketSize == 2) {
         for (; k < peeled_k; k += PacketSize) {
-          PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
+          PacketBlock<Packet, 4> kernel;
           kernel.packet[0] = dm0.template loadPacket<Packet>(k);
-          kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
-          kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
-          kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
-          ptranspose(kernel);
+          kernel.packet[1] = dm1.template loadPacket<Packet>(k);
+          kernel.packet[2] = dm2.template loadPacket<Packet>(k);
+          kernel.packet[3] = dm3.template loadPacket<Packet>(k);
+          EIGEN_IF_CONSTEXPR(PacketSize == 2) {
+            // For PacketSize==2 we cannot ptranspose 4 packets directly; compose two
+            // 2-packet transposes and re-interleave so the 4 stores produce the
+            // packed-rhs layout (each store writing one half-row of the panel).
+            PacketBlock<Packet, 2> tmp01;
+            tmp01.packet[0] = kernel.packet[0];
+            tmp01.packet[1] = kernel.packet[1];
+            ptranspose(tmp01);
+            PacketBlock<Packet, 2> tmp23;
+            tmp23.packet[0] = kernel.packet[2];
+            tmp23.packet[1] = kernel.packet[3];
+            ptranspose(tmp23);
+            kernel.packet[0] = tmp01.packet[0];
+            kernel.packet[1] = tmp23.packet[0];
+            kernel.packet[2] = tmp01.packet[1];
+            kernel.packet[3] = tmp23.packet[1];
+          }
+          else {
+            ptranspose(kernel);
+          }
           pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
-          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
-          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
-          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
+          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1]));
+          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2]));
+          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3]));
           count += 4 * PacketSize;
         }
       }