// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
#ifndef EIGEN_GPU_DEVICE_DISPATCH_H
#define EIGEN_GPU_DEVICE_DISPATCH_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include <climits>
#include <cstdint>
#include <limits>
#include "./DeviceExpr.h"
#include "./DeviceBlasExpr.h"
#include "./DeviceSolverExpr.h"
#include "./GpuContext.h"
#include "./CuSolverSupport.h"
namespace Eigen {
namespace gpu {
namespace internal {
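// Exact-aliasing check: true when both matrices share the same non-null device
// pointer. It does not attempt to detect partially overlapping storage.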
template <typename Scalar>
bool aliases_device_memory(const DeviceMatrix<Scalar>& a, const DeviceMatrix<Scalar>& b) {
return a.data() != nullptr && a.data() == b.data();
}
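// Dispatches dst = beta_val * dst + alpha_scale * alpha(lhs) * alpha(rhs) * op(A) * op(B)
// to cublasXgemm, resizing dst to op(A).rows() x op(B).cols() if needed.
// beta_val == 0 assigns, beta_val == 1 accumulates (see Assignment::operator= / operator+= below).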
template <typename Lhs, typename Rhs>
void dispatch_gemm(
Context& ctx, DeviceMatrix<typename device_expr_traits<Lhs>::scalar_type>& dst, const GemmExpr<Lhs, Rhs>& expr,
typename device_expr_traits<Lhs>::scalar_type beta_val,
typename device_expr_traits<Lhs>::scalar_type alpha_scale = typename device_expr_traits<Lhs>::scalar_type(1)) {
using Scalar = typename device_expr_traits<Lhs>::scalar_type;
using traits_lhs = device_expr_traits<Lhs>;
using traits_rhs = device_expr_traits<Rhs>;
const DeviceMatrix<Scalar>& A = traits_lhs::matrix(expr.lhs());
const DeviceMatrix<Scalar>& B = traits_rhs::matrix(expr.rhs());
// cuBLAS GEMM: C must not alias A or B (undefined behavior).
eigen_assert(!aliases_device_memory(dst, A) && "GEMM: output aliases left operand (use a temporary)");
eigen_assert(!aliases_device_memory(dst, B) && "GEMM: output aliases right operand (use a temporary)");
constexpr cublasOperation_t transA = to_cublas_op(traits_lhs::op);
constexpr cublasOperation_t transB = to_cublas_op(traits_rhs::op);
const int64_t m = (traits_lhs::op == GpuOp::NoTrans) ? A.rows() : A.cols();
const int64_t k = (traits_lhs::op == GpuOp::NoTrans) ? A.cols() : A.rows();
const int64_t n = (traits_rhs::op == GpuOp::NoTrans) ? B.cols() : B.rows();
const int64_t rhs_k = (traits_rhs::op == GpuOp::NoTrans) ? B.rows() : B.cols();
eigen_assert(k == rhs_k && "DeviceMatrix GEMM dimension mismatch");
const int64_t lda = A.rows();
const int64_t ldb = B.rows();
if (!dst.empty()) {
dst.waitReady(ctx.stream());
}
const bool resized = dst.empty() || dst.rows() != m || dst.cols() != n;
if (resized) {
dst.resize(m, n);
}
const int64_t ldc = dst.rows();
Scalar alpha_local = alpha_scale * traits_lhs::alpha(expr.lhs()) * traits_rhs::alpha(expr.rhs());
A.waitReady(ctx.stream());
B.waitReady(ctx.stream());
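// A freshly (re)allocated dst holds uninitialized memory; when the caller asked to
// accumulate (beta != 0), zero it first so the accumulation reads zeros rather than garbage.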
if (resized && beta_val != Scalar(0) && dst.sizeInBytes() > 0) {
EIGEN_CUDA_RUNTIME_CHECK(cudaMemsetAsync(dst.data(), 0, dst.sizeInBytes(), ctx.stream()));
}
eigen_assert(m <= INT_MAX && n <= INT_MAX && k <= INT_MAX && lda <= INT_MAX && ldb <= INT_MAX && ldc <= INT_MAX &&
"cublasXgemm dimensions exceed int range");
// cuBLAS reads alpha and beta through host pointers. Store them in an array
// to prevent the compiler from eliding their stack slots — clang and MSVC
// at -O1+ otherwise optimise away the stores for complex types, leaving
// cuBLAS with a dangling pointer.
Scalar scalars[2] = {alpha_local, beta_val};
EIGEN_CUBLAS_CHECK(cublasXgemm(ctx.cublasHandle(), transA, transB, static_cast<int>(m), static_cast<int>(n),
static_cast<int>(k), &scalars[0], A.data(), static_cast<int>(lda), B.data(),
static_cast<int>(ldb), &scalars[1], dst.data(), static_cast<int>(ldc)));
dst.recordReady(ctx.stream());
}
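// Solves A * X = B for a symmetric/Hermitian positive-definite A via the cuSOLVER
// 64-bit API: A is copied into a scratch buffer, factored with cusolverDnXpotrf, and
// dst (initialized with a copy of B) is overwritten in place by cusolverDnXpotrs.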
template <typename Scalar, int UpLo>
void dispatch_llt_solve(Context& ctx, DeviceMatrix<Scalar>& dst, const LltSolveExpr<Scalar, UpLo>& expr) {
const DeviceMatrix<Scalar>& A = expr.matrix();
const DeviceMatrix<Scalar>& B = expr.rhs();
eigen_assert(A.rows() == A.cols() && "LLT requires a square matrix");
eigen_assert(B.rows() == A.rows() && "LLT solve: RHS rows must match matrix size");
const int64_t n = static_cast<int64_t>(A.rows());
const int64_t nrhs = static_cast<int64_t>(B.cols());
if (n == 0 || nrhs == 0) {
if (!dst.empty()) dst.waitReady(ctx.stream());
dst.resize(n, B.cols());
return;
}
A.waitReady(ctx.stream());
B.waitReady(ctx.stream());
if (!dst.empty()) dst.waitReady(ctx.stream());
constexpr cudaDataType_t dtype = cuda_data_type<Scalar>::value;
constexpr cublasFillMode_t uplo = cusolver_fill_mode<UpLo>::value;
const int64_t lda = static_cast<int64_t>(A.rows());
const int64_t ldb = static_cast<int64_t>(B.rows());
const size_t mat_bytes = static_cast<size_t>(lda) * static_cast<size_t>(n) * sizeof(Scalar);
const size_t rhs_bytes = static_cast<size_t>(ldb) * static_cast<size_t>(nrhs) * sizeof(Scalar);
DeviceBuffer d_factor(mat_bytes);
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_factor.get(), A.data(), mat_bytes, cudaMemcpyDeviceToDevice, ctx.stream()));
// Two info slots (potrf, potrs) so we can queue both kernels back-to-back
// and host-sync once at the end. If potrf fails, potrs runs on garbage but
// the assert fires after the single sync — saving a round trip.
PinnedHostBuffer h_info(2 * sizeof(int));
int* info_words = static_cast<int*>(h_info.get());
CusolverParams params;
DeviceBuffer d_info(2 * sizeof(int));
int* d_info_potrf = static_cast<int*>(d_info.get());
int* d_info_potrs = d_info_potrf + 1;
size_t dev_ws = 0, host_ws = 0;
EIGEN_CUSOLVER_CHECK(cusolverDnXpotrf_bufferSize(ctx.cusolverHandle(), params.p, uplo, n, dtype, d_factor.get(), lda,
dtype, &dev_ws, &host_ws));
DeviceBuffer d_workspace(dev_ws);
std::vector<char> h_workspace(host_ws);
EIGEN_CUSOLVER_CHECK(cusolverDnXpotrf(ctx.cusolverHandle(), params.p, uplo, n, dtype, d_factor.get(), lda, dtype,
d_workspace.get(), dev_ws, host_ws > 0 ? h_workspace.data() : nullptr, host_ws,
d_info_potrf));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(&info_words[0], d_info_potrf, sizeof(int), cudaMemcpyDeviceToHost, ctx.stream()));
dst.resize(n, B.cols());
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(dst.data(), B.data(), rhs_bytes, cudaMemcpyDeviceToDevice, ctx.stream()));
EIGEN_CUSOLVER_CHECK(cusolverDnXpotrs(ctx.cusolverHandle(), params.p, uplo, n, nrhs, dtype, d_factor.get(), lda,
dtype, dst.data(), static_cast<int64_t>(dst.rows()), d_info_potrs));
// Workspace locals must outlive the async kernels — sync before they unwind.
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(&info_words[1], d_info_potrs, sizeof(int), cudaMemcpyDeviceToHost, ctx.stream()));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx.stream()));
eigen_assert(info_words[0] == 0 && "cuSOLVER LLT factorization failed (matrix not positive definite)");
eigen_assert(info_words[1] == 0 && "cuSOLVER LLT solve failed");
dst.recordReady(ctx.stream());
}
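// Solves A * X = B for a general square A via cuSOLVER cusolverDnXgetrf (LU with
// partial pivoting, 64-bit ipiv) followed by cusolverDnXgetrs on a copy of B held in dst.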
template <typename Scalar>
void dispatch_lu_solve(Context& ctx, DeviceMatrix<Scalar>& dst, const LuSolveExpr<Scalar>& expr) {
const DeviceMatrix<Scalar>& A = expr.matrix();
const DeviceMatrix<Scalar>& B = expr.rhs();
eigen_assert(A.rows() == A.cols() && "LU requires a square matrix");
eigen_assert(B.rows() == A.rows() && "LU solve: RHS rows must match matrix size");
const int64_t n = static_cast<int64_t>(A.rows());
const int64_t nrhs = static_cast<int64_t>(B.cols());
if (n == 0 || nrhs == 0) {
if (!dst.empty()) dst.waitReady(ctx.stream());
dst.resize(n, B.cols());
return;
}
A.waitReady(ctx.stream());
B.waitReady(ctx.stream());
if (!dst.empty()) dst.waitReady(ctx.stream());
constexpr cudaDataType_t dtype = cuda_data_type<Scalar>::value;
const int64_t lda = static_cast<int64_t>(A.rows());
const int64_t ldb = static_cast<int64_t>(B.rows());
const size_t mat_bytes = static_cast<size_t>(lda) * static_cast<size_t>(n) * sizeof(Scalar);
const size_t rhs_bytes = static_cast<size_t>(ldb) * static_cast<size_t>(nrhs) * sizeof(Scalar);
const size_t ipiv_bytes = static_cast<size_t>(n) * sizeof(int64_t);
DeviceBuffer d_lu(mat_bytes);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(d_lu.get(), A.data(), mat_bytes, cudaMemcpyDeviceToDevice, ctx.stream()));
DeviceBuffer d_ipiv(ipiv_bytes);
PinnedHostBuffer h_info(2 * sizeof(int));
int* info_words = static_cast<int*>(h_info.get());
CusolverParams params;
DeviceBuffer d_info(2 * sizeof(int));
int* d_info_getrf = static_cast<int*>(d_info.get());
int* d_info_getrs = d_info_getrf + 1;
size_t dev_ws = 0, host_ws = 0;
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrf_bufferSize(ctx.cusolverHandle(), params.p, n, n, dtype, d_lu.get(), lda, dtype,
&dev_ws, &host_ws));
DeviceBuffer d_workspace(dev_ws);
std::vector<char> h_workspace(host_ws);
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrf(ctx.cusolverHandle(), params.p, n, n, dtype, d_lu.get(), lda,
static_cast<int64_t*>(d_ipiv.get()), dtype, d_workspace.get(), dev_ws,
host_ws > 0 ? h_workspace.data() : nullptr, host_ws, d_info_getrf));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(&info_words[0], d_info_getrf, sizeof(int), cudaMemcpyDeviceToHost, ctx.stream()));
dst.resize(n, B.cols());
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(dst.data(), B.data(), rhs_bytes, cudaMemcpyDeviceToDevice, ctx.stream()));
EIGEN_CUSOLVER_CHECK(cusolverDnXgetrs(ctx.cusolverHandle(), params.p, CUBLAS_OP_N, n, nrhs, dtype, d_lu.get(), lda,
static_cast<const int64_t*>(d_ipiv.get()), dtype, dst.data(),
static_cast<int64_t>(dst.rows()), d_info_getrs));
// Workspace locals must outlive the async kernels — sync before they unwind.
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(&info_words[1], d_info_getrs, sizeof(int), cudaMemcpyDeviceToHost, ctx.stream()));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx.stream()));
eigen_assert(info_words[0] == 0 && "cuSOLVER LU factorization failed (singular matrix)");
eigen_assert(info_words[1] == 0 && "cuSOLVER LU solve failed");
dst.recordReady(ctx.stream());
}
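// Solves T * X = B for a triangular T (UpLo selects the referenced triangle) by
// copying B into dst and running cublasXtrsm in place with a non-unit diagonal.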
template <typename Scalar, int UpLo>
void dispatch_trsm(Context& ctx, DeviceMatrix<Scalar>& dst, const TrsmExpr<Scalar, UpLo>& expr) {
const DeviceMatrix<Scalar>& A = expr.matrix();
const DeviceMatrix<Scalar>& B = expr.rhs();
eigen_assert(A.rows() == A.cols() && "TRSM requires a square triangular matrix");
eigen_assert(B.rows() == A.rows() && "TRSM: RHS rows must match matrix size");
eigen_assert(A.rows() <= INT_MAX && B.cols() <= INT_MAX && "cublasXtrsm dimensions exceed int range");
const int n = static_cast<int>(A.rows());
const int nrhs = static_cast<int>(B.cols());
if (n == 0 || nrhs == 0) {
if (!dst.empty()) dst.waitReady(ctx.stream());
dst.resize(n, B.cols());
return;
}
A.waitReady(ctx.stream());
B.waitReady(ctx.stream());
eigen_assert(!aliases_device_memory(dst, A) && "DeviceMatrix TRSM destination aliases triangular operand");
eigen_assert(!aliases_device_memory(dst, B) && "DeviceMatrix TRSM destination aliases RHS operand");
if (!dst.empty()) dst.waitReady(ctx.stream());
dst.resize(n, B.cols());
const size_t rhs_bytes = static_cast<size_t>(dst.rows()) * static_cast<size_t>(nrhs) * sizeof(Scalar);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(dst.data(), B.data(), rhs_bytes, cudaMemcpyDeviceToDevice, ctx.stream()));
constexpr cublasFillMode_t uplo = (UpLo == Lower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
Scalar alpha(1);
EIGEN_CUBLAS_CHECK(cublasXtrsm(ctx.cublasHandle(), CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, n, nrhs,
&alpha, A.data(), static_cast<int>(A.rows()), dst.data(),
static_cast<int>(dst.rows())));
dst.recordReady(ctx.stream());
}
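// Computes dst = A * B where A is self-adjoint and only its UpLo triangle is
// referenced (cublasXsymm with alpha = 1, beta = 0).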
template <typename Scalar, int UpLo>
void dispatch_symm(Context& ctx, DeviceMatrix<Scalar>& dst, const SymmExpr<Scalar, UpLo>& expr) {
const DeviceMatrix<Scalar>& A = expr.matrix();
const DeviceMatrix<Scalar>& B = expr.rhs();
eigen_assert(A.rows() == A.cols() && "SYMM requires a square matrix");
eigen_assert(B.rows() == A.rows() && "SYMM: RHS rows must match matrix size");
eigen_assert(A.rows() <= INT_MAX && B.cols() <= INT_MAX && B.rows() <= INT_MAX &&
"cublasXsymm dimensions exceed int range");
const int m = static_cast<int>(A.rows());
const int n = static_cast<int>(B.cols());
if (m == 0 || n == 0) {
if (!dst.empty()) dst.waitReady(ctx.stream());
dst.resize(m, B.cols());
return;
}
A.waitReady(ctx.stream());
B.waitReady(ctx.stream());
eigen_assert(!aliases_device_memory(dst, A) && "DeviceMatrix SYMM destination aliases self-adjoint operand");
eigen_assert(!aliases_device_memory(dst, B) && "DeviceMatrix SYMM destination aliases RHS operand");
if (!dst.empty()) dst.waitReady(ctx.stream());
dst.resize(m, n);
constexpr cublasFillMode_t uplo = (UpLo == Lower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
// See dispatch_gemm: array prevents compiler from eliding host-pointer stack slots.
Scalar scalars[2] = {Scalar(1), Scalar(0)};
EIGEN_CUBLAS_CHECK(cublasXsymm(ctx.cublasHandle(), CUBLAS_SIDE_LEFT, uplo, m, n, &scalars[0], A.data(),
static_cast<int>(A.rows()), B.data(), static_cast<int>(B.rows()), &scalars[1],
dst.data(), static_cast<int>(dst.rows())));
dst.recordReady(ctx.stream());
}
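// Rank-k update on the UpLo triangle of dst: dst = alpha_val * A * A^T + beta_val * dst
// (or the Hermitian A * A^H variant, depending on what cublasXsyrk resolves to for
// complex Scalar). Used by SelfAdjointView::rankUpdate below.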
template <typename Scalar, int UpLo>
void dispatch_syrk(Context& ctx, DeviceMatrix<Scalar>& dst, const SyrkExpr<Scalar, UpLo>& expr,
typename NumTraits<Scalar>::Real alpha_val, typename NumTraits<Scalar>::Real beta_val) {
using RealScalar = typename NumTraits<Scalar>::Real;
const DeviceMatrix<Scalar>& A = expr.matrix();
eigen_assert(A.rows() <= INT_MAX && A.cols() <= INT_MAX && "cublasXsyrk dimensions exceed int range");
const int n = static_cast<int>(A.rows());
const int k = static_cast<int>(A.cols());
if (n == 0) {
if (!dst.empty()) dst.waitReady(ctx.stream());
dst.resize(0, 0);
return;
}
A.waitReady(ctx.stream());
eigen_assert(!aliases_device_memory(dst, A) && "DeviceMatrix SYRK destination aliases input operand");
if (!dst.empty()) dst.waitReady(ctx.stream());
if (dst.empty() || dst.rows() != n || dst.cols() != n) {
dst.resize(n, n);
if (beta_val != RealScalar(0)) {
EIGEN_CUDA_RUNTIME_CHECK(cudaMemsetAsync(dst.data(), 0, dst.sizeInBytes(), ctx.stream()));
}
}
constexpr cublasFillMode_t uplo = (UpLo == Lower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
EIGEN_CUBLAS_CHECK(cublasXsyrk(ctx.cublasHandle(), uplo, CUBLAS_OP_N, n, k, &alpha_val, A.data(),
static_cast<int>(A.rows()), &beta_val, dst.data(), static_cast<int>(dst.rows())));
dst.recordReady(ctx.stream());
}
} // namespace internal
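// Assignment binds a destination matrix to a Context and maps the supported expression
// types onto the dispatch_* routines above. It is presumably what
// DeviceMatrix::device(Context&) returns (see the out-of-line operators below), so a
// typical call looks roughly like:
//   C.device(ctx) = A * B;    // dispatch_gemm, beta = 0
//   C.device(ctx) += A * B;   // dispatch_gemm, beta = 1
// (sketch only; the expression-building operators live in DeviceBlasExpr.h).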
template <typename Scalar_>
class Assignment {
public:
using Scalar = Scalar_;
Assignment(DeviceMatrix<Scalar>& dst, Context& ctx) : dst_(dst), ctx_(ctx) {}
template <typename Lhs, typename Rhs>
DeviceMatrix<Scalar>& operator=(const GemmExpr<Lhs, Rhs>& expr) {
internal::dispatch_gemm(ctx_, dst_, expr, Scalar(0));
return dst_;
}
template <typename Lhs, typename Rhs>
DeviceMatrix<Scalar>& operator+=(const GemmExpr<Lhs, Rhs>& expr) {
internal::dispatch_gemm(ctx_, dst_, expr, Scalar(1));
return dst_;
}
template <typename Lhs, typename Rhs>
DeviceMatrix<Scalar>& operator-=(const GemmExpr<Lhs, Rhs>& expr) {
internal::dispatch_gemm(ctx_, dst_, expr, Scalar(1), Scalar(-1));
return dst_;
}
template <int UpLo>
DeviceMatrix<Scalar>& operator=(const LltSolveExpr<Scalar, UpLo>& expr) {
internal::dispatch_llt_solve(ctx_, dst_, expr);
return dst_;
}
DeviceMatrix<Scalar>& operator=(const LuSolveExpr<Scalar>& expr) {
internal::dispatch_lu_solve(ctx_, dst_, expr);
return dst_;
}
template <int UpLo>
DeviceMatrix<Scalar>& operator=(const TrsmExpr<Scalar, UpLo>& expr) {
internal::dispatch_trsm(ctx_, dst_, expr);
return dst_;
}
template <int UpLo>
DeviceMatrix<Scalar>& operator=(const SymmExpr<Scalar, UpLo>& expr) {
internal::dispatch_symm(ctx_, dst_, expr);
return dst_;
}
template <typename Expr>
DeviceMatrix<Scalar>& operator=(const Expr&) {
static_assert(sizeof(Expr) == 0,
"DeviceMatrix expression not supported: no cuBLAS/cuSOLVER mapping. "
"Supported: GEMM (A*B), TRSM (.triangularView().solve()), "
"SYMM (.selfadjointView()*B), LLT (.llt().solve()), LU (.lu().solve()).");
return dst_;
}
private:
DeviceMatrix<Scalar>& dst_;
Context& ctx_;
};
// Out-of-line definitions: these depend on Context::threadLocal(), which
// requires the full Context definition that is not available in DeviceMatrix.h.
template <typename Scalar_>
template <typename Lhs, typename Rhs>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const GemmExpr<Lhs, Rhs>& expr) {
device(Context::threadLocal()) = expr;
return *this;
}
template <typename Scalar_>
template <typename Lhs, typename Rhs>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator+=(const GemmExpr<Lhs, Rhs>& expr) {
device(Context::threadLocal()) += expr;
return *this;
}
template <typename Scalar_>
template <int UpLo>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const LltSolveExpr<Scalar_, UpLo>& expr) {
device(Context::threadLocal()) = expr;
return *this;
}
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const LuSolveExpr<Scalar_>& expr) {
device(Context::threadLocal()) = expr;
return *this;
}
template <typename Scalar_>
template <int UpLo>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const TrsmExpr<Scalar_, UpLo>& expr) {
device(Context::threadLocal()) = expr;
return *this;
}
template <typename Scalar_>
template <int UpLo>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const SymmExpr<Scalar_, UpLo>& expr) {
device(Context::threadLocal()) = expr;
return *this;
}
template <typename Scalar_, int UpLo_>
void SelfAdjointView<Scalar_, UpLo_>::rankUpdate(const DeviceMatrix<Scalar_>& A, RealScalar alpha) {
SyrkExpr<Scalar_, UpLo_> expr(A);
RealScalar beta = matrix().empty() ? RealScalar(0) : RealScalar(1);
internal::dispatch_syrk(Context::threadLocal(), matrix(), expr, alpha, beta);
}
// ---- Helper: scoped CUBLAS_POINTER_MODE_DEVICE ------------------------------
// Saves the current pointer mode, switches to device, runs the callable,
// then restores the original mode. Used by dot, norm, operator*=(DeviceScalar),
// and operator+=(DeviceScaledDevice).
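// Usage sketch (see dot() and norm() below for the concrete call sites):
//   internal::with_device_pointer_mode(ctx.cublasHandle(), [&] {
//     EIGEN_CUBLAS_CHECK(internal::cublasXdot(ctx.cublasHandle(), n, x, 1, y, 1, result.devicePtr()));
//   });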
namespace internal {
template <typename F>
void with_device_pointer_mode(cublasHandle_t h, F&& f) {
cublasPointerMode_t prev;
EIGEN_CUBLAS_CHECK(cublasGetPointerMode(h, &prev));
EIGEN_CUBLAS_CHECK(cublasSetPointerMode(h, CUBLAS_POINTER_MODE_DEVICE));
f();
EIGEN_CUBLAS_CHECK(cublasSetPointerMode(h, prev));
}
} // namespace internal
// ---- DeviceMatrix BLAS-1 out-of-line definitions ----------------------------
// Defined here because they need the full Context definition.
// All methods take an explicit Context& so callers can ensure same-stream
// execution (zero event overhead when all operations share one context).
//
// Reduction methods (dot, norm, squaredNorm) use CUBLAS_POINTER_MODE_DEVICE:
// the scalar result is written to device memory and stays there until read
// via DeviceScalar's implicit conversion to Scalar (which syncs).
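// Illustrative use of a reduction result (a sketch; x is a DeviceMatrix<double> and
// ctx a Context, both placeholders):
//   DeviceScalar<double> r2 = x.squaredNorm(ctx); // result stays in device memory
//   double h = r2;                                // implicit conversion syncs here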
namespace internal {
// BLAS-1 cuBLAS wrappers take int counts. Index is ptrdiff_t on 64-bit
// systems, so guard against silent truncation for matrices with > INT_MAX
// elements.
inline int blas1_int_size(Index rows, Index cols) {
const int64_t total = static_cast<int64_t>(rows) * static_cast<int64_t>(cols);
eigen_assert(total <= static_cast<int64_t>((std::numeric_limits<int>::max)()) &&
"cuBLAS BLAS-1 length exceeds int range");
return static_cast<int>(total);
}
} // namespace internal
template <typename Scalar_>
DeviceScalar<typename DeviceMatrix<Scalar_>::Scalar> DeviceMatrix<Scalar_>::dot(Context& ctx,
const DeviceMatrix& other) const {
const int n = internal::blas1_int_size(rows_, cols_);
eigen_assert(n == internal::blas1_int_size(other.rows_, other.cols_));
DeviceScalar<Scalar> result(Scalar(0), ctx.stream());
if (n > 0) {
waitReady(ctx.stream());
other.waitReady(ctx.stream());
internal::with_device_pointer_mode(ctx.cublasHandle(), [&] {
EIGEN_CUBLAS_CHECK(
internal::cublasXdot(ctx.cublasHandle(), n, data_.get(), 1, other.data_.get(), 1, result.devicePtr()));
});
}
return result;
}
namespace internal {
// Real: dot(x,x) returns DeviceScalar<Scalar> which IS DeviceScalar<RealScalar>.
// Move-construct without any sync.
template <typename Scalar, typename RealScalar>
typename std::enable_if<std::is_same<Scalar, RealScalar>::value, DeviceScalar<RealScalar>>::type squaredNorm_from_dot(
DeviceScalar<Scalar>&& d, cudaStream_t) {
return std::move(d);
}
// Complex: must sync to extract the real part (DeviceScalar arithmetic is real-only).
template <typename Scalar, typename RealScalar>
typename std::enable_if<!std::is_same<Scalar, RealScalar>::value, DeviceScalar<RealScalar>>::type squaredNorm_from_dot(
DeviceScalar<Scalar>&& d, cudaStream_t stream) {
return DeviceScalar<RealScalar>(numext::real(Scalar(d)), stream);
}
} // namespace internal
template <typename Scalar_>
DeviceScalar<typename NumTraits<Scalar_>::Real> DeviceMatrix<Scalar_>::squaredNorm(Context& ctx) const {
// Use dot(x,x) instead of nrm2()^2: dot kernel is ~4.5x faster than nrm2
// (nrm2 uses a numerically careful scaled-sum-of-squares algorithm that is
// unnecessary for CG convergence checks).
using RealScalar = typename NumTraits<Scalar_>::Real;
return internal::squaredNorm_from_dot<Scalar_, RealScalar>(dot(ctx, *this), ctx.stream());
}
template <typename Scalar_>
DeviceScalar<typename NumTraits<Scalar_>::Real> DeviceMatrix<Scalar_>::norm(Context& ctx) const {
using RealScalar = typename NumTraits<Scalar>::Real;
const int n = internal::blas1_int_size(rows_, cols_);
DeviceScalar<RealScalar> result(RealScalar(0), ctx.stream());
if (n > 0) {
waitReady(ctx.stream());
internal::with_device_pointer_mode(ctx.cublasHandle(), [&] {
EIGEN_CUBLAS_CHECK(internal::cublasXnrm2(ctx.cublasHandle(), n, data_.get(), 1, result.devicePtr()));
});
}
return result;
}
template <typename Scalar_>
void DeviceMatrix<Scalar_>::setZero(cudaStream_t stream) {
if (sizeInBytes() > 0) {
waitReady(stream);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemsetAsync(data_.get(), 0, sizeInBytes(), stream));
recordReady(stream);
}
}
template <typename Scalar_>
void DeviceMatrix<Scalar_>::setZero(Context& ctx) {
setZero(ctx.stream());
}
template <typename Scalar_>
void DeviceMatrix<Scalar_>::addScaled(Context& ctx, Scalar alpha, const DeviceMatrix& x) {
const int n = internal::blas1_int_size(rows_, cols_);
eigen_assert(n == internal::blas1_int_size(x.rows_, x.cols_));
if (n > 0) {
waitReady(ctx.stream());
x.waitReady(ctx.stream());
EIGEN_CUBLAS_CHECK(internal::cublasXaxpy(ctx.cublasHandle(), n, &alpha, x.data_.get(), 1, data_.get(), 1));
recordReady(ctx.stream());
}
}
template <typename Scalar_>
void DeviceMatrix<Scalar_>::scale(Context& ctx, Scalar alpha) {
const int n = internal::blas1_int_size(rows_, cols_);
if (n > 0) {
waitReady(ctx.stream());
EIGEN_CUBLAS_CHECK(internal::cublasXscal(ctx.cublasHandle(), n, &alpha, data_.get(), 1));
recordReady(ctx.stream());
}
}
template <typename Scalar_>
void DeviceMatrix<Scalar_>::copyFrom(Context& ctx, const DeviceMatrix& other) {
// Wait on *this before resize — resize may free the old buffer while another
// stream is still reading it.
if (!empty()) waitReady(ctx.stream());
resize(other.rows_, other.cols_);
const int n = internal::blas1_int_size(rows_, cols_);
if (n > 0) {
other.waitReady(ctx.stream());
EIGEN_CUBLAS_CHECK(internal::cublasXcopy(ctx.cublasHandle(), n, other.data_.get(), 1, data_.get(), 1));
recordReady(ctx.stream());
}
}
// ---- BLAS-1 operator overloads for CG compatibility -------------------------
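// Illustrative CG-style update using these overloads (a sketch; assumes the
// scalar-times-matrix operator from DeviceExpr.h produces Scaled<DeviceMatrix>):
//   x += alpha * p;   // axpy
//   r -= alpha * q;   // axpy with negated alpha
//   p *= beta;        // scal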
// this += alpha * x (axpy)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator+=(const Scaled<DeviceMatrix>& expr) {
addScaled(Context::threadLocal(), expr.scalar(), internal::device_expr_traits<DeviceMatrix>::matrix(expr.inner()));
return *this;
}
// this -= alpha * x (axpy with negated alpha)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator-=(const Scaled<DeviceMatrix>& expr) {
addScaled(Context::threadLocal(), -expr.scalar(), internal::device_expr_traits<DeviceMatrix>::matrix(expr.inner()));
return *this;
}
// this += x (axpy with alpha=1)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator+=(const DeviceMatrix& other) {
Scalar one(1);
addScaled(Context::threadLocal(), one, other);
return *this;
}
// this -= x (axpy with alpha=-1)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator-=(const DeviceMatrix& other) {
Scalar neg_one(-1);
addScaled(Context::threadLocal(), neg_one, other);
return *this;
}
// this *= alpha (scal, host pointer)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator*=(Scalar alpha) {
scale(Context::threadLocal(), alpha);
return *this;
}
// this *= alpha (scal, device pointer — avoids host sync)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator*=(const DeviceScalar<Scalar>& alpha) {
const int n = internal::blas1_int_size(rows_, cols_);
if (n > 0) {
auto& ctx = Context::threadLocal();
waitReady(ctx.stream());
internal::with_device_pointer_mode(ctx.cublasHandle(), [&] {
EIGEN_CUBLAS_CHECK(internal::cublasXscal(ctx.cublasHandle(), n, alpha.devicePtr(), data_.get(), 1));
});
recordReady(ctx.stream());
}
return *this;
}
// this += DeviceScalar * x (axpy with CUBLAS_POINTER_MODE_DEVICE)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator+=(const DeviceScaledDevice<Scalar_>& expr) {
const int n = internal::blas1_int_size(rows_, cols_);
const auto& x = expr.matrix();
eigen_assert(n == internal::blas1_int_size(x.rows_, x.cols_));
if (n > 0) {
auto& ctx = Context::threadLocal();
waitReady(ctx.stream());
x.waitReady(ctx.stream());
internal::with_device_pointer_mode(ctx.cublasHandle(), [&] {
EIGEN_CUBLAS_CHECK(
internal::cublasXaxpy(ctx.cublasHandle(), n, expr.alpha().devicePtr(), x.data_.get(), 1, data_.get(), 1));
});
recordReady(ctx.stream());
}
return *this;
}
// this -= DeviceScalar * x (axpy with negated device scalar)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator-=(const DeviceScaledDevice<Scalar_>& expr) {
auto neg_alpha = -expr.alpha();
DeviceScaledDevice<Scalar_> neg_expr(neg_alpha, expr.matrix());
return operator+=(neg_expr);
}
// this = alpha * A + beta * B (cuBLAS geam)
template <typename Scalar_>
DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const DeviceAddExpr<Scalar_>& expr) {
auto& ctx = Context::threadLocal();
const auto& A = expr.A();
const auto& B = expr.B();
eigen_assert(A.rows() == B.rows() && A.cols() == B.cols());
eigen_assert(A.rows() <= INT_MAX && A.cols() <= INT_MAX && "cublasXgeam dimensions exceed int range");
const int m = static_cast<int>(A.rows());
const int n = static_cast<int>(A.cols());
// Wait on *this before resize — resize may free the old buffer while another
// stream is still reading it.
if (!empty()) waitReady(ctx.stream());
resize(A.rows(), A.cols());
if (m > 0 && n > 0) {
A.waitReady(ctx.stream());
B.waitReady(ctx.stream());
Scalar_ alpha = expr.alpha();
Scalar_ beta = expr.beta();
EIGEN_CUBLAS_CHECK(internal::cublasXgeam(ctx.cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, &alpha, A.data(), m,
&beta, B.data(), m, data_.get(), m));
recordReady(ctx.stream());
}
return *this;
}
// cwiseProduct (allocating).
template <typename Scalar_>
DeviceMatrix<Scalar_> DeviceMatrix<Scalar_>::cwiseProduct(Context& ctx, const DeviceMatrix& other) const {
const int n = internal::blas1_int_size(rows_, cols_);
eigen_assert(n == internal::blas1_int_size(other.rows_, other.cols_));
DeviceMatrix result(rows_, cols_);
if (n > 0) {
waitReady(ctx.stream());
other.waitReady(ctx.stream());
internal::device_cwiseProduct(data_.get(), other.data_.get(), result.data_.get(), n, ctx.stream());
result.recordReady(ctx.stream());
}
return result;
}
// In-place cwiseProduct: this = a .* b (reuses this buffer, no allocation).
template <typename Scalar_>
void DeviceMatrix<Scalar_>::cwiseProduct(Context& ctx, const DeviceMatrix& a, const DeviceMatrix& b) {
const int n = internal::blas1_int_size(a.rows_, a.cols_);
eigen_assert(n == internal::blas1_int_size(b.rows_, b.cols_));
if (!empty()) waitReady(ctx.stream());
resize(a.rows_, a.cols_);
if (n > 0) {
a.waitReady(ctx.stream());
b.waitReady(ctx.stream());
internal::device_cwiseProduct(a.data_.get(), b.data_.get(), data_.get(), n, ctx.stream());
recordReady(ctx.stream());
}
}
// Convenience overloads using thread-local default Context.
template <typename Scalar_>
DeviceScalar<typename DeviceMatrix<Scalar_>::Scalar> DeviceMatrix<Scalar_>::dot(const DeviceMatrix& other) const {
return dot(Context::threadLocal(), other);
}
template <typename Scalar_>
DeviceScalar<typename NumTraits<Scalar_>::Real> DeviceMatrix<Scalar_>::squaredNorm() const {
return squaredNorm(Context::threadLocal());
}
template <typename Scalar_>
DeviceScalar<typename NumTraits<Scalar_>::Real> DeviceMatrix<Scalar_>::norm() const {
return norm(Context::threadLocal());
}
template <typename Scalar_>
void DeviceMatrix<Scalar_>::setZero() {
setZero(Context::threadLocal());
}
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_DEVICE_DISPATCH_H