// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// GPU SVD decomposition using cuSOLVER.
//
// Wraps cusolverDnXgesvd (the QR-based SVD of the 64-bit generic API).
// Stores U, S, VT on device. Solve uses cuBLAS GEMM:
//   X = VT^H * diag(D) * U^H * B.
//
// cuSOLVER returns VT (not V). We store and expose VT directly.
//
// Usage:
// SVD<double> svd(A, ComputeThinU | ComputeThinV);
// VectorXd S = svd.singularValues();
// MatrixXd U = svd.matrixU(); // m×k or m×m
// MatrixXd VT = svd.matrixVT(); // k×n or n×n (this is V^T)
// MatrixXd X = svd.solve(B); // pseudoinverse
// MatrixXd X = svd.solve(B, k); // truncated (top k triplets)
// MatrixXd X = svd.solve(B, 0.1); // Tikhonov regularized
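//
// compute() reports failure through info(); check it before reading factors
// when convergence is not guaranteed:
//   SVD<double> svd(A);
//   if (svd.info() != Success) { /* gesvd failed; factors are invalid */ }
//
// A device-resident input skips the host upload (only a D2D copy into
// internal storage is made):
//   SVD<double> svd2(d_A);  // d_A is a DeviceMatrix<double>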
#ifndef EIGEN_GPU_SVD_H
#define EIGEN_GPU_SVD_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include "./GpuSolverContext.h"
namespace Eigen {
namespace gpu {
template <typename Scalar_>
class SVD {
public:
using Scalar = Scalar_;
using RealScalar = typename NumTraits<Scalar>::Real;
using PlainMatrix = Eigen::Matrix<Scalar, Dynamic, Dynamic, ColMajor>;
using PlainVector = Eigen::Matrix<Scalar, Dynamic, 1>;
using RealVector = Eigen::Matrix<RealScalar, Dynamic, 1>;
SVD() = default;
template <typename InputType>
explicit SVD(const EigenBase<InputType>& A, unsigned int options = ComputeThinU | ComputeThinV) {
compute(A, options);
}
explicit SVD(const DeviceMatrix<Scalar>& d_A, unsigned int options = ComputeThinU | ComputeThinV) {
compute(d_A, options);
}
~SVD() = default;
SVD(const SVD&) = delete;
SVD& operator=(const SVD&) = delete;
SVD(SVD&& o) noexcept
: solver_ctx_(std::move(o.solver_ctx_)),
d_A_(std::move(o.d_A_)),
d_U_(std::move(o.d_U_)),
d_S_(std::move(o.d_S_)),
d_VT_(std::move(o.d_VT_)),
options_(o.options_),
m_(o.m_),
n_(o.n_),
lda_(o.lda_),
transposed_(o.transposed_) {
o.options_ = 0;
o.m_ = 0;
o.n_ = 0;
o.lda_ = 0;
o.transposed_ = false;
}
SVD& operator=(SVD&& o) noexcept {
if (this != &o) {
solver_ctx_ = std::move(o.solver_ctx_);
d_A_ = std::move(o.d_A_);
d_U_ = std::move(o.d_U_);
d_S_ = std::move(o.d_S_);
d_VT_ = std::move(o.d_VT_);
options_ = o.options_;
m_ = o.m_;
n_ = o.n_;
lda_ = o.lda_;
transposed_ = o.transposed_;
o.options_ = 0;
o.m_ = 0;
o.n_ = 0;
o.lda_ = 0;
o.transposed_ = false;
}
return *this;
}
// ---- Factorization -------------------------------------------------------
template <typename InputType>
SVD& compute(const EigenBase<InputType>& A, unsigned int options = ComputeThinU | ComputeThinV) {
// Upload to device, then delegate. The wide-matrix transpose runs on the
// GPU (via cublasXgeam) inside the device-input path; no host transpose.
return compute(DeviceMatrix<Scalar>::fromHost(A.derived(), solver_ctx_.stream_), options);
}
SVD& compute(const DeviceMatrix<Scalar>& d_A, unsigned int options = ComputeThinU | ComputeThinV) {
options_ = options;
m_ = d_A.rows();
n_ = d_A.cols();
lda_ = 0;
transposed_ = false;
if (m_ == 0 || n_ == 0) {
d_A_ = internal::DeviceBuffer();
d_U_ = internal::DeviceBuffer();
d_S_ = internal::DeviceBuffer();
d_VT_ = internal::DeviceBuffer();
solver_ctx_.info_ = Success;
solver_ctx_.info_synced_ = true;
return *this;
}
transposed_ = (m_ < n_);
d_A.waitReady(solver_ctx_.stream_);
if (transposed_) {
// Transpose on device via cuBLAS geam: d_A_ = A^H.
std::swap(m_, n_);
lda_ = m_;
const size_t mat_bytes = static_cast<size_t>(lda_) * static_cast<size_t>(n_) * sizeof(Scalar);
d_A_ = internal::DeviceBuffer(mat_bytes);
// geam: C(m×n) = alpha * op(A) + beta * op(B). beta=0, B=nullptr.
Scalar alpha_one(1), beta_zero(0);
EIGEN_CUBLAS_CHECK(internal::cublasXgeam(
solver_ctx_.cublas_, CUBLAS_OP_C, CUBLAS_OP_N, internal::to_blas_int(m_), internal::to_blas_int(n_),
&alpha_one, d_A.data(), internal::to_blas_int(d_A.rows()), &beta_zero, static_cast<const Scalar*>(nullptr),
internal::to_blas_int(m_), static_cast<Scalar*>(d_A_.get()), internal::to_blas_int(m_)));
} else {
lda_ = static_cast<int64_t>(d_A.rows());
const size_t mat_bytes = static_cast<size_t>(lda_) * static_cast<size_t>(n_) * sizeof(Scalar);
d_A_ = internal::DeviceBuffer(mat_bytes);
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_A_.get(), d_A.data(), mat_bytes, cudaMemcpyDeviceToDevice, solver_ctx_.stream_));
}
factorize();
return *this;
}
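// Reuse (a sketch of intended use, not a guarantee): one SVD object can
// factor successive matrices. Each compute() call replaces the stored
// factors while the solver context (stream, handles, scratch) persists, so
// repeated factorizations of similar-size problems are expected to amortize
// workspace allocations:
//   SVD<float> svd;
//   svd.compute(A1).solve(B1);
//   svd.compute(A2).solve(B2);  // same stream and scratch buffers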
// ---- Accessors -----------------------------------------------------------
ComputationInfo info() const { return solver_ctx_.info(); }
Index rows() const { return transposed_ ? n_ : m_; }
Index cols() const { return transposed_ ? m_ : n_; }
/** Singular values, in descending order (always available). Downloads from device on each call. */
RealVector singularValues() const {
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success);
const Index k = (std::min)(m_, n_);
RealVector S(k);
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpy(S.data(), d_S_.get(), static_cast<size_t>(k) * sizeof(RealScalar), cudaMemcpyDeviceToHost));
return S;
}
/** Left singular vectors U. */
PlainMatrix matrixU() const {
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success);
eigen_assert((options_ & (ComputeThinU | ComputeFullU)) && "matrixU() requires ComputeThinU or ComputeFullU");
const Index m_orig = transposed_ ? n_ : m_;
const Index n_orig = transposed_ ? m_ : n_;
const Index k = (std::min)(m_orig, n_orig);
if (!transposed_) {
const Index ucols = (options_ & ComputeFullU) ? m_ : k;
PlainMatrix U(m_, ucols);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpy(U.data(), d_U_.get(),
static_cast<size_t>(m_) * static_cast<size_t>(ucols) * sizeof(Scalar),
cudaMemcpyDeviceToHost));
return U;
} else {
const Index vtrows = (options_ & ComputeFullU) ? m_orig : k;
PlainMatrix VT_stored(vtrows, n_);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpy(VT_stored.data(), d_VT_.get(),
static_cast<size_t>(vtrows) * static_cast<size_t>(n_) * sizeof(Scalar),
cudaMemcpyDeviceToHost));
return VT_stored.adjoint();
}
}
/** Right singular vectors V (matches host JacobiSVD/BDCSVD). */
PlainMatrix matrixV() const { return matrixVT().adjoint(); }
/** Right singular vectors transposed V^T. */
PlainMatrix matrixVT() const {
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success);
eigen_assert((options_ & (ComputeThinV | ComputeFullV)) && "matrixVT() requires ComputeThinV or ComputeFullV");
const Index m_orig = transposed_ ? n_ : m_;
const Index n_orig = transposed_ ? m_ : n_;
const Index k = (std::min)(m_orig, n_orig);
if (!transposed_) {
const Index vtrows = (options_ & ComputeFullV) ? n_ : k;
PlainMatrix VT(vtrows, n_);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpy(VT.data(), d_VT_.get(),
static_cast<size_t>(vtrows) * static_cast<size_t>(n_) * sizeof(Scalar),
cudaMemcpyDeviceToHost));
return VT;
} else {
const Index ucols = (options_ & ComputeFullV) ? n_orig : k;
PlainMatrix U_stored(m_, ucols);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpy(U_stored.data(), d_U_.get(),
static_cast<size_t>(m_) * static_cast<size_t>(ucols) * sizeof(Scalar),
cudaMemcpyDeviceToHost));
return U_stored.adjoint();
}
}
// ---- Device-side accessors (zero-copy views; chain into cuBLAS without D2D) ---------
//
// These return non-owning DeviceMatrix views over the SVD's internal device storage.
// The view borrows the pointer: destruction does not free; the SVD object must outlive
// any view derived from it. For the common case (m >= n) all three accessors are pure
// metadata: zero kernel launches, zero allocations.
//
// For wide matrices (m < n, internally factored as A^H), original U and V^T are the
// adjoints of the stored buffers, so d_matrixU() / d_matrixVT() build them via a
// cublasXgeam into an owning temporary. d_singularValues() remains zero-copy.
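//
// A sketch of chaining into caller-owned cuBLAS work; my_stream is a
// placeholder for the caller's stream, everything else exists in this file:
//   SVD<float> svd(d_A, ComputeThinU | ComputeThinV);
//   DeviceMatrix<float> d_U = svd.d_matrixU();  // zero-copy view when m >= n
//   d_U.waitReady(my_stream);                   // order my_stream after U
//   // d_U.data()/rows()/cols() can now feed a GEMM on my_stream, provided
//   // svd outlives both the view and the GEMM.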
/** Singular values as a k × 1 view on this solver's stream. */
DeviceMatrix<RealScalar> d_singularValues() const {
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success);
const Index k = (std::min)(m_, n_);
auto v = DeviceMatrix<RealScalar>::view(static_cast<RealScalar*>(d_S_.get()), k, 1);
v.recordReady(solver_ctx_.stream_);
return v;
}
/** Left singular vectors U as a DeviceMatrix on this solver's stream.
* For m >= n: zero-copy view. For m < n: owning (one cublasXgeam adjoint pass). */
DeviceMatrix<Scalar> d_matrixU() const {
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success);
eigen_assert((options_ & (ComputeThinU | ComputeFullU)) && "d_matrixU() requires ComputeThinU or ComputeFullU");
const Index m_orig = transposed_ ? n_ : m_;
const Index n_orig = transposed_ ? m_ : n_;
const Index k = (std::min)(m_orig, n_orig);
if (!transposed_) {
const Index ucols = (options_ & ComputeFullU) ? m_ : k;
auto v = DeviceMatrix<Scalar>::view(static_cast<Scalar*>(d_U_.get()), m_, ucols);
v.recordReady(solver_ctx_.stream_);
return v;
}
// transposed: U_orig = VT_stored^H -> conjugate-transpose via cublasXgeam.
const Index vtrows_stored = (options_ & ComputeFullU) ? n_ : k;
DeviceMatrix<Scalar> result(n_, vtrows_stored);
if (n_ > 0 && vtrows_stored > 0) {
Scalar alpha_one(1), beta_zero(0);
EIGEN_CUBLAS_CHECK(internal::cublasXgeam(
solver_ctx_.cublas_, CUBLAS_OP_C, CUBLAS_OP_N, internal::to_blas_int(n_),
internal::to_blas_int(vtrows_stored), &alpha_one, static_cast<const Scalar*>(d_VT_.get()),
internal::to_blas_int(vtrows_stored), &beta_zero, static_cast<const Scalar*>(nullptr),
internal::to_blas_int(n_), result.data(), internal::to_blas_int(n_)));
result.recordReady(solver_ctx_.stream_);
}
return result;
}
/** Right singular vectors transposed V^T as a DeviceMatrix on this solver's stream.
* For m >= n: zero-copy view. For m < n: owning (one cublasXgeam adjoint pass). */
DeviceMatrix<Scalar> d_matrixVT() const {
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success);
eigen_assert((options_ & (ComputeThinV | ComputeFullV)) && "d_matrixVT() requires ComputeThinV or ComputeFullV");
const Index m_orig = transposed_ ? n_ : m_;
const Index n_orig = transposed_ ? m_ : n_;
const Index k = (std::min)(m_orig, n_orig);
if (!transposed_) {
const Index vtrows = (options_ & ComputeFullV) ? n_ : k;
auto v = DeviceMatrix<Scalar>::view(static_cast<Scalar*>(d_VT_.get()), vtrows, n_);
v.recordReady(solver_ctx_.stream_);
return v;
}
// transposed: VT_orig = U_stored^H.
const Index ucols = (options_ & ComputeFullV) ? n_orig : k;
DeviceMatrix<Scalar> result(ucols, m_);
if (ucols > 0 && m_ > 0) {
Scalar alpha_one(1), beta_zero(0);
EIGEN_CUBLAS_CHECK(
internal::cublasXgeam(solver_ctx_.cublas_, CUBLAS_OP_C, CUBLAS_OP_N, internal::to_blas_int(ucols),
internal::to_blas_int(m_), &alpha_one, static_cast<const Scalar*>(d_U_.get()),
internal::to_blas_int(m_), &beta_zero, static_cast<const Scalar*>(nullptr),
internal::to_blas_int(ucols), result.data(), internal::to_blas_int(ucols)));
result.recordReady(solver_ctx_.stream_);
}
return result;
}
/** Number of singular values above threshold. A negative threshold (the
 * default) selects max(m, n) * S(0) * epsilon. */
Index rank(RealScalar threshold = RealScalar(-1)) const {
RealVector S = singularValues();
if (S.size() == 0) return 0;
if (threshold < 0) {
threshold = (std::max)(m_, n_) * S(0) * NumTraits<RealScalar>::epsilon();
}
return (S.array() > threshold).count();
}
// ---- Solve ---------------------------------------------------------------
/** Pseudoinverse solve: X = V * diag(1/S) * U^H * B. */
template <typename Rhs>
PlainMatrix solve(const MatrixBase<Rhs>& B) const {
return solve_impl(B, (std::min)(m_, n_), RealScalar(0));
}
/** Truncated solve: use only the top trunc singular triplets. */
template <typename Rhs>
PlainMatrix solve(const MatrixBase<Rhs>& B, Index trunc) const {
eigen_assert(trunc > 0 && trunc <= (std::min)(m_, n_));
return solve_impl(B, trunc, RealScalar(0));
}
/** Tikhonov-regularized solve: D_ii = S_i / (S_i^2 + lambda^2). */
template <typename Rhs>
PlainMatrix solve(const MatrixBase<Rhs>& B, RealScalar lambda) const {
eigen_assert(lambda > 0);
return solve_impl(B, (std::min)(m_, n_), lambda);
}
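// Rationale for the Tikhonov diagonal: in the SVD basis the ridge problem
//   min_x ||A x - b||^2 + lambda^2 ||x||^2
// has the closed-form solution x = V * diag(S_i / (S_i^2 + lambda^2)) * U^H * b,
// which smoothly damps small singular values and recovers the pseudoinverse
// (D_ii = 1 / S_i) as lambda -> 0.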
cudaStream_t stream() const { return solver_ctx_.stream_; }
private:
mutable internal::GpuSolverContext solver_ctx_;
internal::DeviceBuffer d_A_;
internal::DeviceBuffer d_U_;
internal::DeviceBuffer d_S_;
internal::DeviceBuffer d_VT_;
unsigned int options_ = 0;
int64_t m_ = 0;
int64_t n_ = 0;
int64_t lda_ = 0;
bool transposed_ = false;
// Swap U↔V flags for the transposed case.
static unsigned int swap_uv_options(unsigned int opts) {
unsigned int result = 0;
if (opts & ComputeThinU) result |= ComputeThinV;
if (opts & ComputeFullU) result |= ComputeFullV;
if (opts & ComputeThinV) result |= ComputeThinU;
if (opts & ComputeFullV) result |= ComputeFullU;
return result;
}
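// Correctness of the swap: if A^H = U S V^H, then A = V S U^H, so the U of
// the transposed problem is the V of the original and vice versa. factorize()
// requests the factors with swapped flags; the accessors undo the swap by
// adjointing the stored buffers.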
static signed char jobu(unsigned int opts) {
if (opts & ComputeFullU) return 'A';
if (opts & ComputeThinU) return 'S';
return 'N';
}
static signed char jobvt(unsigned int opts) {
if (opts & ComputeFullV) return 'A';
if (opts & ComputeThinV) return 'S';
return 'N';
}
void factorize() {
constexpr cudaDataType_t dtype = internal::cusolver_data_type<Scalar>::value;
constexpr cudaDataType_t rtype = internal::cuda_data_type<RealScalar>::value;
const Index k = (std::min)(m_, n_);
solver_ctx_.mark_pending();
d_S_ = internal::DeviceBuffer(static_cast<size_t>(k) * sizeof(RealScalar));
const unsigned int int_opts = transposed_ ? swap_uv_options(options_) : options_;
const Index ucols = (int_opts & ComputeFullU) ? m_ : ((int_opts & ComputeThinU) ? k : 0);
const Index vtrows = (int_opts & ComputeFullV) ? n_ : ((int_opts & ComputeThinV) ? k : 0);
const int64_t ldu = m_;
const int64_t ldvt = vtrows > 0 ? vtrows : 1;
d_U_ = internal::DeviceBuffer();
d_VT_ = internal::DeviceBuffer();
if (ucols > 0) d_U_ = internal::DeviceBuffer(static_cast<size_t>(m_) * static_cast<size_t>(ucols) * sizeof(Scalar));
if (vtrows > 0)
d_VT_ = internal::DeviceBuffer(static_cast<size_t>(vtrows) * static_cast<size_t>(n_) * sizeof(Scalar));
eigen_assert(m_ >= n_ && "Internal error: m_ < n_ should have been handled by transpose in compute()");
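// Two-phase cuSOLVER pattern: query workspace sizes with the exact arguments
// of the subsequent call, size the scratch buffers accordingly, then run
// gesvd. ensure_scratch() is expected to be grow-only, so repeated
// factorizations reuse prior allocations when sizes do not increase.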
size_t dev_ws = 0, host_ws = 0;
EIGEN_CUSOLVER_CHECK(cusolverDnXgesvd_bufferSize(
solver_ctx_.cusolver_, solver_ctx_.params_.p, jobu(int_opts), jobvt(int_opts), m_, n_, dtype, d_A_.get(), lda_,
rtype, d_S_.get(), dtype, ucols > 0 ? d_U_.get() : nullptr, ldu, dtype, vtrows > 0 ? d_VT_.get() : nullptr,
ldvt, dtype, &dev_ws, &host_ws));
solver_ctx_.ensure_scratch(dev_ws);
solver_ctx_.h_workspace_.resize(host_ws);
EIGEN_CUSOLVER_CHECK(
cusolverDnXgesvd(solver_ctx_.cusolver_, solver_ctx_.params_.p, jobu(int_opts), jobvt(int_opts), m_, n_, dtype,
d_A_.get(), lda_, rtype, d_S_.get(), dtype, ucols > 0 ? d_U_.get() : nullptr, ldu, dtype,
vtrows > 0 ? d_VT_.get() : nullptr, ldvt, dtype, solver_ctx_.scratch_workspace(), dev_ws,
host_ws > 0 ? solver_ctx_.h_workspace_.data() : nullptr, host_ws, solver_ctx_.scratch_info()));
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(&solver_ctx_.info_word(), solver_ctx_.scratch_info(), sizeof(int),
cudaMemcpyDeviceToHost, solver_ctx_.stream_));
}
template <typename Rhs>
PlainMatrix solve_impl(const MatrixBase<Rhs>& B, Index trunc, RealScalar lambda) const {
solver_ctx_.sync_info();
eigen_assert(solver_ctx_.info_ == Success && "SVD::solve called on a failed or uninitialized decomposition");
eigen_assert((options_ & (ComputeThinU | ComputeFullU)) && "solve requires U");
eigen_assert((options_ & (ComputeThinV | ComputeFullV)) && "solve requires V");
const Index m_orig = transposed_ ? n_ : m_;
const Index n_orig = transposed_ ? m_ : n_;
eigen_assert(B.rows() == m_orig);
const Index k = (std::min)(m_, n_);
const Index kk = (std::min)(trunc, k);
const Index nrhs = B.cols();
// Empty problem: no rank, no RHS, or zero domain -> result is the zero matrix.
// Returning early avoids reading S(0) below when k == 0 and prevents zero-extent
// GEMM/dgmm calls.
if (kk == 0 || nrhs == 0 || n_orig == 0) {
return PlainMatrix::Zero(n_orig, nrhs);
}
// Enqueue both transfers on solver_ctx_.stream_ in one batch and sync once. Issuing the
// B upload before reading S means B's H2D is already in flight while we wait for
// gesvd-then-S-D2H, instead of two back-to-back blocking syncs.
const PlainMatrix rhs(B);
internal::DeviceBuffer d_B(static_cast<size_t>(m_orig) * static_cast<size_t>(nrhs) * sizeof(Scalar));
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(d_B.get(), rhs.data(),
static_cast<size_t>(m_orig) * static_cast<size_t>(nrhs) * sizeof(Scalar),
cudaMemcpyHostToDevice, solver_ctx_.stream_));
RealVector S(k);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(S.data(), d_S_.get(), static_cast<size_t>(k) * sizeof(RealScalar),
cudaMemcpyDeviceToHost, solver_ctx_.stream_));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(solver_ctx_.stream_));
// Typed pointers for device buffers.
auto* U_dev = static_cast<const Scalar*>(d_U_.get());
auto* VT_dev = static_cast<const Scalar*>(d_VT_.get());
auto* B_dev = static_cast<const Scalar*>(d_B.get());
Scalar scalars[2] = {Scalar(1), Scalar(0)};
// Step 1: tmp = U_orig^H * B (kk × nrhs).
internal::DeviceBuffer d_tmp(static_cast<size_t>(kk) * static_cast<size_t>(nrhs) * sizeof(Scalar));
auto* tmp_dev = static_cast<Scalar*>(d_tmp.get());
if (!transposed_) {
EIGEN_CUBLAS_CHECK(internal::cublasXgemm(solver_ctx_.cublas_, CUBLAS_OP_C, CUBLAS_OP_N, internal::to_blas_int(kk),
internal::to_blas_int(nrhs), internal::to_blas_int(m_), &scalars[0],
U_dev, internal::to_blas_int(m_), B_dev, internal::to_blas_int(m_orig),
&scalars[1], tmp_dev, internal::to_blas_int(kk)));
} else {
const Index vtrows_stored = (swap_uv_options(options_) & ComputeFullV) ? n_ : k;
EIGEN_CUBLAS_CHECK(internal::cublasXgemm(
solver_ctx_.cublas_, CUBLAS_OP_N, CUBLAS_OP_N, internal::to_blas_int(kk), internal::to_blas_int(nrhs),
internal::to_blas_int(m_orig), &scalars[0], VT_dev, internal::to_blas_int(vtrows_stored), B_dev,
internal::to_blas_int(m_orig), &scalars[1], tmp_dev, internal::to_blas_int(kk)));
}
// Step 2: Apply diag(D) to tmp on device via cublasXdgmm.
// D is built on host from the (small, k-entry) singular-values vector S
// and uploaded to a small device buffer; tmp itself stays on device.
//
// For lambda == 0 we mirror Eigen's SVDBase::_solve_impl: drop singular
// values below S(0) * k * eps (numerical-rank truncation), so this
// pseudoinverse solve agrees with CPU BDCSVD::solve on near-singular A.
// dgmm wants the diagonal in the matrix scalar type — for complex Scalar
// the diagonal is still real, so we build the real values then cast.
const RealScalar drop_threshold = S(0) * RealScalar(k) * NumTraits<RealScalar>::epsilon();
auto S_head = S.head(kk).array();
PlainVector D(kk);
if (lambda == RealScalar(0)) {
D = (S_head > drop_threshold).select(S_head.inverse(), RealScalar(0)).matrix().template cast<Scalar>();
} else {
D = (S_head / (S_head.square() + lambda * lambda)).matrix().template cast<Scalar>();
}
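// Worked example of the lambda == 0 path: in double precision with k = 3 and
// S = (4, 2, 1e-20), drop_threshold ~= 4 * 3 * 2.2e-16 ~= 2.7e-15, so
// D = (0.25, 0.5, 0): the numerically zero direction is dropped rather than
// amplified by a factor of ~1e20.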
internal::DeviceBuffer d_D(static_cast<size_t>(kk) * sizeof(Scalar));
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(d_D.get(), D.data(), static_cast<size_t>(kk) * sizeof(Scalar),
cudaMemcpyHostToDevice, solver_ctx_.stream_));
EIGEN_CUBLAS_CHECK(internal::cublasXdgmm(
solver_ctx_.cublas_, CUBLAS_SIDE_LEFT, internal::to_blas_int(kk), internal::to_blas_int(nrhs), tmp_dev,
internal::to_blas_int(kk), static_cast<const Scalar*>(d_D.get()), 1, tmp_dev, internal::to_blas_int(kk)));
// Step 3: X = V_orig * tmp (n_orig × nrhs).
PlainMatrix X(n_orig, nrhs);
internal::DeviceBuffer d_X(static_cast<size_t>(n_orig) * static_cast<size_t>(nrhs) * sizeof(Scalar));
auto* X_dev = static_cast<Scalar*>(d_X.get());
if (!transposed_) {
const Index vtrows = (options_ & ComputeFullV) ? n_ : k;
EIGEN_CUBLAS_CHECK(internal::cublasXgemm(
solver_ctx_.cublas_, CUBLAS_OP_C, CUBLAS_OP_N, internal::to_blas_int(n_orig), internal::to_blas_int(nrhs),
internal::to_blas_int(kk), &scalars[0], VT_dev, internal::to_blas_int(vtrows), tmp_dev,
internal::to_blas_int(kk), &scalars[1], X_dev, internal::to_blas_int(n_orig)));
} else {
EIGEN_CUBLAS_CHECK(internal::cublasXgemm(
solver_ctx_.cublas_, CUBLAS_OP_N, CUBLAS_OP_N, internal::to_blas_int(n_orig), internal::to_blas_int(nrhs),
internal::to_blas_int(kk), &scalars[0], U_dev, internal::to_blas_int(m_), tmp_dev, internal::to_blas_int(kk),
&scalars[1], X_dev, internal::to_blas_int(n_orig)));
}
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(X.data(), d_X.get(),
static_cast<size_t>(n_orig) * static_cast<size_t>(nrhs) * sizeof(Scalar),
cudaMemcpyDeviceToHost, solver_ctx_.stream_));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(solver_ctx_.stream_));
return X;
}
};
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_SVD_H