Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Statistics and SparseArrays as extensions #286

Merged
merged 6 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
name = "FillArrays"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.5.0"
version = "1.6.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
FillArraysSparseArraysExt = "SparseArrays"
FillArraysStatisticsExt = "Statistics"

[compat]
Aqua = "0.5, 0.6"
julia = "1.6"
Expand All @@ -16,8 +24,10 @@ julia = "1.6"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "Base64", "ReverseDiff", "StaticArrays"]
test = ["Aqua", "Test", "Base64", "ReverseDiff", "SparseArrays", "StaticArrays", "Statistics"]
59 changes: 59 additions & 0 deletions ext/FillArraysSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module FillArraysSparseArraysExt

using SparseArrays
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
import Base: convert, kron
using FillArrays
using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value
using LinearAlgebra

##################
## Sparse arrays
##################
SparseVector{T}(Z::ZerosVector) where T = spzeros(T, length(Z))
SparseVector{Tv,Ti}(Z::ZerosVector) where {Tv,Ti} = spzeros(Tv, Ti, length(Z))

convert(::Type{AbstractSparseVector}, Z::ZerosVector{T}) where T = spzeros(T, length(Z))
convert(::Type{AbstractSparseVector{T}}, Z::ZerosVector) where T= spzeros(T, length(Z))

SparseMatrixCSC{T}(Z::ZerosMatrix) where T = spzeros(T, size(Z)...)
SparseMatrixCSC{Tv,Ti}(Z::Zeros{T,2,Axes}) where {Tv,Ti<:Integer,T,Axes} = spzeros(Tv, Ti, size(Z)...)

convert(::Type{AbstractSparseMatrix}, Z::ZerosMatrix{T}) where T = spzeros(T, size(Z)...)
convert(::Type{AbstractSparseMatrix{T}}, Z::ZerosMatrix) where T = spzeros(T, size(Z)...)

convert(::Type{AbstractSparseArray}, Z::Zeros{T}) where T = spzeros(T, size(Z)...)
convert(::Type{AbstractSparseArray{Tv}}, Z::Zeros{T}) where {T,Tv} = spzeros(Tv, size(Z)...)
convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Zeros{T}) where {T,Tv,Ti} = spzeros(Tv, Ti, size(Z)...)
convert(::Type{AbstractSparseArray{Tv,Ti,N}}, Z::Zeros{T,N}) where {T,Tv,Ti,N} = spzeros(Tv, Ti, size(Z)...)

SparseMatrixCSC{Tv}(Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)
# works around missing `speye`:
SparseMatrixCSC{Tv,Ti}(Z::Eye{T}) where {T,Tv,Ti<:Integer} =
convert(SparseMatrixCSC{Tv,Ti}, SparseMatrixCSC{Tv}(I, size(Z)...))

convert(::Type{AbstractSparseMatrix}, Z::Eye{T}) where {T} = SparseMatrixCSC{T}(I, size(Z)...)
convert(::Type{AbstractSparseMatrix{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)

convert(::Type{AbstractSparseArray}, Z::Eye{T}) where T = SparseMatrixCSC{T}(I, size(Z)...)
convert(::Type{AbstractSparseArray{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)


convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} =
convert(SparseMatrixCSC{Tv,Ti}, Z)
convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} =
convert(SparseMatrixCSC{Tv,Ti}, Z)

function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv}
SparseMatrixCSC{Tv,eltype(axes(R,1))}(R)
end
function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti}
Base.require_one_based_indexing(R)
v = parent(R)
J = getindex_value(v)*I
SparseMatrixCSC{Tv,Ti}(J, size(R))
end

# TODO: remove in v2.0
@deprecate kron(E1::RectDiagonalFill, E2::RectDiagonalFill) kron(sparse(E1), sparse(E2))

end # module
33 changes: 33 additions & 0 deletions ext/FillArraysStatisticsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module FillArraysStatisticsExt

import Statistics: mean, var, cov, cor
using LinearAlgebra: diagind

using FillArrays
using FillArrays: AbstractFill, AbstractFillVector, AbstractFillMatrix, getindex_value

mean(A::AbstractFill; dims=(:)) = mean(identity, A; dims=dims)
function mean(f::Union{Function, Type}, A::AbstractFill; dims=(:))
val = float(f(getindex_value(A)))
dims isa Colon ? val :
Fill(val, ntuple(d -> d in dims ? 1 : size(A,d), ndims(A))...)
end


function var(A::AbstractFill{T}; corrected::Bool=true, mean=nothing, dims=(:)) where {T<:Number}
dims isa Colon ? zero(float(T)) :
Zeros{float(T)}(ntuple(d -> d in dims ? 1 : size(A,d), ndims(A))...)
end

cov(::AbstractFillVector{T}; corrected::Bool=true) where {T<:Number} = zero(float(T))
cov(A::AbstractFillMatrix{T}; corrected::Bool=true, dims::Integer=1) where {T<:Number} =
Zeros{float(T)}(size(A, 3-dims), size(A, 3-dims))

cor(::AbstractFillVector{T}) where {T<:Number} = one(float(T))
function cor(A::AbstractFillMatrix{T}; dims::Integer=1) where {T<:Number}
out = fill(float(T)(NaN), size(A, 3-dims), size(A, 3-dims))
out[diagind(out)] .= 1
out
end

end # module
85 changes: 6 additions & 79 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
""" `FillArrays` module to lazily represent matrices with a single value """
module FillArrays

