Adding lowlevel APIs for optimized RHS packet load in TensorFlow
 SpatialConvolution

Low-level APIs are added in order to optimized packet load in gemm_pack_rhs
in TensorFlow SpatialConvolution. The optimization is for scenario when a
packet is split across 2 adjacent columns. In this case we read it as two
'partial' packets and then merge these into 1. Currently this only works for
Packet16f (AVX512) and Packet8f (AVX2). We plan to add this for other
packet types (such as Packet8d) also.

This optimization shows significant speedup in SpatialConvolution with
certain parameters. Some examples are below.

Benchmark parameters are specified as:
Batch size, Input dim, Depth, Num of filters, Filter dim

Speedup numbers are specified for number of threads 1, 2, 4, 8, 16.

AVX512:

Parameters                  | Speedup (Num of threads: 1, 2, 4, 8, 16)
----------------------------|------------------------------------------
128,   24x24,  3, 64,   5x5 |2.18X, 2.13X, 1.73X, 1.64X, 1.66X
128,   24x24,  1, 64,   8x8 |2.00X, 1.98X, 1.93X, 1.91X, 1.91X
 32,   24x24,  3, 64,   5x5 |2.26X, 2.14X, 2.17X, 2.22X, 2.33X
128,   24x24,  3, 64,   3x3 |1.51X, 1.45X, 1.45X, 1.67X, 1.57X
 32,   14x14, 24, 64,   5x5 |1.21X, 1.19X, 1.16X, 1.70X, 1.17X
128, 128x128,  3, 96, 11x11 |2.17X, 2.18X, 2.19X, 2.20X, 2.18X

AVX2:

Parameters                  | Speedup (Num of threads: 1, 2, 4, 8, 16)
----------------------------|------------------------------------------
128,   24x24,  3, 64,   5x5 | 1.66X, 1.65X, 1.61X, 1.56X, 1.49X
 32,   24x24,  3, 64,   5x5 | 1.71X, 1.63X, 1.77X, 1.58X, 1.68X
128,   24x24,  1, 64,   5x5 | 1.44X, 1.40X, 1.38X, 1.37X, 1.33X
128,   24x24,  3, 64,   3x3 | 1.68X, 1.63X, 1.58X, 1.56X, 1.62X
128, 128x128,  3, 96, 11x11 | 1.36X, 1.36X, 1.37X, 1.37X, 1.37X

In the higher level benchmark cifar10, we observe a runtime improvement
of around 6% for AVX512 on Intel Skylake server (8 cores).

On lower level PackRhs micro-benchmarks specified in TensorFlow
tensorflow/core/kernels/eigen_spatial_convolutions_test.cc, we observe
the following runtime numbers:

AVX512:

Parameters                                                     | Runtime without patch (ns) | Runtime with patch (ns) | Speedup
---------------------------------------------------------------|----------------------------|-------------------------|---------
BM_RHS_NAME(PackRhs, 128, 24, 24, 3, 64, 5, 5, 1, 1, 256, 56)  |  41350                     | 15073                   | 2.74X
BM_RHS_NAME(PackRhs, 32, 64, 64, 32, 64, 5, 5, 1, 1, 256, 56)  |   7277                     |  7341                   | 0.99X
BM_RHS_NAME(PackRhs, 32, 64, 64, 32, 64, 5, 5, 2, 2, 256, 56)  |   8675                     |  8681                   | 1.00X
BM_RHS_NAME(PackRhs, 32, 64, 64, 30, 64, 5, 5, 1, 1, 256, 56)  |  24155                     | 16079                   | 1.50X
BM_RHS_NAME(PackRhs, 32, 64, 64, 30, 64, 5, 5, 2, 2, 256, 56)  |  25052                     | 17152                   | 1.46X
BM_RHS_NAME(PackRhs, 32, 256, 256, 4, 16, 8, 8, 1, 1, 256, 56) |  18269                     | 18345                   | 1.00X
BM_RHS_NAME(PackRhs, 32, 256, 256, 4, 16, 8, 8, 2, 4, 256, 56) |  19468                     | 19872                   | 0.98X
BM_RHS_NAME(PackRhs, 32, 64, 64, 4, 16, 3, 3, 1, 1, 36, 432)   | 156060                     | 42432                   | 3.68X
BM_RHS_NAME(PackRhs, 32, 64, 64, 4, 16, 3, 3, 2, 2, 36, 432)   | 132701                     | 36944                   | 3.59X

