blob: 026e1dbc1301143061defa257526cb617d024ba4 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2024 Kseniya Zaytseva <kseniya.zaytseva@syntacore.com>
// Copyright (C) 2025 Chip Kerchner <ckerchner@tenstorrent.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_RVV10_GENERAL_BLOCK_KERNEL_H
#define EIGEN_RVV10_GENERAL_BLOCK_KERNEL_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
/********************************* real ************************************/
template <>
struct gebp_traits<float, float, false, false, Architecture::RVV10, GEBPPacketFull>
: gebp_traits<float, float, false, false, Architecture::Generic, GEBPPacketFull> {
typedef float RhsPacket;
typedef QuadPacket<float> RhsPacketx4;
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const { dest = pset1<RhsPacket>(*b); }
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const {
pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
}
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const { loadRhs(b, dest); }
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { dest = ploadquad<RhsPacket>(b); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/,
const FixedInt<0>&) const {
#if EIGEN_RISCV64_DEFAULT_LMUL == 1
c = __riscv_vfmadd_vf_f32m1(a, b, c, unpacket_traits<AccPacket>::size);
#elif EIGEN_RISCV64_DEFAULT_LMUL == 2
c = __riscv_vfmadd_vf_f32m2(a, b, c, unpacket_traits<AccPacket>::size);
#elif EIGEN_RISCV64_DEFAULT_LMUL == 4
c = __riscv_vfmadd_vf_f32m4(a, b, c, unpacket_traits<AccPacket>::size);
#endif
}
#if EIGEN_RISCV64_DEFAULT_LMUL >= 2
EIGEN_STRONG_INLINE void madd(const Packet1Xf& a, const RhsPacket& b, Packet1Xf& c, RhsPacket& /*tmp*/,
const FixedInt<0>&) const {
c = __riscv_vfmadd_vf_f32m1(a, b, c, unpacket_traits<Packet1Xf>::size);
}
#endif
#if EIGEN_RISCV64_DEFAULT_LMUL == 4
EIGEN_STRONG_INLINE void madd(const Packet2Xf& a, const RhsPacket& b, Packet2Xf& c, RhsPacket& /*tmp*/,
const FixedInt<0>&) const {
c = __riscv_vfmadd_vf_f32m2(a, b, c, unpacket_traits<Packet2Xf>::size);
}
#endif
template <typename LaneIdType>
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/,
const LaneIdType& lane) const {
#if EIGEN_RISCV64_DEFAULT_LMUL == 1
c = __riscv_vfmadd_vf_f32m1(a, b.get(lane), c, unpacket_traits<AccPacket>::size);
#elif EIGEN_RISCV64_DEFAULT_LMUL == 2
c = __riscv_vfmadd_vf_f32m2(a, b.get(lane), c, unpacket_traits<AccPacket>::size);
#elif EIGEN_RISCV64_DEFAULT_LMUL == 4
c = __riscv_vfmadd_vf_f32m4(a, b.get(lane), c, unpacket_traits<AccPacket>::size);
#endif
}
};
template <>
struct gebp_traits<double, double, false, false, Architecture::RVV10, GEBPPacketFull>
: gebp_traits<double, double, false, false, Architecture::Generic, GEBPPacketFull> {
typedef double RhsPacket;
typedef QuadPacket<double> RhsPacketx4;
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const { dest = pset1<RhsPacket>(*b); }
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const {
pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
}
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const { loadRhs(b, dest); }
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { dest = ploadquad<RhsPacket>(b); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/,
const FixedInt<0>&) const {
#if EIGEN_RISCV64_DEFAULT_LMUL == 1
c = __riscv_vfmadd_vf_f64m1(a, b, c, unpacket_traits<AccPacket>::size);
#elif EIGEN_RISCV64_DEFAULT_LMUL == 2
c = __riscv_vfmadd_vf_f64m2(a, b, c, unpacket_traits<AccPacket>::size);
#elif EIGEN_RISCV64_DEFAULT_LMUL == 4
c = __riscv_vfmadd_vf_f64m4(a, b, c, unpacket_traits<AccPacket>::size);
#endif
}
#if EIGEN_RISCV64_DEFAULT_LMUL >= 2
EIGEN_STRONG_INLINE void madd(const Packet1Xd& a, const RhsPacket& b, Packet1Xd& c, RhsPacket& /*tmp*/,
const FixedInt<0>&) const {
c = __riscv_vfmadd_vf_f64m1(a, b, c, unpacket_traits<Packet1Xd>::size);
}
#endif
#if EIGEN_RISCV64_DEFAULT_LMUL == 4
EIGEN_STRONG_INLINE void madd(const Packet2Xd& a, const RhsPacket& b, Packet2Xd& c, RhsPacket& /*tmp*/,
const FixedInt<0>&) const {
c = __riscv_vfmadd_vf_f64m2(a, b, c, unpacket_traits<Packet2Xd>::size);
}
#endif
template <typename LaneIdType>
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/,
const LaneIdType& lane) const {
#if EIGEN_RISCV64_DEFAULT_LMUL == 1
c = __riscv_vfmadd_vf_f64m1(a, b.get(lane), c, unpacket_traits<AccPacket>::size);
#elif EIGEN_RISCV64_DEFAULT_LMUL == 2
c = __riscv_vfmadd_vf_f64m2(a, b.get(lane), c, unpacket_traits<AccPacket>::size);
#elif EIGEN_RISCV64_DEFAULT_LMUL == 4
c = __riscv_vfmadd_vf_f64m4(a, b.get(lane), c, unpacket_traits<AccPacket>::size);
#endif
}
};
#if defined(EIGEN_VECTORIZE_RVV10FP16)
/** \internal
 * RVV 1.0 GEBP traits for Eigen::half * Eigen::half products.
 *
 * Only compiled when EIGEN_VECTORIZE_RVV10FP16 is defined. LHS and accumulator
 * are PacketXh vectors; the RHS is a scalar half fed to vfmadd.vf, giving
 * c = a * b + c. Register grouping is f16m1 when EIGEN_RISCV64_DEFAULT_LMUL is 1
 * and f16m2 otherwise.
 */
