Skip to content

Commit

Permalink
Fix inference of FFT plan creation (#2683)
Browse files Browse the repository at this point in the history
  • Loading branch information
jipolanco authored Mar 10, 2025
1 parent 9455b65 commit 6fdca86
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 27 deletions.
60 changes: 36 additions & 24 deletions lib/cufft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,23 @@ end
# `region` is an iterable subset of the array dimensions,
# specified as an integer, range, tuple, or array

# Generic entry points for the complex plans, taking `region` in any non-tuple form.
# AbstractFFTs.jl turns calls like plan_fft(X) into plan_fft(X, 1:ndims(X)); aggressive
# constant propagation of `region` helps inference of the tuple length in such calls.
# The converted tuple is forwarded to the `region::NTuple{R,Int}` method of the same
# function.
for f in (:plan_fft!, :plan_bfft!, :plan_fft, :plan_bfft)
    @eval begin
        Base.@constprop :aggressive function $f(
            X::DenseCuArray{T,N}, region
        ) where {T<:cufftComplexes,N}
            # Bind the tuple to a new name so `region` keeps a single type.
            dims = NTuple{length(region),Int}(region)
            return $f(X, dims)
        end
    end
end

# inplace complex
function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
function plan_fft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_FORWARD
inplace = true
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region, size(X))
sizex = size(X)[1:md]
Expand All @@ -166,11 +177,9 @@ function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
function plan_bfft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_INVERSE
inplace = true
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region, size(X))
sizex = size(X)[1:md]
Expand All @@ -180,11 +189,9 @@ function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
end

# out-of-place complex
function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
function plan_fft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_FORWARD
inplace = false
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
Expand All @@ -193,11 +200,9 @@ function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_INVERSE
inplace = false
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
Expand All @@ -207,19 +212,23 @@ function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
end

# out-of-place real-to-complex
function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
K = CUFFT_FORWARD
inplace = false
# Generic entry point: convert `region` to a tuple of `Int`s and forward to the
# `region::NTuple{R,Int}` method. Aggressive constant propagation of `region` is
# requested to help inference of the tuple length.
Base.@constprop :aggressive function plan_rfft(
    X::DenseCuArray{T,N}, region
) where {T<:cufftReals,N}
    dims = NTuple{length(region),Int}(region)
    return plan_rfft(X, dims)
end

function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftReals,N,R}
K = CUFFT_FORWARD
inplace = false

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]

handle = cufftGetPlan(complex(T), T, sizex, region)

ydims = collect(size(X))
ydims[region[1]] = div(ydims[region[1]], 2) + 1
xdims = size(X)
ydims = Base.setindex(xdims, div(xdims[region[1]], 2) + 1, region[1])

# The buffer is not needed for real-to-complex (`mul!`),
# but it’s required for complex-to-real (`ldiv!`).
Expand All @@ -230,21 +239,24 @@ function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
end

# out-of-place complex-to-real
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
K = CUFFT_INVERSE
inplace = false
# Generic entry point: convert `region` to a tuple of `Int`s and forward to the
# `region::NTuple{R,Int}` method, passing `d` through unchanged. Aggressive constant
# propagation is requested to help inference of the tuple length.
Base.@constprop :aggressive function plan_brfft(
    X::DenseCuArray{T,N}, d::Integer, region
) where {T<:cufftComplexes,N}
    dims = NTuple{length(region),Int}(region)
    return plan_brfft(X, d, dims)
end

ydims = collect(size(X))
ydims[region[1]] = d
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_INVERSE
inplace = false

handle = cufftGetPlan(real(T), T, (ydims...,), region)
xdims = size(X)
ydims = Base.setindex(xdims, d, region[1])
handle = cufftGetPlan(real(T), T, ydims, region)

buffer = CuArray{T}(undef, size(X))
B = typeof(buffer)

CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), ydims, region, buffer)
end


Expand Down
21 changes: 18 additions & 3 deletions test/libraries/cufft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ atol(::Type{Complex{T}}) where {T} = atol(T)
function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
fftw_X = fft(X)
d_X = CuArray(X)
p = plan_fft(d_X)
p = @inferred plan_fft(d_X)
d_Y = p * d_X
Y = collect(d_Y)
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
Expand All @@ -130,12 +130,16 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
Z = collect(d_Z)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))

pinvb = @inferred plan_bfft(d_Y)
d_Z = pinvb * d_Y
Z = collect(d_Z) ./ length(d_Z)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
end

function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
fftw_X = fft(X)
d_X = CuArray(X)
p = plan_fft!(d_X)
p = @inferred plan_fft!(d_X)
p * d_X
Y = collect(d_X)
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
Expand All @@ -144,6 +148,12 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
pinv * d_X
Z = collect(d_X)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
p * d_X

pinvb = @inferred plan_bfft!(d_X)
pinvb * d_X
Z = collect(d_X) ./ length(X)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
end

function batched(X::AbstractArray{T,N},region) where {T <: Complex,N}
Expand Down Expand Up @@ -261,7 +271,7 @@ end
function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
fftw_X = rfft(X)
d_X = CuArray(X)
p = plan_rfft(d_X)
p = @inferred plan_rfft(d_X)
d_Y = p * d_X
Y = collect(d_Y)
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
Expand All @@ -280,6 +290,11 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
d_W = pinv3 * d_X
W = collect(d_W)
@test isapprox(W, Y, rtol = rtol(T), atol = atol(T))

pinvb = @inferred plan_brfft(d_Y,size(X,1))
d_Z = pinvb * d_Y
Z = collect(d_Z) ./ length(X)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
end

function batched(X::AbstractArray{T,N},region) where {T <: Real,N}
Expand Down

0 comments on commit 6fdca86

Please sign in to comment.