AVX2:

Parameters                                                     | Runtime without patch (ns) | Runtime with patch (ns) | Speedup
---------------------------------------------------------------|----------------------------|-------------------------|---------
BM_RHS_NAME(PackRhs, 128, 24, 24, 3, 64, 5, 5, 1, 1, 256, 56)  | 26233                      | 12393                   | 2.12X
BM_RHS_NAME(PackRhs, 32, 64, 64, 32, 64, 5, 5, 1, 1, 256, 56)  |  6091                      |  6062                   | 1.00X
BM_RHS_NAME(PackRhs, 32, 64, 64, 32, 64, 5, 5, 2, 2, 256, 56)  |  7427                      |  7408                   | 1.00X
BM_RHS_NAME(PackRhs, 32, 64, 64, 30, 64, 5, 5, 1, 1, 256, 56)  | 23453                      | 20826                   | 1.13X
BM_RHS_NAME(PackRhs, 32, 64, 64, 30, 64, 5, 5, 2, 2, 256, 56)  | 23167                      | 22091                   | 1.09X
BM_RHS_NAME(PackRhs, 32, 256, 256, 4, 16, 8, 8, 1, 1, 256, 56) | 23422                      | 23682                   | 0.99X
BM_RHS_NAME(PackRhs, 32, 256, 256, 4, 16, 8, 8, 2, 4, 256, 56) | 23165                      | 23663                   | 0.98X
BM_RHS_NAME(PackRhs, 32, 64, 64, 4, 16, 3, 3, 1, 1, 36, 432)   | 72689                      | 44969                   | 1.62X
BM_RHS_NAME(PackRhs, 32, 64, 64, 4, 16, 3, 3, 2, 2, 36, 432)   | 61732                      | 39779                   | 1.55X

All benchmarks on Intel Skylake server with 8 cores.
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index 04a321b..3bba022 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -287,6 +287,14 @@
 template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
 ploadu(const typename unpacket_traits<Packet>::type* from) { return *from; }
 
+/** \internal \returns a packet version of \a *from, (un-aligned masked load)
+ * There is no generic implementation. We only have implementations for specialized
+ * cases. Generic case should not be called.
+ */
+template<typename Packet> EIGEN_DEVICE_FUNC inline
+typename enable_if<unpacket_traits<Packet>::masked_load_available, Packet>::type
+ploadu(const typename unpacket_traits<Packet>::type* from, typename unpacket_traits<Packet>::mask_t umask);
+
 /** \internal \returns a packet with constant coefficients \a a, e.g.: (a,a,a,a) */
 template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
 pset1(const typename unpacket_traits<Packet>::type& a) { return a; }
diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h
index 5b8ff59..3d229cd 100644
--- a/Eigen/src/Core/arch/AVX/Complex.h
+++ b/Eigen/src/Core/arch/AVX/Complex.h
@@ -47,7 +47,7 @@
 };
 #endif
 
-template<> struct unpacket_traits<Packet4cf> { typedef std::complex<float> type; enum {size=4, alignment=Aligned32, vectorizable=true}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet4cf> { typedef std::complex<float> type; enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false}; typedef Packet2cf half; };
 
 template<> EIGEN_STRONG_INLINE Packet4cf padd<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_add_ps(a.v,b.v)); }
 template<> EIGEN_STRONG_INLINE Packet4cf psub<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_sub_ps(a.v,b.v)); }
@@ -263,7 +263,7 @@
 };
 #endif
 
