// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// Unified GPU execution context.
//
// gpu::Context owns a CUDA stream and NVIDIA library handles (cuBLAS
// eagerly, cuSOLVER / cuBLASLt / cuSPARSE lazily on first use). It is the
// entry point for all GPU operations on gpu::DeviceMatrix.
//
// The cuSOLVER handle is created on the first call to cusolverHandle()
// so that translation units which only use cuFFT or cuBLAS paths (e.g.
// the cufft test) do not pull cusolverDn* symbols into the link.
//
// Usage:
// gpu::Context ctx; // explicit context
// d_C.device(ctx) = d_A * d_B; // GEMM on ctx's stream
//
// d_C = d_A * d_B; // thread-local default context
// gpu::Context& ctx = gpu::Context::threadLocal();
#ifndef EIGEN_GPU_CONTEXT_H
#define EIGEN_GPU_CONTEXT_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include "./CuBlasSupport.h"
#include "./CuSolverSupport.h"
#include <cusparse.h>
namespace Eigen {
namespace gpu {
/** \ingroup GPU_Module
 * \class Context
 * \brief Unified GPU execution context owning a CUDA stream and library handles.
 *
 * Each Context instance is tied to one CUDA stream: either a dedicated stream
 * it creates and owns (default constructor), or a caller-owned stream passed
 * to the explicit constructor. A cuBLAS handle bound to that stream is created
 * eagerly in the constructor.
 *
 * The cuSOLVER handle is created on first use via cusolverHandle();
 * translation units that never call it do not require cuSOLVER at link time.
 * cuBLASLt and cuSPARSE handles are similarly lazy. Multiple contexts enable
 * concurrent execution on independent streams.
 *
 * A lazily-created thread-local default is available via threadLocal() for
 * simple single-stream usage. A single Context is not thread-safe — use one
 * per thread, or external synchronization, since cuBLAS / cuSOLVER handles
 * are not thread-safe per handle and lazy-init of secondary handles is racy.
 */
class Context {
 public:
  /** Create a new context with a dedicated CUDA stream.
   *
   * The stream is created with cudaStreamCreate() and is destroyed by the
   * destructor. A cuBLAS handle bound to the stream is created immediately. */
  Context() {
    EIGEN_CUDA_RUNTIME_CHECK(cudaStreamCreate(&stream_));
    init_cublas();
  }

  /** Create a context on an existing stream (e.g., stream 0 = nullptr).
   *
   * The caller retains ownership of the stream — this context will not
   * destroy it. A cuBLAS handle bound to \p stream is created immediately.
   *
   * \param stream  CUDA stream that all handles of this context are bound to. */
  explicit Context(cudaStream_t stream) : stream_(stream), owns_stream_(false) { init_cublas(); }

  /** Release every library handle that was created, then the stream if owned.
   *
   * Status codes returned by the destroy calls are deliberately discarded. */
  ~Context() {
    // Indirect calls keep cusolverDnDestroy / cusparseDestroy out of TUs that
    // never call cusolverHandle() / cusparseHandle() (e.g. the cufft test).
    // Both the handle and its destroyer must be set: the destroyer is recorded
    // as soon as creation has been attempted successfully, and the handle test
    // avoids handing a null handle to the library if creation failed without
    // the check terminating the program.
    if (cusparse_ && cusparse_destroyer_) (void)cusparse_destroyer_(cusparse_);
    if (cusolver_ && cusolver_destroyer_) (void)cusolver_destroyer_(cusolver_);
    if (cublas_lt_) (void)cublasLtDestroy(cublas_lt_);
    if (cublas_) (void)cublasDestroy(cublas_);
    if (owns_stream_ && stream_) (void)cudaStreamDestroy(stream_);
  }

  // Non-copyable, non-movable (owns library handles).
  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;
  Context(Context&&) = delete;
  Context& operator=(Context&&) = delete;

  /** Get the thread-local default context.
   *
   * If setThreadLocal() has been called with a non-null pointer on this
   * thread, returns that context. Otherwise lazily creates a new context with
   * a dedicated stream.
   *
   * \note The thread-local instance is destroyed when the thread exits (or at
   * static destruction time for the main thread). On some CUDA driver
   * configurations this may print "CUDA_ERROR_DEINITIALIZED" to stderr if the
   * CUDA context has already been torn down. These errors are harmless and are
   * suppressed in the destructor, but they can produce noise in test output.
   * To avoid this, call cudaDeviceReset() only after all Context instances
   * (including thread-local ones) have been destroyed. */
  static Context& threadLocal() {
    Context* override = tl_override_ptr();
    if (override) return *override;
    thread_local Context ctx;
    return ctx;
  }

