TensorMorphing: re-enable BlockAccess for slicing of bool tensors

libeigen/eigen!2452

Co-authored-by: Rasmus Munk Larsen <rmlarsen@gmail.com>
diff --git a/unsupported/Eigen/src/Tensor/TensorMorphing.h b/unsupported/Eigen/src/Tensor/TensorMorphing.h
index 7a5c4b0..1ccf414 100644
--- a/unsupported/Eigen/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/src/Tensor/TensorMorphing.h
@@ -367,9 +367,7 @@
     // slice offsets and sizes.
     IsAligned = false,
     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
-    BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess &&
-                  // FIXME: Temporary workaround for bug in slicing of bool tensors.
-                  !internal::is_same<std::remove_const_t<Scalar>, bool>::value,
+    BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess,
     PreferBlockAccess = true,
     CoordAccess = false,
     RawAccess = false
diff --git a/unsupported/test/tensor_block_eval.cpp b/unsupported/test/tensor_block_eval.cpp
index cbd72a4..aecd7b2 100644
--- a/unsupported/test/tensor_block_eval.cpp
+++ b/unsupported/test/tensor_block_eval.cpp
@@ -459,6 +459,40 @@
                                            [&slice_size]() { return RandomBlock<Layout>(slice_size, 1, 10); });
 }
 
+// Exercise the block evaluator for bool slices as a sub-expression of a
+// block-aware parent op at sizes spanning Packet16b (16-lane bool packet)
+// boundaries. Before the BlockAccess fix for bool slicing, the parent op's
+// BlockAccess would be forced to false and this composition would never
+// dispatch through the block path.
+template <int NumDims, int Layout>
+static void test_eval_tensor_slice_bool_composite() {
+  const Index boundary_sizes[] = {15, 16, 17, 31, 32, 33, 47, 48, 49};
+  for (Index sz : boundary_sizes) {
+    DSizes<Index, NumDims> dims;
+    for (int i = 0; i < NumDims; ++i) dims[i] = sz;
+
+    Tensor<bool, NumDims, Layout> lhs(dims);
+    Tensor<bool, NumDims, Layout> rhs(dims);
+    lhs.setRandom();
+    rhs.setRandom();
+
+    // Slice skewed off the boundary so the block layout straddles packet
+    // boundaries of the underlying tensor.
+    DSizes<Index, NumDims> slice_start;
+    DSizes<Index, NumDims> slice_size;
+    for (int i = 0; i < NumDims; ++i) {
+      slice_start[i] = sz >= 2 ? 1 : 0;
+      slice_size[i] = sz - slice_start[i];
+    }
+
+    auto expr = lhs.slice(slice_start, slice_size) && rhs.slice(slice_start, slice_size);
+
+    VerifyBlockEvaluator<bool, NumDims, Layout>(expr, [&slice_size]() { return FixedSizeBlock(slice_size); });
+    VerifyBlockEvaluator<bool, NumDims, Layout>(expr,
+                                                [&slice_size, sz]() { return RandomBlock<Layout>(slice_size, 1, sz); });
+  }
+}
+
 template <typename T, int NumDims, int Layout>
 static void test_eval_tensor_shuffle() {
   DSizes<Index, NumDims> dims = RandomDims<NumDims>(5, 15);
@@ -788,6 +822,12 @@
   CALL_SUBTESTS_DIMS_LAYOUTS_TYPES(4, test_eval_tensor_generator);
   CALL_SUBTESTS_DIMS_LAYOUTS_TYPES(4, test_eval_tensor_reverse);
   CALL_SUBTESTS_DIMS_LAYOUTS_TYPES(5, test_eval_tensor_slice);
+  CALL_SUBTEST_PART(5)((test_eval_tensor_slice_bool_composite<1, RowMajor>()));
+  CALL_SUBTEST_PART(5)((test_eval_tensor_slice_bool_composite<2, RowMajor>()));
+  CALL_SUBTEST_PART(5)((test_eval_tensor_slice_bool_composite<3, RowMajor>()));
+  CALL_SUBTEST_PART(5)((test_eval_tensor_slice_bool_composite<1, ColMajor>()));
+  CALL_SUBTEST_PART(5)((test_eval_tensor_slice_bool_composite<2, ColMajor>()));
+  CALL_SUBTEST_PART(5)((test_eval_tensor_slice_bool_composite<3, ColMajor>()));
   CALL_SUBTESTS_DIMS_LAYOUTS_TYPES(5, test_eval_tensor_shuffle);
 
   CALL_SUBTESTS_LAYOUTS_TYPES(6, test_eval_tensor_reshape_with_bcast);
diff --git a/unsupported/test/tensor_executor.cpp b/unsupported/test/tensor_executor.cpp
index 5a74a04..f2fd2da 100644
--- a/unsupported/test/tensor_executor.cpp
+++ b/unsupported/test/tensor_executor.cpp
@@ -372,6 +372,45 @@
   }
 }
 