-template<> struct unpacket_traits<Packet2cd> { typedef std::complex<double> type; enum {size=2, alignment=Aligned32, vectorizable=true}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet2cd> { typedef std::complex<double> type; enum {size=2, alignment=Aligned32, vectorizable=true, masked_load_available=false}; typedef Packet1cd half; };
 
 template<> EIGEN_STRONG_INLINE Packet2cd padd<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_add_pd(a.v,b.v)); }
 template<> EIGEN_STRONG_INLINE Packet2cd psub<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_sub_pd(a.v,b.v)); }
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index f88e360..3f94f85 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -117,14 +117,15 @@
   typedef float     type;
   typedef Packet4f  half;
   typedef Packet8i  integer_packet;
-  enum {size=8, alignment=Aligned32, vectorizable=true};
+  typedef uint8_t   mask_t;
+  enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true};
 };
 template<> struct unpacket_traits<Packet4d> {
   typedef double type;
   typedef Packet2d half;
-  enum {size=4, alignment=Aligned32, vectorizable=true};
+  enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false};
 };
-template<> struct unpacket_traits<Packet8i> { typedef int    type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false}; };
+template<> struct unpacket_traits<Packet8i> { typedef int    type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false}; };
 
 template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float&  from) { return _mm256_set1_ps(from); }
 template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
@@ -385,6 +386,14 @@
 template<> EIGEN_STRONG_INLINE Packet4d ploadu<Packet4d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_pd(from); }
 template<> EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); }
 
+template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) {
+  __m256i mask = _mm256_set1_epi8(static_cast<char>(umask));
+  const __m256i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
+  mask = _mm256_or_si256(mask, bit_mask);
+  mask = _mm256_cmpeq_epi32(mask, _mm256_set1_epi32(0xffffffff));
+  EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask);
+}
+
 // Loads 4 floats from memory a returns the packet {a0, a0  a1, a1, a2, a2, a3, a3}
 template<> EIGEN_STRONG_INLINE Packet8f ploaddup<Packet8f>(const float* from)
 {
diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h
index 9a89dd0..5ab2ffe 100644
--- a/Eigen/src/Core/arch/AVX512/Complex.h
+++ b/Eigen/src/Core/arch/AVX512/Complex.h
@@ -51,7 +51,8 @@
   enum {
     size = 8,
     alignment=unpacket_traits<Packet16f>::alignment,
-    vectorizable=true
+    vectorizable=true,
+    masked_load_available=false
   };
   typedef Packet4cf half;
 };
@@ -247,7 +248,8 @@
   enum {
     size = 4,
     alignment = unpacket_traits<Packet8d>::alignment,
-    vectorizable=true
+    vectorizable=true,
+    masked_load_available=false
   };
   typedef Packet2cd half;
 };
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 60b723b..094309e 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -102,19 +102,20 @@
   typedef float type;
   typedef Packet8f half;
   typedef Packet16i integer_packet;
-  enum { size = 16, alignment=Aligned64, vectorizable=true };
+  typedef uint16_t mask_t;
+  enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true };
 };
 template <>
 struct unpacket_traits<Packet8d> {
   typedef double type;
   typedef Packet4d half;
-  enum { size = 8, alignment=Aligned64, vectorizable=true };
+  enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false };
 };
 template <>
 struct unpacket_traits<Packet16i> {
   typedef int type;
   typedef Packet8i half;
-  enum { size = 16, alignment=Aligned64, vectorizable=false };
+  enum { size = 16, alignment=Aligned64, vectorizable=false, masked_load_available=false };
 };
 
 template <>
@@ -485,6 +486,12 @@
       reinterpret_cast<const __m512i*>(from));
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) {
+  __mmask16 mask = static_cast<__mmask16>(umask);
+  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
+}
+
 // Loads 8 floats from memory a returns the packet
 // {a0, a0  a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
 template <>
diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h
index 0c9f8e0..ebc3b2a 100644
--- a/Eigen/src/Core/arch/AltiVec/Complex.h
+++ b/Eigen/src/Core/arch/AltiVec/Complex.h
@@ -60,7 +60,7 @@
   };
 };
 
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet2cf half; };
 
 template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>&  from)
 {
@@ -291,7 +291,7 @@
   };
 };
 
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet1cd half; };
 
 template<> EIGEN_STRONG_INLINE Packet1cd pload <Packet1cd>(const std::complex<double>* from) { return Packet1cd(pload<Packet2d>((const double*)from)); }
 template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) { return Packet1cd(ploadu<Packet2d>((const double*)from)); }
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index 9535724..b5484e6 100755
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -192,13 +192,13 @@
   typedef float     type;
   typedef Packet4f  half;
   typedef Packet4i  integer_packet;
