gemm_pack_rhs: enable vectorized transpose for half-width packets libeigen/eigen!2453 Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/Eigen/src/Core/arch/AVX512/GemmKernel.h b/Eigen/src/Core/arch/AVX512/GemmKernel.h index 556894c..50e26ec 100644 --- a/Eigen/src/Core/arch/AVX512/GemmKernel.h +++ b/Eigen/src/Core/arch/AVX512/GemmKernel.h
@@ -1006,10 +1006,9 @@ const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6); const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7); Index k = 0; - if ((PacketSize % 8) == 0) // TODO enable vectorized transposition for PacketSize==4 - { + EIGEN_IF_CONSTEXPR((PacketSize % 8) == 0 || PacketSize == 4) { for (; k < peeled_k; k += PacketSize) { - PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel; + PacketBlock<Packet, 8> kernel; kernel.packet[0] = dm0.template loadPacket<Packet>(k); kernel.packet[1] = dm1.template loadPacket<Packet>(k); @@ -1020,16 +1019,43 @@ kernel.packet[6] = dm6.template loadPacket<Packet>(k); kernel.packet[7] = dm7.template loadPacket<Packet>(k); - ptranspose(kernel); + EIGEN_IF_CONSTEXPR(PacketSize == 4) { + // For PacketSize==4 we cannot ptranspose 8 packets directly; compose two + // 4-packet transposes (cols 0-3 and 4-7) and interleave the halves so + // the 8 stores produce 4 rows of 8 packed elements. + PacketBlock<Packet, 4> tmp_lo; + tmp_lo.packet[0] = kernel.packet[0]; + tmp_lo.packet[1] = kernel.packet[1]; + tmp_lo.packet[2] = kernel.packet[2]; + tmp_lo.packet[3] = kernel.packet[3]; + ptranspose(tmp_lo); + PacketBlock<Packet, 4> tmp_hi; + tmp_hi.packet[0] = kernel.packet[4]; + tmp_hi.packet[1] = kernel.packet[5]; + tmp_hi.packet[2] = kernel.packet[6]; + tmp_hi.packet[3] = kernel.packet[7]; + ptranspose(tmp_hi); + kernel.packet[0] = tmp_lo.packet[0]; + kernel.packet[1] = tmp_hi.packet[0]; + kernel.packet[2] = tmp_lo.packet[1]; + kernel.packet[3] = tmp_hi.packet[1]; + kernel.packet[4] = tmp_lo.packet[2]; + kernel.packet[5] = tmp_hi.packet[2]; + kernel.packet[6] = tmp_lo.packet[3]; + kernel.packet[7] = tmp_hi.packet[3]; + } + else { + ptranspose(kernel); + } pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0])); - pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize])); - pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize])); - pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize])); - pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize])); - pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize])); - pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize])); - pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize])); + pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1])); + pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2])); + pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3])); + pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4])); + pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5])); + pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6])); + pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7])); count += 8 * PacketSize; } } @@ -1059,19 +1085,35 @@ const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3); Index k = 0; - if ((PacketSize % 4) == 0) // TODO: enable vectorized transposition for PacketSize==2. - { + EIGEN_IF_CONSTEXPR((PacketSize % 4) == 0 || PacketSize == 2) { for (; k < peeled_k; k += PacketSize) { - PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel; + PacketBlock<Packet, 4> kernel; kernel.packet[0] = dm0.template loadPacket<Packet>(k); - kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k); - kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k); - kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k); - ptranspose(kernel); + kernel.packet[1] = dm1.template loadPacket<Packet>(k); + kernel.packet[2] = dm2.template loadPacket<Packet>(k); + kernel.packet[3] = dm3.template loadPacket<Packet>(k); + EIGEN_IF_CONSTEXPR(PacketSize == 2) { + // See the matching note in GeneralBlockPanelKernel.h. + PacketBlock<Packet, 2> tmp01; + tmp01.packet[0] = kernel.packet[0]; + tmp01.packet[1] = kernel.packet[1]; + ptranspose(tmp01); + PacketBlock<Packet, 2> tmp23; + tmp23.packet[0] = kernel.packet[2]; + tmp23.packet[1] = kernel.packet[3]; + ptranspose(tmp23); + kernel.packet[0] = tmp01.packet[0]; + kernel.packet[1] = tmp23.packet[0]; + kernel.packet[2] = tmp01.packet[1]; + kernel.packet[3] = tmp23.packet[1]; + } + else { + ptranspose(kernel); + } pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0])); - pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize])); - pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize])); - pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize])); + pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1])); + pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2])); + pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3])); count += 4 * PacketSize; } }
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index c4f2b83..2989a1b 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -2222,19 +2222,37 @@ const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3); Index k = 0; - if ((PacketSize % 4) == 0) // TODO: enable vectorized transposition for PacketSize==2. - { + EIGEN_IF_CONSTEXPR((PacketSize % 4) == 0 || PacketSize == 2) { for (; k < peeled_k; k += PacketSize) { - PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel; + PacketBlock<Packet, 4> kernel; kernel.packet[0] = dm0.template loadPacket<Packet>(k); - kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k); - kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k); - kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k); - ptranspose(kernel); + kernel.packet[1] = dm1.template loadPacket<Packet>(k); + kernel.packet[2] = dm2.template loadPacket<Packet>(k); + kernel.packet[3] = dm3.template loadPacket<Packet>(k); + EIGEN_IF_CONSTEXPR(PacketSize == 2) { + // For PacketSize==2 we cannot ptranspose 4 packets directly; compose two + // 2-packet transposes and re-interleave so the 4 stores produce the + // packed-rhs layout (each store writing one half-row of the panel). + PacketBlock<Packet, 2> tmp01; + tmp01.packet[0] = kernel.packet[0]; + tmp01.packet[1] = kernel.packet[1]; + ptranspose(tmp01); + PacketBlock<Packet, 2> tmp23; + tmp23.packet[0] = kernel.packet[2]; + tmp23.packet[1] = kernel.packet[3]; + ptranspose(tmp23); + kernel.packet[0] = tmp01.packet[0]; + kernel.packet[1] = tmp23.packet[0]; + kernel.packet[2] = tmp01.packet[1]; + kernel.packet[3] = tmp23.packet[1]; + } + else { + ptranspose(kernel); + } pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0])); - pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize])); - pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize])); - pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize])); + pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1])); + pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2])); + pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3])); count += 4 * PacketSize; } }