@@ -4,23 +4,25 @@ import AbstractFFTs
44import LinearAlgebra. mul!
55using AbstractFFTs: Plan
66
7- mutable struct TestPlan{T,N} <: Plan{T}
7+ mutable struct TestPlan{T,N,inplace } <: Plan{T}
88 region
99 sz:: NTuple{N,Int}
1010 pinv:: Plan{T}
11- function TestPlan {T} (region, sz:: NTuple{N,Int} ) where {T,N}
12- return new {T,N} (region, sz)
11+ function TestPlan {T,inplace } (region, sz:: NTuple{N,Int} ) where {T,N,inplace }
12+ return new {T,N,inplace } (region, sz)
1313 end
1414end
15+ TestPlan {T} (region, sz) where {T} = TestPlan {T,false} (region, sz)
1516
16- mutable struct InverseTestPlan{T,N} <: Plan{T}
17+ mutable struct InverseTestPlan{T,N,inplace } <: Plan{T}
1718 region
1819 sz:: NTuple{N,Int}
1920 pinv:: Plan{T}
20- function InverseTestPlan {T} (region, sz:: NTuple{N,Int} ) where {T,N}
21- return new {T,N} (region, sz)
21+ function InverseTestPlan {T,inplace } (region, sz:: NTuple{N,Int} ) where {T,N,inplace }
22+ return new {T,N,inplace } (region, sz)
2223 end
2324end
25+ InverseTestPlan {T} (region, sz) where {T} = InverseTestPlan {T,false} (region, sz)
2426
2527Base. size (p:: TestPlan ) = p. sz
2628Base. ndims (:: TestPlan{T,N} ) where {T,N} = N
@@ -34,18 +36,25 @@ function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T
3436 return InverseTestPlan {T} (region, size (x))
3537end
3638
37- function AbstractFFTs. plan_inv (p:: TestPlan{T} ) where {T}
38- unscaled_pinv = InverseTestPlan {T} (p. region, p. sz)
39- N = AbstractFFTs. normalization (T, p. sz, p. region)
40- unscaled_pinv. pinv = AbstractFFTs. ScaledPlan (p, N)
41- pinv = AbstractFFTs. ScaledPlan (unscaled_pinv, N)
39+ function AbstractFFTs. plan_fft! (x:: AbstractArray{T} , region; kwargs... ) where {T}
40+ return TestPlan {T,true} (region, size (x))
41+ end
42+ function AbstractFFTs. plan_bfft! (x:: AbstractArray{T} , region; kwargs... ) where {T}
43+ return InverseTestPlan {T,true} (region, size (x))
44+ end
45+
46+ function AbstractFFTs. plan_inv (p:: TestPlan{T,N,inplace} ) where {T,N,inplace}
47+ unscaled_pinv = InverseTestPlan {T,inplace} (p. region, p. sz)
48+ _N = AbstractFFTs. normalization (T, p. sz, p. region)
49+ unscaled_pinv. pinv = AbstractFFTs. ScaledPlan (p, _N)
50+ pinv = AbstractFFTs. ScaledPlan (unscaled_pinv, _N)
4251 return pinv
4352end
44- function AbstractFFTs. plan_inv (pinv:: InverseTestPlan{T} ) where {T}
45- unscaled_p = TestPlan {T} (pinv. region, pinv. sz)
46- N = AbstractFFTs. normalization (T, pinv. sz, pinv. region)
47- unscaled_p. pinv = AbstractFFTs. ScaledPlan (pinv, N )
48- p = AbstractFFTs. ScaledPlan (unscaled_p, N )
53+ function AbstractFFTs. plan_inv (pinv:: InverseTestPlan{T,N,inplace } ) where {T,N,inplace }
54+ unscaled_p = TestPlan {T,inplace } (pinv. region, pinv. sz)
55+ _N = AbstractFFTs. normalization (T, pinv. sz, pinv. region)
56+ unscaled_p. pinv = AbstractFFTs. ScaledPlan (pinv, _N )
57+ p = AbstractFFTs. ScaledPlan (unscaled_p, _N )
4958 return p
5059end
5160
@@ -80,20 +89,23 @@ function dft!(
8089end
8190
8291function mul! (
83- y:: AbstractArray{<:Complex,N} , p:: TestPlan , x:: AbstractArray{<:Union{Complex,Real},N}
84- ) where {N}
92+ y:: AbstractArray{<:Complex,N} , p:: TestPlan{T,N,false} , x:: AbstractArray{<:Union{Complex,Real},N}
93+ ) where {T, N}
8594 size (y) == size (p) == size (x) || throw (DimensionMismatch ())
8695 dft! (y, x, p. region, - 1 )
8796end
8897function mul! (
89- y:: AbstractArray{<:Complex,N} , p:: InverseTestPlan , x:: AbstractArray{<:Union{Complex,Real},N}
90- ) where {N}
98+ y:: AbstractArray{<:Complex,N} , p:: InverseTestPlan{T,N,false} , x:: AbstractArray{<:Union{Complex,Real},N}
99+ ) where {T, N}
91100 size (y) == size (p) == size (x) || throw (DimensionMismatch ())
92101 dft! (y, x, p. region, 1 )
93102end
94103
95- Base.:* (p:: TestPlan , x:: AbstractArray ) = mul! (similar (x, complex (float (eltype (x)))), p, x)
96- Base.:* (p:: InverseTestPlan , x:: AbstractArray ) = mul! (similar (x, complex (float (eltype (x)))), p, x)
104+ Base.:* (p:: TestPlan{T,N,false} , x:: AbstractArray ) where {T,N} = mul! (similar (x, complex (float (eltype (x)))), p, x)
105+ Base.:* (p:: InverseTestPlan{T,N,false} , x:: AbstractArray ) where {T,N} = mul! (similar (x, complex (float (eltype (x)))), p, x)
106+
107+ Base.:* (p:: TestPlan{T,N,true} , x:: AbstractArray ) where {T,N} = copy! (x, dft! (similar (x), x, p. region, - 1 ))
108+ Base.:* (p:: InverseTestPlan{T,N,true} , x:: AbstractArray ) where {T,N} = copy! (x, dft! (similar (x), x, p. region, 1 ))
97109
98110mutable struct TestRPlan{T,N} <: Plan{T}
99111 region
0 commit comments