Use 3px8/2px8/1px8/1x8 gebp_kernel on arm64-neon
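
Widen the RHS panel from nr=4 to nr=8 for float, double and (when
EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC is available) half on arm64-neon,
and add the matching 3pX8, 2pX8 and 1pX8 vectorized micro-kernels plus a
scalar 1X8 tail to gebp_kernel. Under EIGEN_COMP_GNUC_STRICT the
lane-indexed FMAs are emitted through inline asm, e.g. for float

    asm("fmla %0.4s, %1.4s, %2.s[0]" : "+w" (c) : "w" (a), "w" (b));

instead of vfmaq_laneq_f32(c, a, b, 0), which both avoids the costly dup
of gcc bug 89101 (fixed in gcc9) and works around gcc's register split
problem on arm64-neon. The previously commented-out nr>=8 paths in the
two RHS packing routines are re-enabled on top of the LinearMapper
interface.
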
diff --git a/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h
index 6cd6edd..5022205 100644
--- a/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h
+++ b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h
@@ -49,7 +49,9 @@
{
typedef float RhsPacket;
typedef float32x4_t RhsPacketx4;
-
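+ // Widen the RHS panel from the generic nr=4 to nr=8 so the 8-column
+ // micro-kernels below can be used.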
+ enum {
+ nr = 8
+ };
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
{
dest = *b;
@@ -77,7 +79,6 @@
{
c = vfmaq_n_f32(c, a, b);
}
-
// NOTE: Template parameter inference failed when compiled with Android NDK:
// "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>'".
@@ -94,9 +95,10 @@
template<int LaneID>
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
{
- #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
- // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
- // vfmaq_laneq_f32 is implemented through a costly dup
+ #if EIGEN_COMP_GNUC_STRICT
+ // 1. Workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101:
+ //    vfmaq_laneq_f32 is implemented through a costly dup (fixed in gcc9).
+ // 2. Workaround the gcc register split problem on arm64-neon, which is why
+ //    the asm path is now used for all strict gcc versions, not just gcc < 9.
if(LaneID==0) asm("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==1) asm("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==2) asm("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
@@ -113,7 +115,9 @@
: gebp_traits<double,double,false,false,Architecture::Generic>
{
typedef double RhsPacket;
-
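+ // As for float: widen the RHS panel to 8 columns.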
+ enum {
+ nr = 8
+ };
struct RhsPacketx4 {
float64x2_t B_0, B_1;
};
@@ -163,9 +167,10 @@
template <int LaneID>
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
{
- #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
- // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
- // vfmaq_laneq_f64 is implemented through a costly dup
+ #if EIGEN_COMP_GNUC_STRICT
+ // 1. Workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101:
+ //    vfmaq_laneq_f64 is implemented through a costly dup (fixed in gcc9).
+ // 2. Workaround the gcc register split problem on arm64-neon, which is why
+ //    the asm path is now used for all strict gcc versions, not just gcc < 9.
if(LaneID==0) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
else if(LaneID==1) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
else if(LaneID==2) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
@@ -179,6 +184,77 @@
}
};
+#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
+
+template<>
+struct gebp_traits <half,half,false,false,Architecture::NEON>
+ : gebp_traits<half,half,false,false,Architecture::Generic>
+{
+ typedef half RhsPacket;
+ typedef float16x4_t RhsPacketx4;
+ typedef float16x4_t PacketHalf;
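+ // RhsPacketx4 holds four fp16 RHS coefficients and PacketHalf is the 64-bit
+ // (4-lane) packet used by the half-width madd overload below; as for float
+ // and double, the RHS panel is widened to nr=8.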
+ enum {
+ nr = 8
+ };
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ dest = *b;
+ }
+
+ EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
+ {
+ dest = vld1_f16((const __fp16 *)b);
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
+ {
+ dest = *b;
+ }
+
+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
+ {}
+
+ EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
+ {
+ loadRhs(b,dest);
+ }
+
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ {
+ c = vfmaq_n_f16(c, a, b);
+ }
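+ // Same scalar-broadcast FMA on a 64-bit (4-lane) accumulator, used on the
+ // half-packet path.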
+ EIGEN_STRONG_INLINE void madd(const PacketHalf& a, const RhsPacket& b, PacketHalf& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ {
+ c = vfma_n_f16(c, a, b);
+ }
+
+ // NOTE: Template parameter inference failed when compiled with Android NDK:
+ // "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
+ { madd_helper<0>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
+ { madd_helper<1>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
+ { madd_helper<2>(a, b, c); }
+ EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
+ { madd_helper<3>(a, b, c); }
+ private:
+ template<int LaneID>
+ EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
+ {
+ #if EIGEN_COMP_GNUC_STRICT
+ // 1. vfmaq_lane_f16 is implemented through a costly dup.
+ // 2. Workaround the gcc register split problem on arm64-neon.
+ if(LaneID==0) asm("fmla %0.8h, %1.8h, %2.h[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==1) asm("fmla %0.8h, %1.8h, %2.h[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==2) asm("fmla %0.8h, %1.8h, %2.h[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ else if(LaneID==3) asm("fmla %0.8h, %1.8h, %2.h[3]\n" : "+w" (c) : "w" (a), "w" (b) : );
+ #else
+ c = vfmaq_lane_f16(c, a, b, LaneID);
+ #endif
+ }
+};
+#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
#endif // EIGEN_ARCH_ARM64
} // namespace internal
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
index b1a1277..0b07651 100644
--- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h
+++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -1070,6 +1070,7 @@
typedef typename Traits::RhsPacketx4 RhsPacketx4;
typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15;
+ typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 27>::type RhsPanel27;
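+ // Used by the 3pX8 micro-kernel below, where 27 registers are live
+ // (24 accumulators + 3 lhs packets).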
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
@@ -1215,13 +1216,135 @@
int prefetch_res_offset, Index peeled_kc, Index pk, Index cols, Index depth, Index packet_cols4)
{
GEBPTraits traits;
-
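+ // Columns [0, packet_cols8) are handled by the 1pX8 path below; the
+ // existing 4-column loop then starts at packet_cols8.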
+ Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
// loops on each largest micro horizontal panel of lhs
// (LhsProgress x depth)
for(Index i=peelStart; i<peelEnd; i+=LhsProgress)
{
+ for(Index j2=0; j2<packet_cols8; j2+=8)
+ {
+ const LhsScalar* blA = &blockA[i*strideA+offsetA*(LhsProgress)];
+ prefetch(&blA[0]);
+
+ // gets res block as register
+ AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
+ traits.initAcc(C0);
+ traits.initAcc(C1);
+ traits.initAcc(C2);
+ traits.initAcc(C3);
+ traits.initAcc(C4);
+ traits.initAcc(C5);
+ traits.initAcc(C6);
+ traits.initAcc(C7);
+
+ LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
+ LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
+ LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
+ LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
+ LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
+ LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
+ LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
+ LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
+ r0.prefetch(prefetch_res_offset);
+ r1.prefetch(prefetch_res_offset);
+ r2.prefetch(prefetch_res_offset);
+ r3.prefetch(prefetch_res_offset);
+ r4.prefetch(prefetch_res_offset);
+ r5.prefetch(prefetch_res_offset);
+ r6.prefetch(prefetch_res_offset);
+ r7.prefetch(prefetch_res_offset);
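+ // The RHS has been packed in 8-column panels, hence the offsetB*8 term.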
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
+ prefetch(&blB[0]);
+
+ LhsPacket A0;
+ for(Index k=0; k<peeled_kc; k+=pk)
+ {
+ RhsPacketx4 rhs_panel;
+ RhsPacket T0;
+#define EIGEN_GEBGP_ONESTEP(K) \
+ do { \
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX8"); \
+ traits.loadLhs(&blA[(0 + 1 * K) * LhsProgress], A0); \
+ traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
+ traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
+ traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
+ traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
+ traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
+ traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
+ traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
+ traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX8"); \
+ } while (false)
+
+ EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX8");
+
+ EIGEN_GEBGP_ONESTEP(0);
+ EIGEN_GEBGP_ONESTEP(1);
+ EIGEN_GEBGP_ONESTEP(2);
+ EIGEN_GEBGP_ONESTEP(3);
+ EIGEN_GEBGP_ONESTEP(4);
+ EIGEN_GEBGP_ONESTEP(5);
+ EIGEN_GEBGP_ONESTEP(6);
+ EIGEN_GEBGP_ONESTEP(7);
+
+ blB += pk*8*RhsProgress;
+ blA += pk*(1*LhsProgress);
+
+ EIGEN_ASM_COMMENT("end gebp micro kernel 1pX8");
+ }
+ // process remaining peeled loop
+ for(Index k=peeled_kc; k<depth; k++)
+ {
+ RhsPacketx4 rhs_panel;
+ RhsPacket T0;
+ EIGEN_GEBGP_ONESTEP(0);
+ blB += 8*RhsProgress;
+ blA += 1*LhsProgress;
+ }
+
+#undef EIGEN_GEBGP_ONESTEP
+
+ ResPacket R0, R1;
+ ResPacket alphav = pset1<ResPacket>(alpha);
+
+ R0 = r0.template loadPacket<ResPacket>(0);
+ R1 = r1.template loadPacket<ResPacket>(0);
+ traits.acc(C0, alphav, R0);
+ traits.acc(C1, alphav, R1);
+ r0.storePacket(0, R0);
+ r1.storePacket(0, R1);
+
+ R0 = r2.template loadPacket<ResPacket>(0);
+ R1 = r3.template loadPacket<ResPacket>(0);
+ traits.acc(C2, alphav, R0);
+ traits.acc(C3, alphav, R1);
+ r2.storePacket(0, R0);
+ r3.storePacket(0, R1);
+
+ R0 = r4.template loadPacket<ResPacket>(0);
+ R1 = r5.template loadPacket<ResPacket>(0);
+ traits.acc(C4, alphav, R0);
+ traits.acc(C5, alphav, R1);
+ r4.storePacket(0, R0);
+ r5.storePacket(0, R1);
+
+ R0 = r6.template loadPacket<ResPacket>(0);
+ R1 = r7.template loadPacket<ResPacket>(0);
+ traits.acc(C6, alphav, R0);
+ traits.acc(C7, alphav, R1);
+ r6.storePacket(0, R0);
+ r7.storePacket(0, R1);
+ }
// loops on each largest micro vertical panel of rhs (depth * nr)
- for(Index j2=0; j2<packet_cols4; j2+=nr)
+ for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
// We select a LhsProgress x nr micro block of res
// which is entirely stored into 1 x nr registers.
@@ -1257,7 +1380,7 @@
r3.prefetch(prefetch_res_offset);
// performs "inner" products
- const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
prefetch(&blB[0]);
LhsPacket A0, A1;
@@ -1415,6 +1538,7 @@
if(strideB==-1) strideB = depth;
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
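+ // Number of columns covered by the 8-wide kernels (0 unless nr>=8).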
+ Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
const Index peeled_mc3 = mr>=3*Traits::LhsProgress ? (rows/(3*LhsProgress))*(3*LhsProgress) : 0;
const Index peeled_mc2 = mr>=2*Traits::LhsProgress ? peeled_mc3+((rows-peeled_mc3)/(2*LhsProgress))*(2*LhsProgress) : 0;
const Index peeled_mc1 = mr>=1*Traits::LhsProgress ? peeled_mc2+((rows-peeled_mc2)/(1*LhsProgress))*(1*LhsProgress) : 0;
@@ -1443,7 +1567,219 @@
for(Index i1=0; i1<peeled_mc3; i1+=actual_panel_rows)
{
const Index actual_panel_end = (std::min)(i1+actual_panel_rows, peeled_mc3);
- for(Index j2=0; j2<packet_cols4; j2+=nr)
+
+ // nr >= 8
+ for(Index j2=0; j2<packet_cols8; j2+=8)
+ {
+ for(Index i=i1; i<actual_panel_end; i+=3*LhsProgress)
+ {
+ const LhsScalar* blA = &blockA[i*strideA+offsetA*(3*LhsProgress)];
+ prefetch(&blA[0]);
+ // gets res block as register
+ AccPacket C0, C1, C2, C3, C4, C5, C6, C7,
+ C8, C9, C10, C11, C12, C13, C14, C15,
+ C16, C17, C18, C19, C20, C21, C22, C23;
+ traits.initAcc(C0); traits.initAcc(C1); traits.initAcc(C2); traits.initAcc(C3);
+ traits.initAcc(C4); traits.initAcc(C5); traits.initAcc(C6); traits.initAcc(C7);
+ traits.initAcc(C8); traits.initAcc(C9); traits.initAcc(C10); traits.initAcc(C11);
+ traits.initAcc(C12); traits.initAcc(C13); traits.initAcc(C14); traits.initAcc(C15);
+ traits.initAcc(C16); traits.initAcc(C17); traits.initAcc(C18); traits.initAcc(C19);
+ traits.initAcc(C20); traits.initAcc(C21); traits.initAcc(C22); traits.initAcc(C23);
+
+ LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
+ LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
+ LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
+ LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
+ LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
+ LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
+ LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
+ LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
+
+ r0.prefetch(0);
+ r1.prefetch(0);
+ r2.prefetch(0);
+ r3.prefetch(0);
+ r4.prefetch(0);
+ r5.prefetch(0);
+ r6.prefetch(0);
+ r7.prefetch(0);
+
+ // performs "inner" products
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
+ prefetch(&blB[0]);
+ LhsPacket A0, A1;
+ for(Index k=0; k<peeled_kc; k+=pk)
+ {
+ EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX8");
+ // 27 registers are taken (24 for acc, 3 for lhs).
+ RhsPanel27 rhs_panel;
+ RhsPacket T0;
+ LhsPacket A2;
+ #if EIGEN_COMP_GNUC_STRICT && EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && !(EIGEN_GNUC_AT_LEAST(9,0))
+ // see http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1633
+ // without this workaround A0, A1, and A2 are loaded in the same register,
+ // which is not good for pipelining
+ #define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND __asm__ ("" : "+w,m" (A0), "+w,m" (A1), "+w,m" (A2));
+ #else
+ #define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND
+ #endif
+
+#define EIGEN_GEBP_ONESTEP(K) \
+ do { \
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX8"); \
+ traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
+ traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
+ traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
+ EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND \
+ traits.loadRhs(blB + (0 + 8 * K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
+ traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
+ traits.madd(A2, rhs_panel, C16, T0, fix<0>); \
+ traits.updateRhs(blB + (1 + 8 * K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
+ traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
+ traits.madd(A2, rhs_panel, C17, T0, fix<1>); \
+ traits.updateRhs(blB + (2 + 8 * K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
+ traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
+ traits.madd(A2, rhs_panel, C18, T0, fix<2>); \
+ traits.updateRhs(blB + (3 + 8 * K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
+ traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
+ traits.madd(A2, rhs_panel, C19, T0, fix<3>); \
+ traits.loadRhs(blB + (4 + 8 * K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
+ traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
+ traits.madd(A2, rhs_panel, C20, T0, fix<0>); \
+ traits.updateRhs(blB + (5 + 8 * K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
+ traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
+ traits.madd(A2, rhs_panel, C21, T0, fix<1>); \
+ traits.updateRhs(blB + (6 + 8 * K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
+ traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
+ traits.madd(A2, rhs_panel, C22, T0, fix<2>); \
+ traits.updateRhs(blB + (7 + 8 * K) * Traits::RhsProgress, rhs_panel); \
+ traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
+ traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
+ traits.madd(A2, rhs_panel, C23, T0, fix<3>); \
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX8"); \
+ } while (false)
+
+ EIGEN_GEBP_ONESTEP(0);
+ EIGEN_GEBP_ONESTEP(1);
+ EIGEN_GEBP_ONESTEP(2);
+ EIGEN_GEBP_ONESTEP(3);
+ EIGEN_GEBP_ONESTEP(4);
+ EIGEN_GEBP_ONESTEP(5);
+ EIGEN_GEBP_ONESTEP(6);
+ EIGEN_GEBP_ONESTEP(7);
+
+ blB += pk * 8 * RhsProgress;
+ blA += pk * 3 * Traits::LhsProgress;
+ EIGEN_ASM_COMMENT("end gebp micro kernel 3pX8");
+ }
+
+ // process remaining peeled loop
+ for (Index k = peeled_kc; k < depth; k++)
+ {
+ RhsPanel27 rhs_panel;
+ RhsPacket T0;
+ LhsPacket A2;
+ EIGEN_GEBP_ONESTEP(0);
+ blB += 8 * RhsProgress;
+ blA += 3 * Traits::LhsProgress;
+ }
+
+ #undef EIGEN_GEBP_ONESTEP
+
+ ResPacket R0, R1, R2;
+ ResPacket alphav = pset1<ResPacket>(alpha);
+
+ R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
+ traits.acc(C0, alphav, R0);
+ traits.acc(C8, alphav, R1);
+ traits.acc(C16, alphav, R2);
+ r0.storePacket(0 * Traits::ResPacketSize, R0);
+ r0.storePacket(1 * Traits::ResPacketSize, R1);
+ r0.storePacket(2 * Traits::ResPacketSize, R2);
+
+ R0 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r1.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
+ traits.acc(C1, alphav, R0);
+ traits.acc(C9, alphav, R1);
+ traits.acc(C17, alphav, R2);
+ r1.storePacket(0 * Traits::ResPacketSize, R0);
+ r1.storePacket(1 * Traits::ResPacketSize, R1);
+ r1.storePacket(2 * Traits::ResPacketSize, R2);
+
+ R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r2.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
+ traits.acc(C2, alphav, R0);
+ traits.acc(C10, alphav, R1);
+ traits.acc(C18, alphav, R2);
+ r2.storePacket(0 * Traits::ResPacketSize, R0);
+ r2.storePacket(1 * Traits::ResPacketSize, R1);
+ r2.storePacket(2 * Traits::ResPacketSize, R2);
+
+ R0 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r3.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
+ traits.acc(C3, alphav, R0);
+ traits.acc(C11, alphav, R1);
+ traits.acc(C19, alphav, R2);
+ r3.storePacket(0 * Traits::ResPacketSize, R0);
+ r3.storePacket(1 * Traits::ResPacketSize, R1);
+ r3.storePacket(2 * Traits::ResPacketSize, R2);
+
+ R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r4.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
+ traits.acc(C4, alphav, R0);
+ traits.acc(C12, alphav, R1);
+ traits.acc(C20, alphav, R2);
+ r4.storePacket(0 * Traits::ResPacketSize, R0);
+ r4.storePacket(1 * Traits::ResPacketSize, R1);
+ r4.storePacket(2 * Traits::ResPacketSize, R2);
+
+ R0 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r5.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
+ traits.acc(C5, alphav, R0);
+ traits.acc(C13, alphav, R1);
+ traits.acc(C21, alphav, R2);
+ r5.storePacket(0 * Traits::ResPacketSize, R0);
+ r5.storePacket(1 * Traits::ResPacketSize, R1);
+ r5.storePacket(2 * Traits::ResPacketSize, R2);
+
+ R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r6.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
+ traits.acc(C6, alphav, R0);
+ traits.acc(C14, alphav, R1);
+ traits.acc(C22, alphav, R2);
+ r6.storePacket(0 * Traits::ResPacketSize, R0);
+ r6.storePacket(1 * Traits::ResPacketSize, R1);
+ r6.storePacket(2 * Traits::ResPacketSize, R2);
+
+ R0 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r7.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
+ traits.acc(C7, alphav, R0);
+ traits.acc(C15, alphav, R1);
+ traits.acc(C23, alphav, R2);
+ r7.storePacket(0 * Traits::ResPacketSize, R0);
+ r7.storePacket(1 * Traits::ResPacketSize, R1);
+ r7.storePacket(2 * Traits::ResPacketSize, R2);
+ }
+ }
+
+ for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
for(Index i=i1; i<actual_panel_end; i+=3*LhsProgress)
{
@@ -1473,14 +1809,14 @@
r3.prefetch(0);
// performs "inner" products
- const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
prefetch(&blB[0]);
LhsPacket A0, A1;
for(Index k=0; k<peeled_kc; k+=pk)
{
EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX4");
- // 15 registers are taken (12 for acc, 2 for lhs).
+ // 15 registers are taken (12 for acc, 3 for lhs).
RhsPanel15 rhs_panel;
RhsPacket T0;
LhsPacket A2;
@@ -1689,7 +2025,170 @@
for(Index i1=peeled_mc3; i1<peeled_mc2; i1+=actual_panel_rows)
{
Index actual_panel_end = (std::min)(i1+actual_panel_rows, peeled_mc2);
- for(Index j2=0; j2<packet_cols4; j2+=nr)
+ for(Index j2=0; j2<packet_cols8; j2+=8)
+ {
+ for(Index i=i1; i<actual_panel_end; i+=2*LhsProgress)
+ {
+ const LhsScalar* blA = &blockA[i*strideA+offsetA*(2*Traits::LhsProgress)];
+ prefetch(&blA[0]);
+
+ AccPacket C0, C1, C2, C3, C4, C5, C6, C7,
+ C8, C9, C10, C11, C12, C13, C14, C15;
+ traits.initAcc(C0); traits.initAcc(C1); traits.initAcc(C2); traits.initAcc(C3);
+ traits.initAcc(C4); traits.initAcc(C5); traits.initAcc(C6); traits.initAcc(C7);
+ traits.initAcc(C8); traits.initAcc(C9); traits.initAcc(C10); traits.initAcc(C11);
+ traits.initAcc(C12); traits.initAcc(C13); traits.initAcc(C14); traits.initAcc(C15);
+
+ LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
+ LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
+ LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
+ LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
+ LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
+ LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
+ LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
+ LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
+ r0.prefetch(prefetch_res_offset);
+ r1.prefetch(prefetch_res_offset);
+ r2.prefetch(prefetch_res_offset);
+ r3.prefetch(prefetch_res_offset);
+ r4.prefetch(prefetch_res_offset);
+ r5.prefetch(prefetch_res_offset);
+ r6.prefetch(prefetch_res_offset);
+ r7.prefetch(prefetch_res_offset);
+
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
+ prefetch(&blB[0]);
+ LhsPacket A0, A1;
+ for(Index k=0; k<peeled_kc; k+=pk)
+ {
+ RhsPacketx4 rhs_panel;
+ RhsPacket T0;
+ // NOTE: the begin/end asm comments below work around bug 935!
+ // but they are not enough for gcc>=6 without FMA (bug 1637)
+ #if EIGEN_GNUC_AT_LEAST(6,0) && defined(EIGEN_VECTORIZE_SSE)
+ #define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND __asm__ ("" : [a0] "+x,m" (A0), [a1] "+x,m" (A1));
+ #else
+ #define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND
+ #endif
+#define EIGEN_GEBGP_ONESTEP(K) \
+ do { \
+ EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX8"); \
+ traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
+ traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
+ traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
+ traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
+ traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
+ traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
+ traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
+ traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
+ traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
+ traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
+ traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
+ traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
+ traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
+ traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
+ traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
+ traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
+ traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
+ traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
+ traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
+ EIGEN_GEBP_2Px8_SPILLING_WORKAROUND \
+ EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX8"); \
+ } while (false)
+
+ EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX8");
+
+ EIGEN_GEBGP_ONESTEP(0);
+ EIGEN_GEBGP_ONESTEP(1);
+ EIGEN_GEBGP_ONESTEP(2);
+ EIGEN_GEBGP_ONESTEP(3);
+ EIGEN_GEBGP_ONESTEP(4);
+ EIGEN_GEBGP_ONESTEP(5);
+ EIGEN_GEBGP_ONESTEP(6);
+ EIGEN_GEBGP_ONESTEP(7);
+
+ blB += pk*8*RhsProgress;
+ blA += pk*(2*Traits::LhsProgress);
+
+ EIGEN_ASM_COMMENT("end gebp micro kernel 2pX8");
+ }
+ // process remaining peeled loop
+ for(Index k=peeled_kc; k<depth; k++)
+ {
+ RhsPacketx4 rhs_panel;
+ RhsPacket T0;
+ EIGEN_GEBGP_ONESTEP(0);
+ blB += 8*RhsProgress;
+ blA += 2*Traits::LhsProgress;
+ }
+
+#undef EIGEN_GEBGP_ONESTEP
+
+ ResPacket R0, R1, R2, R3;
+ ResPacket alphav = pset1<ResPacket>(alpha);
+
+ R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R3 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ traits.acc(C0, alphav, R0);
+ traits.acc(C8, alphav, R1);
+ traits.acc(C1, alphav, R2);
+ traits.acc(C9, alphav, R3);
+ r0.storePacket(0 * Traits::ResPacketSize, R0);
+ r0.storePacket(1 * Traits::ResPacketSize, R1);
+ r1.storePacket(0 * Traits::ResPacketSize, R2);
+ r1.storePacket(1 * Traits::ResPacketSize, R3);
+
+ R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R3 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ traits.acc(C2, alphav, R0);
+ traits.acc(C10, alphav, R1);
+ traits.acc(C3, alphav, R2);
+ traits.acc(C11, alphav, R3);
+ r2.storePacket(0 * Traits::ResPacketSize, R0);
+ r2.storePacket(1 * Traits::ResPacketSize, R1);
+ r3.storePacket(0 * Traits::ResPacketSize, R2);
+ r3.storePacket(1 * Traits::ResPacketSize, R3);
+
+ R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R3 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ traits.acc(C4, alphav, R0);
+ traits.acc(C12, alphav, R1);
+ traits.acc(C5, alphav, R2);
+ traits.acc(C13, alphav, R3);
+ r4.storePacket(0 * Traits::ResPacketSize, R0);
+ r4.storePacket(1 * Traits::ResPacketSize, R1);
+ r5.storePacket(0 * Traits::ResPacketSize, R2);
+ r5.storePacket(1 * Traits::ResPacketSize, R3);
+
+ R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ R2 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
+ R3 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
+ traits.acc(C6, alphav, R0);
+ traits.acc(C14, alphav, R1);
+ traits.acc(C7, alphav, R2);
+ traits.acc(C15, alphav, R3);
+ r6.storePacket(0 * Traits::ResPacketSize, R0);
+ r6.storePacket(1 * Traits::ResPacketSize, R1);
+ r7.storePacket(0 * Traits::ResPacketSize, R2);
+ r7.storePacket(1 * Traits::ResPacketSize, R3);
+ }
+ }
+
+ for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
for(Index i=i1; i<actual_panel_end; i+=2*LhsProgress)
{
@@ -1717,7 +2216,7 @@
r3.prefetch(prefetch_res_offset);
// performs "inner" products
- const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
prefetch(&blB[0]);
LhsPacket A0, A1;
@@ -1907,14 +2406,66 @@
if(peeled_mc_quarter<rows)
{
// loop on each panel of the rhs
- for(Index j2=0; j2<packet_cols4; j2+=nr)
+ for(Index j2=0; j2<packet_cols8; j2+=8)
{
// loop on each row of the lhs (1*LhsProgress x depth)
for(Index i=peeled_mc_quarter; i<rows; i+=1)
{
const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]);
- const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+ // gets a 1 x 1 res block as registers
+ ResScalar C0(0),C1(0),C2(0),C3(0),C4(0),C5(0),C6(0),C7(0);
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
+ for(Index k=0; k<depth; k++)
+ {
+ LhsScalar A0 = blA[k];
+ RhsScalar B_0;
+
+ B_0 = blB[0];
+ C0 = cj.pmadd(A0, B_0, C0);
+
+ B_0 = blB[1];
+ C1 = cj.pmadd(A0, B_0, C1);
+
+ B_0 = blB[2];
+ C2 = cj.pmadd(A0, B_0, C2);
+
+ B_0 = blB[3];
+ C3 = cj.pmadd(A0, B_0, C3);
+
+ B_0 = blB[4];
+ C4 = cj.pmadd(A0, B_0, C4);
+
+ B_0 = blB[5];
+ C5 = cj.pmadd(A0, B_0, C5);
+
+ B_0 = blB[6];
+ C6 = cj.pmadd(A0, B_0, C6);
+
+ B_0 = blB[7];
+ C7 = cj.pmadd(A0, B_0, C7);
+
+ blB += 8;
+ }
+ res(i, j2 + 0) += alpha * C0;
+ res(i, j2 + 1) += alpha * C1;
+ res(i, j2 + 2) += alpha * C2;
+ res(i, j2 + 3) += alpha * C3;
+ res(i, j2 + 4) += alpha * C4;
+ res(i, j2 + 5) += alpha * C5;
+ res(i, j2 + 6) += alpha * C6;
+ res(i, j2 + 7) += alpha * C7;
+ }
+ }
+
+ for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
+ {
+ // loop on each row of the lhs (1*LhsProgress x depth)
+ for(Index i=peeled_mc_quarter; i<rows; i+=1)
+ {
+ const LhsScalar* blA = &blockA[i*strideA+offsetA];
+ prefetch(&blA[0]);
+ const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
// If LhsProgress is 8 or 16, it assumes that there is a
// half or quarter packet, respectively, of the same size as
@@ -2397,51 +2948,121 @@
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
Index count = 0;
const Index peeled_k = (depth/PacketSize)*PacketSize;
-// if(nr>=8)
-// {
-// for(Index j2=0; j2<packet_cols8; j2+=8)
-// {
-// // skip what we have before
-// if(PanelMode) count += 8 * offset;
-// const Scalar* b0 = &rhs[(j2+0)*rhsStride];
-// const Scalar* b1 = &rhs[(j2+1)*rhsStride];
-// const Scalar* b2 = &rhs[(j2+2)*rhsStride];
-// const Scalar* b3 = &rhs[(j2+3)*rhsStride];
-// const Scalar* b4 = &rhs[(j2+4)*rhsStride];
-// const Scalar* b5 = &rhs[(j2+5)*rhsStride];
-// const Scalar* b6 = &rhs[(j2+6)*rhsStride];
-// const Scalar* b7 = &rhs[(j2+7)*rhsStride];
-// Index k=0;
-// if(PacketSize==8) // TODO enable vectorized transposition for PacketSize==4
-// {
-// for(; k<peeled_k; k+=PacketSize) {
-// PacketBlock<Packet> kernel;
-// for (int p = 0; p < PacketSize; ++p) {
-// kernel.packet[p] = ploadu<Packet>(&rhs[(j2+p)*rhsStride+k]);
-// }
-// ptranspose(kernel);
-// for (int p = 0; p < PacketSize; ++p) {
-// pstoreu(blockB+count, cj.pconj(kernel.packet[p]));
-// count+=PacketSize;
-// }
-// }
-// }
-// for(; k<depth; k++)
-// {
-// blockB[count+0] = cj(b0[k]);
-// blockB[count+1] = cj(b1[k]);
-// blockB[count+2] = cj(b2[k]);
-// blockB[count+3] = cj(b3[k]);
-// blockB[count+4] = cj(b4[k]);
-// blockB[count+5] = cj(b5[k]);
-// blockB[count+6] = cj(b6[k]);
-// blockB[count+7] = cj(b7[k]);
-// count += 8;
-// }
-// // skip what we have after
-// if(PanelMode) count += 8 * (stride-offset-depth);
-// }
-// }
+
+ if(nr>=8)
+ {
+ for(Index j2=0; j2<packet_cols8; j2+=8)
+ {
+ // skip what we have before
+ if(PanelMode) count += 8 * offset;
+ const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+ const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
+ const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
+ const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
+ const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
+ Index k = 0;
+ if (PacketSize % 2 == 0 && PacketSize <= 8) // i.e. PacketSize is 2, 4 or 8
+ {
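+ // Transpose PacketSize x PacketSize tiles taken from the 8 source columns
+ // so that blockB stores 8 consecutive RHS coefficients per depth index.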
+ for (; k < peeled_k; k += PacketSize)
+ {
+ if (PacketSize == 2)
+ {
+ PacketBlock<Packet, PacketSize==2 ? 2 : PacketSize> kernel0, kernel1, kernel2, kernel3;
+ kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
+ kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
+ kernel1.packet[0%PacketSize] = dm2.template loadPacket<Packet>(k);
+ kernel1.packet[1%PacketSize] = dm3.template loadPacket<Packet>(k);
+ kernel2.packet[0%PacketSize] = dm4.template loadPacket<Packet>(k);
+ kernel2.packet[1%PacketSize] = dm5.template loadPacket<Packet>(k);
+ kernel3.packet[0%PacketSize] = dm6.template loadPacket<Packet>(k);
+ kernel3.packet[1%PacketSize] = dm7.template loadPacket<Packet>(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ ptranspose(kernel2);
+ ptranspose(kernel3);
+
+ pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
+ pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel1.packet[0 % PacketSize]));
+ pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel2.packet[0 % PacketSize]));
+ pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel3.packet[0 % PacketSize]));
+
+ pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
+ pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel1.packet[1 % PacketSize]));
+ pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel2.packet[1 % PacketSize]));
+ pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel3.packet[1 % PacketSize]));
+ count+=8*PacketSize;
+ }
+ else if (PacketSize == 4)
+ {
+ PacketBlock<Packet, PacketSize==4 ? 4 : PacketSize> kernel0, kernel1;
+
+ kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
+ kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
+ kernel0.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
+ kernel0.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
+ kernel1.packet[0%PacketSize] = dm4.template loadPacket<Packet>(k);
+ kernel1.packet[1%PacketSize] = dm5.template loadPacket<Packet>(k);
+ kernel1.packet[2%PacketSize] = dm6.template loadPacket<Packet>(k);
+ kernel1.packet[3%PacketSize] = dm7.template loadPacket<Packet>(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+
+ pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel0.packet[0%PacketSize]));
+ pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel1.packet[0%PacketSize]));
+ pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel0.packet[1%PacketSize]));
+ pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel1.packet[1%PacketSize]));
+ pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel0.packet[2%PacketSize]));
+ pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel1.packet[2%PacketSize]));
+ pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel0.packet[3%PacketSize]));
+ pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel1.packet[3%PacketSize]));
+ count+=8*PacketSize;
+ }
+ else if (PacketSize == 8)
+ {
+ PacketBlock<Packet, PacketSize==8 ? 8 : PacketSize> kernel0;
+
+ kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
+ kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
+ kernel0.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
+ kernel0.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
+ kernel0.packet[4%PacketSize] = dm4.template loadPacket<Packet>(k);
+ kernel0.packet[5%PacketSize] = dm5.template loadPacket<Packet>(k);
+ kernel0.packet[6%PacketSize] = dm6.template loadPacket<Packet>(k);
+ kernel0.packet[7%PacketSize] = dm7.template loadPacket<Packet>(k);
+ ptranspose(kernel0);
+
+ pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel0.packet[0%PacketSize]));
+ pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel0.packet[1%PacketSize]));
+ pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel0.packet[2%PacketSize]));
+ pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel0.packet[3%PacketSize]));
+ pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel0.packet[4%PacketSize]));
+ pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel0.packet[5%PacketSize]));
+ pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel0.packet[6%PacketSize]));
+ pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel0.packet[7%PacketSize]));
+ count+=8*PacketSize;
+ }
+ }
+ }
+
+ for(; k<depth; k++)
+ {
+ blockB[count+0] = cj(dm0(k));
+ blockB[count+1] = cj(dm1(k));
+ blockB[count+2] = cj(dm2(k));
+ blockB[count+3] = cj(dm3(k));
+ blockB[count+4] = cj(dm4(k));
+ blockB[count+5] = cj(dm5(k));
+ blockB[count+6] = cj(dm6(k));
+ blockB[count+7] = cj(dm7(k));
+ count += 8;
+ }
+ // skip what we have after
+ if(PanelMode) count += 8 * (stride-offset-depth);
+ }
+ }
if(nr>=4)
{
@@ -2522,39 +3143,41 @@
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
Index count = 0;
- // if(nr>=8)
- // {
- // for(Index j2=0; j2<packet_cols8; j2+=8)
- // {
- // // skip what we have before
- // if(PanelMode) count += 8 * offset;
- // for(Index k=0; k<depth; k++)
- // {
- // if (PacketSize==8) {
- // Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
- // pstoreu(blockB+count, cj.pconj(A));
- // } else if (PacketSize==4) {
- // Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
- // Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
- // pstoreu(blockB+count, cj.pconj(A));
- // pstoreu(blockB+count+PacketSize, cj.pconj(B));
- // } else {
- // const Scalar* b0 = &rhs[k*rhsStride + j2];
- // blockB[count+0] = cj(b0[0]);
- // blockB[count+1] = cj(b0[1]);
- // blockB[count+2] = cj(b0[2]);
- // blockB[count+3] = cj(b0[3]);
- // blockB[count+4] = cj(b0[4]);
- // blockB[count+5] = cj(b0[5]);
- // blockB[count+6] = cj(b0[6]);
- // blockB[count+7] = cj(b0[7]);
- // }
- // count += 8;
- // }
- // // skip what we have after
- // if(PanelMode) count += 8 * (stride-offset-depth);
- // }
- // }
+ if(nr>=8)
+ {
+ for(Index j2=0; j2<packet_cols8; j2+=8)
+ {
+ // skip what we have before
+ if(PanelMode) count += 8 * offset;
+ for(Index k=0; k<depth; k++)
+ {
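+ // Row-major RHS: the 8 coefficients of row k are contiguous, so one
+ // (PacketSize==8) or two (PacketSize==4) packet loads suffice.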
+ if (PacketSize==8) {
+ Packet A = rhs.template loadPacket<Packet>(k, j2);
+ pstoreu(blockB+count, cj.pconj(A));
+ count += PacketSize;
+ } else if (PacketSize==4) {
+ Packet A = rhs.template loadPacket<Packet>(k, j2);
+ Packet B = rhs.template loadPacket<Packet>(k, j2 + 4);
+ pstoreu(blockB+count, cj.pconj(A));
+ pstoreu(blockB+count+PacketSize, cj.pconj(B));
+ count += 2*PacketSize;
+ } else {
+ const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
+ blockB[count+0] = cj(dm0(0));
+ blockB[count+1] = cj(dm0(1));
+ blockB[count+2] = cj(dm0(2));
+ blockB[count+3] = cj(dm0(3));
+ blockB[count+4] = cj(dm0(4));
+ blockB[count+5] = cj(dm0(5));
+ blockB[count+6] = cj(dm0(6));
+ blockB[count+7] = cj(dm0(7));
+ count += 8;
+ }
+ }
+ // skip what we have after
+ if(PanelMode) count += 8 * (stride-offset-depth);
+ }
+ }
if(nr>=4)
{
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)