blob: 92c10a913b4156c54e3da2aa1815eb0610a941b5 [file] [edit]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Eigen Authors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// GPU partial-pivoting LU decomposition using cuSOLVER.
//
// Wraps cusolverDnXgetrf (factorization) and cusolverDnXgetrs (solve).
// The factored LU matrix and pivot array are kept in device memory for the
// lifetime of the object, so repeated solves only transfer the RHS/solution.
//
// Requires CUDA 11.0+ (cusolverDnX generic API).
//
// Usage:
// gpu::LU<double> lu(A); // upload A, getrf, LU+ipiv on device
// if (lu.info() != Success) { ... }
// MatrixXd x = lu.solve(b); // getrs NoTrans, only b transferred
// MatrixXd xt = lu.solve(b, gpu::GpuOp::Trans); // A^T x = b
#ifndef EIGEN_GPU_LU_H
#define EIGEN_GPU_LU_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include "./GpuSolverContext.h"
namespace Eigen {
namespace gpu {
/** \ingroup GPU_Module
* \class LU
* \brief GPU LU decomposition with partial pivoting via cuSOLVER
*
* \tparam Scalar_ Element type: float, double, complex<float>, complex<double>
*
* Decomposes a square matrix A = P L U on the GPU and retains the factored
* matrix and pivot array in device memory. Solves A*X=B, A^T*X=B, or
* A^H*X=B by passing the appropriate gpu::GpuOp.
*
* Each LU object owns a dedicated CUDA stream and cuSOLVER handle.
*/
template <typename Scalar_>
class LU {
public:
using Scalar = Scalar_;
using RealScalar = typename NumTraits<Scalar>::Real;
// Host-side column-major matrix type used for uploads and returned solutions.
using PlainMatrix = Eigen::Matrix<Scalar, Dynamic, Dynamic, ColMajor>;
// ---- Construction / destruction ------------------------------------------
/** Creates an empty, unfactorized object; call compute() before solve(). */
LU() = default;
/** Uploads A and factorizes immediately (see compute()). */
template <typename InputType>
explicit LU(const EigenBase<InputType>& A) {
compute(A);
}
~LU() = default;
// Non-copyable: the object owns device buffers and a dedicated solver context.
LU(const LU&) = delete;
LU& operator=(const LU&) = delete;
/** Move construction transfers ownership of the solver context, the factored
 * matrix and the pivot buffer; the moved-from object is reset to size 0 so
 * it behaves like an empty decomposition afterwards. */
LU(LU&& o) noexcept
: solver_ctx_(std::move(o.solver_ctx_)),
d_lu_(std::move(o.d_lu_)),
lu_alloc_size_(o.lu_alloc_size_),
d_ipiv_(std::move(o.d_ipiv_)),
n_(o.n_),
lda_(o.lda_) {
o.lu_alloc_size_ = 0;
o.n_ = 0;
o.lda_ = 0;
}
/** Move assignment: same ownership transfer as move construction; the
 * self-assignment guard leaves `*this` untouched when `&o == this`. */
LU& operator=(LU&& o) noexcept {
if (this != &o) {
solver_ctx_ = std::move(o.solver_ctx_);
d_lu_ = std::move(o.d_lu_);
lu_alloc_size_ = o.lu_alloc_size_;
d_ipiv_ = std::move(o.d_ipiv_);
n_ = o.n_;
lda_ = o.lda_;
o.lu_alloc_size_ = 0;
o.n_ = 0;
o.lda_ = 0;
}
return *this;
}
// ---- Factorization -------------------------------------------------------
/** Compute the LU factorization of A (host matrix, must be square).
 *
 * Evaluates A into a contiguous column-major temporary, uploads it
 * asynchronously on this object's stream and enqueues getrf. Check info()
 * (or rely on solve()'s assertion) before trusting the result.
 *
 * NOTE(review): `mat` is destroyed at return, before the stream syncs. This
 * is safe only because cudaMemcpyAsync from *pageable* host memory returns
 * after the source has been staged; confirm the Eigen allocator never hands
 * out pinned memory here. */
template <typename InputType>
LU& compute(const EigenBase<InputType>& A) {
eigen_assert(A.rows() == A.cols() && "LU requires a square matrix");
if (!begin_compute(A.rows())) return *this;
// Materialize the expression so there is contiguous column-major data to copy.
const PlainMatrix mat(A.derived());
lda_ = static_cast<int64_t>(mat.rows());
allocate_lu_storage();
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_lu_.get(), mat.data(), matrixBytes(), cudaMemcpyHostToDevice, solver_ctx_.stream_));
factorize();
return *this;
}
/** Compute the LU factorization from a device-resident matrix (D2D copy).
 * The input is copied into this object's storage, so the caller's matrix is
 * left untouched. */
LU& compute(const DeviceMatrix<Scalar>& d_A) {
eigen_assert(d_A.rows() == d_A.cols() && "LU requires a square matrix");
if (!begin_compute(d_A.rows())) return *this;
lda_ = static_cast<int64_t>(d_A.rows());
// Order our stream behind whatever producer may still be writing d_A.
d_A.waitReady(solver_ctx_.stream_);
allocate_lu_storage();
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_lu_.get(), d_A.data(), matrixBytes(), cudaMemcpyDeviceToDevice, solver_ctx_.stream_));
factorize();
return *this;
}
/** Compute the LU factorization from a device matrix (move, no copy).
 * Adopts d_A's device allocation and factorizes it in place. */
