Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce allocations in dims_howmany #269

Merged
merged 9 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FFTW"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.6.1"
version = "1.7.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
87 changes: 62 additions & 25 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -566,20 +566,51 @@ unsafe_execute!(plan::r2rFFTWPlan{T},
# re-use the table of trigonometric constants from the first plan.

# Compute dims and howmany for FFTW guru planner
function dims_howmany(X::StridedArray, Y::StridedArray,
sz::Vector{Int}, region)
reg = Int[region...]::Vector{Int}
if length(unique(reg)) < length(reg)
_anyrepeated(::Union{Number, AbstractUnitRange}) = false
function _anyrepeated(region)
any(region) do x
count(==(x), region) > 1
end
end

# Utility methods to reduce allocations in dims_howmany
@inline _setindex(oreg, v, n) = (oreg[n] = v; oreg)
@inline _setindex(oreg::Tuple, v, n) = Base.setindex(oreg, v, n)
@inline _filtercoll(region::Union{Int, Tuple}, len) = ntuple(zero, len)
@inline _filtercoll(region, len) = Vector{Int}(undef, len)
# Optimized filter(∉(region), 1:ndims(X))
function _filter_notin_region(region, ::Val{ndimsX}) where {ndimsX}
oreg = _filtercoll(region, ndimsX - length(region))
n = 1
for dim in 1:ndimsX
dim in region && continue
oreg = _setindex(oreg, dim, n)
n += 1
end
oreg
end
function dims_howmany(X::StridedArray, Y::StridedArray, sz, region)
if _anyrepeated(region)
throw(ArgumentError("each dimension can be transformed at most once"))
end
ist = [strides(X)...]
ost = [strides(Y)...]
dims = Matrix(transpose([sz[reg] ist[reg] ost[reg]]))
oreg = [1:ndims(X);]
oreg[reg] .= 0
oreg = filter(d -> d > 0, oreg)
howmany = Matrix(transpose([sz[oreg] ist[oreg] ost[oreg]]))
return (dims, howmany)
ist = strides(X)
ost = strides(Y)
dims = Matrix{Int}(undef, 3, length(region))
for (ind, i) in enumerate(region)
dims[1, ind] = sz[i]
dims[2, ind] = ist[i]
dims[3, ind] = ost[i]
end

oreg = _filter_notin_region(region, Val(ndims(X)))
howmany = Matrix{Int}(undef, 3, length(oreg))
for (ind, i) in enumerate(oreg)
howmany[1, ind] = sz[i]
howmany[2, ind] = ist[i]
howmany[3, ind] = ost[i]
end

return dims, howmany
end

function fix_kinds(region, kinds)
Expand All @@ -604,6 +635,10 @@ function fix_kinds(region, kinds)
return k
end

_circshiftmin1(v) = circshift(collect(Int, v), -1)
_circshiftmin1(t::Tuple) = (t[2:end]..., t[1])
_circshiftmin1(x::Integer) = x

# low-level FFTWPlan creation (for internal use in FFTW module)
for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
(:Float32,:(Complex{Float32}),"fftwf",:libfftw3f))
Expand All @@ -613,7 +648,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
direction = K
unsafe_set_timelimit($Tr, timelimit)
R = isa(region, Tuple) ? region : copy(region)
dims, howmany = dims_howmany(X, Y, [size(X)...], R)
dims, howmany = dims_howmany(X, Y, size(X), R)
plan = ccall(($(string(fftw,"_plan_guru64_dft")),$lib[]),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Expand All @@ -631,9 +666,9 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
Y::StridedArray{$Tc,N},
region, flags::Integer, timelimit::Real) where {inplace,N}
R = isa(region, Tuple) ? region : copy(region)
region = circshift(Int[region...],-1) # FFTW halves last dim
regionshft = _circshiftmin1(region) # FFTW halves last dim
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
dims, howmany = dims_howmany(X, Y, size(X), regionshft)
plan = ccall(($(string(fftw,"_plan_guru64_dft_r2c")),$lib[]),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Expand All @@ -651,9 +686,9 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
Y::StridedArray{$Tr,N},
region, flags::Integer, timelimit::Real) where {inplace,N}
R = isa(region, Tuple) ? region : copy(region)
region = circshift(Int[region...],-1) # FFTW halves last dim
regionshft = _circshiftmin1(region) # FFTW halves last dim
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(Y)...], region)
dims, howmany = dims_howmany(X, Y, size(Y), regionshft)
plan = ccall(($(string(fftw,"_plan_guru64_dft_c2r")),$lib[]),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Expand All @@ -675,7 +710,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
R = isa(region, Tuple) ? region : copy(region)
knd = fix_kinds(region, kinds)
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
dims, howmany = dims_howmany(X, Y, size(X), region)
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]),
PlanPtr,
(Int32, Ptr{Int}, Int32, Ptr{Int},
Expand All @@ -698,9 +733,11 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
R = isa(region, Tuple) ? region : copy(region)
knd = fix_kinds(region, kinds)
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
dims[2:3, 1:size(dims,2)] *= 2
howmany[2:3, 1:size(howmany,2)] *= 2
dims, howmany = dims_howmany(X, Y, size(X), region)
@views begin
dims[2:3, :] .*= 2
howmany[2:3, :] .*= 2
end
howmany = [howmany [2,1,1]] # append loop over real/imag parts
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]),
PlanPtr,
Expand Down Expand Up @@ -759,9 +796,9 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD))
cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit)
end
$plan_f(X::StridedArray{<:fftwComplex}; kws...) =
$plan_f(X, 1:ndims(X); kws...)
$plan_f(X, ntuple(identity, ndims(X)); kws...)
$plan_f!(X::StridedArray{<:fftwComplex}; kws...) =
$plan_f!(X, 1:ndims(X); kws...)
$plan_f!(X, ntuple(identity, ndims(X)); kws...)

function plan_inv(p::cFFTWPlan{T,$direction,inplace,N};
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace}
Expand Down Expand Up @@ -845,8 +882,8 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
end
end

plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...)
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...)
plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,ntuple(identity, ndims(X));kws...)
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,ntuple(identity, ndims(X));kws...)

function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N},
num_threads::Union{Nothing, Integer} = nothing) where N
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ true_fftd3_m3d[:,:,2] .= -15
end

@testset "rfft/rfftn" begin
# Test regions as int/collection
@test rfft(m4,1) == rfft(m4,1:1) == rfft(m4,(1,)) == rfft(m4, [1])

rfft_m4 = rfft(m4,1)
rfftd2_m4 = rfft(m4,2)
rfftn_m4 = rfft(m4)
Expand Down