blob: afc941becbfbef653228572f30245f348e9408d2 [file]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Eigen Authors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// GPU partial-pivoting LU decomposition using cuSOLVER.
//
// Wraps cusolverDnXgetrf (factorization) and cusolverDnXgetrs (solve).
// The factored LU matrix and pivot array are kept in device memory for the
// lifetime of the object, so repeated solves only transfer the RHS/solution.
//
// Requires CUDA 11.0+ (cusolverDnX generic API).
//
// Usage:
// gpu::LU<double> lu(A); // upload A, getrf, LU+ipiv on device
// if (lu.info() != Success) { ... }
// MatrixXd x = lu.solve(b); // getrs NoTrans, only b transferred
// MatrixXd xt = lu.solve(b, gpu::GpuOp::Trans); // A^T x = b
#ifndef EIGEN_GPU_LU_H
#define EIGEN_GPU_LU_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include "./CuBlasSupport.h"
#include "./CuSolverSupport.h"
#include <vector>
namespace Eigen {
namespace gpu {
/** \ingroup GPU_Module
* \class LU
* \brief GPU LU decomposition with partial pivoting via cuSOLVER
*
* \tparam Scalar_ Element type: float, double, complex<float>, complex<double>
*
* Decomposes a square matrix A = P L U on the GPU and retains the factored
* matrix and pivot array in device memory. Solves A*X=B, A^T*X=B, or
* A^H*X=B by passing the appropriate gpu::GpuOp.
*
* Each LU object owns a dedicated CUDA stream and cuSOLVER handle.
*/
template <typename Scalar_>
class LU {
public:
using Scalar = Scalar_;
using RealScalar = typename NumTraits<Scalar>::Real;
using PlainMatrix = Eigen::Matrix<Scalar, Dynamic, Dynamic, ColMajor>;
// ---- Construction / destruction ------------------------------------------
LU() { init_context(); }
template <typename InputType>
explicit LU(const EigenBase<InputType>& A) : LU() {
compute(A);
}
~LU() {
// Ignore errors here: dtors are noexcept, and EIGEN_CUSOLVER_CHECK /
// EIGEN_CUDA_RUNTIME_CHECK are eigen_assert-based — firing one from a
// noexcept dtor terminates the program. The trailing cudaFree() of the
// device buffers (via internal::DeviceBuffer::~DeviceBuffer) is sync.
if (handle_) (void)cusolverDnDestroy(handle_);
if (stream_) (void)cudaStreamDestroy(stream_);
}
LU(const LU&) = delete;
LU& operator=(const LU&) = delete;
LU(LU&& o) noexcept
: stream_(o.stream_),
handle_(o.handle_),
params_(std::move(o.params_)),
d_lu_(std::move(o.d_lu_)),
lu_alloc_size_(o.lu_alloc_size_),
d_ipiv_(std::move(o.d_ipiv_)),
d_scratch_(std::move(o.d_scratch_)),
scratch_size_(o.scratch_size_),
h_workspace_(std::move(o.h_workspace_)),
n_(o.n_),
lda_(o.lda_),
info_(o.info_),
pinned_info_(std::move(o.pinned_info_)),
info_synced_(o.info_synced_) {
o.stream_ = nullptr;
o.handle_ = nullptr;
o.lu_alloc_size_ = 0;
o.scratch_size_ = 0;
o.n_ = 0;
o.lda_ = 0;
o.info_ = InvalidInput;
o.info_synced_ = true;
}
LU& operator=(LU&& o) noexcept {
if (this != &o) {
// Drain pending work on the old stream before tearing it down so the
// upcoming move of d_lu_/d_ipiv_/d_scratch_ doesn't free buffers that an
// in-flight kernel is still touching. The asymmetric move-ctor doesn't
// need this — it constructs onto uninitialized storage.
if (stream_) (void)cudaStreamSynchronize(stream_);
if (handle_) (void)cusolverDnDestroy(handle_);
if (stream_) (void)cudaStreamDestroy(stream_);
stream_ = o.stream_;
handle_ = o.handle_;
params_ = std::move(o.params_);
d_lu_ = std::move(o.d_lu_);
lu_alloc_size_ = o.lu_alloc_size_;
d_ipiv_ = std::move(o.d_ipiv_);
d_scratch_ = std::move(o.d_scratch_);
scratch_size_ = o.scratch_size_;
h_workspace_ = std::move(o.h_workspace_);
n_ = o.n_;
lda_ = o.lda_;
info_ = o.info_;
pinned_info_ = std::move(o.pinned_info_);
info_synced_ = o.info_synced_;
o.stream_ = nullptr;
o.handle_ = nullptr;
o.lu_alloc_size_ = 0;
o.scratch_size_ = 0;
o.n_ = 0;
o.lda_ = 0;
o.info_ = InvalidInput;
o.info_synced_ = true;
}
return *this;
}
// ---- Factorization -------------------------------------------------------
/** Compute the LU factorization of A (host matrix, must be square). */
template <typename InputType>
LU& compute(const EigenBase<InputType>& A) {
eigen_assert(A.rows() == A.cols() && "LU requires a square matrix");
if (!begin_compute(A.rows())) return *this;
const PlainMatrix mat(A.derived());
lda_ = static_cast<int64_t>(mat.outerStride());
allocate_lu_storage();
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(d_lu_.get(), mat.data(), matrixBytes(), cudaMemcpyHostToDevice, stream_));
factorize();
return *this;
}
/** Compute the LU factorization from a device-resident matrix (D2D copy). */
LU& compute(const DeviceMatrix<Scalar>& d_A) {
eigen_assert(d_A.rows() == d_A.cols() && "LU requires a square matrix");
if (!begin_compute(d_A.rows())) return *this;
lda_ = static_cast<int64_t>(d_A.outerStride());
d_A.waitReady(stream_);
allocate_lu_storage();
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_lu_.get(), d_A.data(), matrixBytes(), cudaMemcpyDeviceToDevice, stream_));
factorize();
return *this;
}
/** Compute the LU factorization from a device matrix (move, no copy). */
LU& compute(DeviceMatrix<Scalar>&& d_A) {
eigen_assert(d_A.rows() == d_A.cols() && "LU requires a square matrix");
if (!begin_compute(d_A.rows())) return *this;
lda_ = static_cast<int64_t>(d_A.outerStride());
d_A.waitReady(stream_);
d_lu_ = internal::DeviceBuffer::adopt(static_cast<void*>(d_A.release()));
lu_alloc_size_ = matrixBytes();
factorize();
return *this;
}
// ---- Solve ---------------------------------------------------------------
/** Solve op(A) * X = B using the cached LU factorization (host → host).
*
* \param B Right-hand side (n x nrhs host matrix).
* \param op gpu::GpuOp::NoTrans (default), Trans, or ConjTrans.
*/
template <typename Rhs>
PlainMatrix solve(const MatrixBase<Rhs>& B, GpuOp op = GpuOp::NoTrans) const {
const_cast<LU*>(this)->sync_info();
eigen_assert(info_ == Success && "LU::solve called on a failed or uninitialized factorization");
eigen_assert(B.rows() == n_);
const PlainMatrix rhs(B);
const int64_t nrhs = static_cast<int64_t>(rhs.cols());
const int64_t ldb = static_cast<int64_t>(rhs.outerStride());
internal::DeviceBuffer d_x(matrixBytes(nrhs, ldb));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_x.get(), rhs.data(), matrixBytes(nrhs, ldb), cudaMemcpyHostToDevice, stream_));
DeviceMatrix<Scalar> d_X = solve_impl(nrhs, ldb, op, std::move(d_x));
PlainMatrix X(n_, B.cols());
int solve_info = 0;
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(X.data(), d_X.data(), matrixBytes(nrhs, ldb), cudaMemcpyDeviceToHost, stream_));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(&solve_info, scratch_info(), sizeof(int), cudaMemcpyDeviceToHost, stream_));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(stream_));
eigen_assert(solve_info == 0 && "cusolverDnXgetrs reported an error");
return X;
}
/** Solve op(A) * X = B with device-resident RHS. Fully async. */
DeviceMatrix<Scalar> solve(const DeviceMatrix<Scalar>& d_B, GpuOp op = GpuOp::NoTrans) const {
eigen_assert(d_B.rows() == n_);
d_B.waitReady(stream_);
const int64_t nrhs = static_cast<int64_t>(d_B.cols());
const int64_t ldb = static_cast<int64_t>(d_B.outerStride());
internal::DeviceBuffer d_x(matrixBytes(nrhs, ldb));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_x.get(), d_B.data(), matrixBytes(nrhs, ldb), cudaMemcpyDeviceToDevice, stream_));
return solve_impl(nrhs, ldb, op, std::move(d_x));
}
// ---- Accessors -----------------------------------------------------------
/** Lazily synchronizes the stream on first call after compute(). */
ComputationInfo info() const {
const_cast<LU*>(this)->sync_info();
return info_;
}
Index rows() const { return n_; }
Index cols() const { return n_; }
cudaStream_t stream() const { return stream_; }
private:
cudaStream_t stream_ = nullptr;
cusolverDnHandle_t handle_ = nullptr;
internal::CusolverParams params_; // cuSOLVER params (created once, reused)
internal::DeviceBuffer d_lu_; // LU factors on device (grows, never shrinks)
size_t lu_alloc_size_ = 0; // current d_lu_ allocation size
internal::DeviceBuffer d_ipiv_; // pivot indices (int64_t) on device
internal::DeviceBuffer d_scratch_; // combined workspace + info word (grows, never shrinks)
size_t scratch_size_ = 0; // current scratch allocation size
std::vector<char> h_workspace_; // host workspace (kept alive until next compute)
int64_t n_ = 0;
int64_t lda_ = 0;
ComputationInfo info_ = InvalidInput;
internal::PinnedHostBuffer pinned_info_{sizeof(int)}; // pinned host memory for async D2H
bool info_synced_ = true; // has the stream been synced for info?
int& info_word() { return *static_cast<int*>(pinned_info_.get()); }
int info_word() const { return *static_cast<const int*>(pinned_info_.get()); }
bool begin_compute(Index rows) {
n_ = rows;
info_ = InvalidInput;
if (n_ == 0) {
info_ = Success;
info_synced_ = true;
return false;
}
return true;
}
size_t matrixBytes() const { return matrixBytes(n_, lda_); }
static size_t matrixBytes(int64_t cols, int64_t outer_stride) {
return static_cast<size_t>(outer_stride) * static_cast<size_t>(cols) * sizeof(Scalar);
}
void allocate_lu_storage() {
size_t needed = matrixBytes();
if (needed > lu_alloc_size_) {
d_lu_ = internal::DeviceBuffer(needed);
lu_alloc_size_ = needed;
}
}
// Scratch layout: [ workspace (aligned) | info_word (sizeof(int)) ].
// Workspace size is rounded up to 16 bytes so the info word lands aligned.
static constexpr size_t kInfoBytes = sizeof(int);
static constexpr size_t kScratchAlign = 16;
static size_t scratchBytesFor(size_t workspace_bytes) {
workspace_bytes = (workspace_bytes + kScratchAlign - 1) & ~(kScratchAlign - 1);
return workspace_bytes + kInfoBytes;
}
// Ensure d_scratch_ is at least `bytes`. Grows but never shrinks.
// Syncs the stream before reallocating to avoid freeing memory that
// async kernels may still be using.
void ensure_scratch(size_t bytes) {
if (bytes > scratch_size_) {
if (d_scratch_) EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(stream_));
d_scratch_ = internal::DeviceBuffer(bytes);
scratch_size_ = bytes;
}
}
void* scratch_workspace() const { return d_scratch_.get(); }
int* scratch_info() const {
eigen_assert(d_scratch_ && scratch_size_ >= kInfoBytes);
return reinterpret_cast<int*>(static_cast<char*>(d_scratch_.get()) + scratch_size_ - kInfoBytes);
}
// Solve in place on `d_x` (which already holds B), then re-wrap as a
// typed DeviceMatrix carrying shape and a ready event. The release/adopt
// hop just hands ownership of the raw cudaMalloc pointer from the untyped
// DeviceBuffer to the typed DeviceMatrix without copying.
DeviceMatrix<Scalar> solve_impl(int64_t nrhs, int64_t ldb, GpuOp op, internal::DeviceBuffer&& d_x) const {
constexpr cudaDataType_t dtype = internal::cusolver_data_type<Scalar>::value;
const cublasOperation_t trans = internal::to_cublas_op(op);
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrs(handle_, params_.p, trans, n_, nrhs, dtype, d_lu_.get(), lda_,
static_cast<const int64_t*>(d_ipiv_.get()), dtype, d_x.get(), ldb,
scratch_info()));
DeviceMatrix<Scalar> result =
DeviceMatrix<Scalar>::adopt(static_cast<Scalar*>(d_x.release()), n_, static_cast<Index>(nrhs));
result.recordReady(stream_);
return result;
}
void init_context() {
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamCreate(&stream_));
EIGEN_CUSOLVER_CHECK(cusolverDnCreate(&handle_));
EIGEN_CUSOLVER_CHECK(cusolverDnSetStream(handle_, stream_));
ensure_scratch(kInfoBytes);
}
void sync_info() {
if (!info_synced_) {
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(stream_));
info_ = (info_word() == 0) ? Success : NumericalIssue;
info_synced_ = true;
}
}
// Run cusolverDnXgetrf on d_lu_ (already on device). Allocates d_ipiv_.
// Enqueues factorization + async info download. Does NOT sync.
// Workspaces are stored as members to ensure they outlive the async kernels.
void factorize() {
constexpr cudaDataType_t dtype = internal::cusolver_data_type<Scalar>::value;
const size_t ipiv_bytes = static_cast<size_t>(n_) * sizeof(int64_t);
info_synced_ = false;
info_ = InvalidInput;
d_ipiv_ = internal::DeviceBuffer(ipiv_bytes);
size_t dev_ws_bytes = 0, host_ws_bytes = 0;
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrf_bufferSize(handle_, params_.p, n_, n_, dtype, d_lu_.get(), lda_, dtype,
&dev_ws_bytes, &host_ws_bytes));
ensure_scratch(scratchBytesFor(dev_ws_bytes));
h_workspace_.resize(host_ws_bytes);
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrf(handle_, params_.p, n_, n_, dtype, d_lu_.get(), lda_,
static_cast<int64_t*>(d_ipiv_.get()), dtype, scratch_workspace(),
dev_ws_bytes, host_ws_bytes > 0 ? h_workspace_.data() : nullptr,
host_ws_bytes, scratch_info()));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(&info_word(), scratch_info(), sizeof(int), cudaMemcpyDeviceToHost, stream_));
}
};
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_LU_H