  /** Override the thread-local default context for this thread.
   *
   * The caller retains ownership of \p ctx — it must outlive all uses; the
   * stored pointer is not cleared when \p ctx is destroyed.
   * Pass nullptr to restore the lazily-created default. */
  static void setThreadLocal(Context* ctx) { tl_override_ptr() = ctx; }

  /** CUDA stream this context's handles are bound to. */
  cudaStream_t stream() const { return stream_; }

  /** cuBLAS handle (created in the constructor, bound to stream()). */
  cublasHandle_t cublasHandle() const { return cublas_; }

  /** Returns the cuSOLVER handle, creating it on first call.
   *
   * The new handle is bound to stream(). Later calls return the cached
   * handle. */
  cusolverDnHandle_t cusolverHandle() {
    if (cusolver_) return cusolver_;

    EIGEN_CUSOLVER_CHECK(cusolverDnCreate(&cusolver_));
    // Take ownership before binding the stream, so that ~Context() releases the
    // handle even when the bind step reports an error and the check does not
    // terminate the program.
    // NOTE(review): whether the check aborts, throws, or compiles out depends
    // on how EIGEN_CUSOLVER_CHECK is defined — confirm.
    cusolver_destroyer_ = &destroyCusolver;
    EIGEN_CUSOLVER_CHECK(cusolverDnSetStream(cusolver_, stream_));
    return cusolver_;
  }

  /** cuBLASLt handle (lazy-initialized on first GEMM call). */
  cublasLtHandle_t cublasLtHandle() const {
    if (cublas_lt_) return cublas_lt_;

    EIGEN_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_));
    return cublas_lt_;
  }

  /** Workspace buffer for cublasLtMatmul (grown lazily by cublaslt_gemm).
   * Not thread-safe — all GEMM calls must be on this context's stream. */
  internal::DeviceBuffer* gemmWorkspace() const { return &gemm_workspace_; }

  /** cuSPARSE handle (lazy-initialized on first call).
   *
   * The new handle is bound to stream(). Failures are reported through
   * eigen_assert only. Later calls return the cached handle. */
  cusparseHandle_t cusparseHandle() const {
    if (cusparse_) return cusparse_;

    cusparseStatus_t s1 = cusparseCreate(&cusparse_);
    eigen_assert(s1 == CUSPARSE_STATUS_SUCCESS && "cusparseCreate failed");
    EIGEN_UNUSED_VARIABLE(s1);
    // Take ownership before binding the stream, so that ~Context() releases the
    // handle even when cusparseSetStream fails and the assertion does not
    // terminate the program.
    cusparse_destroyer_ = &destroyCusparse;
    cusparseStatus_t s2 = cusparseSetStream(cusparse_, stream_);
    eigen_assert(s2 == CUSPARSE_STATUS_SUCCESS && "cusparseSetStream failed");
    EIGEN_UNUSED_VARIABLE(s2);
    return cusparse_;
  }

 private:
  // Thin forwarding wrappers. Their addresses are taken only inside the lazy
  // getters, so the destructor reaches the library destroy functions solely
  // through the stored function pointers.
  static cusolverStatus_t destroyCusolver(cusolverDnHandle_t h) { return cusolverDnDestroy(h); }
  static cusparseStatus_t destroyCusparse(cusparseHandle_t h) { return cusparseDestroy(h); }

  cudaStream_t stream_ = nullptr;
  cublasHandle_t cublas_ = nullptr;                                     // eager
  cusolverDnHandle_t cusolver_ = nullptr;                               // lazy
  cusolverStatus_t (*cusolver_destroyer_)(cusolverDnHandle_t) = nullptr;  // set once cusolver_ is created
  mutable cublasLtHandle_t cublas_lt_ = nullptr;                        // lazy
  mutable cusparseHandle_t cusparse_ = nullptr;                         // lazy
  mutable cusparseStatus_t (*cusparse_destroyer_)(cusparseHandle_t) = nullptr;  // set once cusparse_ is created
  mutable internal::DeviceBuffer gemm_workspace_;                       // lazy
  bool owns_stream_ = true;  // false when the stream was supplied by the caller

  /** Per-thread slot holding the context installed by setThreadLocal()
   * (nullptr when no override is active). */
  static Context*& tl_override_ptr() {
    thread_local Context* ptr = nullptr;
    return ptr;
  }

  /** Create the cuBLAS handle and bind it to stream_. */
  void init_cublas() {
    EIGEN_CUBLAS_CHECK(cublasCreate(&cublas_));
    EIGEN_CUBLAS_CHECK(cublasSetStream(cublas_, stream_));
  }
};
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_CONTEXT_H