blob: 0e478d5fc67f42c8f053890f79edb8b02a613d8a [file] [edit]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// GPU FFT via cuFFT.
//
// FFT class with plan caching. Supports 1D and 2D transforms:
// C2C (complex-to-complex), R2C (real-to-complex), C2R (complex-to-real).
//
// Stream and cuBLAS handle come from a gpu::Context — the default
// constructor binds to Context::threadLocal() so an FFT instance shares a
// stream with other GPU operations on the same thread by default. Pass an
// explicit Context to bind to a different stream.
//
// Inverse transforms are scaled by 1/n (1D) or 1/(rows*cols) (2D) so that
// inv(fwd(x)) and invReal(fwd(r), r.size()) recover their inputs up to
// floating-point rounding, matching Eigen's FFT convention.
//
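// cuFFT plans are cached per instance by (rank, dimensions, transform type),
// bound to the Context's stream, reused across calls, and destroyed together
// with the FFT object.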
//
// Thread safety: not thread-safe. Concurrent fwd/inv calls on a single FFT
// instance race on the cached plans and the bound Context. Use one FFT
// instance per thread.
//
// Usage:
// FFT<float> fft; // shares the thread-local Context
// VectorXcf X = fft.fwd(x); // 1D C2C or R2C
// VectorXcf y = fft.inv(X); // 1D C2C inverse
// VectorXf r = fft.invReal(X, n); // 1D C2R inverse
// MatrixXcf B = fft.fwd2(A); // 2D C2C forward
// MatrixXcf C = fft.inv2(B); // 2D C2C inverse
//
// gpu::Context ctx;
// FFT<float> fft2(ctx); // shares ctx's stream/cuBLAS
#ifndef EIGEN_GPU_FFT_H
#define EIGEN_GPU_FFT_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include "./CuFftSupport.h"
#include "./CuBlasSupport.h"
#include "./GpuContext.h"
#include <unordered_map>
namespace Eigen {
namespace gpu {
template <typename Scalar_>
class FFT {
public:
using Scalar = Scalar_;
using Complex = std::complex<Scalar>;
using ComplexVector = Matrix<Complex, Dynamic, 1>;
using RealVector = Matrix<Scalar, Dynamic, 1>;
using ComplexMatrix = Matrix<Complex, Dynamic, Dynamic, ColMajor>;
/** Construct an FFT bound to the calling thread's default Context.
* The instance is thread-affine: it must not outlive the thread that
* constructed it, since it borrows a pointer into thread-local storage.
* For cross-thread lifetimes, pass an explicit Context. */
FFT() : ctx_(&Context::threadLocal()) {}
/** Construct an FFT bound to the given Context. The Context must outlive
* this FFT instance; this object only borrows its stream and cuBLAS handle. */
explicit FFT(Context& ctx) : ctx_(&ctx) {}
~FFT() {
for (auto& kv : plans_) (void)cufftDestroy(kv.second);
}
FFT(const FFT&) = delete;
FFT& operator=(const FFT&) = delete;
// ---- 1D Complex-to-Complex ------------------------------------------------
/** Forward 1D C2C FFT. */
template <typename Derived, std::enable_if_t<NumTraits<typename Derived::Scalar>::IsComplex>* = nullptr>
ComplexVector fwd(const MatrixBase<Derived>& x) {
const ComplexVector input(x.derived());
const int n = static_cast<int>(input.size());
if (n == 0) return ComplexVector(0);
ensure_buffers(n * sizeof(Complex), n * sizeof(Complex));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_in_.get(), input.data(), n * sizeof(Complex), cudaMemcpyHostToDevice, ctx_->stream()));
cufftHandle plan = get_plan_1d(n, internal::cufft_c2c_type<Scalar>::value);
EIGEN_CUFFT_CHECK(internal::cufftExecC2C_dispatch(plan, static_cast<Complex*>(d_in_.get()),
static_cast<Complex*>(d_out_.get()), CUFFT_FORWARD));
ComplexVector result(n);
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(result.data(), d_out_.get(), n * sizeof(Complex), cudaMemcpyDeviceToHost, ctx_->stream()));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx_->stream()));
return result;
}
/** Inverse 1D C2C FFT. Scaled by 1/n. */
template <typename Derived>
ComplexVector inv(const MatrixBase<Derived>& X) {
static_assert(NumTraits<typename Derived::Scalar>::IsComplex, "inv() requires complex input");
const ComplexVector input(X.derived());
const int n = static_cast<int>(input.size());
if (n == 0) return ComplexVector(0);
ensure_buffers(n * sizeof(Complex), n * sizeof(Complex));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_in_.get(), input.data(), n * sizeof(Complex), cudaMemcpyHostToDevice, ctx_->stream()));
cufftHandle plan = get_plan_1d(n, internal::cufft_c2c_type<Scalar>::value);
EIGEN_CUFFT_CHECK(internal::cufftExecC2C_dispatch(plan, static_cast<Complex*>(d_in_.get()),
static_cast<Complex*>(d_out_.get()), CUFFT_INVERSE));
// Scale by 1/n.
EIGEN_CUBLAS_CHECK(
internal::cublasXscal(ctx_->cublasHandle(), n, Scalar(1) / Scalar(n), static_cast<Complex*>(d_out_.get()), 1));
ComplexVector result(n);
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(result.data(), d_out_.get(), n * sizeof(Complex), cudaMemcpyDeviceToHost, ctx_->stream()));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx_->stream()));
return result;
}
// ---- 1D Real-to-Complex ---------------------------------------------------
/** Forward 1D R2C FFT. Returns n/2+1 complex values (half-spectrum). */
template <typename Derived, std::enable_if_t<!NumTraits<typename Derived::Scalar>::IsComplex>* = nullptr>
ComplexVector fwd(const MatrixBase<Derived>& x) {
const RealVector input(x.derived());
const int n = static_cast<int>(input.size());
if (n == 0) return ComplexVector(0);
const int n_complex = n / 2 + 1;
ensure_buffers(n * sizeof(Scalar), n_complex * sizeof(Complex));
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(d_in_.get(), input.data(), n * sizeof(Scalar), cudaMemcpyHostToDevice, ctx_->stream()));
cufftHandle plan = get_plan_1d(n, internal::cufft_r2c_type<Scalar>::value);
EIGEN_CUFFT_CHECK(
internal::cufftExecR2C_dispatch(plan, static_cast<Scalar*>(d_in_.get()), static_cast<Complex*>(d_out_.get())));
ComplexVector result(n_complex);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(result.data(), d_out_.get(), n_complex * sizeof(Complex),
cudaMemcpyDeviceToHost, ctx_->stream()));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx_->stream()));
return result;
}
// ---- 1D Complex-to-Real ---------------------------------------------------
/** Inverse 1D C2R FFT. Input is n/2+1 complex values, output is nfft real values.
* Scaled by 1/nfft. Caller must specify nfft (original real signal length). */
template <typename Derived>
RealVector invReal(const MatrixBase<Derived>& X, Index nfft) {
static_assert(NumTraits<typename Derived::Scalar>::IsComplex, "invReal() requires complex input");
const ComplexVector input(X.derived());
const int n = static_cast<int>(nfft);
const int n_complex = n / 2 + 1;
eigen_assert(input.size() == n_complex);
if (n == 0) return RealVector(0);
ensure_buffers(n_complex * sizeof(Complex), n * sizeof(Scalar));
// cuFFT C2R may overwrite the input, so we copy to d_in_.
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(d_in_.get(), input.data(), n_complex * sizeof(Complex),
cudaMemcpyHostToDevice, ctx_->stream()));
cufftHandle plan = get_plan_1d(n, internal::cufft_c2r_type<Scalar>::value);
EIGEN_CUFFT_CHECK(
internal::cufftExecC2R_dispatch(plan, static_cast<Complex*>(d_in_.get()), static_cast<Scalar*>(d_out_.get())));
// Scale by 1/n.
EIGEN_CUBLAS_CHECK(
internal::cublasXscal(ctx_->cublasHandle(), n, Scalar(1) / Scalar(n), static_cast<Scalar*>(d_out_.get()), 1));
RealVector result(n);
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(result.data(), d_out_.get(), n * sizeof(Scalar), cudaMemcpyDeviceToHost, ctx_->stream()));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx_->stream()));
return result;
}
// ---- 2D Complex-to-Complex ------------------------------------------------
/** Forward 2D C2C FFT. Input and output are rows x cols complex matrices. */
template <typename Derived>
ComplexMatrix fwd2(const MatrixBase<Derived>& A) {
static_assert(NumTraits<typename Derived::Scalar>::IsComplex, "fwd2() requires complex input");
const ComplexMatrix input(A.derived());
const int rows = static_cast<int>(input.rows());
const int cols = static_cast<int>(input.cols());
if (rows == 0 || cols == 0) return ComplexMatrix(rows, cols);
const size_t total = static_cast<size_t>(rows) * static_cast<size_t>(cols) * sizeof(Complex);
ensure_buffers(total, total);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(d_in_.get(), input.data(), total, cudaMemcpyHostToDevice, ctx_->stream()));
cufftHandle plan = get_plan_2d(rows, cols, internal::cufft_c2c_type<Scalar>::value);
EIGEN_CUFFT_CHECK(internal::cufftExecC2C_dispatch(plan, static_cast<Complex*>(d_in_.get()),
static_cast<Complex*>(d_out_.get()), CUFFT_FORWARD));
ComplexMatrix result(rows, cols);
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(result.data(), d_out_.get(), total, cudaMemcpyDeviceToHost, ctx_->stream()));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx_->stream()));
return result;
}
/** Inverse 2D C2C FFT. Scaled by 1/(rows*cols). */
template <typename Derived>
ComplexMatrix inv2(const MatrixBase<Derived>& A) {
static_assert(NumTraits<typename Derived::Scalar>::IsComplex, "inv2() requires complex input");
const ComplexMatrix input(A.derived());
const int rows = static_cast<int>(input.rows());
const int cols = static_cast<int>(input.cols());
if (rows == 0 || cols == 0) return ComplexMatrix(rows, cols);
const size_t total = static_cast<size_t>(rows) * static_cast<size_t>(cols) * sizeof(Complex);
ensure_buffers(total, total);
EIGEN_CUDA_RUNTIME_CHECK(cudaMemcpyAsync(d_in_.get(), input.data(), total, cudaMemcpyHostToDevice, ctx_->stream()));
cufftHandle plan = get_plan_2d(rows, cols, internal::cufft_c2c_type<Scalar>::value);
EIGEN_CUFFT_CHECK(internal::cufftExecC2C_dispatch(plan, static_cast<Complex*>(d_in_.get()),
static_cast<Complex*>(d_out_.get()), CUFFT_INVERSE));
// Scale by 1/(rows*cols).
const int total_elems = rows * cols;
EIGEN_CUBLAS_CHECK(internal::cublasXscal(ctx_->cublasHandle(), total_elems, Scalar(1) / Scalar(total_elems),
static_cast<Complex*>(d_out_.get()), 1));
ComplexMatrix result(rows, cols);
EIGEN_CUDA_RUNTIME_CHECK(
cudaMemcpyAsync(result.data(), d_out_.get(), total, cudaMemcpyDeviceToHost, ctx_->stream()));
EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx_->stream()));
return result;
}
// ---- Accessors ------------------------------------------------------------
/** The CUDA stream borrowed from the bound Context. */
cudaStream_t stream() const { return ctx_->stream(); }
/** The Context this FFT is bound to. */
Context& context() const { return *ctx_; }
private:
Context* ctx_;
std::unordered_map<int64_t, cufftHandle> plans_;
internal::DeviceBuffer d_in_;
internal::DeviceBuffer d_out_;
size_t d_in_size_ = 0;
size_t d_out_size_ = 0;
// Buffers grow but never shrink. The pre-realloc sync drains the *bound*
// Context's stream — including unrelated GEMMs/solves/`device(ctx) = ...`
// assignments queued on it — so callers running FFTs alongside other GPU
// work on the same Context should size up front (call fwd/inv with the
// largest expected n once) to avoid mid-pipeline stalls.
void ensure_buffers(size_t in_bytes, size_t out_bytes) {
if (in_bytes > d_in_size_) {
if (d_in_) EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx_->stream()));
d_in_ = internal::DeviceBuffer(in_bytes);
d_in_size_ = in_bytes;
}
if (out_bytes > d_out_size_) {
if (d_out_) EIGEN_CUDA_RUNTIME_CHECK(cudaStreamSynchronize(ctx_->stream()));
d_out_ = internal::DeviceBuffer(out_bytes);
d_out_size_ = out_bytes;
}
}
// Plan key encoding: rank (1 bit) | type (4 bits) | dims.
// cufftType uses 7 bits; the top 3 (precision discriminator) are redundant
// since Scalar fixes precision per FFT instance, so mask to 4 bits — without
// it, e.g. plan_key_1d(5, C2C) and plan_key_1d(7, C2C) collide.
static constexpr int64_t kTypeMask = 0xF;
static constexpr int kCols2DBits = 30; // bits 5..34
static constexpr int kRows2DBits = 29; // bits 35..63
static int64_t plan_key_1d(int n, cufftType type) { return (int64_t(n) << 5) | (int64_t(type & kTypeMask) << 1) | 0; }
static int64_t plan_key_2d(int rows, int cols, cufftType type) {
eigen_assert(rows >= 0 && int64_t(rows) < (int64_t(1) << kRows2DBits) &&
"FFT plan rows exceed plan-key bit budget");
eigen_assert(cols >= 0 && int64_t(cols) < (int64_t(1) << kCols2DBits) &&
"FFT plan cols exceed plan-key bit budget");
return (int64_t(rows) << 35) | (int64_t(cols) << 5) | (int64_t(type & kTypeMask) << 1) | 1;
}
cufftHandle get_plan_1d(int n, cufftType type) {
int64_t key = plan_key_1d(n, type);
auto it = plans_.find(key);
if (it != plans_.end()) return it->second;
cufftHandle plan;
EIGEN_CUFFT_CHECK(cufftPlan1d(&plan, n, type, /*batch=*/1));
EIGEN_CUFFT_CHECK(cufftSetStream(plan, ctx_->stream()));
plans_[key] = plan;
return plan;
}
cufftHandle get_plan_2d(int rows, int cols, cufftType type) {
int64_t key = plan_key_2d(rows, cols, type);
auto it = plans_.find(key);
if (it != plans_.end()) return it->second;
// cuFFT uses row-major (C order) for 2D: first dim = rows, second = cols.
// Eigen matrices are column-major, so we pass (cols, rows) to cuFFT
// to get the correct 2D transform.
cufftHandle plan;
EIGEN_CUFFT_CHECK(cufftPlan2d(&plan, cols, rows, type));
EIGEN_CUFFT_CHECK(cufftSetStream(plan, ctx_->stream()));
plans_[key] = plan;
return plan;
}
};
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_FFT_H