template <>
struct gebp_traits<half, half, false, false, Architecture::RVV10>
    : gebp_traits<half, half, false, false, Architecture::Generic> {
  typedef half RhsPacket;
  typedef PacketXh LhsPacket;
  typedef PacketXh AccPacket;
  typedef QuadPacket<half> RhsPacketx4;

  // A scalar RHS is simply the coefficient itself.
  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const { dest = pset1<RhsPacket>(*b); }

  // Populate all four quad slots from b through pbroadcast4.
  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const {
    pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
  }

  EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const { loadRhs(b, dest); }

  // Intentionally empty: loadRhs already filled the quad.
  EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}

  EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { dest = pload<RhsPacket>(b); }

  /** Full-width update: c = a * b + c with vl = unpacket_traits<AccPacket>::size. */
  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/,
                                const FixedInt<0>&) const {
    // The intrinsics take the compiler's native _Float16, so reinterpret the bits of Eigen::half.
    const _Float16 coeff = numext::bit_cast<_Float16>(b);
#if EIGEN_RISCV64_DEFAULT_LMUL != 1
    c = __riscv_vfmadd_vf_f16m2(a, coeff, c, unpacket_traits<AccPacket>::size);
#else
    c = __riscv_vfmadd_vf_f16m1(a, coeff, c, unpacket_traits<AccPacket>::size);
#endif
  }

#if EIGEN_RISCV64_DEFAULT_LMUL >= 2
  // f16m1 update, declared only when the default LMUL exceeds 1.
  EIGEN_STRONG_INLINE void madd(const Packet1Xh& a, const RhsPacket& b, Packet1Xh& c, RhsPacket& /*tmp*/,
                                const FixedInt<0>&) const {
    const _Float16 coeff = numext::bit_cast<_Float16>(b);
    c = __riscv_vfmadd_vf_f16m1(a, coeff, c, unpacket_traits<Packet1Xh>::size);
  }
