-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathabstractblocksparsearray.jl
103 lines (87 loc) · 3.39 KB
/
abstractblocksparsearray.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
using BlockArrays:
BlockArrays, AbstractBlockArray, Block, BlockIndex, BlockedUnitRange, blocks
using SparseArraysBase: sparse_getindex, sparse_setindex!
# Generic function declared without any methods; it only reserves the name
# so that other code can refer to (or extend) `nonzero_keys`.
# TODO: Delete this. This function was replaced
# by `stored_length` but is still used in `NDTensors`.
function nonzero_keys end
# Abstract supertype of block sparse arrays with element type `T` and `N`
# dimensions. It subtypes `BlockArrays.AbstractBlockArray{T,N}`, so generic
# `AbstractBlockArray` methods from `BlockArrays` can dispatch on its subtypes.
abstract type AbstractBlockSparseArray{T,N} <: AbstractBlockArray{T,N} end
using Derive: @array_aliases
# Define the aliases derived from `AbstractBlockSparseArray`, such as
# `AnyAbstractBlockSparseArray` (which the `@derive` call in this file targets)
# and, presumably, `AbstractBlockSparseVector`/`AbstractBlockSparseMatrix`.
@array_aliases AbstractBlockSparseArray
using Derive: Derive
# Every subtype of `AbstractBlockSparseArray` uses `BlockSparseArrayInterface`
# as its `Derive` interface.
Derive.interface(::Type{<:AbstractBlockSparseArray}) = BlockSparseArrayInterface()
using Derive: @derive
# TODO: These need to be loaded since `AbstractArrayOps`
# includes overloads of functions from these modules.
# Ideally that wouldn't be needed, and it could be
# circumvented with `GlobalRef`.
using ArrayLayouts: ArrayLayouts
using LinearAlgebra: LinearAlgebra
# Derive `Base.getindex`, `Base.setindex!`, etc.
# The derived definitions are attached to `AnyAbstractBlockSparseArray`, the
# alias produced by `@array_aliases AbstractBlockSparseArray` in this file.
# TODO: Define `AbstractMatrixOps` and overload for
# `AnyAbstractSparseMatrix` and `AnyAbstractSparseVector`,
# which is where matrix multiplication and factorizations
# should go.
@derive AnyAbstractBlockSparseArray AbstractArrayOps
## Base `AbstractArray` interface

# Erroring fallback: concrete subtypes of `AbstractBlockSparseArray` are
# expected to supply their own `Base.axes` method.
function Base.axes(::AbstractBlockSparseArray)
    return error("Not implemented")
end
# Storage type used for the blocks of a block sparse array type: a
# `SparseArrayDOK` of `AbstractArray` blocks. Whatever the array type fixes
# is carried over into the block type:
#   - neither `T` nor `N` known -> `SparseArrayDOK{AbstractArray}`
#   - only `T` known            -> `SparseArrayDOK{AbstractArray{T}}`
#   - both `T` and `N` known    -> `SparseArrayDOK{AbstractArray{T,N},N}`
# TODO: Add some logic to unwrapping wrapped arrays.
# TODO: Decide what a good default is.
blockstype(::Type{<:AbstractBlockSparseArray}) = SparseArrayDOK{AbstractArray}
blockstype(::Type{<:AbstractBlockSparseArray{T}}) where {T} =
    SparseArrayDOK{AbstractArray{T}}
blockstype(::Type{<:AbstractBlockSparseArray{T,N}}) where {T,N} =
    SparseArrayDOK{AbstractArray{T,N},N}
# Scalar indexing `a[i, j, ...]` with exactly `N` integers. This signature is
# spelled out explicitly to resolve a method ambiguity with `BlockArrays`; the
# lookup itself is delegated to `blocksparse_getindex`.
Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, indices::Vararg{Int,N}) where {N} =
    blocksparse_getindex(a, indices...)
# Zero-dimensional indexing `a[]`, also specialized to resolve a method
# ambiguity with `BlockArrays`; delegates to `blocksparse_getindex`.
Base.getindex(a::AbstractBlockSparseArray{<:Any,0}) = blocksparse_getindex(a)
## # Fix ambiguity error with `BlockArrays`.
## function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N}
## return ArrayLayouts.layout_getindex(a, I)
## end
##
## # Fix ambiguity error with `BlockArrays`.
## function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1})
## return ArrayLayouts.layout_getindex(a, I)
## end
##
## # Fix ambiguity error with `BlockArrays`.
## function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector})
## ## return blocksparse_getindex(a, I...)
## return ArrayLayouts.layout_getindex(a, I...)
## end
# Scalar assignment `a[i, j, ...] = value` with exactly `N` integers, spelled
# out to resolve a method ambiguity with `BlockArrays`. The store is delegated
# to `blocksparse_setindex!`; the array itself is returned, as `setindex!`
# conventionally does.
function Base.setindex!(
    a::AbstractBlockSparseArray{<:Any,N}, value, indices::Vararg{Int,N}
) where {N}
    blocksparse_setindex!(a, value, indices...)
    return a
end
# Zero-dimensional assignment `a[] = value`. An explicit method is needed to
# fix a method ambiguity error; the store is delegated to
# `blocksparse_setindex!` and `a` is returned.
function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value)
    blocksparse_setindex!(a, value)
    return a
end
# Block assignment `a[Block(i), Block(j), ...] = value`: store `value` as the
# block of `a` at the given block position, and return `a`.
#
# Throws `DimensionMismatch` when `size(value)` differs from the size of the
# target block, which is read off the blocked axes of `a`. The block is
# written through `blocks(a)`; whether `value` is copied or stored as-is depends
# on that storage's `setindex!` (not visible here — NOTE(review): confirm).
function Base.setindex!(
    a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
) where {N}
    # Size of the destination block along each dimension.
    target_size = ntuple(N) do dim
        return length(axes(a, dim)[I[dim]])
    end
    # Integer block coordinates, e.g. `(Block(2), Block(1))` -> `(2, 1)`.
    block_coords = Int.(I)
    size(value) == target_size || throw(
        DimensionMismatch(
            "Trying to set block $(Block(block_coords...)), which has a size $target_size, with data of size $(size(value)).",
        ),
    )
    blocks(a)[block_coords...] = value
    return a
end