blob: 7a65befa0c29a8042fcb7f25351a0048c4b861d5 [file] [edit]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// BLAS Level 3 expression types for gpu::DeviceMatrix (beyond GEMM):
// TrsmExpr -> cublasXtrsm (triangular solve)
// SymmExpr -> cublasXsymm (symmetric multiply, real)
// -> cublasXhemm (Hermitian multiply, complex)
// SyrkExpr -> cublasXsyrk (symmetric rank-k update, real)
// -> cublasXherk (Hermitian rank-k update, complex)
#ifndef EIGEN_GPU_DEVICE_BLAS_EXPR_H
#define EIGEN_GPU_DEVICE_BLAS_EXPR_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include <functional>
namespace Eigen {
namespace gpu {
// Forward declarations.
// DeviceMatrix is referenced by every view/expression below but is not defined
// in this header. TrsmExpr is named by TriangularView::solve() before its own
// definition further down.
template <class Scalar_>
class DeviceMatrix;

template <class Scalar_, int UpLo_>
class TrsmExpr;
// ---- TriangularView --------------------------------------------------------
// Non-owning view returned by d_A.triangularView<UpLo>().
// Its only operation is solve(d_B), which produces a TrsmExpr (cublasXtrsm).
template <class Scalar_, int UpLo_>
class TriangularView {
 public:
  using Scalar = Scalar_;
  static constexpr int UpLo = UpLo_;

  /** Wraps \p m by reference.
   *  Only a reference is kept, so \p m must outlive this view. */
  explicit TriangularView(const DeviceMatrix<Scalar>& m) : mat_{m} {}

  /** The matrix whose UpLo triangle this view exposes. */
  const DeviceMatrix<Scalar>& matrix() const { return mat_.get(); }

  /** Build a TRSM solve expression against \p rhs.
   *  The result only records references to the viewed matrix and \p rhs. */
  TrsmExpr<Scalar, UpLo_> solve(const DeviceMatrix<Scalar>& rhs) const {
    return TrsmExpr<Scalar, UpLo_>(mat_.get(), rhs);
  }

 private:
  std::reference_wrapper<const DeviceMatrix<Scalar>> mat_;
};
// ---- TrsmExpr: triangularView<UpLo>().solve(B) -> cublasXtrsm --------------
/** Deferred triangular solve.
 *  Records references to the triangular matrix A and the right-hand side B.
 *  UpLo is the triangle mode carried over from triangularView<UpLo>().
 *  No data is copied, so A and B must stay alive until the expression is evaluated. */
template <class Scalar_, int UpLo_>
class TrsmExpr {
 public:
  using Scalar = Scalar_;
  static constexpr int UpLo = UpLo_;

  TrsmExpr(const DeviceMatrix<Scalar>& A, const DeviceMatrix<Scalar>& B) : A_{A}, B_{B} {}

  /** Triangular coefficient matrix A. */
  const DeviceMatrix<Scalar>& matrix() const { return A_.get(); }

  /** Right-hand side B. */
  const DeviceMatrix<Scalar>& rhs() const { return B_.get(); }

 private:
  std::reference_wrapper<const DeviceMatrix<Scalar>> A_;
  std::reference_wrapper<const DeviceMatrix<Scalar>> B_;
};
// ---- SelfAdjointView -------------------------------------------------------
// Mutable, non-owning view returned by d_A.selfadjointView<UpLo>().
// It supports multiplication (view * d_B -> SymmExpr via operator* below) and
// an in-place rank-k update of the viewed matrix.
template <typename Scalar_, int UpLo_>
class SelfAdjointView {
 public:
  using Scalar = Scalar_;
  // Type of the rank-update scale factor.
  // It is real even for complex Scalar, because cublasXherk takes a real alpha.
  using RealScalar = typename NumTraits<Scalar>::Real;
  static constexpr int UpLo = UpLo_;

  /** Wraps \p m by (mutable) reference.
   *  Only a reference is kept, so \p m must outlive this view. */
  explicit SelfAdjointView(DeviceMatrix<Scalar>& m) : mat_(m) {}

  /** Read-only access to the viewed matrix. */
  const DeviceMatrix<Scalar>& matrix() const { return mat_; }
  /** Mutable access to the viewed matrix (the target of rankUpdate). */
  DeviceMatrix<Scalar>& matrix() { return mat_; }

