Skip to content

Commit 0b9166f

Browse files
committed
Introduce OneElementArray
1 parent b41609b commit 0b9166f

9 files changed

+397
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.10"
4+
version = "0.2.11"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
88
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
99
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
1010
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
11+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1314

@@ -17,6 +18,7 @@ Aqua = "0.8.9"
1718
ArrayLayouts = "1.11.0"
1819
DerivableInterfaces = "0.3.7"
1920
Dictionaries = "0.4.3"
21+
FillArrays = "1.13.0"
2022
LinearAlgebra = "1.10"
2123
MapBroadcast = "0.1.5"
2224
SafeTestsets = "0.1"

src/SparseArraysBase.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ module SparseArraysBase
33
export SparseArrayDOK,
44
SparseMatrixDOK,
55
SparseVectorDOK,
6+
OneElementArray,
7+
OneElementMatrix,
8+
OneElementVector,
69
eachstoredindex,
710
isstored,
11+
oneelementarray,
812
storedlength,
913
storedpairs,
1014
storedvalues
@@ -14,5 +18,6 @@ include("sparsearrayinterface.jl")
1418
include("wrappers.jl")
1519
include("abstractsparsearray.jl")
1620
include("sparsearraydok.jl")
21+
include("oneelementarray.jl")
1722

1823
end

