| // This file is part of Eigen, a lightweight C++ template library |
| // for linear algebra. |
| // |
| // Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com> |
| // |
| // This Source Code Form is subject to the terms of the Mozilla |
| // Public License v. 2.0. If a copy of the MPL was not distributed |
| // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
// Dispatch functions that map DeviceMatrix expressions to NVIDIA library calls.
//
//   dispatch_gemm()      — GemmExpr     → cublasXgemm
//   dispatch_llt_solve() — LltSolveExpr → cusolverDnXpotrf / cusolverDnXpotrs
//   dispatch_lu_solve()  — LuSolveExpr  → cusolverDnXgetrf / cusolverDnXgetrs
//   dispatch_trsm()      — TrsmExpr     → cublasXtrsm
//   dispatch_symm()      — SymmExpr     → cublasXsymm
//   dispatch_syrk()      — SyrkExpr     → cublasXsyrk
//
// Each function documents the exact library call and parameters.
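//
// Usage sketch (variable names are illustrative; the expression spellings
// follow the unsupported-expression diagnostic in Assignment below):
//
//   Context& ctx = Context::threadLocal();
//   DeviceMatrix<float> d_A, d_B, d_C, d_X;
//   d_C.device(ctx)  = d_A * d_B;   // dispatch_gemm, beta = 0
//   d_C.device(ctx) += d_A * d_B;   // dispatch_gemm, beta = 1
//   d_X = d_A.llt().solve(d_B);     // dispatch_llt_solve, thread-local context
//   d_X = d_A.lu().solve(d_B);      // dispatch_lu_solve, thread-local context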
| |
| #ifndef EIGEN_GPU_DEVICE_DISPATCH_H |
| #define EIGEN_GPU_DEVICE_DISPATCH_H |
| |
| // IWYU pragma: private |
| #include "./InternalHeaderCheck.h" |
| |
#include <climits>  // INT_MAX range checks before narrowing to cuBLAS int parameters
#include <cstdint>  // int64_t dimensions and strides
#include <vector>   // host-side cuSOLVER workspace buffers
| |
| #include "./DeviceExpr.h" |
| #include "./DeviceBlasExpr.h" |
| #include "./DeviceSolverExpr.h" |
| #include "./GpuContext.h" |
| #include "./CuSolverSupport.h" |
| |
| namespace Eigen { |
| namespace gpu { |
| namespace internal { |
| |
// Returns true when the two matrices share the same non-null base pointer.
// Note this is an exact-pointer check; overlapping but non-identical ranges
// are not detected.
template <typename Scalar>
bool aliases_device_memory(const DeviceMatrix<Scalar>& a, const DeviceMatrix<Scalar>& b) {
  return a.data() != nullptr && a.data() == b.data();
}
| |
| // ---- GEMM dispatch ---------------------------------------------------------- |
| // GemmExpr<Lhs, Rhs> → cublasXgemm (type-specific Sgemm/Dgemm/Cgemm/Zgemm). |
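//
// Computes dst = alpha_scale * alpha(lhs) * alpha(rhs) * op(A) * op(B) + beta_val * dst.
// beta_val selects assignment (0) or accumulation (1); alpha_scale = -1 with
// beta_val = 1 implements operator-= (see Assignment below).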
| |
| template <typename Lhs, typename Rhs> |
| void dispatch_gemm( |
| Context& ctx, DeviceMatrix<typename device_expr_traits<Lhs>::scalar_type>& dst, const GemmExpr<Lhs, Rhs>& expr, |
| typename device_expr_traits<Lhs>::scalar_type beta_val, |
| typename device_expr_traits<Lhs>::scalar_type alpha_scale = typename device_expr_traits<Lhs>::scalar_type(1)) { |
| using Scalar = typename device_expr_traits<Lhs>::scalar_type; |
| using traits_lhs = device_expr_traits<Lhs>; |
| using traits_rhs = device_expr_traits<Rhs>; |
| |
| const DeviceMatrix<Scalar>& A = traits_lhs::matrix(expr.lhs()); |
| const DeviceMatrix<Scalar>& B = traits_rhs::matrix(expr.rhs()); |
| |
| constexpr cublasOperation_t transA = to_cublas_op(traits_lhs::op); |
| constexpr cublasOperation_t transB = to_cublas_op(traits_rhs::op); |
| |
| // GEMM dimensions: C(m,n) = op(A)(m,k) * op(B)(k,n) |
| const int64_t m = (traits_lhs::op == GpuOp::NoTrans) ? A.rows() : A.cols(); |
| const int64_t k = (traits_lhs::op == GpuOp::NoTrans) ? A.cols() : A.rows(); |
| const int64_t n = (traits_rhs::op == GpuOp::NoTrans) ? B.cols() : B.rows(); |
| const int64_t rhs_k = (traits_rhs::op == GpuOp::NoTrans) ? B.rows() : B.cols(); |
| |
| eigen_assert(k == rhs_k && "DeviceMatrix GEMM dimension mismatch"); |
| |
| const int64_t lda = A.outerStride(); |
| const int64_t ldb = B.outerStride(); |
| |
| eigen_assert(!aliases_device_memory(dst, A) && "DeviceMatrix GEMM destination aliases lhs operand"); |
| eigen_assert(!aliases_device_memory(dst, B) && "DeviceMatrix GEMM destination aliases rhs operand"); |
| |
| if (!dst.empty()) { |
| dst.waitReady(ctx.stream()); |
| } |
| |
| const bool resized = dst.empty() || dst.rows() != m || dst.cols() != n; |
| if (resized) { |
| dst.resize(m, n); |
| } |
| const int64_t ldc = dst.outerStride(); |
| |
| Scalar alpha_local = alpha_scale * traits_lhs::alpha(expr.lhs()) * traits_rhs::alpha(expr.rhs()); |
| |
| A.waitReady(ctx.stream()); |
| B.waitReady(ctx.stream()); |
| |
  // A freshly resized destination holds uninitialized memory; when the GEMM
  // accumulates (beta_val != 0), zero it first so the read side of C is defined.
  if (resized && beta_val != Scalar(0) && dst.sizeInBytes() > 0) {
    EIGEN_CUDA_RUNTIME_CHECK(cudaMemsetAsync(dst.data(), 0, dst.sizeInBytes(), ctx.stream()));
  }
| |
| eigen_assert(m <= INT_MAX && n <= INT_MAX && k <= INT_MAX && lda <= INT_MAX && ldb <= INT_MAX && ldc <= INT_MAX && |
| "cublasXgemm dimensions exceed int range"); |
| |
| Scalar scalars[2] = {alpha_local, beta_val}; |
| EIGEN_CUBLAS_CHECK(cublasXgemm(ctx.cublasHandle(), transA, transB, static_cast<int>(m), static_cast<int>(n), |
| static_cast<int>(k), &scalars[0], A.data(), static_cast<int>(lda), B.data(), |
| static_cast<int>(ldb), &scalars[1], dst.data(), static_cast<int>(ldc))); |
| |
| dst.recordReady(ctx.stream()); |
| } |
| |
| // ---- LLT solve dispatch ----------------------------------------------------- |
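// Solves A * X = B via Cholesky factorization. A is copied into scratch
// because cusolverDnXpotrf factorizes in place; B is copied into dst because
// cusolverDnXpotrs solves in place.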
| |
| template <typename Scalar, int UpLo> |
| void dispatch_llt_solve(Context& ctx, DeviceMatrix<Scalar>& dst, const LltSolveExpr<Scalar, UpLo>& expr) { |
| const DeviceMatrix<Scalar>& A = expr.matrix(); |
| const DeviceMatrix<Scalar>& B = expr.rhs(); |
| |
| eigen_assert(A.rows() == A.cols() && "LLT requires a square matrix"); |
| eigen_assert(B.rows() == A.rows() && "LLT solve: RHS rows must match matrix size"); |
| |
| const int64_t n = static_cast<int64_t>(A.rows()); |
| const int64_t nrhs = static_cast<int64_t>(B.cols()); |
| |
| if (n == 0 || nrhs == 0) { |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| dst.resize(n, B.cols()); |
| return; |
| } |
| |
| A.waitReady(ctx.stream()); |
| B.waitReady(ctx.stream()); |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| |
| constexpr cudaDataType_t dtype = cuda_data_type<Scalar>::value; |
| constexpr cublasFillMode_t uplo = cusolver_fill_mode<UpLo>::value; |
| const int64_t lda = static_cast<int64_t>(A.outerStride()); |
| const int64_t ldb = static_cast<int64_t>(B.outerStride()); |
| const size_t mat_bytes = static_cast<size_t>(lda) * static_cast<size_t>(n) * sizeof(Scalar); |
| const size_t rhs_bytes = static_cast<size_t>(ldb) * static_cast<size_t>(nrhs) * sizeof(Scalar); |
| |
| DeviceBuffer d_factor(mat_bytes); |
| EIGEN_CUDA_RUNTIME_CHECK( |
| cudaMemcpyAsync(d_factor.get(), A.data(), mat_bytes, cudaMemcpyDeviceToDevice, ctx.stream())); |
| |
| // Two info slots (potrf, potrs) so we can queue both kernels back-to-back |
| // and host-sync once at the end. If potrf fails, potrs runs on garbage but |
| // the assert fires after the single sync — saving a round trip. |
| PinnedHostBuffer h_info(2 * sizeof(int)); |
| int* info_words = static_cast<int*>(h_info.get()); |
| |
| CusolverParams params; |
| DeviceBuffer d_info(2 * sizeof(int)); |
| int* d_info_potrf = static_cast<int*>(d_info.get()); |
| int* d_info_potrs = d_info_potrf + 1; |
| size_t dev_ws = 0, host_ws = 0; |
| EIGEN_CUSOLVER_CHECK(cusolverDnXpotrf_bufferSize(ctx.cusolverHandle(), params.p, uplo, n, dtype, d_factor.get(), lda, |
| dtype, &dev_ws, &host_ws)); |
| |
| DeviceBuffer d_workspace(dev_ws); |
| std::vector<char> h_workspace(host_ws); |
| |
| EIGEN_CUSOLVER_CHECK(cusolverDnXpotrf(ctx.cusolverHandle(), params.p, uplo, n, dtype, d_factor.get(), lda, dtype, |
| d_workspace.get(), dev_ws, host_ws > 0 ? h_workspace.data() : nullptr, host_ws, |
| d_info_potrf)); |
| |
| EIGEN_CUDA_RUNTIME_CHECK( |
| cudaMemcpyAsync(&info_words[0], d_info_potrf, sizeof(int), cudaMemcpyDeviceToHost, ctx.stream())); |
| |
  dst.resize(n, B.cols());
  // The RHS panel is copied into dst as one flat block, which is only correct
  // when dst and B share an outer stride.
  eigen_assert(static_cast<int64_t>(dst.outerStride()) == ldb && "LLT solve: flat RHS copy requires matching outer strides");
  EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(dst.data(), B.data(), rhs_bytes, cudaMemcpyDeviceToDevice, ctx.stream()));
| |
| EIGEN_CUSOLVER_CHECK(cusolverDnXpotrs(ctx.cusolverHandle(), params.p, uplo, n, nrhs, dtype, d_factor.get(), lda, |
| dtype, dst.data(), static_cast<int64_t>(dst.outerStride()), d_info_potrs)); |
| |
| EIGEN_CUDA_RUNTIME_CHECK( |
| cudaMemcpyAsync(&info_words[1], d_info_potrs, sizeof(int), cudaMemcpyDeviceToHost, ctx.stream())); |
| EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx.stream())); |
| eigen_assert(info_words[0] == 0 && "cuSOLVER LLT factorization failed (matrix not positive definite)"); |
| eigen_assert(info_words[1] == 0 && "cuSOLVER LLT solve failed"); |
| |
| dst.recordReady(ctx.stream()); |
| } |
| |
| // ---- LU solve dispatch ------------------------------------------------------ |
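// Solves A * X = B via LU factorization with partial pivoting. A is copied
// into scratch because cusolverDnXgetrf factorizes in place; B is copied into
// dst because cusolverDnXgetrs solves in place.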
| |
| template <typename Scalar> |
| void dispatch_lu_solve(Context& ctx, DeviceMatrix<Scalar>& dst, const LuSolveExpr<Scalar>& expr) { |
| const DeviceMatrix<Scalar>& A = expr.matrix(); |
| const DeviceMatrix<Scalar>& B = expr.rhs(); |
| |
| eigen_assert(A.rows() == A.cols() && "LU requires a square matrix"); |
| eigen_assert(B.rows() == A.rows() && "LU solve: RHS rows must match matrix size"); |
| |
| const int64_t n = static_cast<int64_t>(A.rows()); |
| const int64_t nrhs = static_cast<int64_t>(B.cols()); |
| |
| if (n == 0 || nrhs == 0) { |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| dst.resize(n, B.cols()); |
| return; |
| } |
| |
| A.waitReady(ctx.stream()); |
| B.waitReady(ctx.stream()); |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| |
| constexpr cudaDataType_t dtype = cuda_data_type<Scalar>::value; |
| const int64_t lda = static_cast<int64_t>(A.outerStride()); |
| const int64_t ldb = static_cast<int64_t>(B.outerStride()); |
| const size_t mat_bytes = static_cast<size_t>(lda) * static_cast<size_t>(n) * sizeof(Scalar); |
| const size_t rhs_bytes = static_cast<size_t>(ldb) * static_cast<size_t>(nrhs) * sizeof(Scalar); |
| const size_t ipiv_bytes = static_cast<size_t>(n) * sizeof(int64_t); |
| |
| DeviceBuffer d_lu(mat_bytes); |
| EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(d_lu.get(), A.data(), mat_bytes, cudaMemcpyDeviceToDevice, ctx.stream())); |
| |
| DeviceBuffer d_ipiv(ipiv_bytes); |
| |
| // See dispatch_llt_solve: two info slots + single end-of-chain sync. |
| PinnedHostBuffer h_info(2 * sizeof(int)); |
| int* info_words = static_cast<int*>(h_info.get()); |
| |
| CusolverParams params; |
| DeviceBuffer d_info(2 * sizeof(int)); |
| int* d_info_getrf = static_cast<int*>(d_info.get()); |
| int* d_info_getrs = d_info_getrf + 1; |
| size_t dev_ws = 0, host_ws = 0; |
| EIGEN_CUSOLVER_CHECK(cusolverDnXgetrf_bufferSize(ctx.cusolverHandle(), params.p, n, n, dtype, d_lu.get(), lda, dtype, |
| &dev_ws, &host_ws)); |
| |
| DeviceBuffer d_workspace(dev_ws); |
| std::vector<char> h_workspace(host_ws); |
| |
| EIGEN_CUSOLVER_CHECK(cusolverDnXgetrf(ctx.cusolverHandle(), params.p, n, n, dtype, d_lu.get(), lda, |
| static_cast<int64_t*>(d_ipiv.get()), dtype, d_workspace.get(), dev_ws, |
| host_ws > 0 ? h_workspace.data() : nullptr, host_ws, d_info_getrf)); |
| |
| EIGEN_CUDA_RUNTIME_CHECK( |
| cudaMemcpyAsync(&info_words[0], d_info_getrf, sizeof(int), cudaMemcpyDeviceToHost, ctx.stream())); |
| |
  dst.resize(n, B.cols());
  // The RHS panel is copied into dst as one flat block, which is only correct
  // when dst and B share an outer stride.
  eigen_assert(static_cast<int64_t>(dst.outerStride()) == ldb && "LU solve: flat RHS copy requires matching outer strides");
  EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(dst.data(), B.data(), rhs_bytes, cudaMemcpyDeviceToDevice, ctx.stream()));
| |
| EIGEN_CUSOLVER_CHECK(cusolverDnXgetrs(ctx.cusolverHandle(), params.p, CUBLAS_OP_N, n, nrhs, dtype, d_lu.get(), lda, |
| static_cast<const int64_t*>(d_ipiv.get()), dtype, dst.data(), |
| static_cast<int64_t>(dst.outerStride()), d_info_getrs)); |
| |
| EIGEN_CUDA_RUNTIME_CHECK( |
| cudaMemcpyAsync(&info_words[1], d_info_getrs, sizeof(int), cudaMemcpyDeviceToHost, ctx.stream())); |
| EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx.stream())); |
| eigen_assert(info_words[0] == 0 && "cuSOLVER LU factorization failed (singular matrix)"); |
| eigen_assert(info_words[1] == 0 && "cuSOLVER LU solve failed"); |
| |
| dst.recordReady(ctx.stream()); |
| } |
| |
| // ---- TRSM dispatch ---------------------------------------------------------- |
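// Solves A * X = B in place for a triangular A (left side, non-unit diagonal)
// via cublasXtrsm; B is first copied into dst since cublasXtrsm overwrites
// the RHS with the solution.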
| |
| template <typename Scalar, int UpLo> |
| void dispatch_trsm(Context& ctx, DeviceMatrix<Scalar>& dst, const TrsmExpr<Scalar, UpLo>& expr) { |
| const DeviceMatrix<Scalar>& A = expr.matrix(); |
| const DeviceMatrix<Scalar>& B = expr.rhs(); |
| |
| eigen_assert(A.rows() == A.cols() && "TRSM requires a square triangular matrix"); |
| eigen_assert(B.rows() == A.rows() && "TRSM: RHS rows must match matrix size"); |
| |
| eigen_assert(A.rows() <= INT_MAX && B.cols() <= INT_MAX && A.outerStride() <= INT_MAX && |
| "cublasXtrsm dimensions exceed int range"); |
| |
| const int n = static_cast<int>(A.rows()); |
| const int nrhs = static_cast<int>(B.cols()); |
| |
| if (n == 0 || nrhs == 0) { |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| dst.resize(n, B.cols()); |
| return; |
| } |
| |
| A.waitReady(ctx.stream()); |
| B.waitReady(ctx.stream()); |
| eigen_assert(!aliases_device_memory(dst, A) && "DeviceMatrix TRSM destination aliases triangular operand"); |
| eigen_assert(!aliases_device_memory(dst, B) && "DeviceMatrix TRSM destination aliases RHS operand"); |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| |
  dst.resize(n, B.cols());
  // The RHS panel is copied into dst as one flat block, which is only correct
  // when dst and B share an outer stride.
  eigen_assert(dst.outerStride() == B.outerStride() && "TRSM: flat RHS copy requires matching outer strides");
  const size_t rhs_bytes = static_cast<size_t>(dst.outerStride()) * static_cast<size_t>(nrhs) * sizeof(Scalar);
  EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(dst.data(), B.data(), rhs_bytes, cudaMemcpyDeviceToDevice, ctx.stream()));
| |
| constexpr cublasFillMode_t uplo = (UpLo == Lower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; |
| Scalar alpha(1); |
| |
| EIGEN_CUBLAS_CHECK(cublasXtrsm(ctx.cublasHandle(), CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, n, nrhs, |
| &alpha, A.data(), static_cast<int>(A.outerStride()), dst.data(), |
| static_cast<int>(dst.outerStride()))); |
| |
| dst.recordReady(ctx.stream()); |
| } |
| |
| // ---- SYMM/HEMM dispatch ----------------------------------------------------- |
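// Computes dst = A * B where A is self-adjoint and only its UpLo triangle is
// referenced, via cublasXsymm with alpha = 1 and beta = 0.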
| |
| template <typename Scalar, int UpLo> |
| void dispatch_symm(Context& ctx, DeviceMatrix<Scalar>& dst, const SymmExpr<Scalar, UpLo>& expr) { |
| const DeviceMatrix<Scalar>& A = expr.matrix(); |
| const DeviceMatrix<Scalar>& B = expr.rhs(); |
| |
| eigen_assert(A.rows() == A.cols() && "SYMM requires a square matrix"); |
| eigen_assert(B.rows() == A.rows() && "SYMM: RHS rows must match matrix size"); |
| eigen_assert(A.rows() <= INT_MAX && B.cols() <= INT_MAX && A.outerStride() <= INT_MAX && B.outerStride() <= INT_MAX && |
| "cublasXsymm dimensions exceed int range"); |
| |
| const int m = static_cast<int>(A.rows()); |
| const int n = static_cast<int>(B.cols()); |
| |
| if (m == 0 || n == 0) { |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
    dst.resize(m, B.cols());
| return; |
| } |
| |
| A.waitReady(ctx.stream()); |
| B.waitReady(ctx.stream()); |
| eigen_assert(!aliases_device_memory(dst, A) && "DeviceMatrix SYMM destination aliases self-adjoint operand"); |
| eigen_assert(!aliases_device_memory(dst, B) && "DeviceMatrix SYMM destination aliases RHS operand"); |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| |
| dst.resize(m, n); |
| |
| constexpr cublasFillMode_t uplo = (UpLo == Lower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; |
| Scalar scalars[2] = {Scalar(1), Scalar(0)}; |
| |
| EIGEN_CUBLAS_CHECK(cublasXsymm(ctx.cublasHandle(), CUBLAS_SIDE_LEFT, uplo, m, n, &scalars[0], A.data(), |
| static_cast<int>(A.outerStride()), B.data(), static_cast<int>(B.outerStride()), |
| &scalars[1], dst.data(), static_cast<int>(dst.outerStride()))); |
| |
| dst.recordReady(ctx.stream()); |
| } |
| |
| // ---- SYRK/HERK dispatch ----------------------------------------------------- |
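// Computes dst = alpha_val * A * A^T + beta_val * dst (A * A^H in the
// Hermitian case) on the UpLo triangle via cublasXsyrk. Reached through
// SelfAdjointView::rankUpdate() below rather than an assignment operator.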
| |
| template <typename Scalar, int UpLo> |
| void dispatch_syrk(Context& ctx, DeviceMatrix<Scalar>& dst, const SyrkExpr<Scalar, UpLo>& expr, |
| typename NumTraits<Scalar>::Real alpha_val, typename NumTraits<Scalar>::Real beta_val) { |
| using RealScalar = typename NumTraits<Scalar>::Real; |
| const DeviceMatrix<Scalar>& A = expr.matrix(); |
| |
| eigen_assert(A.rows() <= INT_MAX && A.cols() <= INT_MAX && A.outerStride() <= INT_MAX && |
| "cublasXsyrk dimensions exceed int range"); |
| |
| const int n = static_cast<int>(A.rows()); |
| const int k = static_cast<int>(A.cols()); |
| |
| if (n == 0) { |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| dst.resize(0, 0); |
| return; |
| } |
| |
| A.waitReady(ctx.stream()); |
| eigen_assert(!aliases_device_memory(dst, A) && "DeviceMatrix SYRK destination aliases input operand"); |
| if (!dst.empty()) dst.waitReady(ctx.stream()); |
| |
  if (dst.empty() || dst.rows() != n || dst.cols() != n) {
    dst.resize(n, n);
    // A freshly resized destination holds uninitialized memory; when the rank
    // update accumulates (beta_val != 0), zero it so the read side is defined.
    if (beta_val != RealScalar(0)) {
      EIGEN_CUDA_RUNTIME_CHECK(cudaMemsetAsync(dst.data(), 0, dst.sizeInBytes(), ctx.stream()));
    }
  }
| |
| constexpr cublasFillMode_t uplo = (UpLo == Lower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; |
| |
| EIGEN_CUBLAS_CHECK(cublasXsyrk(ctx.cublasHandle(), uplo, CUBLAS_OP_N, n, k, &alpha_val, A.data(), |
| static_cast<int>(A.outerStride()), &beta_val, dst.data(), |
| static_cast<int>(dst.outerStride()))); |
| |
| dst.recordReady(ctx.stream()); |
| } |
| |
| } // namespace internal |
| |
| // ---- Assignment: d_C.device(ctx) = expr ------------------------------------ |
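// Proxy returned by DeviceMatrix::device(ctx). Routes each supported
// expression type to its dispatch function on the given context's stream; any
// other expression trips the static_assert in the catch-all overload.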
| |
| template <typename Scalar_> |
| class Assignment { |
| public: |
| using Scalar = Scalar_; |
| |
| Assignment(DeviceMatrix<Scalar>& dst, Context& ctx) : dst_(dst), ctx_(ctx) {} |
| |
| template <typename Lhs, typename Rhs> |
| DeviceMatrix<Scalar>& operator=(const GemmExpr<Lhs, Rhs>& expr) { |
| internal::dispatch_gemm(ctx_, dst_, expr, Scalar(0)); |
| return dst_; |
| } |
| |
| template <typename Lhs, typename Rhs> |
| DeviceMatrix<Scalar>& operator+=(const GemmExpr<Lhs, Rhs>& expr) { |
| internal::dispatch_gemm(ctx_, dst_, expr, Scalar(1)); |
| return dst_; |
| } |
| |
| template <typename Lhs, typename Rhs> |
| DeviceMatrix<Scalar>& operator-=(const GemmExpr<Lhs, Rhs>& expr) { |
| internal::dispatch_gemm(ctx_, dst_, expr, Scalar(1), Scalar(-1)); |
| return dst_; |
| } |
| |
| template <int UpLo> |
| DeviceMatrix<Scalar>& operator=(const LltSolveExpr<Scalar, UpLo>& expr) { |
| internal::dispatch_llt_solve(ctx_, dst_, expr); |
| return dst_; |
| } |
| |
| DeviceMatrix<Scalar>& operator=(const LuSolveExpr<Scalar>& expr) { |
| internal::dispatch_lu_solve(ctx_, dst_, expr); |
| return dst_; |
| } |
| |
| template <int UpLo> |
| DeviceMatrix<Scalar>& operator=(const TrsmExpr<Scalar, UpLo>& expr) { |
| internal::dispatch_trsm(ctx_, dst_, expr); |
| return dst_; |
| } |
| |
| template <int UpLo> |
| DeviceMatrix<Scalar>& operator=(const SymmExpr<Scalar, UpLo>& expr) { |
| internal::dispatch_symm(ctx_, dst_, expr); |
| return dst_; |
| } |
| |
| template <typename Expr> |
| DeviceMatrix<Scalar>& operator=(const Expr&) { |
| static_assert(sizeof(Expr) == 0, |
| "DeviceMatrix expression not supported: no cuBLAS/cuSOLVER mapping. " |
| "Supported: GEMM (A*B), TRSM (.triangularView().solve()), " |
| "SYMM (.selfadjointView()*B), LLT (.llt().solve()), LU (.lu().solve())."); |
| return dst_; |
| } |
| |
| private: |
| DeviceMatrix<Scalar>& dst_; |
| Context& ctx_; |
| }; |
| |
| // ---- Out-of-line Matrix expression operator= definitions ------------------ |
| // Declared in DeviceMatrix.h, defined here because they need Context::threadLocal(). |
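// As a result, a plain `d_C = d_A * d_B;` (without .device(ctx)) runs on the
// thread-local context's stream.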
| |
| template <typename Scalar_> |
| template <typename Lhs, typename Rhs> |
| DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const GemmExpr<Lhs, Rhs>& expr) { |
| device(Context::threadLocal()) = expr; |
| return *this; |
| } |
| |
| template <typename Scalar_> |
| template <typename Lhs, typename Rhs> |
| DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator+=(const GemmExpr<Lhs, Rhs>& expr) { |
| device(Context::threadLocal()) += expr; |
| return *this; |
| } |
| |
| template <typename Scalar_> |
| template <int UpLo> |
| DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const LltSolveExpr<Scalar_, UpLo>& expr) { |
| device(Context::threadLocal()) = expr; |
| return *this; |
| } |
| |
| template <typename Scalar_> |
| DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const LuSolveExpr<Scalar_>& expr) { |
| device(Context::threadLocal()) = expr; |
| return *this; |
| } |
| |
| template <typename Scalar_> |
| template <int UpLo> |
| DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const TrsmExpr<Scalar_, UpLo>& expr) { |
| device(Context::threadLocal()) = expr; |
| return *this; |
| } |
| |
| template <typename Scalar_> |
| template <int UpLo> |
| DeviceMatrix<Scalar_>& DeviceMatrix<Scalar_>::operator=(const SymmExpr<Scalar_, UpLo>& expr) { |
| device(Context::threadLocal()) = expr; |
| return *this; |
| } |
| |
| // SelfAdjointView::rankUpdate — defined here because it needs Context. |
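// Computes matrix() = alpha * A * A^T (A * A^H in the complex case), added to
// the existing contents when matrix() is non-empty (beta = 1), assigning
// otherwise (beta = 0).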
| template <typename Scalar_, int UpLo_> |
| void SelfAdjointView<Scalar_, UpLo_>::rankUpdate(const DeviceMatrix<Scalar_>& A, RealScalar alpha) { |
| SyrkExpr<Scalar_, UpLo_> expr(A); |
| RealScalar beta = matrix().empty() ? RealScalar(0) : RealScalar(1); |
| internal::dispatch_syrk(Context::threadLocal(), matrix(), expr, alpha, beta); |
| } |
| |
| } // namespace gpu |
| } // namespace Eigen |
| |
| #endif // EIGEN_GPU_DEVICE_DISPATCH_H |