-  enum {size=4, alignment=Aligned16, vectorizable=true};
+  enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false};
 };
 template<> struct unpacket_traits<Packet4i>
 {
   typedef int       type;
   typedef Packet4i  half;
-  enum {size=4, alignment=Aligned16, vectorizable=false};
+  enum {size=4, alignment=Aligned16, vectorizable=false, masked_load_available=false};
 };
 
 inline std::ostream & operator <<(std::ostream & s, const Packet16uc & v)
@@ -921,7 +921,7 @@
   };
 };
 
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef Packet2d half; };
+template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet2d half; };
 
 inline std::ostream & operator <<(std::ostream & s, const Packet2l & v)
 {
diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h
index cd4615a..7fac0a5 100644
--- a/Eigen/src/Core/arch/GPU/PacketMath.h
+++ b/Eigen/src/Core/arch/GPU/PacketMath.h
@@ -92,8 +92,8 @@
 };
 
 
-template<> struct unpacket_traits<float4>  { typedef float  type; enum {size=4, alignment=Aligned16, vectorizable=true}; typedef float4 half; };
-template<> struct unpacket_traits<double2> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef double2 half; };
+template<> struct unpacket_traits<float4>  { typedef float  type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef float4 half; };
+template<> struct unpacket_traits<double2> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef double2 half; };
 
 template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pset1<float4>(const float&  from) {
   return make_float4(from, from, from, from);
diff --git a/Eigen/src/Core/arch/GPU/PacketMathHalf.h b/Eigen/src/Core/arch/GPU/PacketMathHalf.h
index cd518c7..2bee56f 100644
--- a/Eigen/src/Core/arch/GPU/PacketMathHalf.h
+++ b/Eigen/src/Core/arch/GPU/PacketMathHalf.h
@@ -42,7 +42,7 @@
   };
 };
 
-template<> struct unpacket_traits<half2> { typedef Eigen::half type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef half2 half; };
+template<> struct unpacket_traits<half2> { typedef Eigen::half type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef half2 half; };
 
 template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1<half2>(const Eigen::half& from) {
   return __half2half2(from);
@@ -567,7 +567,7 @@
 };
 
 
-template<> struct unpacket_traits<Packet16h> { typedef Eigen::half type; enum {size=16, alignment=Aligned32, vectorizable=true}; typedef Packet16h half; };
+template<> struct unpacket_traits<Packet16h> { typedef Eigen::half type; typedef uint16_t mask_t; enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=true}; typedef Packet16h half; };
 
 template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
   Packet16h result;
@@ -591,6 +591,14 @@
   return result;
 }
 
+template<> EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from,
+                                                           uint16_t umask) {
+  __mmask16 mask = static_cast<__mmask16>(umask);
+  Packet16h result;
+  result.x = _mm256_maskz_loadu_epi16(mask, from);
+  return result;
+}
+
 template<> EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
   // (void*) -> workaround clang warning:
   // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
@@ -1056,7 +1064,7 @@
 };
 
 
-template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true}; typedef Packet8h half; };
+template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet8h half; };
 
 template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
   Packet8h result;
@@ -1419,7 +1427,7 @@
 };
 
 