src/oneelementarray.jl

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
using FillArrays: Fill
2+
3+
# Like [`FillArrays.OneElement`](https://github.com/JuliaArrays/FillArrays.jl)
4+
# and [`OneHotArrays.OneHotArray`](https://github.com/FluxML/OneHotArrays.jl).
5+
struct OneElementArray{T,N,I,A,F} <: AbstractSparseArray{T,N}
6+
value::T
7+
index::I
8+
axes::A
9+
getunstoredindex::F
10+
end
11+
12+
using DerivableInterfaces: @array_aliases
13+
# Define `OneElementMatrix`, `AnyOneElementArray`, etc.
14+
@array_aliases OneElementArray
15+
16+
function OneElementArray{T,N}(
17+
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}, getunstoredindex
18+
) where {T,N}
19+
return OneElementArray{T,N,typeof(index),typeof(axes),typeof(getunstoredindex)}(
20+
value, index, axes, getunstoredindex
21+
)
22+
end
23+
24+
function OneElementArray{T,N}(
25+
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
26+
) where {T,N}
27+
return OneElementArray{T,N}(value, index, axes, default_getunstoredindex)
28+
end
29+
function OneElementArray{<:Any,N}(
30+
value::T, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
31+
) where {T,N}
32+
return OneElementArray{T,N}(value, index, axes)
33+
end
34+
function OneElementArray(
35+
value::T, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
36+
) where {T,N}
37+
return OneElementArray{T,N}(value, index, axes)
38+
end
39+
40+
function OneElementArray{T,N}(
41+
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
42+
) where {T,N}
43+
return OneElementArray{T,N}(one(T), index, axes)
44+
end
45+
function OneElementArray{<:Any,N}(
46+
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
47+
) where {N}
48+
return OneElementArray{Bool,N}(index, axes)
49+
end
50+
function OneElementArray{T}(
51+
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
52+
) where {T,N}
53+
return OneElementArray{T,N}(index, axes)
54+
end
55+
function OneElementArray(index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {N}
56+
return OneElementArray{Bool,N}(index, axes)
57+
end
58+
59+
function OneElementArray{T,N}(
60+
value, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
61+
) where {T,N}
62+
return OneElementArray{T,N}(value, last.(ax_ind), first.(ax_ind))
63+
end
64+
function OneElementArray{<:Any,N}(
65+
value::T, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
66+
) where {T,N}
67+
return OneElementArray{T,N}(value, ax_ind...)
68+
end
69+
function OneElementArray{T}(
70+
value, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
71+
) where {T,N}
72+
return OneElementArray{T,N}(value, ax_ind...)
73+
end
74+
function OneElementArray(
75+
value::T, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
76+
) where {T,N}
77+
return OneElementArray{T,N}(value, ax_ind...)
78+
end
79+
80+
function OneElementArray{T,N}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {T,N}
81+
return OneElementArray{T,N}(last.(ax_ind), first.(ax_ind))
82+
end
83+
function OneElementArray{<:Any,N}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {N}
84+
return OneElementArray{Bool,N}(ax_ind...)
85+
end
86+
function OneElementArray{T}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {T,N}
87+
return OneElementArray{T,N}(ax_ind...)
88+
end
89+
function OneElementArray(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {N}
90+
return OneElementArray{Bool,N}(ax_ind...)
91+
end
92+
93+
# Fix ambiguity errors.
94+
function OneElementArray{T,0}(value, index::Tuple{}, axes::Tuple{}) where {T}
95+
return OneElementArray{T,0}(value, index, axes, default_getunstoredindex)
96+
end
97+
function OneElementArray{<:Any,0}(value::T, index::Tuple{}, axes::Tuple{}) where {T}
98+
return OneElementArray{T,0}(value, index, axes)
99+
end
100+
function OneElementArray{T}(value, index::Tuple{}, axes::Tuple{}) where {T}
101+
return OneElementArray{T,0}(value, index, axes)
102+
end
103+
function OneElementArray(value::T, index::Tuple{}, axes::Tuple{}) where {T}
104+
return OneElementArray{T,0}(value, index, axes)
105+
end
106+
107+
# Fix ambiguity errors.
108+
function OneElementArray{T,0}(index::Tuple{}, axes::Tuple{}) where {T}
109+
return OneElementArray{T,0}(one(T), index, axes)
110+
end
111+
function OneElementArray{<:Any,0}(index::Tuple{}, axes::Tuple{})
112+
return OneElementArray{Bool,0}(index, axes)
113+
end
114+
function OneElementArray{T}(index::Tuple{}, axes::Tuple{}) where {T}
115+
return OneElementArray{T,0}(index, axes)
116+
end
117+
function OneElementArray(index::Tuple{}, axes::Tuple{})
118+
return OneElementArray{Bool,0}(value, index, axes)
119+
end
120+
121+
function OneElementArray{T,0}(value) where {T}
122+
return OneElementArray{T,0}(value, (), ())
123+
end
124+
function OneElementArray{<:Any,0}(value::T) where {T}
125+
return OneElementArray{T,0}(value)
126+
end
127+
function OneElementArray{T}(value) where {T}
128+
return OneElementArray{T,0}(value)
129+
end
130+
function OneElementArray(value::T) where {T}
131+
return OneElementArray{T}(value)
132+
end
133+
134+
function OneElementArray{T,0}(index::Tuple{}, axes::Tuple{}) where {T}
135+
return OneElementArray{T,0}(one(T), (), ())
136+
end
137+
function OneElementArray{T,0}() where {T}
138+
return OneElementArray{T,0}((), ())
139+
end
140+
function OneElementArray{<:Any,0}()
141+
return OneElementArray{Bool,0}(value)
142+
end
143+
function OneElementArray{T}() where {T}
144+
return OneElementArray{T,0}()
145+
end
146+
function OneElementArray()
147+
return OneElementArray{Bool}()
148+
end
149+
150+
function OneElementArray{T,N}(
151+
value, index::NTuple{N,Int}, size::NTuple{N,Integer}
152+
) where {T,N}
153+
return OneElementArray{T,N}(value, index, Base.oneto.(size))
154+
end
155+
function OneElementArray{<:Any,N}(
156+
value::T, index::NTuple{N,Int}, size::NTuple{N,Integer}
157+
) where {T,N}
158+
return OneElementArray{T,N}(value, index, size)
159+
end
160+
function OneElementArray{T}(
161+
value, index::NTuple{N,Int}, size::NTuple{N,Integer}
162+
) where {T,N}
163+
return OneElementArray{T,N}(value, index, size)
164+
end
165+
function OneElementArray(
166+
value::T, index::NTuple{N,Int}, size::NTuple{N,Integer}
167+
) where {T,N}
168+
return OneElementArray{T,N}(value, index, Base.oneto.(size))
169+
end
170+
171+
function OneElementArray{T,N}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {T,N}
172+
return OneElementArray{T,N}(one(T), index, size)
173+
end
174+
function OneElementArray{<:Any,N}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
175+
return OneElementArray{Bool,N}(index, size)
176+
end
177+
function OneElementArray{T}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {T,N}
178+
return OneElementArray{T,N}(index, size)
179+
end
180+
function OneElementArray(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
181+
return OneElementArray{Bool,N}(index, size)
182+
end
183+
184+
function OneElementVector{T}(value, index::Int, length::Integer) where {T}
185+
return OneElementVector{T}(value, (index,), (length,))
186+
end
187+
function OneElementVector(value::T, index::Int, length::Integer) where {T}
188+
return OneElementVector{T}(value, index, length)
189+
end
190+
function OneElementArray{T}(value, index::Int, length::Integer) where {T}
191+
return OneElementVector{T}(value, index, length)
192+
end
193+
function OneElementArray(value::T, index::Int, length::Integer) where {T}
194+
return OneElementVector{T}(value, index, length)
195+
end
196+
197+
function OneElementVector{T}(index::Int, size::Integer) where {T}
198+
return OneElementVector{T}((index,), (size,))
199+
end
200+
function OneElementVector(index::Int, length::Integer)
201+
return OneElementVector{Bool}(index, length)
202+
end
203+
function OneElementArray{T}(index::Int, size::Integer) where {T}
204+
return OneElementVector{T}(index, size)
205+
end
206+
OneElementArray(index::Int, size::Integer) = OneElementVector{Bool}(index, size)
207+
208+
# Interface to overload for constructing arrays like `OneElementArray`,
209+
# that may not be `OneElementArray` (i.e. wrapped versions).
210+
function oneelement(
211+
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
212+
) where {N}
213+
return OneElementArray(value, index, axes)
214+
end
215+
function oneelement(
216+
eltype::Type, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
217+
) where {N}
218+
return oneelement(one(eltype), index, axes)
219+
end
220+
function oneelement(index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {N}
221+
return oneelement(Bool, index, axes)
222+
end
223+
224+
function oneelement(value, index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
225+
return OneElementArray(value, index, size)
226+
end
227+
function oneelement(eltype::Type, index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
228+
return oneelement(one(eltype), index, size)
229+
end
230+
function oneelement(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
231+
return oneelement(Bool, index, size)
232+
end
233+
234+
function oneelement(value, ax_ind::Pair{<:AbstractUnitRange,Int}...)
235+
return oneelement(value, last.(ax_ind), first.(ax_ind))
236+
end
237+
function oneelement(eltype::Type, ax_ind::Pair{<:AbstractUnitRange,Int}...)
238+
return oneelement(one(eltype), ax_ind...)
239+
end
240+
function oneelement(ax_ind::Pair{<:AbstractUnitRange,Int}...)
241+
return oneelement(Bool, ax_ind...)
242+
end
243+
244+
function oneelement(value)
245+
return oneelement(value, (), ())
246+
end
247+
function oneelement(eltype::Type)
248+
return oneelement(one(eltype))
249+
end
250+
function oneelement()
251+
return oneelement(Bool)
252+
end
253+
254+
Base.axes(a::OneElementArray) = getfield(a, :axes)
255+
Base.size(a::OneElementArray) = length.(axes(a))
256+
storedvalue(a::OneElementArray) = getfield(a, :value)
257+
storedvalues(a::OneElementArray) = Fill(storedvalue(a), 1)
258+
259+
storedindex(a::OneElementArray) = getfield(a, :index)
260+
function isstored(a::OneElementArray, I::Int...)
261+
return I == storedindex(a)
262+
end
263+
function eachstoredindex(a::OneElementArray)
264+
return Fill(CartesianIndex(storedindex(a)), 1)
265+
end
266+
267+
function getstoredindex(a::OneElementArray, I::Int...)
268+
return storedvalue(a)
269+
end
270+
function getunstoredindex(a::OneElementArray, I::Int...)
271+
return a.getunstoredindex(a, I...)
272+
end
273+
function setstoredindex!(a::OneElementArray, value, I::Int...)
274+
return error("`OneElementArray` is immutable, you can't set elements.")
275+
end
276+
function setunstoredindex!(a::OneElementArray, value, I::Int...)
277+
return error("`OneElementArray` is immutable, you can't set elements.")
278+
end
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)