LU& compute(DeviceMatrix<Scalar>&& d_A) {
eigen_assert(d_A.rows() == d_A.cols() && "LU requires a square matrix");
if (!begin_compute(d_A.rows())) return *this;
lda_ = static_cast<int64_t>(d_A.rows());
d_A.waitReady(solver_ctx_.stream_);
// Take ownership of the raw device pointer; no data movement happens.
d_lu_ = internal::DeviceBuffer::adopt(static_cast<void*>(d_A.release()));
// NOTE(review): records the logical size, not the adopted allocation's true
// capacity — conservative (at worst triggers a needless realloc later) as
// long as DeviceMatrix allocates at least rows*cols elements; confirm.
lu_alloc_size_ = matrixBytes();
factorize();
return *this;
}
// ---- Solve ---------------------------------------------------------------
/** Solve op(A) * X = B using the cached LU factorization (host → host).
 *
 * \param B Right-hand side (n x nrhs host matrix).
 * \param op gpu::GpuOp::NoTrans (default), Trans, or ConjTrans.
 *
 * Blocks until the solution has been copied back to the host. Asserts if
 * the factorization failed/was never run, or if getrs reports an error. */
template <typename Rhs>
PlainMatrix solve(const MatrixBase<Rhs>& B, GpuOp op = GpuOp::NoTrans) const {
// Wait for the factorization's info word so a singular A is caught here.
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success && "LU::solve called on a failed or uninitialized factorization");
eigen_assert(B.rows() == n_);
const PlainMatrix rhs(B);
const int64_t nrhs = static_cast<int64_t>(rhs.cols());
const int64_t ldb = static_cast<int64_t>(rhs.rows());
// getrs solves in place, so B is uploaded into the buffer that becomes X.
internal::DeviceBuffer d_x(matrixBytes(nrhs, ldb));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_x.get(), rhs.data(), matrixBytes(nrhs, ldb), cudaMemcpyHostToDevice, solver_ctx_.stream_));
DeviceMatrix<Scalar> d_X = solve_impl(nrhs, ldb, op, std::move(d_x));
// ldb == n_ (asserted above), so the byte count below matches X exactly.
PlainMatrix X(n_, B.cols());
int solve_info = 0;
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(X.data(), d_X.data(), matrixBytes(nrhs, ldb), cudaMemcpyDeviceToHost, solver_ctx_.stream_));
// Also fetch the getrs status word that solve_impl left in scratch_info().
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(&solve_info, solver_ctx_.scratch_info(), sizeof(int),
cudaMemcpyDeviceToHost, solver_ctx_.stream_));
// One sync covers both copies above; after it, X and solve_info are valid.
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(solver_ctx_.stream_));
eigen_assert(solve_info == 0 && "cusolverDnXgetrs reported an error");
return X;
}
/** Solve op(A) * X = B with device-resident RHS. Returns immediately after
 * enqueuing the solve; the sync_info()/info_ check at entry forces a host
 * wait on the factorization status so a singular A causes a clean assert
 * rather than a silently-bad result.
 *
 * \param d_B Right-hand side (n x nrhs device matrix); copied, not modified.
 * \param op gpu::GpuOp::NoTrans (default), Trans, or ConjTrans.
 *
 * Unlike the host overload, the per-solve getrs status word is NOT checked
 * here (that would require a host sync, defeating the async return). */
DeviceMatrix<Scalar> solve(const DeviceMatrix<Scalar>& d_B, GpuOp op = GpuOp::NoTrans) const {
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success && "LU::solve called on a failed or uninitialized factorization");
eigen_assert(d_B.rows() == n_);
// Order our stream behind any producer still writing d_B.
d_B.waitReady(solver_ctx_.stream_);
const int64_t nrhs = static_cast<int64_t>(d_B.cols());
const int64_t ldb = static_cast<int64_t>(d_B.rows());
// Copy B into a fresh buffer: getrs overwrites it with the solution.
internal::DeviceBuffer d_x(matrixBytes(nrhs, ldb));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_x.get(), d_B.data(), matrixBytes(nrhs, ldb), cudaMemcpyDeviceToDevice, solver_ctx_.stream_));
return solve_impl(nrhs, ldb, op, std::move(d_x));
}
// ---- Accessors -----------------------------------------------------------
/** Status of the last factorization (forces a host wait via the context). */
ComputationInfo info() const { return solver_ctx_.info(); }
Index rows() const { return n_; }
Index cols() const { return n_; }
/** The CUDA stream all of this object's work is enqueued on. */
cudaStream_t stream() const { return solver_ctx_.stream_; }
private:
// Owns the CUDA stream, cuSOLVER handle, workspaces and the status word.
// mutable: const solve() paths still sync and read back the status.
mutable internal::GpuSolverContext solver_ctx_;
internal::DeviceBuffer d_lu_;   // factored matrix (L and U packed, getrf layout)
size_t lu_alloc_size_ = 0;      // capacity of d_lu_ in bytes, reused across compute() calls
internal::DeviceBuffer d_ipiv_; // pivot indices: n_ entries of int64_t (64-bit getrf API)
int64_t n_ = 0;                 // matrix dimension (square)
int64_t lda_ = 0;               // leading dimension of the stored factor
/** Common entry step for all compute() overloads: records the new size and
 * pessimistically invalidates the status until factorize() runs. Returns
 * false for the trivial n == 0 case, which is reported as an immediate
 * Success with nothing enqueued on the GPU. */
