| // This file is part of Eigen, a lightweight C++ template library |
| // for linear algebra. |
| // |
| // Copyright (C) 2025 Charlie Schlosser <cs.schlosser@gmail.com> |
| // |
| // This Source Code Form is subject to the terms of the Mozilla |
| // Public License v. 2.0. If a copy of the MPL was not distributed |
| // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
| #ifndef EIGEN_REALVIEW_H |
| #define EIGEN_REALVIEW_H |
| |
| // IWYU pragma: private |
| #include "./InternalHeaderCheck.h" |
| |
| namespace Eigen { |
| |
| namespace internal { |
| |
| // Write access and vectorization requires array-oriented access to the real and imaginary components. |
| // From https://en.cppreference.com/w/cpp/numeric/complex.html: |
| // For any pointer to an element of an array of std::complex<T> named p and any valid array index i, |
| // reinterpret_cast<T*>(p)[2 * i] is the real part of the complex number p[i], and |
| // reinterpret_cast<T*>(p)[2 * i + 1] is the imaginary part of the complex number p[i]. |
| |
| template <typename T> |
| struct complex_array_access : std::false_type {}; |
| template <typename T> |
| struct complex_array_access<std::complex<T>> : std::true_type {}; |
| |
| template <typename Xpr> |
| struct traits<RealView<Xpr>> : public traits<Xpr> { |
| template <typename T> |
| static constexpr int double_size(T size, bool times_two) { |
| int size_as_int = int(size); |
| if (size_as_int == Dynamic) return Dynamic; |
| return times_two ? (2 * size_as_int) : size_as_int; |
| } |
| |
| using Base = traits<Xpr>; |
| using ComplexScalar = typename Base::Scalar; |
| using Scalar = typename NumTraits<ComplexScalar>::Real; |
| |
| static constexpr bool ArrayAccess = complex_array_access<ComplexScalar>::value; |
| static constexpr int ActualDirectAccessBit = ArrayAccess ? DirectAccessBit : 0; |
| static constexpr int ActualLvaluebit = !std::is_const<Xpr>::value && ArrayAccess ? LvalueBit : 0; |
| static constexpr int ActualPacketAccessBit = packet_traits<Scalar>::Vectorizable ? PacketAccessBit : 0; |
| static constexpr int FlagMask = |
| ActualDirectAccessBit | ActualLvaluebit | ActualPacketAccessBit | HereditaryBits | LinearAccessBit; |
| static constexpr int BaseFlags = int(evaluator<Xpr>::Flags) | int(Base::Flags); |
| static constexpr int Flags = BaseFlags & FlagMask; |
| static constexpr bool IsRowMajor = Flags & RowMajorBit; |
| static constexpr int RowsAtCompileTime = double_size(Base::RowsAtCompileTime, !IsRowMajor); |
| static constexpr int ColsAtCompileTime = double_size(Base::ColsAtCompileTime, IsRowMajor); |
| static constexpr int SizeAtCompileTime = size_at_compile_time(RowsAtCompileTime, ColsAtCompileTime); |
| static constexpr int MaxRowsAtCompileTime = double_size(Base::MaxRowsAtCompileTime, !IsRowMajor); |
| static constexpr int MaxColsAtCompileTime = double_size(Base::MaxColsAtCompileTime, IsRowMajor); |
| static constexpr int MaxSizeAtCompileTime = size_at_compile_time(MaxRowsAtCompileTime, MaxColsAtCompileTime); |
| static constexpr int OuterStrideAtCompileTime = double_size(outer_stride_at_compile_time<Xpr>::ret, true); |
| static constexpr int InnerStrideAtCompileTime = inner_stride_at_compile_time<Xpr>::ret; |
| }; |
| |
| template <typename Xpr> |
| struct evaluator<RealView<Xpr>> : private evaluator<Xpr> { |
| using BaseEvaluator = evaluator<Xpr>; |
| using XprType = RealView<Xpr>; |
| using ExpressionTraits = traits<XprType>; |
| using ComplexScalar = typename ExpressionTraits::ComplexScalar; |
| using Scalar = typename ExpressionTraits::Scalar; |
| |
| static constexpr int Flags = ExpressionTraits::Flags; |
| static constexpr int CoeffReadCost = BaseEvaluator::CoeffReadCost; |
| static constexpr int Alignment = BaseEvaluator::Alignment; |
| static constexpr bool IsRowMajor = ExpressionTraits::IsRowMajor; |
| static constexpr bool DirectAccess = Flags & DirectAccessBit; |
| |
| using ComplexCoeffReturnType = std::conditional_t<DirectAccess, const ComplexScalar&, ComplexScalar>; |
| using CoeffReturnType = std::conditional_t<DirectAccess, const Scalar&, Scalar>; |
| |
| EIGEN_DEVICE_FUNC explicit evaluator(XprType realView) : BaseEvaluator(realView.m_xpr) {} |
| |
| template <bool Enable = DirectAccess, std::enable_if_t<!Enable, bool> = true> |
| constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index row, Index col) const { |
| Index r = IsRowMajor ? row : row / 2; |
| Index c = IsRowMajor ? col / 2 : col; |
| bool p = (IsRowMajor ? col : row) & 1; |
| ComplexScalar ccoeff = BaseEvaluator::coeff(r, c); |
| return p ? numext::imag(ccoeff) : numext::real(ccoeff); |
| } |
| template <bool Enable = DirectAccess, std::enable_if_t<Enable, bool> = true> |
| constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { |
| Index r = IsRowMajor ? row : row / 2; |
| Index c = IsRowMajor ? col / 2 : col; |
| Index p = (IsRowMajor ? col : row) & 1; |
| ComplexCoeffReturnType ccoeff = BaseEvaluator::coeff(r, c); |
| return reinterpret_cast<const Scalar(&)[2]>(ccoeff)[p]; |
| } |
| template <bool Enable = DirectAccess, std::enable_if_t<!Enable, bool> = true> |
| constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const { |
| ComplexScalar ccoeff = BaseEvaluator::coeff(index / 2); |
| bool p = index & 1; |
| return p ? numext::imag(ccoeff) : numext::real(ccoeff); |
| } |
| template <bool Enable = DirectAccess, std::enable_if_t<Enable, bool> = true> |
| constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { |
| ComplexCoeffReturnType ccoeff = BaseEvaluator::coeff(index / 2); |
| Index p = index & 1; |
| return reinterpret_cast<const Scalar(&)[2]>(ccoeff)[p]; |
| } |
| constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { |
| Index r = IsRowMajor ? row : row / 2; |
| Index c = IsRowMajor ? col / 2 : col; |
| Index p = (IsRowMajor ? col : row) & 1; |
| ComplexScalar& ccoeffRef = BaseEvaluator::coeffRef(r, c); |
| return reinterpret_cast<Scalar(&)[2]>(ccoeffRef)[p]; |
| } |
| constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { |
| ComplexScalar& ccoeffRef = BaseEvaluator::coeffRef(index / 2); |
| Index p = index & 1; |
| return reinterpret_cast<Scalar(&)[2]>(ccoeffRef)[p]; |
| } |
| |
| // If the first index is odd (imaginary), discard the first scalar |
| // in 'result' and assign the missing scalar. |
| // This operation is safe as the real component of the first scalar must exist. |
| |
| template <int LoadMode, typename PacketType> |
| EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { |
| constexpr int RealPacketSize = unpacket_traits<PacketType>::size; |
| using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type; |
| EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value), |
| MISSING COMPATIBLE COMPLEX PACKET TYPE) |
| Index r = IsRowMajor ? row : row / 2; |
| Index c = IsRowMajor ? col / 2 : col; |
| bool p = (IsRowMajor ? col : row) & 1; |
| ComplexPacket cresult = BaseEvaluator::template packet<LoadMode, ComplexPacket>(r, c); |
| PacketType result = preinterpret<PacketType>(cresult); |
| if (p) { |
| Scalar aux[RealPacketSize + 1]; |
| pstoreu(aux, result); |
| Index lastr = IsRowMajor ? row : row + RealPacketSize - 1; |
| Index lastc = IsRowMajor ? col + RealPacketSize - 1 : col; |
| aux[RealPacketSize] = coeff(lastr, lastc); |
| result = ploadu<PacketType>(aux + 1); |
| } |
| return result; |
| } |
| |
| template <int LoadMode, typename PacketType> |
| EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index index) const { |
| constexpr int RealPacketSize = unpacket_traits<PacketType>::size; |
| using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type; |
| EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value), |
| MISSING COMPATIBLE COMPLEX PACKET TYPE) |
| ComplexPacket cresult = BaseEvaluator::template packet<LoadMode, ComplexPacket>(index / 2); |
| PacketType result = preinterpret<PacketType>(cresult); |
| bool p = index & 1; |
| if (p) { |
| Scalar aux[RealPacketSize + 1]; |
| pstoreu(aux, result); |
| aux[RealPacketSize] = coeff(index + RealPacketSize - 1); |
| result = ploadu<PacketType>(aux + 1); |
| } |
| return result; |
| } |
| |
| // The requested real packet segment forms the half-open interval [begin, end), where 'end' = 'begin' + 'count'. |
| // In order to access the underlying complex array, even indices must be aligned with the real components |
| // of the complex scalars. 'begin' and 'count' must be modified as follows: |
| // a) 'begin' must be rounded down to the nearest even number; and |
| // b) 'end' must be rounded up to the nearest even number. |
| |
| template <int LoadMode, typename PacketType> |
| EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index row, Index col, Index begin, Index count) const { |
| constexpr int RealPacketSize = unpacket_traits<PacketType>::size; |
| using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type; |
| EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value), |
| MISSING COMPATIBLE COMPLEX PACKET TYPE) |
| Index actualBegin = numext::round_down(begin, 2); |
| Index actualEnd = numext::round_down(begin + count + 1, 2); |
| Index actualCount = actualEnd - actualBegin; |
| Index r = IsRowMajor ? row : row / 2; |
| Index c = IsRowMajor ? col / 2 : col; |
| ComplexPacket cresult = |
| BaseEvaluator::template packetSegment<LoadMode, ComplexPacket>(r, c, actualBegin / 2, actualCount / 2); |
| PacketType result = preinterpret<PacketType>(cresult); |
| bool p = (IsRowMajor ? col : row) & 1; |
| if (p) { |
| Scalar aux[RealPacketSize + 1] = {}; |
| pstoreu(aux, result); |
| Index lastr = IsRowMajor ? row : row + actualEnd - 1; |
| Index lastc = IsRowMajor ? col + actualEnd - 1 : col; |
| aux[actualEnd] = coeff(lastr, lastc); |
| result = ploadu<PacketType>(aux + 1); |
| } |
| return result; |
| } |
| |
| template <int LoadMode, typename PacketType> |
| EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index index, Index begin, Index count) const { |
| constexpr int RealPacketSize = unpacket_traits<PacketType>::size; |
| using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type; |
| EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value), |
| MISSING COMPATIBLE COMPLEX PACKET TYPE) |
| Index actualBegin = numext::round_down(begin, 2); |
| Index actualEnd = numext::round_down(begin + count + 1, 2); |
| Index actualCount = actualEnd - actualBegin; |
| ComplexPacket cresult = |
| BaseEvaluator::template packetSegment<LoadMode, ComplexPacket>(index / 2, actualBegin / 2, actualCount / 2); |
| PacketType result = preinterpret<PacketType>(cresult); |
| bool p = index & 1; |
| if (p) { |
| Scalar aux[RealPacketSize + 1] = {}; |
| pstoreu(aux, result); |
| aux[actualEnd] = coeff(index + actualEnd - 1); |
| result = ploadu<PacketType>(aux + 1); |
| } |
| return result; |
| } |
| }; |
| |
| } // namespace internal |
| |
| template <typename Xpr> |
| class RealView : public internal::dense_xpr_base<RealView<Xpr>>::type { |
| using ExpressionTraits = internal::traits<RealView>; |
| EIGEN_STATIC_ASSERT(NumTraits<typename Xpr::Scalar>::IsComplex, SCALAR MUST BE COMPLEX) |
| public: |
| using Scalar = typename ExpressionTraits::Scalar; |
| using Nested = RealView; |
| |
| EIGEN_DEVICE_FUNC explicit RealView(Xpr& xpr) : m_xpr(xpr) {} |
| EIGEN_DEVICE_FUNC constexpr Index rows() const noexcept { return Xpr::IsRowMajor ? m_xpr.rows() : 2 * m_xpr.rows(); } |
| EIGEN_DEVICE_FUNC constexpr Index cols() const noexcept { return Xpr::IsRowMajor ? 2 * m_xpr.cols() : m_xpr.cols(); } |
| EIGEN_DEVICE_FUNC constexpr Index size() const noexcept { return 2 * m_xpr.size(); } |
| EIGEN_DEVICE_FUNC constexpr Index innerStride() const noexcept { return m_xpr.innerStride(); } |
| EIGEN_DEVICE_FUNC constexpr Index outerStride() const noexcept { return 2 * m_xpr.outerStride(); } |
| EIGEN_DEVICE_FUNC void resize(Index rows, Index cols) { |
| m_xpr.resize(Xpr::IsRowMajor ? rows : rows / 2, Xpr::IsRowMajor ? cols / 2 : cols); |
| } |
| EIGEN_DEVICE_FUNC void resize(Index size) { m_xpr.resize(size / 2); } |
| EIGEN_DEVICE_FUNC Scalar* data() { return reinterpret_cast<Scalar*>(m_xpr.data()); } |
| EIGEN_DEVICE_FUNC const Scalar* data() const { return reinterpret_cast<const Scalar*>(m_xpr.data()); } |
| |
| EIGEN_DEVICE_FUNC RealView(const RealView&) = default; |
| |
| EIGEN_DEVICE_FUNC RealView& operator=(const RealView& other); |
| |
| template <typename OtherDerived> |
| EIGEN_DEVICE_FUNC RealView& operator=(const RealView<OtherDerived>& other); |
| |
| template <typename OtherDerived> |
| EIGEN_DEVICE_FUNC RealView& operator=(const DenseBase<OtherDerived>& other); |
| |
| protected: |
| friend struct internal::evaluator<RealView>; |
| Xpr& m_xpr; |
| }; |
| |
| template <typename Xpr> |
| EIGEN_DEVICE_FUNC RealView<Xpr>& RealView<Xpr>::operator=(const RealView& other) { |
| internal::call_assignment(*this, other); |
| return *this; |
| } |
| |
| template <typename Xpr> |
| template <typename OtherDerived> |
| EIGEN_DEVICE_FUNC RealView<Xpr>& RealView<Xpr>::operator=(const RealView<OtherDerived>& other) { |
| internal::call_assignment(*this, other); |
| return *this; |
| } |
| |
| template <typename Xpr> |
| template <typename OtherDerived> |
| EIGEN_DEVICE_FUNC RealView<Xpr>& RealView<Xpr>::operator=(const DenseBase<OtherDerived>& other) { |
| internal::call_assignment(*this, other.derived()); |
| return *this; |
| } |
| |
| template <typename Derived> |
| EIGEN_DEVICE_FUNC typename DenseBase<Derived>::RealViewReturnType DenseBase<Derived>::realView() { |
| return RealViewReturnType(derived()); |
| } |
| |
| template <typename Derived> |
| EIGEN_DEVICE_FUNC typename DenseBase<Derived>::ConstRealViewReturnType DenseBase<Derived>::realView() const { |
| return ConstRealViewReturnType(derived()); |
| } |
| |
| } // namespace Eigen |
| |
| #endif // EIGEN_REALVIEW_H |