-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsparsearrayinterface.jl
39 lines (32 loc) · 1.26 KB
/
sparsearrayinterface.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
using BlockArrays: Block
using SparseArraysBase: SparseArraysBase, eachstoredindex, storedvalues
"""
    BlockSparseStorage(array)

Thin wrapper around a block sparse array, stored in the `array` field.

It gives the stored entries of `array` an iterable "storage" view through
`keys`, `values` and `iterate` methods defined for this type. The wrapper copies
nothing; it only holds a reference to `array`.
"""
struct BlockSparseStorage{A<:AbstractBlockSparseArray}
    array::A
end
"""
    blockindex_to_cartesianindex(a::AbstractArray, blockindex)

Convert a block-level index into a `CartesianIndex` into the full array `a`.

`blockindex` is read through two fields:
- `blockindex.I`: the per-dimension block numbers.
- `blockindex.α`: the per-dimension offsets within that block.

Each axis of `a` is indexed with the corresponding per-dimension block index.
"""
function blockindex_to_cartesianindex(a::AbstractArray, blockindex)
    # One `Block` per dimension, built from the block coordinates.
    axis_blocks = Block.(blockindex.I)
    # Select the in-block offset along each dimension.
    axis_blockindices = getindex.(axis_blocks, blockindex.α)
    # Resolve each per-dimension block index against the matching axis of `a`.
    return CartesianIndex(getindex.(axes(a), axis_blockindices))
end
"""
    keys(s::BlockSparseStorage)

Lazily iterate the `CartesianIndex` positions, in the full array `s.array`, of
every element that lies inside a stored block.

How it works:
- Stored blocks are visited in the order given by
  `SparseArraysBase.eachstoredindex(blocks(s.array))`.
- Within each block, every in-block index is mapped to a global position with
  `blockindex_to_cartesianindex`.
- The per-block iterators are chained with `Iterators.flatten`. No collection is
  materialized.
"""
function Base.keys(s::BlockSparseStorage)
    block_array = blocks(s.array)
    # `eachstoredindex` is the current SparseArraysBase name for iterating
    # stored indices (imported at the top of this file). The older
    # `stored_indices` name is not imported here.
    stored_blockindices = Iterators.map(eachstoredindex(block_array)) do I
        # Normalize to a CartesianIndex so `Tuple(J)` gives one block coordinate
        # per dimension, whether the iterator yields linear or Cartesian indices.
        J = CartesianIndices(block_array)[I]
        block_axes = axes(block_array[J])
        # Index the block with its own axes to enumerate all in-block positions.
        blockindices = Block(Tuple(J))[block_axes...]
        return Iterators.map(
            blockindex -> blockindex_to_cartesianindex(s.array, blockindex), blockindices
        )
    end
    return Iterators.flatten(stored_blockindices)
end
"""
    values(s::BlockSparseStorage)

Lazily yield `s.array[I]` for each index `I` produced by `eachindex(s)`.
"""
function Base.values(s::BlockSparseStorage)
    # Generator syntax: each element is looked up only when it is iterated.
    return (s.array[I] for I in eachindex(s))
end
"""
    iterate(s::BlockSparseStorage[, state])

Iterate a `BlockSparseStorage` by delegating to the iterator returned by
`values(s)`. Any iteration state is forwarded unchanged.
"""
Base.iterate(s::BlockSparseStorage, state...) = iterate(values(s), state...)
## TODO: Use `SparseArraysBase.getstoredindex`, `storedvalues`, etc.
## function SparseArraysBase.sparse_storage(a::AbstractBlockSparseArray)
## return BlockSparseStorage(a)
## end
"""
    storedlength(a::AnyAbstractBlockSparseArray)

Return the total number of stored elements of `a`.

This is the sum of `storedlength` over every stored block, taken from
`storedvalues(blocks(a))`. The result is `0` when no blocks are stored.
"""
function SparseArraysBase.storedlength(a::AnyAbstractBlockSparseArray)
    stored_blocks = storedvalues(blocks(a))
    # The explicit `init` keeps the empty case well-defined (returns 0::Int).
    return mapreduce(storedlength, +, stored_blocks; init=0)
end