blob: a9ed7888afaeed54ef303c508f232a8c953a8875 [file]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2026 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-License-Identifier: MPL-2.0
// cuFFT support utilities: error-checking macro, scalar-to-cufftType mapping,
// and type-dispatched wrappers around the cufftExec* entry points.
#ifndef EIGEN_GPU_CUFFT_SUPPORT_H
#define EIGEN_GPU_CUFFT_SUPPORT_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include "./GpuSupport.h"
#include <cufft.h>
namespace Eigen {
namespace gpu {
namespace internal {
// ---- Error checking ---------------------------------------------------------
// Evaluates the cuFFT call expression `x` exactly once, keeps its cufftResult
// in the block-local `_r`, and passes `_r == CUFFT_SUCCESS` to eigen_assert.
// The stringized expression (#x) is appended to the assertion message so the
// failing call is named in the diagnostic.
// The do { ... } while (0) wrapper lets the macro be used as a single
// statement (followed by a semicolon), including inside an unbraced if/else.
// EIGEN_UNUSED_VARIABLE(_r) is presumably there to avoid an unused-variable
// warning when eigen_assert expands to nothing.
// NOTE(review): in such builds a failing cuFFT call would pass silently --
// confirm this is intended.
#define EIGEN_CUFFT_CHECK(x) \
do { \
cufftResult _r = (x); \
eigen_assert(_r == CUFFT_SUCCESS && "cuFFT call failed: " #x); \
EIGEN_UNUSED_VARIABLE(_r); \
} while (0)
// ---- Scalar → cufftType traits ----------------------------------------------
// Maps a real scalar type to the cufftType tag of the complex-to-complex
// transform in that precision. Only float and double are specialized; the
// primary template is declared but never defined, so any other Scalar is a
// compile-time error at the point of use.
template <typename Scalar>
struct cufft_c2c_type;

// Single precision: complex<float> -> complex<float>.
template <>
struct cufft_c2c_type<float> {
  static constexpr cufftType value{CUFFT_C2C};
};

// Double precision: complex<double> -> complex<double>.
template <>
struct cufft_c2c_type<double> {
  static constexpr cufftType value{CUFFT_Z2Z};
};
// Maps a real scalar type to the cufftType tag of the real-to-complex
// transform in that precision. Only float and double are specialized; the
// primary template is declared but never defined, so any other Scalar is a
// compile-time error at the point of use.
template <typename Scalar>
struct cufft_r2c_type;

// Single precision: float -> complex<float>.
template <>
struct cufft_r2c_type<float> {
  static constexpr cufftType value{CUFFT_R2C};
};

// Double precision: double -> complex<double>.
template <>
struct cufft_r2c_type<double> {
  static constexpr cufftType value{CUFFT_D2Z};
};
// Maps a real scalar type to the cufftType tag of the complex-to-real
// transform in that precision. Only float and double are specialized; the
// primary template is declared but never defined, so any other Scalar is a
// compile-time error at the point of use.
template <typename Scalar>
struct cufft_c2r_type;

// Single precision: complex<float> -> float.
template <>
struct cufft_c2r_type<float> {
  static constexpr cufftType value{CUFFT_C2R};
};

// Double precision: complex<double> -> double.
template <>
struct cufft_c2r_type<double> {
  static constexpr cufftType value{CUFFT_Z2D};
};
// ---- Type-dispatched cuFFT execution ----------------------------------------
// C2C
// Single-precision complex-to-complex execution.
// Reinterprets the std::complex<float> buffers as cufftComplex and forwards
// to cufftExecC2C. `plan` and `direction` are passed through untouched, and
// the cufftResult reported by cuFFT is returned as-is (no checking here).
// NOTE(review): `direction` is presumably CUFFT_FORWARD or CUFFT_INVERSE --
// confirm against callers.
inline cufftResult cufftExecC2C_dispatch(cufftHandle plan, std::complex<float>* in, std::complex<float>* out,
                                         int direction) {
  auto* const src = reinterpret_cast<cufftComplex*>(in);
  auto* const dst = reinterpret_cast<cufftComplex*>(out);
  return cufftExecC2C(plan, src, dst, direction);
}
// Double-precision complex-to-complex execution.
// Reinterprets the std::complex<double> buffers as cufftDoubleComplex and
// forwards to cufftExecZ2Z. `plan` and `direction` are passed through
// untouched, and the cufftResult reported by cuFFT is returned as-is.
// NOTE(review): `direction` is presumably CUFFT_FORWARD or CUFFT_INVERSE --
// confirm against callers.
inline cufftResult cufftExecC2C_dispatch(cufftHandle plan, std::complex<double>* in, std::complex<double>* out,
                                         int direction) {
  auto* const src = reinterpret_cast<cufftDoubleComplex*>(in);
  auto* const dst = reinterpret_cast<cufftDoubleComplex*>(out);
  return cufftExecZ2Z(plan, src, dst, direction);
}
// R2C
// Single-precision real-to-complex execution.
// The real input pointer is handed to cufftExecR2C unchanged; the
// std::complex<float> output buffer is reinterpreted as cufftComplex.
// Returns cuFFT's cufftResult as-is (no checking here).
inline cufftResult cufftExecR2C_dispatch(cufftHandle plan, float* in, std::complex<float>* out) {
  auto* const spectrum = reinterpret_cast<cufftComplex*>(out);
  return cufftExecR2C(plan, in, spectrum);
}
// Double-precision real-to-complex execution.
// The real input pointer is handed to cufftExecD2Z unchanged; the
// std::complex<double> output buffer is reinterpreted as cufftDoubleComplex.
// Returns cuFFT's cufftResult as-is (no checking here).
inline cufftResult cufftExecR2C_dispatch(cufftHandle plan, double* in, std::complex<double>* out) {
  auto* const spectrum = reinterpret_cast<cufftDoubleComplex*>(out);
  return cufftExecD2Z(plan, in, spectrum);
}
// C2R
// Single-precision complex-to-real execution.
// The std::complex<float> input buffer is reinterpreted as cufftComplex; the
// real output pointer is handed to cufftExecC2R unchanged.
// Returns cuFFT's cufftResult as-is (no checking here).
inline cufftResult cufftExecC2R_dispatch(cufftHandle plan, std::complex<float>* in, float* out) {
  auto* const spectrum = reinterpret_cast<cufftComplex*>(in);
  return cufftExecC2R(plan, spectrum, out);
}
// Double-precision complex-to-real execution.
// The std::complex<double> input buffer is reinterpreted as
// cufftDoubleComplex; the real output pointer is handed to cufftExecZ2D
// unchanged. Returns cuFFT's cufftResult as-is (no checking here).
inline cufftResult cufftExecC2R_dispatch(cufftHandle plan, std::complex<double>* in, double* out) {
  auto* const spectrum = reinterpret_cast<cufftDoubleComplex*>(in);
  return cufftExecZ2D(plan, spectrum, out);
}
} // namespace internal
} // namespace gpu
} // namespace Eigen
#endif // EIGEN_GPU_CUFFT_SUPPORT_H