Skip to content

Commit

Permalink
Add compatibility for FFTW.jl and AbstractFFTs.jl (#2)
Browse files Browse the repository at this point in the history
* Implement FFT methods

* * src/fft.jl: fix fft pointing to wrong package

* Add non-time tests and minor bump
  • Loading branch information
Tokazama authored Jun 25, 2020
1 parent 9af45bd commit ff55da2
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 8 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
name = "TimeAxes"
uuid = "9a9fc9a6-283d-47e9-a2f6-b3a44e559ea3"
authors = ["Zachary P. Christensen <[email protected]>"]
version = "0.2.0"
version = "0.2.1"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AxisIndices = "f52c9ee2-1b1c-4fd8-8546-6350938c7f11"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -16,7 +18,6 @@ NamedDims = "0.2"
Reexport = "0.2"
julia = "1"


[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3 changes: 3 additions & 0 deletions src/TimeAxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ module TimeAxes
end TimeAxes

using AxisIndices
using FFTW
using AbstractFFTs
using NamedDims
using IntervalSets
using Reexport
Expand Down Expand Up @@ -42,5 +44,6 @@ export
include("timedim.jl")
include("timeaxis.jl")
include("timestamps.jl")
include("fft.jl")

end # module
98 changes: 98 additions & 0 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# TODO if FFTW functionality is merged int NamedDims then this code needs to be changed

###
### fft_axes
###
function fft_axis(fxn::Function, axis::AbstractAxis, inds::AbstractUnitRange)
return assign_indices(axis, inds) # default
end

@inline function fft_axes(fxn::Function, axs::Tuple, inds::Tuple, dims::Tuple, cnt::Int=1)
if first(dims) === cnt
return (fft_axis(fxn, first(axs), first(inds)),
fft_axes(fxn, tail(axs), tail(inds), tail(dims), cnt + 1)...)
else
return (assign_indices(first(axs), first(inds)),
fft_axes(fxn, tail(axs), tail(inds), dims, cnt + 1)...)
end
end
fft_axes(fxn::Function, axs::Tuple, inds::Tuple, dims::Tuple{}, cnt::Int=1) = map(assign_indices, axs, inds)
fft_axes(fxn::Function, axs::Tuple{}, inds::Tuple{}, dims::Tuple{}, cnt::Int=1) = ()


###
### fft_names
###
fft_name(fxn::Function, dname::Symbol) = dname # default is to do nothing
@inline function fft_names(fxn::Function, dnames::NTuple{N,Symbol}, dims::Tuple{Vararg{Int}}) where {N}
ntuple(Val(N)) do i
if i in dims
fft_name(fxn, getfield(dnames, i))
else
getfield(dnames, i)
end
end
end


for f in (:fft, :ifft, :bfft)
@eval begin
function AbstractFFTs.$f(A::AbstractAxisArray, dims)
p = AbstractFFTs.$f(parent(A), dims)
axs = fft_axes(AbstractFFTs.$f, axes(A), axes(p), dims)
return unsafe_reconstruct(A, p, axs)
end

function AbstractFFTs.$f(A::NamedDimsArray, dims::Union{Symbol,Integer})
return AbstractFFTs.$f(A, (NamedDims.dim(dimnames(A), dims),))
end

function AbstractFFTs.$f(A::NamedDimsArray, dims::Tuple)
if dims isa Tuple{Vararg{<:Integer}}
dn = fft_names(AbstractFFTs.$f, dimnames(A), dims)
return NamedDimsArray{dn}(AbstractFFTs.$f(parent(A), dims))
else
return AbstractFFTs.$f(A, (NamedDims.dims(dimnames(A), dims),))
end
end

function AbstractFFTs.$f(A::NamedDimsArray{L,T,N}) where {L,T,N}
if has_timedim(A)
return AbstractFFTs.$f(A, (timedim(A),))
else
return AbstractFFTs.$f(A, ntuple(+, Val(N)))
end
end
end
end

for f in (:dct, :idct)
@eval begin
function FFTW.$f(A::AbstractAxisArray, dims)
p = FFTW.$f(parent(A), dims)
axs = fft_axes(FFTW.$f, axes(A), axes(p), dims)
return unsafe_reconstruct(A, p, axs)
end

function FFTW.$f(A::NamedDimsArray, dims::Union{Symbol,Integer})
return FFTW.$f(A, (NamedDims.dim(dimnames(A), dims),))
end

function FFTW.$f(A::NamedDimsArray, dims::Tuple)
if dims isa Tuple{Vararg{<:Integer}}
dn = fft_names(FFTW.$f, dimnames(A), dims)
return NamedDimsArray{dn}(FFTW.$f(parent(A), dims))
else
return FFTW.$f(A, (NamedDims.dim(dimnames(A), dims),))
end
end

function FFTW.$f(A::NamedDimsArray{L,T,N}) where {L,T,N}
if has_timedim(A)
return FFTW.$f(A, (timedim(A),))
else
return FFTW.$f(A, ntuple(+, Val(N)))
end
end
end
end
28 changes: 25 additions & 3 deletions src/timedim.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@

Base.@pure is_time(x::Symbol) = (x === :time) || (x === :Time)

AxisIndices.@defdim time is_time
AxisIndices.@defdim(time, is_time, false)

"""
ntime(x) -> Int
Returns the size along the dimension corresponding to the time. Defaults to 1
"""
@inline function ntime(x)
if has_timedim(x)
return size(x, timedim(x))
else
return 1
end
end

"""
time_end(x)
Expand Down Expand Up @@ -43,8 +56,18 @@ sampling_rate(x) = 1 / time_step(x)
Throw an error if the `x` has a time dimension that is not the last dimension.
"""
@inline assert_timedim_last(x) = is_time(last(dimnames(x)))
@inline function assert_timedim_last(x::AbstractArray{T,N}) where {T,N}
if has_timedim(x)
if timedim(x) === N
return nothing
else
error("time dimension is not last")
end

else
return nothing
end
end

"""
lead(A::AbstractArray, n::Integer)
Expand Down Expand Up @@ -205,4 +228,3 @@ end
return to_axis(axis, keys(axis)[(firstindex(axis) + n):lastindex(axis)], newinds)
end
end

72 changes: 69 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ using Unitful

nia = NamedAxisArray(reshape(1:6, 2, 3), x = 2:3, time = 3.0:5.0)
@test has_timedim(nia)
@test @inferred(assert_timedim_last(nia))
@test @inferred(!assert_timedim_last(NamedAxisArray(reshape(1:6, 3, 2), time = 3.0:5.0, x = 2:3)))
@test @inferred(assert_timedim_last(nia)) == nothing
@test_throws ErrorException assert_timedim_last(NamedAxisArray(reshape(1:6, 3, 2), time = 3.0:5.0, x = 2:3))
@test !has_timedim(parent(nia))
@test @inferred(time_keys(nia)) == 3:5
@test @inferred(ntime(nia)) == 3
Expand Down Expand Up @@ -46,10 +46,76 @@ t2 = @inferred(t[:ts1..:ts2])
@test values(t2) == 1:3
@test keys(t2) == Second(1):Second(1):Second(3)

@testset "fft-tests" begin
A = reshape(1:6, 2, 3)
A_named_axes = NamedAxisArray(A, x = 2:3, time = 3.0:5.0)
A_named_axes_no_time = NamedAxisArray(A, x = 2:3, y = 3.0:5.0)


@testset "fft" begin
A_fft = TimeAxes.fft(A, 2)
A_named_axes_fft = TimeAxes.fft(A_named_axes, 2)
@test A_fft == A_named_axes_fft
@test typeof(A_named_axes_fft) <: NamedAxisArray{(:x, :time)}

A_fft = TimeAxes.fft(A)
A_named_axes_no_time_fft = TimeAxes.fft(A_named_axes_no_time)
@test A_fft == A_named_axes_no_time_fft
@test typeof(A_named_axes_no_time_fft) <: NamedAxisArray{(:x, :y)}
end

@testset "ifft" begin
A_ifft = TimeAxes.ifft(A, 2)
A_named_axes_ifft = TimeAxes.ifft(A_named_axes, 2)
@test A_ifft == A_named_axes_ifft
@test typeof(A_named_axes_ifft) <: NamedAxisArray{(:x, :time)}

A_fft = TimeAxes.ifft(A)
A_named_axes_no_time_fft = TimeAxes.ifft(A_named_axes_no_time)
@test A_fft == A_named_axes_no_time_fft
@test typeof(A_named_axes_no_time_fft) <: NamedAxisArray{(:x, :y)}
end

@testset "bfft" begin
A_bfft = TimeAxes.bfft(A, 2)
A_named_axes_bfft = TimeAxes.bfft(A_named_axes, 2)
@test A_bfft == A_named_axes_bfft
@test typeof(A_named_axes_bfft) <: NamedAxisArray{(:x, :time)}

A_fft = TimeAxes.bfft(A)
A_named_axes_no_time_fft = TimeAxes.bfft(A_named_axes_no_time)
@test A_fft == A_named_axes_no_time_fft
@test typeof(A_named_axes_no_time_fft) <: NamedAxisArray{(:x, :y)}
end

@testset "dct" begin
A_dct = TimeAxes.dct(A, 2)
A_named_axes_dct = TimeAxes.dct(A_named_axes, 2)
@test A_dct == A_named_axes_dct
@test typeof(A_named_axes_dct) <: NamedAxisArray{(:x, :time)}

A_fft = TimeAxes.dct(A)
A_named_axes_no_time_fft = TimeAxes.dct(A_named_axes_no_time)
@test A_fft == A_named_axes_no_time_fft
@test typeof(A_named_axes_no_time_fft) <: NamedAxisArray{(:x, :y)}
end

@testset "idct" begin
A_idct = TimeAxes.idct(A, 2)
A_named_axes_idct = TimeAxes.idct(A_named_axes, 2)
@test A_idct == A_named_axes_idct
@test typeof(A_named_axes_idct) <: NamedAxisArray{(:x, :time)}

A_fft = TimeAxes.idct(A)
A_named_axes_no_time_fft = TimeAxes.idct(A_named_axes_no_time)
@test A_fft == A_named_axes_no_time_fft
@test typeof(A_named_axes_no_time_fft) <: NamedAxisArray{(:x, :y)}
end
end

# this avoids errors due to differences in how Symbols are printing between versions of Julia
if !(VERSION < v"1.4")
@testset "docs" begin
doctest(TimeAxes)
end
end

2 comments on commit ff55da2

@Tokazama
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/16960

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" ff55da20f32dadacf88f15e82b8c701826b15ccc
git push origin v0.2.1

Please sign in to comment.