#endif

  /** Lane-indexed update: picks the quad slot named by lane and forwards to the scalar madd. */
  template <typename LaneIdType>
  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& tmp,
                                const LaneIdType& lane) const {
    madd(a, b.get(lane), c, tmp, FixedInt<0>());
  }
};
#endif
#if defined(EIGEN_VECTORIZE_RVV10BF16)
/** \internal
 * RVV 1.0 GEBP traits for Eigen::bfloat16 * Eigen::bfloat16 products.
 *
 * Only compiled when EIGEN_VECTORIZE_RVV10BF16 is defined. Each madd:
 *  1. widens the bf16 accumulator to f32 (Bf16ToF32);
 *  2. performs a widening multiply-accumulate with the scalar RHS (vfwmaccbf16.vf);
 *  3. narrows the result back to bf16 (F32ToBf16).
 * The accumulator is therefore rounded to bf16 after every call. The f32 grouping
 * is f32m2 when EIGEN_RISCV64_DEFAULT_LMUL is 1 and f32m4 otherwise.
 */
template <>
struct gebp_traits<bfloat16, bfloat16, false, false, Architecture::RVV10>
    : gebp_traits<bfloat16, bfloat16, false, false, Architecture::Generic> {
  typedef bfloat16 RhsPacket;
  typedef PacketXbf LhsPacket;
  typedef PacketXbf AccPacket;
  typedef QuadPacket<bfloat16> RhsPacketx4;

  // A scalar RHS is simply the coefficient itself.
  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const { dest = pset1<RhsPacket>(*b); }

  // Populate all four quad slots from b through pbroadcast4.
  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const {
    pbroadcast4(b, dest.B_0, dest.B1, dest.B2, dest.B3);
  }

  EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const { loadRhs(b, dest); }

  // Intentionally empty: loadRhs already filled the quad.
  EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const {}

  EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { dest = pload<RhsPacket>(b); }

  /** Full-width update: c = bf16(f32(c) + b * a) with vl = unpacket_traits<AccPacket>::size. */
  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/,
                                const FixedInt<0>&) const {
    // The intrinsics take the compiler's native __bf16, so reinterpret the bits of Eigen::bfloat16.
    const __bf16 coeff = numext::bit_cast<__bf16>(b);
#if EIGEN_RISCV64_DEFAULT_LMUL != 1
    c = F32ToBf16(__riscv_vfwmaccbf16_vf_f32m4(Bf16ToF32(c), coeff, a, unpacket_traits<AccPacket>::size));
#else
    c = F32ToBf16(__riscv_vfwmaccbf16_vf_f32m2(Bf16ToF32(c), coeff, a, unpacket_traits<AccPacket>::size));
#endif
  }

#if EIGEN_RISCV64_DEFAULT_LMUL >= 2
  // bf16m1 -> f32m2 update, declared only when the default LMUL exceeds 1.
  EIGEN_STRONG_INLINE void madd(const Packet1Xbf& a, const RhsPacket& b, Packet1Xbf& c, RhsPacket& /*tmp*/,
                                const FixedInt<0>&) const {
    const __bf16 coeff = numext::bit_cast<__bf16>(b);
    c = F32ToBf16(__riscv_vfwmaccbf16_vf_f32m2(Bf16ToF32(c), coeff, a, unpacket_traits<Packet1Xbf>::size));
  }
#endif

  /** Lane-indexed update: picks the quad slot named by lane and forwards to the scalar madd. */
  template <typename LaneIdType>
  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& tmp,
                                const LaneIdType& lane) const {
    madd(a, b.get(lane), c, tmp, FixedInt<0>());
  }
};
#endif
} // namespace internal
} // namespace Eigen
#endif // EIGEN_RVV10_GENERAL_BLOCK_KERNEL_H