Added a test for shuffling
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h index c455300..15a22aa 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
@@ -67,7 +67,7 @@ : m_xpr(expr), m_shuffle(shuffle) {} EIGEN_DEVICE_FUNC - const Shuffle& shuffle() const { return m_shuffle; } + const Shuffle& shufflePermutation() const { return m_shuffle; } EIGEN_DEVICE_FUNC const typename internal::remove_all<typename XprType::Nested>::type& @@ -119,7 +119,7 @@ : m_impl(op.expression(), device) { const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions(); - const Shuffle& shuffle = op.shuffle(); + const Shuffle& shuffle = op.shufflePermutation(); for (int i = 0; i < NumDims; ++i) { m_dimensions[i] = input_dims[shuffle[i]]; }
diff --git a/unsupported/test/cxx11_tensor_shuffling.cpp b/unsupported/test/cxx11_tensor_shuffling.cpp index 2f7fd9e..d11444a 100644 --- a/unsupported/test/cxx11_tensor_shuffling.cpp +++ b/unsupported/test/cxx11_tensor_shuffling.cpp
@@ -176,12 +176,53 @@ } } + +template <int DataLayout> +static void test_shuffle_unshuffle() +{ + Tensor<float, 4, DataLayout> tensor(2,3,5,7); + tensor.setRandom(); + + // Choose a random permutation. + array<ptrdiff_t, 4> shuffles; + for (int i = 0; i < 4; ++i) { + shuffles[i] = i; + } + array<ptrdiff_t, 4> shuffles_inverse; + for (int i = 0; i < 4; ++i) { + const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3); + shuffles_inverse[shuffles[index]] = i; + std::swap(shuffles[i], shuffles[index]); + } + + Tensor<float, 4, DataLayout> shuffle; + shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse); + + VERIFY_IS_EQUAL(shuffle.dimension(0), 2); + VERIFY_IS_EQUAL(shuffle.dimension(1), 3); + VERIFY_IS_EQUAL(shuffle.dimension(2), 5); + VERIFY_IS_EQUAL(shuffle.dimension(3), 7); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l)); + } + } + } + } +} + + void test_cxx11_tensor_shuffling() { - CALL_SUBTEST(test_simple_shuffling<ColMajor>()); - CALL_SUBTEST(test_simple_shuffling<RowMajor>()); - CALL_SUBTEST(test_expr_shuffling<ColMajor>()); - CALL_SUBTEST(test_expr_shuffling<RowMajor>()); - CALL_SUBTEST(test_shuffling_as_value<ColMajor>()); - CALL_SUBTEST(test_shuffling_as_value<RowMajor>()); + CALL_SUBTEST(test_simple_shuffling<ColMajor>()); + CALL_SUBTEST(test_simple_shuffling<RowMajor>()); + CALL_SUBTEST(test_expr_shuffling<ColMajor>()); + CALL_SUBTEST(test_expr_shuffling<RowMajor>()); + CALL_SUBTEST(test_shuffling_as_value<ColMajor>()); + CALL_SUBTEST(test_shuffling_as_value<RowMajor>()); + CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>()); + CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>()); }