bool begin_compute(Index rows) {
n_ = rows;
solver_ctx_.info_ = InvalidInput;
if (n_ == 0) {
solver_ctx_.info_ = Success;
solver_ctx_.info_synced_ = true;
return false;
}
return true;
}
/** Bytes of the currently configured factor (n_ columns, leading dim lda_). */
size_t matrixBytes() const { return matrixBytes(n_, lda_); }
/** Bytes of a `cols`-column matrix with leading dimension `ld`. */
static size_t matrixBytes(int64_t cols, int64_t ld) {
return static_cast<size_t>(ld) * static_cast<size_t>(cols) * sizeof(Scalar);
}
/** Grow d_lu_ if needed; an existing allocation that is already large enough
 * is reused (lu_alloc_size_ tracks the capacity, not the logical size). */
void allocate_lu_storage() {
size_t needed = matrixBytes();
if (needed > lu_alloc_size_) {
d_lu_ = internal::DeviceBuffer(needed);
lu_alloc_size_ = needed;
}
}
// Solve in place on `d_x` (which already holds B), then re-wrap as a typed
// DeviceMatrix carrying shape and a ready event. The release/adopt hop hands
// ownership of the raw cudaMalloc pointer from the untyped DeviceBuffer to
// the typed DeviceMatrix without copying. The getrs status word lands in
// scratch_info(); callers that need it must copy it back themselves (the
// host-path solve() above does, the device-path solve() deliberately skips it).
DeviceMatrix<Scalar> solve_impl(int64_t nrhs, int64_t ldb, GpuOp op, internal::DeviceBuffer&& d_x) const {
constexpr cudaDataType_t dtype = internal::cusolver_data_type<Scalar>::value;
const cublasOperation_t trans = internal::to_cublas_op(op);
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrs(solver_ctx_.cusolver_, solver_ctx_.params_.p, trans, n_, nrhs, dtype,
d_lu_.get(), lda_, static_cast<const int64_t*>(d_ipiv_.get()), dtype,
d_x.get(), ldb, solver_ctx_.scratch_info()));
DeviceMatrix<Scalar> result =
DeviceMatrix<Scalar>::adopt(static_cast<Scalar*>(d_x.release()), n_, static_cast<Index>(nrhs));
// Record an event so consumers on other streams can wait for this solve.
result.recordReady(solver_ctx_.stream_);
return result;
}
/** Enqueue the getrf factorization of the matrix already staged in d_lu_:
 * query workspace sizes, (re)allocate scratch, run cusolverDnXgetrf, then
 * start an async copy of the device info word back to the host so the
 * status becomes observable via info() / sync_info(). */
void factorize() {
constexpr cudaDataType_t dtype = internal::cusolver_data_type<Scalar>::value;
// The generic (64-bit) getrf API uses int64_t pivot indices.
const size_t ipiv_bytes = static_cast<size_t>(n_) * sizeof(int64_t);
// Presumably flags the status word as "in flight" so later info()/sync_info()
// calls know to wait on the async copy below — see GpuSolverContext.
solver_ctx_.mark_pending();
d_ipiv_ = internal::DeviceBuffer(ipiv_bytes);
size_t dev_ws_bytes = 0, host_ws_bytes = 0;
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrf_bufferSize(solver_ctx_.cusolver_, solver_ctx_.params_.p, n_, n_, dtype,
d_lu_.get(), lda_, dtype, &dev_ws_bytes, &host_ws_bytes));
solver_ctx_.ensure_scratch(dev_ws_bytes);
solver_ctx_.h_workspace_.resize(host_ws_bytes);
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrf(
solver_ctx_.cusolver_, solver_ctx_.params_.p, n_, n_, dtype, d_lu_.get(), lda_,
static_cast<int64_t*>(d_ipiv_.get()), dtype, solver_ctx_.scratch_workspace(), dev_ws_bytes,
// data() of an empty workspace may not be valid for cuSOLVER; pass nullptr.
host_ws_bytes > 0 ? solver_ctx_.h_workspace_.data() : nullptr, host_ws_bytes, solver_ctx_.scratch_info()));
// Kick off the D2H copy of the info word now; sync_info() waits on it later.
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(&solver_ctx_.info_word(), solver_ctx_.scratch_info(), sizeof(int),
cudaMemcpyDeviceToHost, solver_ctx_.stream_));
}
};
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_LU_H