+// Regression test for BlockAccess=true on bool slice rvalue expressions.
+// Sweeps sizes spanning Packet16b (16-lane bool packet) boundaries so the
+// tiled executor is forced to cross packet boundaries inside a single slice.
+template <typename T, int NumDims, typename Device, bool Vectorizable, TiledEvaluation Tiling, int Layout>
+void test_execute_slice_rvalue_bool_boundaries(Device d) {
+  static_assert(std::is_same<T, bool>::value, "Only bool is supported.");
+  static_assert(NumDims >= 2, "NumDims must be greater or equal than 2");
+  static constexpr int Options = 0 | Layout;
+
+  const Index boundary_sizes[] = {15, 16, 17, 31, 32, 33, 47, 48, 49, 63, 64, 65};
+  for (Index sz : boundary_sizes) {
+    array<Index, NumDims> src_dims;
+    for (int i = 0; i < NumDims; ++i) src_dims[i] = sz;
+
+    Tensor<bool, NumDims, Options, Index> src(src_dims);
+    src.setRandom();
+
+    DSizes<Index, NumDims> slice_start;
+    DSizes<Index, NumDims> slice_size;
+    for (int i = 0; i < NumDims; ++i) {
+      slice_start[i] = sz >= 2 ? 1 : 0;
+      slice_size[i] = sz - slice_start[i];
+    }
+
+    Tensor<bool, NumDims, Options, Index> golden = src.slice(slice_start, slice_size);
+
+    Tensor<bool, NumDims, Options, Index> dst(golden.dimensions());
+    auto expr = src.slice(slice_start, slice_size);
+
+    using Assign = TensorAssignOp<decltype(dst), const decltype(expr)>;
+    using Executor = internal::TensorExecutor<const Assign, Device, Vectorizable, Tiling>;
+    Executor::run(Assign(dst, expr), d);
+
+    for (Index i = 0; i < dst.dimensions().TotalSize(); ++i) {
+      VERIFY_IS_EQUAL(dst.coeff(i), golden.coeff(i));
+    }
+  }
+}
+
 template <typename T, int NumDims, typename Device, bool Vectorizable, TiledEvaluation Tiling, int Layout>
 void test_execute_slice_lvalue(Device d) {
   static_assert(NumDims >= 2, "NumDims must be greater or equal than 2");
@@ -669,6 +708,13 @@
   CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue, float, 3);
   CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue, float, 4);
   CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue, float, 5);
+  CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue, bool, 2);
+  CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue, bool, 3);
+  CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue, bool, 4);
+  CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue, bool, 5);
+  CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue_bool_boundaries, bool, 2);
+  CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue_bool_boundaries, bool, 3);
+  CALL_SUBTEST_COMBINATIONS(10, test_execute_slice_rvalue_bool_boundaries, bool, 4);
 
   CALL_SUBTEST_COMBINATIONS(11, test_execute_slice_lvalue, float, 2);
   CALL_SUBTEST_COMBINATIONS(11, test_execute_slice_lvalue, float, 3);