-template<> struct unpacket_traits<Packet4h> { typedef Eigen::half type; enum {size=4, alignment=Aligned16, vectorizable=true}; typedef Packet4h half; };
+template<> struct unpacket_traits<Packet4h> { typedef Eigen::half type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet4h half; };
 
 template<> EIGEN_STRONG_INLINE Packet4h pset1<Packet4h>(const Eigen::half& from) {
   Packet4h result;
diff --git a/Eigen/src/Core/arch/MSA/Complex.h b/Eigen/src/Core/arch/MSA/Complex.h
index fa64d35..0ced061 100644
--- a/Eigen/src/Core/arch/MSA/Complex.h
+++ b/Eigen/src/Core/arch/MSA/Complex.h
@@ -127,7 +127,7 @@
 template <>
 struct unpacket_traits<Packet2cf> {
   typedef std::complex<float> type;
-  enum { size = 2, alignment = Aligned16, vectorizable=true };
+  enum { size = 2, alignment = Aligned16, vectorizable=true, masked_load_available=false };
   typedef Packet2cf half;
 };
 
@@ -500,7 +500,7 @@
 template <>
 struct unpacket_traits<Packet1cd> {
   typedef std::complex<double> type;
-  enum { size = 1, alignment = Aligned16, vectorizable=true };
+  enum { size = 1, alignment = Aligned16, vectorizable=true, masked_load_available=false };
   typedef Packet1cd half;
 };
 
diff --git a/Eigen/src/Core/arch/MSA/PacketMath.h b/Eigen/src/Core/arch/MSA/PacketMath.h
index a97156a..f426d5b 100644
--- a/Eigen/src/Core/arch/MSA/PacketMath.h
+++ b/Eigen/src/Core/arch/MSA/PacketMath.h
@@ -117,14 +117,14 @@
 template <>
 struct unpacket_traits<Packet4f> {
   typedef float type;
-  enum { size = 4, alignment = Aligned16, vectorizable=true };
+  enum { size = 4, alignment = Aligned16, vectorizable=true, masked_load_available=false };
   typedef Packet4f half;
 };
 
 template <>
 struct unpacket_traits<Packet4i> {
   typedef int32_t type;
-  enum { size = 4, alignment = Aligned16, vectorizable=true };
+  enum { size = 4, alignment = Aligned16, vectorizable=true, masked_load_available=false };
   typedef Packet4i half;
 };
 
@@ -925,7 +925,7 @@
 template <>
 struct unpacket_traits<Packet2d> {
   typedef double type;
-  enum { size = 2, alignment = Aligned16, vectorizable=true };
+  enum { size = 2, alignment = Aligned16, vectorizable=true, masked_load_available=false };
   typedef Packet2d half;
 };
 
diff --git a/Eigen/src/Core/arch/NEON/Complex.h b/Eigen/src/Core/arch/NEON/Complex.h
index f6c5c21..c17d0a0 100644
--- a/Eigen/src/Core/arch/NEON/Complex.h
+++ b/Eigen/src/Core/arch/NEON/Complex.h
@@ -62,7 +62,7 @@
   };
 };
 
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet2cf half; };
 
 template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>&  from)
 {
@@ -340,7 +340,7 @@
   };
 };
 
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet1cd half; };
 
 template<> EIGEN_STRONG_INLINE Packet1cd pload<Packet1cd>(const std::complex<double>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>((const double*)from)); }
 template<> EIGEN_STRONG_INLINE Packet1cd ploadu<Packet1cd>(const std::complex<double>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu<Packet2d>((const double*)from)); }
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index e8b3518..b8051cf 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -145,13 +145,13 @@
   typedef float     type;
   typedef Packet4f  half;
   typedef Packet4i  integer_packet;
-  enum {size=4, alignment=Aligned16, vectorizable=true};
+  enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false};
 };
 template<> struct unpacket_traits<Packet4i>
 {
   typedef int32_t   type;
   typedef Packet4i  half;
-  enum {size=4, alignment=Aligned16, vectorizable=true};
+  enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false};
 };
 
 template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float&  from) { return vdupq_n_f32(from); }
@@ -657,7 +657,7 @@
   };
 };
 
-template<> struct unpacket_traits<Packet2d> { typedef double  type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef Packet2d half; };
+template<> struct unpacket_traits<Packet2d> { typedef double  type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet2d half; };
 
 template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double&  from) { return vdupq_n_f64(from); }
 
diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h
index f39988e..7d89c32 100644
--- a/Eigen/src/Core/arch/SSE/Complex.h
+++ b/Eigen/src/Core/arch/SSE/Complex.h
@@ -50,7 +50,7 @@
 };
 #endif
 
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet2cf half; };
 
 template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_add_ps(a.v,b.v)); }
 template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_sub_ps(a.v,b.v)); }
