Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce OneElementArray #26

Merged
merged 4 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "SparseArraysBase"
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.10"
version = "0.2.11"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"

Expand All @@ -17,6 +18,7 @@ Aqua = "0.8.9"
ArrayLayouts = "1.11.0"
DerivableInterfaces = "0.3.7"
Dictionaries = "0.4.3"
FillArrays = "1.13.0"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.5"
SafeTestsets = "0.1"
Expand Down
5 changes: 5 additions & 0 deletions src/SparseArraysBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ module SparseArraysBase
export SparseArrayDOK,
SparseMatrixDOK,
SparseVectorDOK,
OneElementArray,
OneElementMatrix,
OneElementVector,
eachstoredindex,
isstored,
oneelementarray,
storedlength,
storedpairs,
storedvalues
Expand All @@ -14,5 +18,6 @@ include("sparsearrayinterface.jl")
include("wrappers.jl")
include("abstractsparsearray.jl")
include("sparsearraydok.jl")
include("oneelementarray.jl")

end
275 changes: 275 additions & 0 deletions src/oneelementarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
using FillArrays: Fill

# Like [`FillArrays.OneElement`](https://github.com/JuliaArrays/FillArrays.jl)
# and [`OneHotArrays.OneHotArray`](https://github.com/FluxML/OneHotArrays.jl).
struct OneElementArray{T,N,I,A,F} <: AbstractSparseArray{T,N}
value::T
index::I
axes::A
getunstoredindex::F
end

using DerivableInterfaces: @array_aliases
# Define `OneElementMatrix`, `AnyOneElementArray`, etc.
@array_aliases OneElementArray

function OneElementArray{T,N}(
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}, getunstoredindex
) where {T,N}
return OneElementArray{T,N,typeof(index),typeof(axes),typeof(getunstoredindex)}(
value, index, axes, getunstoredindex
)
end

function OneElementArray{T,N}(
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
) where {T,N}
return OneElementArray{T,N}(value, index, axes, default_getunstoredindex)
end
function OneElementArray{<:Any,N}(
value::T, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
) where {T,N}
return OneElementArray{T,N}(value, index, axes)
end
function OneElementArray(
value::T, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
) where {T,N}
return OneElementArray{T,N}(value, index, axes)
end

