Add asynchronous evaluation (evalSubExprsIfNeededAsync) support to the 'chip' and 'extract_volume_patches' tensor op evaluators under EIGEN_USE_THREADS, and add sync/async ThreadPoolDevice tests for both ops.
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h b/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h index 32980c7..000b1fb 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h
@@ -182,6 +182,13 @@ return true; } +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType /*data*/, EvalSubExprsCallback done) { + m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); }); + } +#endif // EIGEN_USE_THREADS + EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h b/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h index 75063f5..d8faa4d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h
@@ -365,6 +365,13 @@ return true; } +#ifdef EIGEN_USE_THREADS + template <typename EvalSubExprsCallback> + EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType /*data*/, EvalSubExprsCallback done) { + m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); }); + } +#endif // EIGEN_USE_THREADS + EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index 8961c84..a566d7e 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp
@@ -80,6 +80,86 @@ } } +void test_multithread_chip() { + Tensor<float, 5> in(2, 3, 5, 7, 11); + Tensor<float, 4> out(3, 5, 7, 11); + + in.setRandom(); + + Eigen::ThreadPool tp(internal::random<int>(3, 11)); + Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11)); + + out.device(thread_pool_device) = in.chip(1, 0); + + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + for (int l = 0; l < 11; ++l) { + VERIFY_IS_EQUAL(out(i, j, k, l), in(1, i, j, k, l)); + } + } + } + } +} + +void test_async_multithread_chip() { + Tensor<float, 5> in(2, 3, 5, 7, 11); + Tensor<float, 4> out(3, 5, 7, 11); + + in.setRandom(); + + Eigen::ThreadPool tp(internal::random<int>(3, 11)); + Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11)); + + Eigen::Barrier b(1); + out.device(thread_pool_device, [&b]() { b.Notify(); }) = in.chip(1, 0); + b.Wait(); + + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + for (int l = 0; l < 11; ++l) { + VERIFY_IS_EQUAL(out(i, j, k, l), in(1, i, j, k, l)); + } + } + } + } +} + +void test_multithread_volume_patch() { + Tensor<float, 5> in(4, 2, 3, 5, 7); + Tensor<float, 6> out(4, 1, 1, 1, 2 * 3 * 5, 7); + + in.setRandom(); + + Eigen::ThreadPool tp(internal::random<int>(3, 11)); + Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11)); + + out.device(thread_pool_device) = in.extract_volume_patches(1, 1, 1); + + for (int i = 0; i < in.size(); ++i) { + VERIFY_IS_EQUAL(in.data()[i], out.data()[i]); + } +} + +void test_async_multithread_volume_patch() { + Tensor<float, 5> in(4, 2, 3, 5, 7); + Tensor<float, 6> out(4, 1, 1, 1, 2 * 3 * 5, 7); + + in.setRandom(); + + Eigen::ThreadPool tp(internal::random<int>(3, 11)); + Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11)); + + Eigen::Barrier b(1); + out.device(thread_pool_device, [&b]() { b.Notify(); }) = in.extract_volume_patches(1, 1, 
1); + b.Wait(); + + for (int i = 0; i < in.size(); ++i) { + VERIFY_IS_EQUAL(in.data()[i], out.data()[i]); + } +} + void test_multithread_compound_assignment() { Tensor<float, 3> in1(2, 3, 7); Tensor<float, 3> in2(2, 3, 7); @@ -648,43 +728,49 @@ CALL_SUBTEST_2(test_multithread_contraction<ColMajor>()); CALL_SUBTEST_2(test_multithread_contraction<RowMajor>()); - CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>()); - CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>()); - CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<ColMajor>()); - CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<RowMajor>()); + CALL_SUBTEST_3(test_multithread_chip()); + CALL_SUBTEST_3(test_async_multithread_chip()); - CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>()); - CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>()); + CALL_SUBTEST_4(test_multithread_volume_patch()); + CALL_SUBTEST_4(test_async_multithread_volume_patch()); + + CALL_SUBTEST_5(test_multithread_contraction_agrees_with_singlethread<ColMajor>()); + CALL_SUBTEST_5(test_multithread_contraction_agrees_with_singlethread<RowMajor>()); + CALL_SUBTEST_5(test_multithread_contraction_with_output_kernel<ColMajor>()); + CALL_SUBTEST_5(test_multithread_contraction_with_output_kernel<RowMajor>()); + + CALL_SUBTEST_6(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>()); + CALL_SUBTEST_6(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>()); // Test EvalShardedByInnerDimContext parallelization strategy. 
- CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<ColMajor>()); - CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<RowMajor>()); - CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>()); - CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>()); + CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction<ColMajor>()); + CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction<RowMajor>()); + CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>()); + CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>()); - CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<ColMajor>()); - CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<RowMajor>()); - CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>()); - CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>()); + CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction<ColMajor>()); + CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction<RowMajor>()); + CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>()); + CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>()); // Exercise various cases that have been problematic in the past. 
- CALL_SUBTEST_7(test_contraction_corner_cases<ColMajor>()); - CALL_SUBTEST_7(test_contraction_corner_cases<RowMajor>()); + CALL_SUBTEST_9(test_contraction_corner_cases<ColMajor>()); + CALL_SUBTEST_9(test_contraction_corner_cases<RowMajor>()); - CALL_SUBTEST_8(test_full_contraction<ColMajor>()); - CALL_SUBTEST_8(test_full_contraction<RowMajor>()); + CALL_SUBTEST_10(test_full_contraction<ColMajor>()); + CALL_SUBTEST_10(test_full_contraction<RowMajor>()); - CALL_SUBTEST_9(test_multithreaded_reductions<ColMajor>()); - CALL_SUBTEST_9(test_multithreaded_reductions<RowMajor>()); + CALL_SUBTEST_11(test_multithreaded_reductions<ColMajor>()); + CALL_SUBTEST_11(test_multithreaded_reductions<RowMajor>()); - CALL_SUBTEST_10(test_memcpy()); - CALL_SUBTEST_10(test_multithread_random()); + CALL_SUBTEST_12(test_memcpy()); + CALL_SUBTEST_12(test_multithread_random()); TestAllocator test_allocator; - CALL_SUBTEST_11(test_multithread_shuffle<ColMajor>(NULL)); - CALL_SUBTEST_11(test_multithread_shuffle<RowMajor>(&test_allocator)); - CALL_SUBTEST_11(test_threadpool_allocate(&test_allocator)); + CALL_SUBTEST_13(test_multithread_shuffle<ColMajor>(NULL)); + CALL_SUBTEST_13(test_multithread_shuffle<RowMajor>(&test_allocator)); + CALL_SUBTEST_13(test_threadpool_allocate(&test_allocator)); // Force CMake to split this test. - // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8;9;10;11 + // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8;9;10;11;12;13 }