using LinearAlgebra, SparseArrays, Statistics
using LinearAlgebra
import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
Expand All @@ -16,9 +16,6 @@ import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose,

import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape, BroadcastStyle, Broadcasted

import Statistics: mean, std, var, cov, cor


export Zeros, Ones, Fill, Eye, Trues, Falses, OneElement

import Base: oneto
Expand Down Expand Up @@ -542,52 +539,6 @@ for SMT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal)
end
end

##################
## Sparse arrays
##################
SparseVector{T}(Z::ZerosVector) where T = spzeros(T, length(Z))
SparseVector{Tv,Ti}(Z::ZerosVector) where {Tv,Ti} = spzeros(Tv, Ti, length(Z))

convert(::Type{AbstractSparseVector}, Z::ZerosVector{T}) where T = spzeros(T, length(Z))
convert(::Type{AbstractSparseVector{T}}, Z::ZerosVector) where T= spzeros(T, length(Z))

SparseMatrixCSC{T}(Z::ZerosMatrix) where T = spzeros(T, size(Z)...)
SparseMatrixCSC{Tv,Ti}(Z::Zeros{T,2,Axes}) where {Tv,Ti<:Integer,T,Axes} = spzeros(Tv, Ti, size(Z)...)

convert(::Type{AbstractSparseMatrix}, Z::ZerosMatrix{T}) where T = spzeros(T, size(Z)...)
convert(::Type{AbstractSparseMatrix{T}}, Z::ZerosMatrix) where T = spzeros(T, size(Z)...)

convert(::Type{AbstractSparseArray}, Z::Zeros{T}) where T = spzeros(T, size(Z)...)
convert(::Type{AbstractSparseArray{Tv}}, Z::Zeros{T}) where {T,Tv} = spzeros(Tv, size(Z)...)
convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Zeros{T}) where {T,Tv,Ti} = spzeros(Tv, Ti, size(Z)...)
convert(::Type{AbstractSparseArray{Tv,Ti,N}}, Z::Zeros{T,N}) where {T,Tv,Ti,N} = spzeros(Tv, Ti, size(Z)...)

SparseMatrixCSC{Tv}(Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)
# works around missing `speye`:
SparseMatrixCSC{Tv,Ti}(Z::Eye{T}) where {T,Tv,Ti<:Integer} =
convert(SparseMatrixCSC{Tv,Ti}, SparseMatrixCSC{Tv}(I, size(Z)...))

convert(::Type{AbstractSparseMatrix}, Z::Eye{T}) where {T} = SparseMatrixCSC{T}(I, size(Z)...)
convert(::Type{AbstractSparseMatrix{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)

convert(::Type{AbstractSparseArray}, Z::Eye{T}) where T = SparseMatrixCSC{T}(I, size(Z)...)
convert(::Type{AbstractSparseArray{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)


convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} =
convert(SparseMatrixCSC{Tv,Ti}, Z)
convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} =
convert(SparseMatrixCSC{Tv,Ti}, Z)

function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv}
SparseMatrixCSC{Tv,eltype(axes(R,1))}(R)
end
function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti}
Base.require_one_based_indexing(R)
v = parent(R)
J = getindex_value(v)*I
SparseMatrixCSC{Tv,Ti}(J, size(R))
end

#########
# maximum/minimum
Expand Down Expand Up @@ -698,35 +649,6 @@ function in(x, A::RectDiagonal{<:Number})
x == zero(eltype(A)) || x in A.diag
end

#########
# mean, std
#########

mean(A::AbstractFill; dims=(:)) = mean(identity, A; dims=dims)
function mean(f::Union{Function, Type}, A::AbstractFill; dims=(:))
val = float(f(getindex_value(A)))
dims isa Colon ? val :
Fill(val, ntuple(d -> d in dims ? 1 : size(A,d), ndims(A))...)
end


function var(A::AbstractFill{T}; corrected::Bool=true, mean=nothing, dims=(:)) where {T<:Number}
dims isa Colon ? zero(float(T)) :
Zeros{float(T)}(ntuple(d -> d in dims ? 1 : size(A,d), ndims(A))...)
end

cov(A::AbstractFillVector{T}; corrected::Bool=true) where {T<:Number} = zero(float(T))
cov(A::AbstractFillMatrix{T}; corrected::Bool=true, dims::Integer=1) where {T<:Number} =
Zeros{float(T)}(size(A, 3-dims), size(A, 3-dims))

cor(A::AbstractFillVector{T}) where {T<:Number} = one(float(T))
function cor(A::AbstractFillMatrix{T}; dims::Integer=1) where {T<:Number}
out = fill(float(T)(NaN), size(A, 3-dims), size(A, 3-dims))
out[LinearAlgebra.diagind(out)] .= 1
out
end


#########
# include
#########
Expand All @@ -735,6 +657,11 @@ include("fillalgebra.jl")
include("fillbroadcast.jl")
include("trues.jl")

@static if !isdefined(Base, :get_extension)
include("../ext/FillArraysSparseArraysExt.jl")
include("../ext/FillArraysStatisticsExt.jl")
end

##
# print
##
Expand Down
1 change: 0 additions & 1 deletion src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,4 +467,3 @@ function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat)
sz = _kronsize(f, g)
_kron(f, g, sz)
end
kron(E1::RectDiagonalFill, E2::RectDiagonalFill) = kron(sparse(E1), sparse(E2))