@@ -283,7 +283,7 @@
 };
 #endif
 
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet1cd half; };
 
 template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_add_pd(a.v,b.v)); }
 template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_sub_pd(a.v,b.v)); }
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 9c3750a..04b6360 100755
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -166,17 +166,17 @@
   typedef float     type;
   typedef Packet4f  half;
   typedef Packet4i  integer_packet;
-  enum {size=4, alignment=Aligned16, vectorizable=true};
+  enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false};
 };
 template<> struct unpacket_traits<Packet2d> {
   typedef double    type;
   typedef Packet2d  half;
-  enum {size=2, alignment=Aligned16, vectorizable=true};
+  enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false};
 };
 template<> struct unpacket_traits<Packet4i> {
   typedef int       type;
   typedef Packet4i  half;
-  enum {size=4, alignment=Aligned16, vectorizable=false};
+  enum {size=4, alignment=Aligned16, vectorizable=false, masked_load_available=false};
 };
 
 #ifndef EIGEN_VECTORIZE_AVX
diff --git a/Eigen/src/Core/arch/SYCL/InteropHeaders.h b/Eigen/src/Core/arch/SYCL/InteropHeaders.h
index 294cb10..1afa63b 100644
--- a/Eigen/src/Core/arch/SYCL/InteropHeaders.h
+++ b/Eigen/src/Core/arch/SYCL/InteropHeaders.h
@@ -88,7 +88,7 @@
 #define SYCL_UNPACKET_TRAITS(packet_type, unpacket_type, lengths)\
 template<> struct unpacket_traits<packet_type>  {\
   typedef unpacket_type  type;\
-  enum {size=lengths, alignment=Aligned16, vectorizable=true};\
+  enum {size=lengths, alignment=Aligned16, vectorizable=true, masked_load_available=false};\
   typedef packet_type half;\
 };
 SYCL_UNPACKET_TRAITS(cl::sycl::cl_float4, float, 4)
diff --git a/Eigen/src/Core/arch/ZVector/Complex.h b/Eigen/src/Core/arch/ZVector/Complex.h
index 167c3ee..9fcbcb8 100644
--- a/Eigen/src/Core/arch/ZVector/Complex.h
+++ b/Eigen/src/Core/arch/ZVector/Complex.h
@@ -91,8 +91,8 @@
   };
 };
 
-template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float>  type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef Packet2cf half; };
-template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true}; typedef Packet1cd half; };
+template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float>  type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet2cf half; };
+template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet1cd half; };
 
 /* Forward declaration */
 EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2cf,2>& kernel);
diff --git a/Eigen/src/Core/arch/ZVector/PacketMath.h b/Eigen/src/Core/arch/ZVector/PacketMath.h
index c8e90f1..74e0a13 100755
--- a/Eigen/src/Core/arch/ZVector/PacketMath.h
+++ b/Eigen/src/Core/arch/ZVector/PacketMath.h
@@ -239,9 +239,9 @@
   };
 };
 
-template<> struct unpacket_traits<Packet4i> { typedef int    type; enum {size=4, alignment=Aligned16, vectorizable=true}; typedef Packet4i half; };
-template<> struct unpacket_traits<Packet4f> { typedef float  type; enum {size=4, alignment=Aligned16, vectorizable=true}; typedef Packet4f half; };
-template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true}; typedef Packet2d half; };
+template<> struct unpacket_traits<Packet4i> { typedef int    type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet4i half; };
+template<> struct unpacket_traits<Packet4f> { typedef float  type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet4f half; };
+template<> struct unpacket_traits<Packet2d> { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false}; typedef Packet2d half; };
 
 /* Forward declaration */
 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4f,4>& kernel);
diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h
index 91c2e42..ce01994 100644
--- a/Eigen/src/Core/util/XprHelper.h
+++ b/Eigen/src/Core/util/XprHelper.h
@@ -185,7 +185,8 @@
   {
     size = 1,
     alignment = 1,
-    vectorizable = false
+    vectorizable = false,
+    masked_load_available=false
   };
 };
 
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index 4906f6e..200670b 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -119,6 +119,9 @@
   inline Packet load(const T* from) const { return internal::pload<Packet>(from); }
 
   template<typename T>