function OneElementArray{T,N}(
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
) where {T,N}
return OneElementArray{T,N}(one(T), index, axes)
end
function OneElementArray{<:Any,N}(
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
) where {N}
return OneElementArray{Bool,N}(index, axes)
end
function OneElementArray{T}(
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
) where {T,N}
return OneElementArray{T,N}(index, axes)
end
function OneElementArray(index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {N}
return OneElementArray{Bool,N}(index, axes)
end

function OneElementArray{T,N}(
value, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
) where {T,N}
return OneElementArray{T,N}(value, last.(ax_ind), first.(ax_ind))
end
function OneElementArray{<:Any,N}(
value::T, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
) where {T,N}
return OneElementArray{T,N}(value, ax_ind...)
end
function OneElementArray{T}(

Check warning on line 69 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L69

Added line #L69 was not covered by tests
value, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
) where {T,N}
return OneElementArray{T,N}(value, ax_ind...)

Check warning on line 72 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L72

Added line #L72 was not covered by tests
end
function OneElementArray(
value::T, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
) where {T,N}
return OneElementArray{T,N}(value, ax_ind...)
end

function OneElementArray{T,N}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {T,N}
return OneElementArray{T,N}(last.(ax_ind), first.(ax_ind))
end
function OneElementArray{<:Any,N}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {N}
return OneElementArray{Bool,N}(ax_ind...)
end
function OneElementArray{T}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {T,N}
return OneElementArray{T,N}(ax_ind...)
end
function OneElementArray(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {N}
return OneElementArray{Bool,N}(ax_ind...)
end

# Fix ambiguity errors.
function OneElementArray{T,0}(value, index::Tuple{}, axes::Tuple{}) where {T}
return OneElementArray{T,0}(value, index, axes, default_getunstoredindex)
end
function OneElementArray{<:Any,0}(value::T, index::Tuple{}, axes::Tuple{}) where {T}
return OneElementArray{T,0}(value, index, axes)

Check warning on line 98 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L97-L98

Added lines #L97 - L98 were not covered by tests
end
function OneElementArray{T}(value, index::Tuple{}, axes::Tuple{}) where {T}
return OneElementArray{T,0}(value, index, axes)

Check warning on line 101 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L100-L101

Added lines #L100 - L101 were not covered by tests
end
function OneElementArray(value::T, index::Tuple{}, axes::Tuple{}) where {T}
return OneElementArray{T,0}(value, index, axes)

Check warning on line 104 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
end

# Fix ambiguity errors.
function OneElementArray{T,0}(index::Tuple{}, axes::Tuple{}) where {T}
return OneElementArray{T,0}(one(T), index, axes)
end
function OneElementArray{<:Any,0}(index::Tuple{}, axes::Tuple{})
return OneElementArray{Bool,0}(index, axes)

Check warning on line 112 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L111-L112

Added lines #L111 - L112 were not covered by tests
end
function OneElementArray{T}(index::Tuple{}, axes::Tuple{}) where {T}
return OneElementArray{T,0}(index, axes)

Check warning on line 115 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L114-L115

Added lines #L114 - L115 were not covered by tests
end
function OneElementArray(index::Tuple{}, axes::Tuple{})
return OneElementArray{Bool,0}(value, index, axes)

Check warning on line 118 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L117-L118

Added lines #L117 - L118 were not covered by tests
end

function OneElementArray{T,0}(value) where {T}
return OneElementArray{T,0}(value, (), ())

Check warning on line 122 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L121-L122

Added lines #L121 - L122 were not covered by tests
end
function OneElementArray{<:Any,0}(value::T) where {T}
return OneElementArray{T,0}(value)

Check warning on line 125 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
end
function OneElementArray{T}(value) where {T}
return OneElementArray{T,0}(value)

Check warning on line 128 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L127-L128

Added lines #L127 - L128 were not covered by tests
end
function OneElementArray(value::T) where {T}
return OneElementArray{T}(value)

Check warning on line 131 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L130-L131

Added lines #L130 - L131 were not covered by tests
end

function OneElementArray{T,0}() where {T}
return OneElementArray{T,0}((), ())
end
function OneElementArray{<:Any,0}()
return OneElementArray{Bool,0}(value)

Check warning on line 138 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L137-L138

Added lines #L137 - L138 were not covered by tests
end
function OneElementArray{T}() where {T}
return OneElementArray{T,0}()
end
function OneElementArray()
return OneElementArray{Bool}()
end

function OneElementArray{T,N}(
value, index::NTuple{N,Int}, size::NTuple{N,Integer}
) where {T,N}
return OneElementArray{T,N}(value, index, Base.oneto.(size))
end
function OneElementArray{<:Any,N}(
value::T, index::NTuple{N,Int}, size::NTuple{N,Integer}
) where {T,N}
return OneElementArray{T,N}(value, index, size)
end
function OneElementArray{T}(
value, index::NTuple{N,Int}, size::NTuple{N,Integer}
) where {T,N}
return OneElementArray{T,N}(value, index, size)
end
function OneElementArray(
value::T, index::NTuple{N,Int}, size::NTuple{N,Integer}
) where {T,N}
return OneElementArray{T,N}(value, index, Base.oneto.(size))
end

function OneElementArray{T,N}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {T,N}
return OneElementArray{T,N}(one(T), index, size)
end
function OneElementArray{<:Any,N}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
return OneElementArray{Bool,N}(index, size)
end
function OneElementArray{T}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {T,N}
return OneElementArray{T,N}(index, size)
end
function OneElementArray(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
return OneElementArray{Bool,N}(index, size)
end

function OneElementVector{T}(value, index::Int, length::Integer) where {T}
return OneElementVector{T}(value, (index,), (length,))

Check warning on line 182 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L181-L182

Added lines #L181 - L182 were not covered by tests
end
function OneElementVector(value::T, index::Int, length::Integer) where {T}
return OneElementVector{T}(value, index, length)

Check warning on line 185 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L184-L185

Added lines #L184 - L185 were not covered by tests
end
function OneElementArray{T}(value, index::Int, length::Integer) where {T}
return OneElementVector{T}(value, index, length)

Check warning on line 188 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L187-L188

Added lines #L187 - L188 were not covered by tests
end
function OneElementArray(value::T, index::Int, length::Integer) where {T}
return OneElementVector{T}(value, index, length)

Check warning on line 191 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L190-L191

Added lines #L190 - L191 were not covered by tests
end

function OneElementVector{T}(index::Int, size::Integer) where {T}
return OneElementVector{T}((index,), (size,))
end
function OneElementVector(index::Int, length::Integer)
return OneElementVector{Bool}(index, length)
end
function OneElementArray{T}(index::Int, size::Integer) where {T}
return OneElementVector{T}(index, size)

Check warning on line 201 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L200-L201

Added lines #L200 - L201 were not covered by tests
end
OneElementArray(index::Int, size::Integer) = OneElementVector{Bool}(index, size)

# Interface to overload for constructing arrays like `OneElementArray`,
# that may not be `OneElementArray` (i.e. wrapped versions).
function oneelement(
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
) where {N}
return OneElementArray(value, index, axes)
end
function oneelement(
eltype::Type, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
) where {N}
return oneelement(one(eltype), index, axes)
end
function oneelement(index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {N}
return oneelement(Bool, index, axes)
end

function oneelement(value, index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
return oneelement(value, index, Base.oneto.(size))
end
function oneelement(eltype::Type, index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
return oneelement(one(eltype), index, size)
end
function oneelement(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
return oneelement(Bool, index, size)
end

function oneelement(value, ax_ind::Pair{<:AbstractUnitRange,Int}...)
return oneelement(value, last.(ax_ind), first.(ax_ind))
end
function oneelement(eltype::Type, ax_ind::Pair{<:AbstractUnitRange,Int}...)
return oneelement(one(eltype), ax_ind...)
end
function oneelement(ax_ind::Pair{<:AbstractUnitRange,Int}...)
return oneelement(Bool, ax_ind...)
end

function oneelement(value)
return oneelement(value, (), ())

Check warning on line 242 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L241-L242

Added lines #L241 - L242 were not covered by tests
end
function oneelement(eltype::Type)
return oneelement(one(eltype))

Check warning on line 245 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L244-L245

Added lines #L244 - L245 were not covered by tests
end
function oneelement()
return oneelement(Bool)

Check warning on line 248 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L247-L248

Added lines #L247 - L248 were not covered by tests
end

Base.axes(a::OneElementArray) = getfield(a, :axes)
Base.size(a::OneElementArray) = length.(axes(a))
storedvalue(a::OneElementArray) = getfield(a, :value)
storedvalues(a::OneElementArray) = Fill(storedvalue(a), 1)

storedindex(a::OneElementArray) = getfield(a, :index)
function isstored(a::OneElementArray, I::Int...)
return I == storedindex(a)
end
function eachstoredindex(a::OneElementArray)
return Fill(CartesianIndex(storedindex(a)), 1)
end

function getstoredindex(a::OneElementArray, I::Int...)
return storedvalue(a)
end
function getunstoredindex(a::OneElementArray, I::Int...)
return a.getunstoredindex(a, I...)
end
function setstoredindex!(a::OneElementArray, value, I::Int...)
return error("`OneElementArray` is immutable, you can't set elements.")

Check warning on line 271 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L270-L271

Added lines #L270 - L271 were not covered by tests
end
function setunstoredindex!(a::OneElementArray, value, I::Int...)
return error("`OneElementArray` is immutable, you can't set elements.")

Check warning on line 274 in src/oneelementarray.jl

View check run for this annotation

Codecov / codecov/patch

src/oneelementarray.jl#L273-L274

Added lines #L273 - L274 were not covered by tests
end
File renamed without changes.
11 changes: 10 additions & 1 deletion test/basics/test_diagonal.jl → test/test_diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,22 @@ using SparseArraysBase:

using Test: @test, @testset

# compat with LTS:
@static if VERSION ≥ v"1.11"
_diagind = diagind
else
function _diagind(x::Diagonal, ::IndexCartesian)
return view(CartesianIndices(x), diagind(x))
end
end

elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

@testset "Diagonal{$T}" for T in elts
L = 4
D = Diagonal(rand(T, 4))
@test storedlength(D) == 4
@test eachstoredindex(D) == diagind(D, IndexCartesian())
@test eachstoredindex(D) == _diagind(D, IndexCartesian())
@test isstored(D, 2, 2)
@test getstoredindex(D, 2, 2) == D[2, 2]
@test !isstored(D, 2, 1)
Expand Down
4 changes: 4 additions & 0 deletions test/basics/test_exports.jl → test/test_exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ using Test: @test, @testset
:SparseArrayDOK,
:SparseMatrixDOK,
:SparseVectorDOK,
:OneElementArray,
:OneElementMatrix,
:OneElementVector,
:eachstoredindex,
:isstored,
:oneelementarray,
:storedlength,
:storedpairs,
:storedvalues,
Expand Down
File renamed without changes.
Loading