Fix FFT when destination does not have unit stride.
diff --git a/unsupported/Eigen/FFT b/unsupported/Eigen/FFT index 630be1e..557fdf6 100644 --- a/unsupported/Eigen/FFT +++ b/unsupported/Eigen/FFT
@@ -231,11 +231,12 @@ THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES) if (nfft < 1) nfft = src.size(); - - if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum)) - dst.derived().resize((nfft >> 1) + 1); - else - dst.derived().resize(nfft); + + Index dst_size = nfft; + if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum)) { + dst_size = (nfft >> 1) + 1; + } + dst.derived().resize(dst_size); if (src.innerStride() != 1 || src.size() < nfft) { Matrix<src_type, 1, Dynamic> tmp; @@ -245,9 +246,21 @@ } else { tmp = src; } - fwd(&dst[0], &tmp[0], nfft); + if (dst.innerStride() != 1) { + Matrix<dst_type, 1, Dynamic> out(1, dst_size); + fwd(&out[0], &tmp[0], nfft); + dst.derived() = out; + } else { + fwd(&dst[0], &tmp[0], nfft); + } } else { - fwd(&dst[0], &src[0], nfft); + if (dst.innerStride() != 1) { + Matrix<dst_type, 1, Dynamic> out(1, dst_size); + fwd(&out[0], &src[0], nfft); + dst.derived() = out; + } else { + fwd(&dst[0], &src[0], nfft); + } } } @@ -326,9 +339,22 @@ } else { tmp = src; } - inv(&dst[0], &tmp[0], nfft); + + if (dst.innerStride() != 1) { + Matrix<dst_type, 1, Dynamic> out(1, nfft); + inv(&out[0], &tmp[0], nfft); + dst.derived() = out; + } else { + inv(&dst[0], &tmp[0], nfft); + } } else { - inv(&dst[0], &src[0], nfft); + if (dst.innerStride() != 1) { + Matrix<dst_type, 1, Dynamic> out(1, nfft); + inv(&out[0], &src[0], nfft); + dst.derived() = out; + } else { + inv(&dst[0], &src[0], nfft); + } } }
diff --git a/unsupported/test/fft_test_shared.h b/unsupported/test/fft_test_shared.h index 0e040ad..3adcd90 100644 --- a/unsupported/test/fft_test_shared.h +++ b/unsupported/test/fft_test_shared.h
@@ -164,9 +164,41 @@ } template <typename T> +void test_complex_strided(int nfft) { + typedef typename FFT<T>::Complex Complex; + typedef typename Eigen::Vector<Complex, Dynamic> ComplexVector; + constexpr int kInputStride = 3; + constexpr int kOutputStride = 7; + constexpr int kInvOutputStride = 13; + + FFT<T> fft; + + ComplexVector inbuf(nfft * kInputStride); + inbuf.setRandom(); + ComplexVector outbuf(nfft * kOutputStride); + outbuf.setRandom(); + ComplexVector invoutbuf(nfft * kInvOutputStride); + invoutbuf.setRandom(); + + using StridedComplexVector = Map<ComplexVector, /*MapOptions=*/0, InnerStride<Dynamic>>; + StridedComplexVector input(inbuf.data(), nfft, InnerStride<Dynamic>(kInputStride)); + StridedComplexVector output(outbuf.data(), nfft, InnerStride<Dynamic>(kOutputStride)); + StridedComplexVector inv_output(invoutbuf.data(), nfft, InnerStride<Dynamic>(kInvOutputStride)); + + for (int k = 0; k < nfft; ++k) + input[k] = Complex((T)(rand() / (double)RAND_MAX - .5), (T)(rand() / (double)RAND_MAX - .5)); + fft.fwd(output, input); + + VERIFY(T(fft_rmse(output, input)) < test_precision<T>()); // gross check + fft.inv(inv_output, output); + VERIFY(T(dif_rmse(inv_output, input)) < test_precision<T>()); // gross check +} + +template <typename T> void test_complex(int nfft) { test_complex_generic<StdVectorContainer, T>(nfft); test_complex_generic<EigenVectorContainer, T>(nfft); + test_complex_strided<T>(nfft); } template <typename T, int nrows, int ncols>