| // This file is part of Eigen, a lightweight C++ template library |
| // for linear algebra. |
| // |
| // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> |
| // |
| // This Source Code Form is subject to the terms of the Mozilla |
| // Public License v. 2.0. If a copy of the MPL was not distributed |
| // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
| #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H |
| #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H |
| |
| namespace Eigen { |
| |
| /** \class TensorContraction |
| * \ingroup CXX11_Tensor_Module |
| * |
| * \brief Tensor contraction class. |
| * |
| * |
| */ |
| namespace internal { |
| template<typename Dimensions, typename LhsXprType, typename RhsXprType> |
| struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > |
| { |
| // Type promotion to handle the case where the types of the lhs and the rhs are different. |
| typedef typename internal::promote_storage_type<typename LhsXprType::Scalar, |
| typename RhsXprType::Scalar>::ret Scalar; |
| typedef typename internal::packet_traits<Scalar>::type Packet; |
| typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind, |
| typename traits<RhsXprType>::StorageKind>::ret StorageKind; |
| typedef typename promote_index_type<typename traits<LhsXprType>::Index, |
| typename traits<RhsXprType>::Index>::type Index; |
| typedef typename LhsXprType::Nested LhsNested; |
| typedef typename RhsXprType::Nested RhsNested; |
| typedef typename remove_reference<LhsNested>::type _LhsNested; |
| typedef typename remove_reference<RhsNested>::type _RhsNested; |
| |
| enum { |
| Flags = 0, |
| }; |
| }; |
| |
| template<typename Dimensions, typename LhsXprType, typename RhsXprType> |
| struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense> |
| { |
| typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type; |
| }; |
| |
| template<typename Dimensions, typename LhsXprType, typename RhsXprType> |
| struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type> |
| { |
| typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type; |
| }; |
| |
| } // end namespace internal |
| |
| |
| |
| template<typename Indices, typename LhsXprType, typename RhsXprType> |
| class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType> > |
| { |
| public: |
| typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar; |
| typedef typename Eigen::internal::traits<TensorContractionOp>::Packet Packet; |
| typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; |
| typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType, |
| typename RhsXprType::CoeffReturnType>::ret CoeffReturnType; |
| typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType, |
| typename RhsXprType::PacketReturnType>::ret PacketReturnType; |
| typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested; |
| typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind; |
| typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index; |
| |
| EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) |
| : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} |
| |
| EIGEN_DEVICE_FUNC |
| const Indices& indices() const { return m_indices; } |
| |
| /** \returns the nested expressions */ |
| EIGEN_DEVICE_FUNC |
| const typename internal::remove_all<typename LhsXprType::Nested>::type& |
| lhsExpression() const { return m_lhs_xpr; } |
| |
| EIGEN_DEVICE_FUNC |
| const typename internal::remove_all<typename RhsXprType::Nested>::type& |
| rhsExpression() const { return m_rhs_xpr; } |
| |
| protected: |
| typename LhsXprType::Nested m_lhs_xpr; |
| typename RhsXprType::Nested m_rhs_xpr; |
| const Indices m_indices; |
| }; |
| |
| |
// Compile-time helper: `size` equals N, except that 0 is mapped to 1.
template <size_t N>
struct max_n_1 {
  static const size_t size = (N == 0) ? 1 : N;
};
| |
| |
| template<typename Indices, typename LeftArgType, typename RightArgType, typename Device> |
| struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> |
| { |
| typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; |
| |
| static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size; |
| typedef typename XprType::Index Index; |
| typedef DSizes<Index, NumDims> Dimensions; |
| |
| enum { |
| IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned, |
| PacketAccess = /*TensorEvaluator<LeftArgType>::PacketAccess & TensorEvaluator<RightArgType>::PacketAccess */ |
| false, |
| }; |
| |
| TensorEvaluator(const XprType& op, const Device& device) |
| : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device) |
| { |
| Index index = 0; |
| Index stride = 1; |
| m_shiftright = 1; |
| |
| int skipped = 0; |
| const typename TensorEvaluator<LeftArgType, Device>::Dimensions& left_dims = m_leftImpl.dimensions(); |
| for (int i = 0; i < TensorEvaluator<LeftArgType, Device>::Dimensions::count; ++i) { |
| bool skip = false; |
| for (int j = 0; j < internal::array_size<Indices>::value; ++j) { |
| if (op.indices()[j].first == i) { |
| skip = true; |
| m_leftOffsets[2*skipped] = stride; |
| m_leftOffsets[2*skipped+1] = stride * left_dims[i]; |
| m_stitchsize[skipped] = left_dims[i]; |
| break; |
| } |
| } |
| if (!skip) { |
| m_dimensions[index++] = left_dims[i]; |
| m_shiftright *= left_dims[i]; |
| } else { |
| ++skipped; |
| } |
| stride *= left_dims[i]; |
| } |
| |
| stride = 1; |
| skipped = 0; |
| const typename TensorEvaluator<RightArgType, Device>::Dimensions& right_dims = m_rightImpl.dimensions(); |
| for (int i = 0; i < TensorEvaluator<RightArgType, Device>::Dimensions::count; ++i) { |
| bool skip = false; |
| for (int j = 0; j < internal::array_size<Indices>::value; ++j) { |
| if (op.indices()[j].second == i) { |
| skip = true; |
| m_rightOffsets[2*skipped] = stride; |
| m_rightOffsets[2*skipped+1] = stride * right_dims[i]; |
| break; |
| } |
| } |
| if (!skip) { |
| m_dimensions[index++] = right_dims[i]; |
| } else { |
| ++skipped; |
| } |
| stride *= right_dims[i]; |
| } |
| |
| // Scalar case |
| if (TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count == 2 * internal::array_size<Indices>::value) { |
| m_dimensions[0] = 1; |
| } |
| } |
| |
| typedef typename XprType::Scalar Scalar; |
| typedef typename XprType::CoeffReturnType CoeffReturnType; |
| typedef typename XprType::PacketReturnType PacketReturnType; |
| |
| const Dimensions& dimensions() const { return m_dimensions; } |
| |
| void evalTo(typename XprType::Scalar* buffer) const { |
| for (int i = 0; i < dimensions().TotalSize(); ++i) { |
| buffer[i] += coeff(i); |
| } |
| } |
| EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { |
| m_leftImpl.evalSubExprsIfNeeded(NULL); |
| m_rightImpl.evalSubExprsIfNeeded(NULL); |
| return true; |
| } |
| EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { |
| m_leftImpl.cleanup(); |
| m_rightImpl.cleanup(); |
| } |
| |
| EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const |
| { |
| const Index startLeft = index % m_shiftright; |
| const Index startRight = index / m_shiftright; |
| CoeffReturnType result = CoeffReturnType(0); |
| partialStitch(startLeft, startRight, 0, result); |
| return result; |
| } |
| |
| /* TODO: vectorization |
| template<int LoadMode> |
| EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const |
| { |
| assert(false); |
| }*/ |
| |
| private: |
| EIGEN_DEVICE_FUNC void partialStitch(Index startLeft, Index startRight, int StitchIndex, CoeffReturnType& accum) const { |
| Index firstLeft = (startLeft / m_leftOffsets[2*StitchIndex]) * m_leftOffsets[2*StitchIndex+1] + (startLeft % m_leftOffsets[2*StitchIndex]); |
| Index firstRight = (startRight / m_rightOffsets[2*StitchIndex]) * m_rightOffsets[2*StitchIndex+1] + (startRight % m_rightOffsets[2*StitchIndex]); |
| |
| for (int j = 0; j < m_stitchsize[StitchIndex]; ++j) { |
| const Index left = firstLeft+j*m_leftOffsets[2*StitchIndex]; |
| const Index right = firstRight+j*m_rightOffsets[2*StitchIndex]; |
| if (StitchIndex < internal::array_size<Indices>::value-1) { |
| partialStitch(left, right, StitchIndex+1, accum); |
| } else { |
| accum += m_leftImpl.coeff(left) * m_rightImpl.coeff(right); |
| } |
| } |
| } |
| |
| Scalar* data() const { return NULL; } |
| |
| private: |
| array<Index, 2*internal::array_size<Indices>::value> m_leftOffsets; |
| array<Index, 2*internal::array_size<Indices>::value> m_rightOffsets; |
| array<Index, internal::array_size<Indices>::value> m_stitchsize; |
| Index m_shiftright; |
| Dimensions m_dimensions; |
| TensorEvaluator<LeftArgType, Device> m_leftImpl; |
| TensorEvaluator<RightArgType, Device> m_rightImpl; |
| }; |
| |
| |
| } // end namespace Eigen |
| |
| #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H |