  /** Rank-k update: C.selfadjointView<UpLo>().rankUpdate(A, alpha)
   *  computes C = alpha * A * A^H + C.
   *  Only the triangle of C selected by the template parameter UpLo (Lower or
   *  Upper) is updated, as with cuBLAS syrk/herk. The other triangle is left
   *  untouched.
   *  Maps to cublasXsyrk (real Scalar) or cublasXherk (complex Scalar).
   *  \param A      dense factor matrix.
   *  \param alpha  real scale factor, default 1. */
  void rankUpdate(const DeviceMatrix<Scalar>& A, RealScalar alpha = RealScalar(1));

 private:
  std::reference_wrapper<DeviceMatrix<Scalar>> mat_;
};
// Read-only counterpart of SelfAdjointView.
// It can appear on the left of operator* (view * d_B -> SymmExpr).
// It holds only const access, so it offers no rankUpdate().
template <class Scalar_, int UpLo_>
class ConstSelfAdjointView {
 public:
  using Scalar = Scalar_;
  static constexpr int UpLo = UpLo_;

  /** Wraps \p m by const reference; \p m must outlive this view. */
  explicit ConstSelfAdjointView(const DeviceMatrix<Scalar>& m) : mat_{m} {}

  /** The viewed self-adjoint matrix. */
  const DeviceMatrix<Scalar>& matrix() const { return mat_.get(); }

 private:
  std::reference_wrapper<const DeviceMatrix<Scalar>> mat_;
};
// ---- SymmExpr: selfadjointView<UpLo>() * B -> cublasXsymm/Xhemm -----------
/** Deferred product of a self-adjoint matrix A with a dense matrix B.
 *  A is the left operand and UpLo is the mode given to selfadjointView<UpLo>().
 *  Only references are stored, so A and B must remain alive until the
 *  expression is evaluated. */
template <class Scalar_, int UpLo_>
class SymmExpr {
 public:
  using Scalar = Scalar_;
  static constexpr int UpLo = UpLo_;

  SymmExpr(const DeviceMatrix<Scalar>& A, const DeviceMatrix<Scalar>& B) : A_{A}, B_{B} {}

  /** Self-adjoint left operand A. */
  const DeviceMatrix<Scalar>& matrix() const { return A_.get(); }

  /** Dense right operand B. */
  const DeviceMatrix<Scalar>& rhs() const { return B_.get(); }

 private:
  std::reference_wrapper<const DeviceMatrix<Scalar>> A_;
  std::reference_wrapper<const DeviceMatrix<Scalar>> B_;
};
// operator*: a self-adjoint view (mutable or const) times a DeviceMatrix.
// Both overloads only package references into a SymmExpr; nothing is computed here.
template <class S, int UpLo>
SymmExpr<S, UpLo> operator*(const SelfAdjointView<S, UpLo>& lhs, const DeviceMatrix<S>& rhs) {
  const DeviceMatrix<S>& selfAdjoint = lhs.matrix();
  return SymmExpr<S, UpLo>(selfAdjoint, rhs);
}

template <class S, int UpLo>
SymmExpr<S, UpLo> operator*(const ConstSelfAdjointView<S, UpLo>& lhs, const DeviceMatrix<S>& rhs) {
  const DeviceMatrix<S>& selfAdjoint = lhs.matrix();
  return SymmExpr<S, UpLo>(selfAdjoint, rhs);
}
// ---- SyrkExpr: rankUpdate(A) -> cublasXsyrk/Xhemm --------------------------
/** Deferred symmetric (real) / Hermitian (complex) rank-k update expression.
 *  It holds only a reference to the factor matrix A and carries no alpha
 *  factor, unlike SelfAdjointView::rankUpdate(A, alpha).
 *  NOTE(review): whether evaluation forms A * A^H or A^H * A (cuBLAS trans
 *  N vs T/C) is decided by the evaluator, not by this class. Confirm at the
 *  dispatch site.
 *  NOTE(review): the single-argument constructor is not explicit, so a
 *  DeviceMatrix<Scalar> converts implicitly to SyrkExpr. Consider marking it
 *  explicit once call sites are checked. */
template <class Scalar_, int UpLo_>
class SyrkExpr {
 public:
  using Scalar = Scalar_;
  static constexpr int UpLo = UpLo_;

  SyrkExpr(const DeviceMatrix<Scalar>& A) : A_{A} {}

  /** The factor matrix A; it must outlive this expression. */
  const DeviceMatrix<Scalar>& matrix() const { return A_.get(); }

 private:
  std::reference_wrapper<const DeviceMatrix<Scalar>> A_;
};
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_DEVICE_BLAS_EXPR_H