+  inline Packet load(const T* from, unsigned long long umask) const { return internal::ploadu<Packet>(from, umask); }
+
+  template<typename T>
   inline void store(T* to, const Packet& x) const { internal::pstore(to,x); }
 };
 
@@ -129,6 +132,9 @@
   inline T load(const T* from) const { return *from; }
 
   template<typename T>
+  inline T load(const T* from, unsigned long long) const { return *from; }
+
+  template<typename T>
   inline void store(T* to, const T& x) const { *to = x; }
 };
 
@@ -169,6 +175,7 @@
   const int size = PacketSize*max_size;
   EIGEN_ALIGN_MAX Scalar data1[size];
   EIGEN_ALIGN_MAX Scalar data2[size];
+  EIGEN_ALIGN_MAX Scalar data3[size];
   EIGEN_ALIGN_MAX Packet packets[PacketSize*2];
   EIGEN_ALIGN_MAX Scalar ref[size];
   RealScalar refvalue = RealScalar(0);
@@ -194,6 +201,22 @@
     VERIFY(areApprox(data1, data2+offset, PacketSize) && "internal::pstoreu");
   }
 
+  if (internal::unpacket_traits<Packet>::masked_load_available)
+  {
+    unsigned long long max_umask = (0x1ull << PacketSize);
+    for (int offset=0; offset<PacketSize; ++offset)
+    {
+      for (unsigned long long umask=0; umask<max_umask; ++umask)
+      {
+        packet_helper<internal::unpacket_traits<Packet>::masked_load_available, Packet> h;
+        h.store(data2, h.load(data1+offset, umask));
+        for (int k=0; k<PacketSize; ++k)
+          data3[k] = ((umask & ( 0x1ull << k )) >> k) ? data1[k+offset] : Scalar(0);
+        VERIFY(areApprox(data3, data2, PacketSize) && "internal::ploadu masked");
+      }
+    }
+  }
+
   for (int offset=0; offset<PacketSize; ++offset)
   {
     #define MIN(A,B) (A<B?A:B)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 4ca6b3d..05c684f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -95,6 +95,18 @@
     return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
   }
 
+  // Return a packet starting at `index` where `umask` specifies which elements
+  // have to be loaded. Type/size of mask depends on PacketReturnType, e.g. for
+  // Packet16f, `umask` is of type uint16_t and if a bit is 1, corresponding
+  // float element will be loaded, otherwise 0 will be loaded.
+  // Function has been templatized to enable Sfinae.
+  template <typename PacketReturnTypeT = PacketReturnType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+  typename internal::enable_if<internal::unpacket_traits<PacketReturnTypeT>::masked_load_available, PacketReturnTypeT>::type
+  partialPacket(Index index, typename internal::unpacket_traits<PacketReturnTypeT>::mask_t umask) const
+  {
+    return internal::ploadu<PacketReturnTypeT>(m_data + index, umask);
+  }
+
   template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
   void writePacket(Index index, const PacketReturnType& x)
   {
@@ -244,6 +256,18 @@
     return internal::ploadt_ro<PacketReturnType, LoadMode>(m_data + index);
   }
 
+  // Return a packet starting at `index` where `umask` specifies which elements
+  // have to be loaded. Type/size of mask depends on PacketReturnType, e.g. for
+  // Packet16f, `umask` is of type uint16_t and if a bit is 1, corresponding
+  // float element will be loaded, otherwise 0 will be loaded.
+  // Function has been templatized to enable Sfinae.
+  template <typename PacketReturnTypeT = PacketReturnType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+  typename internal::enable_if<internal::unpacket_traits<PacketReturnTypeT>::masked_load_available, PacketReturnTypeT>::type
+  partialPacket(Index index, typename internal::unpacket_traits<PacketReturnTypeT>::mask_t umask) const
+  {
+    return internal::ploadu<PacketReturnTypeT>(m_data + index, umask);
+  }
+
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
     eigen_assert(m_data);
     const Index index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_dims.IndexOfColMajor(coords)