blob: a9c6705362b0fa8487a854df4657885ea3f3d7c1 [file]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// Lightweight expression types for DeviceMatrix operations.
//
// These are NOT Eigen expression templates. There is no coefficient-level
// evaluation, no lazy fusion, no packet operations: a complete expression
// maps 1:1 to a single NVIDIA library call (cuBLAS or cuSOLVER), and the
// intermediate types below only record that call's parameters (operands,
// transpose mode, scalar factor).
//
// Expression types (all in namespace Eigen::gpu):
// AdjointView<S>     — d_A.adjoint()   → marks ConjTrans for GEMM
// TransposeView<S>   — d_A.transpose() → marks Trans for GEMM
// Scaled<Expr>       — alpha * expr    → carries scalar factor
// GemmExpr<Lhs, Rhs> — lhs * rhs       → dispatches to cublasXgemm
#ifndef EIGEN_GPU_DEVICE_EXPR_H
#define EIGEN_GPU_DEVICE_EXPR_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include <type_traits>
#include <utility>
#include "./CuBlasSupport.h"
namespace Eigen {
namespace gpu {
namespace internal {
// Forward declaration — specializations follow below, after the class definitions.
template <typename Expr>
struct device_expr_traits;
} // namespace internal
// Forward declaration.
template <typename Scalar_>
class DeviceMatrix;
// ---- AdjointView: marks ConjTrans -------------------------------------------
// Returned by DeviceMatrix::adjoint(). Maps to cublasXgemm transA/B = C.
template <typename Scalar_>
class AdjointView {
public:
using Scalar = Scalar_;
explicit AdjointView(const DeviceMatrix<Scalar>& m) : mat_(m) {}
const DeviceMatrix<Scalar>& matrix() const { return mat_; }
private:
const DeviceMatrix<Scalar>& mat_;
};
// ---- TransposeView: marks Trans ---------------------------------------------
// Returned by DeviceMatrix::transpose(). Maps to cublasXgemm transA/B = T.
template <typename Scalar_>
class TransposeView {
public:
using Scalar = Scalar_;
explicit TransposeView(const DeviceMatrix<Scalar>& m) : mat_(m) {}
const DeviceMatrix<Scalar>& matrix() const { return mat_; }
private:
const DeviceMatrix<Scalar>& mat_;
};
// ---- Scaled: alpha * expr ---------------------------------------------------
// Returned by operator*(Scalar, DeviceMatrix/View). Carries the scalar factor.
template <typename Inner>
class Scaled {
public:
using Scalar = typename internal::device_expr_traits<Inner>::scalar_type;
Scaled(Scalar alpha, const Inner& inner) : alpha_(alpha), inner_(inner) {}
Scalar scalar() const { return alpha_; }
const Inner& inner() const { return inner_; }
private:
Scalar alpha_;
const Inner& inner_;
};
// ---- GemmExpr: lhs * rhs -> cublasXgemm ------------------------------------
// Returned by operator*(lhs_expr, rhs_expr). Dispatches to cuBLAS GEMM.
template <typename Lhs, typename Rhs>
class GemmExpr {
public:
using Scalar = typename internal::device_expr_traits<Lhs>::scalar_type;
static_assert(std::is_same<Scalar, typename internal::device_expr_traits<Rhs>::scalar_type>::value,
"DeviceMatrix GEMM: LHS and RHS must have the same scalar type");
GemmExpr(const Lhs& lhs, const Rhs& rhs) : lhs_(lhs), rhs_(rhs) {}
const Lhs& lhs() const { return lhs_; }
const Rhs& rhs() const { return rhs_; }
private:
// Stored by reference — like Eigen's CPU expression templates, these must
// not be captured with auto (the references will dangle). Use .eval() or
// assign to a DeviceMatrix immediately.
const Lhs& lhs_;
const Rhs& rhs_;
};
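// ---- Free operator* overloads that produce GemmExpr -------------------------
// The generic lhs * rhs overload is defined at the end of this header, after
// the device_expr_traits specializations it is constrained on, so that it
// accepts any supported operand pair.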
// ---- Scalar * Matrix / View -> Scaled ---------------------------------------
// alpha * d_A: pairs a plain DeviceMatrix with a scalar factor. S is deduced
// from both arguments, so alpha must have exactly the matrix's scalar type
// (e.g. 2.0f rather than 2 for a DeviceMatrix<float>).
template <typename S>
Scaled<DeviceMatrix<S>> operator*(S alpha, const DeviceMatrix<S>& m) {
  using Result = Scaled<DeviceMatrix<S>>;
  return Result(alpha, m);
}
// alpha * d_A.adjoint(): a conjugate-transposed operand with a scalar factor.
// S is deduced from both arguments, so alpha must have exactly the view's
// scalar type.
template <typename S>
Scaled<AdjointView<S>> operator*(S alpha, const AdjointView<S>& m) {
  using Result = Scaled<AdjointView<S>>;
  return Result(alpha, m);
}
// alpha * d_A.transpose(): a transposed operand with a scalar factor.
// S is deduced from both arguments, so alpha must have exactly the view's
// scalar type.
template <typename S>
Scaled<TransposeView<S>> operator*(S alpha, const TransposeView<S>& m) {
  using Result = Scaled<TransposeView<S>>;
  return Result(alpha, m);
}
namespace internal {
// ---- Traits: extract operation info from expression types -------------------
// Primary template: any type without a specialization below is not a device
// expression. is_device_expr == false removes it from the generic GEMM
// operator* overload set (for example, a plain scalar on either side).
template <typename Expr>
struct device_expr_traits {
  static constexpr bool is_device_expr{false};
};
// A plain DeviceMatrix: used as stored (NoTrans) with a unit scale factor.
template <typename Scalar>
struct device_expr_traits<DeviceMatrix<Scalar>> {
  static constexpr bool is_device_expr = true;
  static constexpr GpuOp op = GpuOp::NoTrans;
  using scalar_type = Scalar;

  static const DeviceMatrix<Scalar>& matrix(const DeviceMatrix<Scalar>& expr) { return expr; }
  static Scalar alpha(const DeviceMatrix<Scalar>& /*expr*/) { return Scalar(1); }
};
// AdjointView: the viewed matrix is used conjugate-transposed (ConjTrans) with
// a unit scale factor.
template <typename Scalar>
struct device_expr_traits<AdjointView<Scalar>> {
  static constexpr bool is_device_expr = true;
  static constexpr GpuOp op = GpuOp::ConjTrans;
  using scalar_type = Scalar;

  static const DeviceMatrix<Scalar>& matrix(const AdjointView<Scalar>& view) { return view.matrix(); }
  static Scalar alpha(const AdjointView<Scalar>& /*view*/) { return Scalar(1); }
};
// TransposeView: the viewed matrix is used transposed (Trans) with a unit
// scale factor.
template <typename Scalar>
struct device_expr_traits<TransposeView<Scalar>> {
  static constexpr bool is_device_expr = true;
  static constexpr GpuOp op = GpuOp::Trans;
  using scalar_type = Scalar;

  static const DeviceMatrix<Scalar>& matrix(const TransposeView<Scalar>& view) { return view.matrix(); }
  static Scalar alpha(const TransposeView<Scalar>& /*view*/) { return Scalar(1); }
};
// Scaled<Inner>: the matrix and transpose mode come from the wrapped operand;
// the carried scalar is multiplied onto the wrapped operand's own factor.
template <typename Inner>
struct device_expr_traits<Scaled<Inner>> {
  using scalar_type = typename device_expr_traits<Inner>::scalar_type;
  static constexpr bool is_device_expr = true;
  static constexpr GpuOp op = device_expr_traits<Inner>::op;

  static const DeviceMatrix<scalar_type>& matrix(const Scaled<Inner>& x) {
    const Inner& wrapped = x.inner();
    return device_expr_traits<Inner>::matrix(wrapped);
  }

  static scalar_type alpha(const Scaled<Inner>& x) {
    const scalar_type inner_alpha = device_expr_traits<Inner>::alpha(x.inner());
    return x.scalar() * inner_alpha;
  }
};
} // namespace internal
// lhs * rhs for any two device expressions (DeviceMatrix, AdjointView,
// TransposeView, or a Scaled of one of those). Only builds a GemmExpr; the
// cuBLAS call itself is presumably issued when the result is assigned to a
// DeviceMatrix (that evaluation lives outside this header). If either operand
// is not a device expression, the enable_if condition fails and this overload
// drops out of overload resolution.
template <typename Lhs, typename Rhs,
          typename std::enable_if<internal::device_expr_traits<Lhs>::is_device_expr &&
                                      internal::device_expr_traits<Rhs>::is_device_expr,
                                  int>::type = 0>
GemmExpr<Lhs, Rhs> operator*(const Lhs& lhs, const Rhs& rhs) {
  return GemmExpr<Lhs, Rhs>(lhs, rhs);
}
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_DEVICE_EXPR_H