blob: 8b88c5c8e764a218bd044f62a739870bc01f2363 [file]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// cuSOLVER-specific support types:
// - cuSOLVER error-checking macro
// - RAII wrapper for cusolverDnParams
// - Scalar → cudaDataType_t mapping
// - (UpLo, StorageOrder) → cublasFillMode_t mapping
//
// Generic CUDA runtime utilities (DeviceBuffer, EIGEN_CUDA_RUNTIME_CHECK)
// are in GpuSupport.h.
#ifndef EIGEN_GPU_CUSOLVER_SUPPORT_H
#define EIGEN_GPU_CUSOLVER_SUPPORT_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include "./GpuSupport.h"
#include <cusolverDn.h>
#include <cstdio>
namespace Eigen {
namespace gpu {
namespace internal {
// ---- Error-checking macros --------------------------------------------------
// cuSOLVER does not ship a cusolverGetErrorString() in the public API, so we
// stringify the codes ourselves. Keeps failed asserts actionable.
inline const char* cusolver_status_name(cusolverStatus_t s) {
switch (s) {
case CUSOLVER_STATUS_SUCCESS:
return "CUSOLVER_STATUS_SUCCESS";
case CUSOLVER_STATUS_NOT_INITIALIZED:
return "CUSOLVER_STATUS_NOT_INITIALIZED";
case CUSOLVER_STATUS_ALLOC_FAILED:
return "CUSOLVER_STATUS_ALLOC_FAILED";
case CUSOLVER_STATUS_INVALID_VALUE:
return "CUSOLVER_STATUS_INVALID_VALUE";
case CUSOLVER_STATUS_ARCH_MISMATCH:
return "CUSOLVER_STATUS_ARCH_MISMATCH";
case CUSOLVER_STATUS_EXECUTION_FAILED:
return "CUSOLVER_STATUS_EXECUTION_FAILED";
case CUSOLVER_STATUS_INTERNAL_ERROR:
return "CUSOLVER_STATUS_INTERNAL_ERROR";
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
case CUSOLVER_STATUS_NOT_SUPPORTED:
return "CUSOLVER_STATUS_NOT_SUPPORTED";
default:
return "CUSOLVER_STATUS_UNKNOWN";
}
}
inline bool report_cusolver_failure(cusolverStatus_t s, const char* expr, const char* file, int line) {
std::fprintf(stderr,
"cuSOLVER call failed\n"
" expr: %s\n"
" status: %s (%d)\n"
" at: %s:%d\n",
expr, cusolver_status_name(s), static_cast<int>(s), file, line);
return false;
}
#define EIGEN_CUSOLVER_CHECK(expr) \
do { \
cusolverStatus_t _s = (expr); \
eigen_assert(_s == CUSOLVER_STATUS_SUCCESS || \
::Eigen::gpu::internal::report_cusolver_failure(_s, #expr, __FILE__, __LINE__)); \
} while (0)
// ---- RAII: cusolverDnParams -------------------------------------------------
struct CusolverParams {
cusolverDnParams_t p = nullptr;
CusolverParams() { EIGEN_CUSOLVER_CHECK(cusolverDnCreateParams(&p)); }
~CusolverParams() {
if (p) (void)cusolverDnDestroyParams(p); // destructor: can't propagate
}
// Move-only.
CusolverParams(CusolverParams&& o) noexcept : p(o.p) { o.p = nullptr; }
CusolverParams& operator=(CusolverParams&& o) noexcept {
if (this != &o) {
if (p) EIGEN_CUSOLVER_CHECK(cusolverDnDestroyParams(p));
p = o.p;
o.p = nullptr;
}
return *this;
}
CusolverParams(const CusolverParams&) = delete;
CusolverParams& operator=(const CusolverParams&) = delete;
};
// ---- Scalar → cudaDataType_t ------------------------------------------------
// Alias for backward compatibility. The canonical trait is cuda_data_type<> in GpuSupport.h.
template <typename Scalar>
using cusolver_data_type = cuda_data_type<Scalar>;
// ---- UpLo → cublasFillMode_t ------------------------------------------------
// cuSOLVER always interprets the matrix as column-major. Callers pass the
// triangle that holds the data in column-major layout.
template <int UpLo>
struct cusolver_fill_mode;
template <>
struct cusolver_fill_mode<Lower> {
static constexpr cublasFillMode_t value = CUBLAS_FILL_MODE_LOWER;
};
template <>
struct cusolver_fill_mode<Upper> {
static constexpr cublasFillMode_t value = CUBLAS_FILL_MODE_UPPER;
};
// ---- Type-specific cuSOLVER wrappers ----------------------------------------
// cuSOLVER does not provide generic X variants for ormqr/unmqr. These overloaded
// wrappers dispatch to the correct type-specific function.
// For real types: ormqr (orthogonal Q). For complex types: unmqr (unitary Q).
inline cusolverStatus_t cusolverDnXormqr(cusolverDnHandle_t h, cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const float* A, int lda, const float* tau, float* C, int ldc,
float* work, int lwork, int* info) {
return cusolverDnSormqr(h, side, trans, m, n, k, A, lda, tau, C, ldc, work, lwork, info);
}
inline cusolverStatus_t cusolverDnXormqr(cusolverDnHandle_t h, cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const double* A, int lda, const double* tau, double* C, int ldc,
double* work, int lwork, int* info) {
return cusolverDnDormqr(h, side, trans, m, n, k, A, lda, tau, C, ldc, work, lwork, info);
}
inline cusolverStatus_t cusolverDnXormqr(cusolverDnHandle_t h, cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const std::complex<float>* A, int lda,
const std::complex<float>* tau, std::complex<float>* C, int ldc,
std::complex<float>* work, int lwork, int* info) {
return cusolverDnCunmqr(h, side, trans, m, n, k, reinterpret_cast<const cuComplex*>(A), lda,
reinterpret_cast<const cuComplex*>(tau), reinterpret_cast<cuComplex*>(C), ldc,
reinterpret_cast<cuComplex*>(work), lwork, info);
}
inline cusolverStatus_t cusolverDnXormqr(cusolverDnHandle_t h, cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const std::complex<double>* A, int lda,
const std::complex<double>* tau, std::complex<double>* C, int ldc,
std::complex<double>* work, int lwork, int* info) {
return cusolverDnZunmqr(h, side, trans, m, n, k, reinterpret_cast<const cuDoubleComplex*>(A), lda,
reinterpret_cast<const cuDoubleComplex*>(tau), reinterpret_cast<cuDoubleComplex*>(C), ldc,
reinterpret_cast<cuDoubleComplex*>(work), lwork, info);
}
// Buffer size wrappers for ormqr/unmqr.
inline cusolverStatus_t cusolverDnXormqr_bufferSize(cusolverDnHandle_t h, cublasSideMode_t side,
cublasOperation_t trans, int m, int n, int k, const float* A,
int lda, const float* tau, const float* C, int ldc, int* lwork) {
return cusolverDnSormqr_bufferSize(h, side, trans, m, n, k, A, lda, tau, C, ldc, lwork);
}
inline cusolverStatus_t cusolverDnXormqr_bufferSize(cusolverDnHandle_t h, cublasSideMode_t side,
cublasOperation_t trans, int m, int n, int k, const double* A,
int lda, const double* tau, const double* C, int ldc, int* lwork) {
return cusolverDnDormqr_bufferSize(h, side, trans, m, n, k, A, lda, tau, C, ldc, lwork);
}
inline cusolverStatus_t cusolverDnXormqr_bufferSize(cusolverDnHandle_t h, cublasSideMode_t side,
cublasOperation_t trans, int m, int n, int k,
const std::complex<float>* A, int lda,
const std::complex<float>* tau, const std::complex<float>* C,
int ldc, int* lwork) {
return cusolverDnCunmqr_bufferSize(h, side, trans, m, n, k, reinterpret_cast<const cuComplex*>(A), lda,
reinterpret_cast<const cuComplex*>(tau), reinterpret_cast<const cuComplex*>(C),
ldc, lwork);
}
inline cusolverStatus_t cusolverDnXormqr_bufferSize(cusolverDnHandle_t h, cublasSideMode_t side,
cublasOperation_t trans, int m, int n, int k,
const std::complex<double>* A, int lda,
const std::complex<double>* tau, const std::complex<double>* C,
int ldc, int* lwork) {
return cusolverDnZunmqr_bufferSize(h, side, trans, m, n, k, reinterpret_cast<const cuDoubleComplex*>(A), lda,
reinterpret_cast<const cuDoubleComplex*>(tau),
reinterpret_cast<const cuDoubleComplex*>(C), ldc, lwork);
}
} // namespace internal
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_CUSOLVER_SUPPORT_H