
Fix inference of FFT plan creation #2683

Merged: 4 commits merged into JuliaGPU:master on Mar 10, 2025

Conversation

jipolanco
Contributor

This PR fixes a couple of type inference issues when creating FFT plans (on Julia 1.11.3 at least):

  1. Since [CUFFT] Preallocate a buffer for complex-to-real FFT #2578, the return type of plan_rfft(u, region) is not fully inferred, because the number of dimensions of the internally created buffer is itself not inferred. This is the case even when region is an NTuple, which should in principle allow inference since the number of transformed dimensions is then statically known.

  2. Plan creation without the region argument may or may not be inferred, since it relies on constant propagation by the Julia compiler: AbstractFFTs.jl currently sets region = 1:ndims(u), and inference only succeeds if this argument is constant-propagated. A solution would be to set region = ntuple(identity, ndims(u)) in AbstractFFTs.jl, but unfortunately an existing PR doing precisely that is still open.

Issue 1 is fixed by avoiding things like collect(size(X)) and array splatting, and working directly with tuples instead.

A workaround for issue 2 is to inline calls to plan_*, at least when the region argument is not a tuple. An alternative would be to use Base.@constprop :aggressive, but this needs Julia 1.10.
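
As a side note, here is a minimal sketch of how one might check this inference behaviour with Test.@inferred; the array size and transform dimensions are arbitrary examples (not taken from this PR), and a CUDA-capable device is assumed:

```julia
# Hypothetical inference check (not part of this PR's test suite).
using CUDA
using AbstractFFTs: plan_rfft
using Test: @inferred

u = CUDA.rand(Float32, 16, 16, 16)

# Issue 1: with `region` given as a tuple, the number of transformed dimensions
# is statically known, so the full plan type (including the dimensionality of
# any internal buffer) should be inferable.
p_tuple = @inferred plan_rfft(u, (1, 2))

# Issue 2: without `region`, AbstractFFTs.jl passes region = 1:ndims(u), so
# inference depends on constant propagation (or on the call being inlined, as
# in the workaround of this PR); whether it succeeds can vary with the Julia
# version.
p_default = plan_rfft(u)
```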


github-actions bot commented Mar 6, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (`git runic master`) to apply these changes.

Suggested changes:
diff --git a/lib/cufft/fft.jl b/lib/cufft/fft.jl
index 8b173f59d..cea5372f0 100644
--- a/lib/cufft/fft.jl
+++ b/lib/cufft/fft.jl
@@ -157,16 +157,16 @@ end
 # plan_fft(X, 1:ndims(X)).
 for f in (:plan_fft!, :plan_bfft!, :plan_fft, :plan_bfft)
     @eval begin
-        Base.@constprop :aggressive function $f(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
+        Base.@constprop :aggressive function $f(X::DenseCuArray{T, N}, region) where {T <: cufftComplexes, N}
             R = length(region)
-            region = NTuple{R,Int}(region)
+            region = NTuple{R, Int}(region)
             $f(X, region)
         end
     end
 end
 
 # inplace complex
-function plan_fft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
+function plan_fft!(X::DenseCuArray{T, N}, region::NTuple{R, Int}) where {T <: cufftComplexes, N, R}
     K = CUFFT_FORWARD
     inplace = true
 
@@ -177,7 +177,7 @@ function plan_fft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftC
     CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
 end
 
-function plan_bfft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
+function plan_bfft!(X::DenseCuArray{T, N}, region::NTuple{R, Int}) where {T <: cufftComplexes, N, R}
     K = CUFFT_INVERSE
     inplace = true
 
@@ -189,7 +189,7 @@ function plan_bfft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufft
 end
 
 # out-of-place complex
-function plan_fft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
+function plan_fft(X::DenseCuArray{T, N}, region::NTuple{R, Int}) where {T <: cufftComplexes, N, R}
     K = CUFFT_FORWARD
     inplace = false
 
@@ -200,7 +200,7 @@ function plan_fft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftCo
     CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
 end
 
-function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
+function plan_bfft(X::DenseCuArray{T, N}, region::NTuple{R, Int}) where {T <: cufftComplexes, N, R}
     K = CUFFT_INVERSE
     inplace = false
 
@@ -212,13 +212,13 @@ function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftC
 end
 
 # out-of-place real-to-complex
-Base.@constprop :aggressive function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
+Base.@constprop :aggressive function plan_rfft(X::DenseCuArray{T, N}, region) where {T <: cufftReals, N}
     R = length(region)
     region = NTuple{R,Int}(region)
     plan_rfft(X, region)
 end
 
-function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftReals,N,R}
+function plan_rfft(X::DenseCuArray{T, N}, region::NTuple{R, Int}) where {T <: cufftReals, N, R}
     K = CUFFT_FORWARD
     inplace = false
 
@@ -239,13 +239,13 @@ function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftR
 end
 
 # out-of-place complex-to-real
-Base.@constprop :aggressive function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
+Base.@constprop :aggressive function plan_brfft(X::DenseCuArray{T, N}, d::Integer, region) where {T <: cufftComplexes, N}
     R = length(region)
     region = NTuple{R,Int}(region)
     plan_brfft(X, d, region)
 end
 
-function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
+function plan_brfft(X::DenseCuArray{T, N}, d::Integer, region::NTuple{R, Int}) where {T <: cufftComplexes, N, R}
     K = CUFFT_INVERSE
     inplace = false
 
@@ -256,7 +256,7 @@ function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::NTuple{R,Int}) whe
     buffer = CuArray{T}(undef, size(X))
     B = typeof(buffer)
 
-    CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), ydims, region, buffer)
+    return CuFFTPlan{real(T), T, K, inplace, N, R, B}(handle, size(X), ydims, region, buffer)
 end
 
 
diff --git a/test/libraries/cufft.jl b/test/libraries/cufft.jl
index 6dc321295..7b36e2e58 100644
--- a/test/libraries/cufft.jl
+++ b/test/libraries/cufft.jl
@@ -133,7 +133,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
     pinvb = @inferred plan_bfft(d_Y)
     d_Z = pinvb * d_Y
     Z = collect(d_Z) ./ length(d_Z)
-    @test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
+    return @test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
 end
 
 function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
@@ -153,7 +153,7 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
     pinvb = @inferred plan_bfft!(d_X)
     pinvb * d_X
     Z = collect(d_X) ./ length(X)
-    @test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
+    return @test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
 end
 
 function batched(X::AbstractArray{T,N},region) where {T <: Complex,N}
@@ -291,10 +291,10 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
     W = collect(d_W)
     @test isapprox(W, Y, rtol = rtol(T), atol = atol(T))
 
-    pinvb = @inferred plan_brfft(d_Y,size(X,1))
+    pinvb = @inferred plan_brfft(d_Y, size(X, 1))
     d_Z = pinvb * d_Y
     Z = collect(d_Z) ./ length(X)
-    @test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
+    return @test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
 end
 
 function batched(X::AbstractArray{T,N},region) where {T <: Real,N}

@maleadt
Member

maleadt commented Mar 7, 2025

> An alternative would be to use Base.@constprop :aggressive, but this needs Julia 1.10.

I'm fine with bumping compat to 1.10, since that is officially the latest LTS now. Also, GPUCompiler.jl already requires 1.10, so this wouldn't change much.

maleadt requested a review from amontoison on March 7, 2025 at 07:01
maleadt added the `enhancement` (New feature or request) and `cuda libraries` (Stuff about CUDA library wrappers) labels on Mar 7, 2025
@jipolanco
Contributor Author

My mistake, Base.@constprop :aggressive has been available since Julia 1.7 (JuliaLang/julia#42125), so this shouldn't require bumping the Julia compat version. It is only @constprop usage within a function body that needs Julia 1.10 (I had misread the docs).
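
For context, the method-level form (the one available since 1.7 and used in this PR's diff above) looks roughly like the following sketch; the function name and body are hypothetical placeholders:

```julia
# Sketch of a method-level @constprop annotation wrapping a tuple-typed method
# (hypothetical names; the real methods live in lib/cufft/fft.jl).
Base.@constprop :aggressive function my_plan(X, region)
    R = length(region)
    region_t = NTuple{R, Int}(region)  # encode the number of dimensions in the type
    return my_plan(X, region_t)        # dispatch to the tuple-typed method below
end

function my_plan(X, region::NTuple{R, Int}) where {R}
    # Actual plan construction would go here; return a placeholder for the sketch.
    return (X, region)
end
```

With aggressive constant propagation, `length(region)` can become a compile-time constant at call sites where `region` is itself a constant (such as 1:ndims(u)), which is what makes the tuple-typed method inferable.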

github-actions bot left a comment

CUDA.jl Benchmarks

| Benchmark suite | Current: 49d9772 | Previous: 2540087 | Ratio |
|---|---|---|---|
| latency/precompile | 46281544368 ns | 46159355585.5 ns | 1.00 |
| latency/ttfp | 7045187008 ns | 6957794099 ns | 1.01 |
| latency/import | 3675819118 ns | 3631949691 ns | 1.01 |
| integration/volumerhs | 9624044.5 ns | 9611328 ns | 1.00 |
| integration/byval/slices=1 | 147184 ns | 146783 ns | 1.00 |
| integration/byval/slices=3 | 425503 ns | 425400 ns | 1.00 |
| integration/byval/reference | 145004 ns | 145027 ns | 1.00 |
| integration/byval/slices=2 | 286413 ns | 286289 ns | 1.00 |
| integration/cudadevrt | 103420 ns | 103478 ns | 1.00 |
| kernel/indexing | 14041 ns | 14116.5 ns | 0.99 |
| kernel/indexing_checked | 14947.5 ns | 14717 ns | 1.02 |
| kernel/occupancy | 639.7973856209151 ns | 631.1345029239766 ns | 1.01 |
| kernel/launch | 2030.2 ns | 2048.6 ns | 0.99 |
| kernel/rand | 14763 ns | 15306 ns | 0.96 |
| array/reverse/1d | 19614 ns | 19766 ns | 0.99 |
| array/reverse/2d | 24657 ns | 23320 ns | 1.06 |
| array/reverse/1d_inplace | 10517 ns | 10115 ns | 1.04 |
| array/reverse/2d_inplace | 12155 ns | 11769 ns | 1.03 |
| array/copy | 21050 ns | 21157 ns | 0.99 |
| array/iteration/findall/int | 158220 ns | 159462 ns | 0.99 |
| array/iteration/findall/bool | 138957 ns | 139925 ns | 0.99 |
| array/iteration/findfirst/int | 153869 ns | 154378.5 ns | 1.00 |
| array/iteration/findfirst/bool | 154538 ns | 155307 ns | 1.00 |
| array/iteration/scalar | 72209.5 ns | 74069 ns | 0.97 |
| array/iteration/logical | 215666.5 ns | 217157 ns | 0.99 |
| array/iteration/findmin/1d | 41565.5 ns | 41948 ns | 0.99 |
| array/iteration/findmin/2d | 94058 ns | 94228.5 ns | 1.00 |
| array/reductions/reduce/1d | 35775 ns | 36315 ns | 0.99 |
| array/reductions/reduce/2d | 44048.5 ns | 40828 ns | 1.08 |
| array/reductions/mapreduce/1d | 33160 ns | 33656 ns | 0.99 |
| array/reductions/mapreduce/2d | 50749 ns | 51092 ns | 0.99 |
| array/broadcast | 20949 ns | 21055 ns | 0.99 |
| array/copyto!/gpu_to_gpu | 11933 ns | 11958 ns | 1.00 |
| array/copyto!/cpu_to_gpu | 209886 ns | 211255 ns | 0.99 |
| array/copyto!/gpu_to_cpu | 245272 ns | 244927 ns | 1.00 |
| array/accumulate/1d | 108803 ns | 109111 ns | 1.00 |
| array/accumulate/2d | 80424 ns | 80441 ns | 1.00 |
| array/construct | 1254.1 ns | 1271.2 ns | 0.99 |
| array/random/randn/Float32 | 43719.5 ns | 47305 ns | 0.92 |
| array/random/randn!/Float32 | 26497 ns | 26683 ns | 0.99 |
| array/random/rand!/Int64 | 27120 ns | 27072 ns | 1.00 |
| array/random/rand!/Float32 | 8828.5 ns | 8803.333333333334 ns | 1.00 |
| array/random/rand/Int64 | 37964.5 ns | 29816 ns | 1.27 |
| array/random/rand/Float32 | 13082 ns | 13176 ns | 0.99 |
| array/permutedims/4d | 61439 ns | 61379.5 ns | 1.00 |
| array/permutedims/2d | 55755 ns | 55887 ns | 1.00 |
| array/permutedims/3d | 56092 ns | 56426.5 ns | 0.99 |
| array/sorting/1d | 2776813 ns | 2766596.5 ns | 1.00 |
| array/sorting/by | 3368603 ns | 3370290 ns | 1.00 |
| array/sorting/2d | 1085307 ns | 1085614.5 ns | 1.00 |
| cuda/synchronization/stream/auto | 1043.7 ns | 1027.9 ns | 1.02 |
| cuda/synchronization/stream/nonblocking | 6367.4 ns | 6419.6 ns | 0.99 |
| cuda/synchronization/stream/blocking | 833.1058823529412 ns | 800.59 ns | 1.04 |
| cuda/synchronization/context/auto | 1162.4 ns | 1152.4 ns | 1.01 |
| cuda/synchronization/context/nonblocking | 6633.2 ns | 6599.2 ns | 1.01 |
| cuda/synchronization/context/blocking | 893.2244897959183 ns | 899.7446808510638 ns | 0.99 |

This comment was automatically generated by a workflow using github-action-benchmark.


codecov bot commented Mar 7, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.17%. Comparing base (2540087) to head (49d9772).
The report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2683      +/-   ##
==========================================
+ Coverage   82.08%   82.17%   +0.08%     
==========================================
  Files         154      154              
  Lines       13661    13661              
==========================================
+ Hits        11214    11226      +12     
+ Misses       2447     2435      -12     

☔ View full report in Codecov by Sentry.

maleadt merged commit 6fdca86 into JuliaGPU:master on Mar 10, 2025
1 of 3 checks passed
jipolanco deleted the jip/inference-fft-plans branch on March 10, 2025 at 09:26