-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmap.jl
164 lines (145 loc) · 5.56 KB
/
map.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
using ArrayLayouts: LayoutArray
using BlockArrays: blockisequal
using LinearAlgebra: Adjoint, Transpose
using SparseArraysBase:
SparseArraysBase,
SparseArrayStyle,
sparse_map!,
sparse_copy!,
sparse_copyto!,
sparse_permutedims!,
sparse_mapreduce,
sparse_iszero,
sparse_isreal
# Collect, for every array in `as`, the blocked Cartesian index ranges computed
# by `blocked_cartesianindices` from that array's own axes, the axes obtained by
# `combine_axes` over all of the arrays, and its `block_stored_indices`; then
# return the `union` of those collections.
# Returns `Vector{<:CartesianIndices}`
function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray})
    shared_axes = combine_axes(map(axes, as)...)
    return union(
        (
            blocked_cartesianindices(axes(a), shared_axes, block_stored_indices(a)) for
            a in as
        )...,
    )
end
# `Base.promote_shape` is used by `map` to determine the output axes.
# Overloading it as below would be type piracy, so it is left disabled;
# avoiding the piracy may require defining `map` directly.
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)
# Fallback: an array that needs no reblocking is returned as-is (the same object).
function reblock(a)
    return a
end
# If the blocking of the slice doesn't match the blocking of the
# parent array, reblock according to the blocking of the parent array:
# each parent index of the slice is converted to a `UnitRange{Int}` and
# a fresh view of the parent is taken with those ranges.
function reblock(
    a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}}
)
    # TODO: This relies on the behavior that slicing a block sparse
    # array with a UnitRange inherits the blocking of the underlying
    # block sparse array, we might change that default behavior
    # so this might become something like `@blocked parent(a)[...]`.
    plain_ranges = map(UnitRange{Int}, parentindices(a))
    return view(parent(a), plain_ranges...)
end
# Slice whose parent indices are all `NonBlockedArray`s: unwrap each one to
# its `array` field and take a new view of the parent with those indices.
function reblock(
    a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
)
    unwrapped = map(idx -> idx.array, parentindices(a))
    return view(parent(a), unwrapped...)
end
# Slice whose parent indices are all `BlockIndices` over block vectors of
# `Block{1}`: remove the blocking by converting each index's `blocks` into a
# plain `Vector`, then take a new view of the parent with those indices.
function reblock(
    a::SubArray{
        <:Any,
        <:Any,
        <:AbstractBlockSparseArray,
        <:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}},
    },
)
    plain_blocks = map(idx -> Vector(idx.blocks), parentindices(a))
    return view(parent(a), plain_blocks...)
end
# TODO: Define as `@interface BlockSparseArrayInterface Base.map!(...)`.
# TODO: Rewrite this so that it takes the blocking structure
# made by combining the blocking of the axes (i.e. the blocking that
# is used to determine `union_stored_blocked_cartesianindices(...)`).
# `reblock` is a partial solution to that, but a bit ad-hoc.
# TODO: Move to `blocksparsearrayinterface/map.jl`.
#
# Blockwise implementation of `map!(f, a_dest, a_srcs...)`.
#
# The destination and every source are first passed through `reblock`. The loop
# then visits each entry `I` of `union_stored_blocked_cartesianindices(a_dest,
# a_srcs...)`, i.e. the union computed from the destination AND all sources. For
# each `I` it:
#   1. finds, per array, the block and the in-block index range that `I` maps to;
#   2. fetches those blocks and takes views of the in-block ranges;
#   3. runs `map!(f, ...)` on the views;
#   4. writes the destination block back with `blocks(a_dest)[...] = block_dest`.
#
# Returns the `reblock`ed destination.
# NOTE(review): for a `SubArray` destination, `reblock` builds a new view, so the
# returned object may not be the same object the caller passed in — confirm no
# caller relies on identity with its input.
function blocksparse_map!(
    f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
)
    # Rebind the destination and sources to their reblocked forms.
    a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
    for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
        # Per array: `_block(BI)` is the block and `BI.indices` the index range
        # inside that block (as used below).
        BI_dest = blockindexrange(a_dest, I)
        BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
        # TODO: Investigate why this doesn't work:
        # block_dest = @view a_dest[_block(BI_dest)]
        # Look the block up by its integer block coordinates instead.
        block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
        # TODO: Investigate why this doesn't work:
        # block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
        block_srcs = ntuple(length(a_srcs)) do i
            return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
        end
        # Restrict each block to the sub-range selected by `I`.
        subblock_dest = @view block_dest[BI_dest.indices...]
        subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
        # TODO: Use `map!!` to handle immutable blocks.
        map!(f, subblock_dest, subblock_srcs...)
        # Replace the entire block, handles initializing new blocks
        # or if blocks are immutable.
        blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest
    end
    return a_dest
end
# TODO: Implement this.
# function SparseArraysBase.sparse_mapreduce(::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray})
# end
# TODO: @derive AbstractArrayInterface
# `map!` with block sparse source(s): forwards to `sparse_map!` and returns `a_dest`.
#
# At least one source is required. With only `a_srcs::Vararg{...}`, a zero-source
# call such as `map!(f, dest)` matched this method for ANY `AbstractArray` `dest`,
# i.e. a Base call on non-owned types was captured here (type piracy). Calls with
# one or more block sparse sources dispatch exactly as before.
function Base.map!(
    f,
    a_dest::AbstractArray,
    a_src::AnyAbstractBlockSparseArray,
    a_srcs::Vararg{AnyAbstractBlockSparseArray},
)
    sparse_map!(f, a_dest, a_src, a_srcs...)
    return a_dest
end
# TODO: @derive AbstractArrayInterface
# `map` over block sparse arrays, implemented by broadcasting `f` over them.
#
# At least one array is required, so a zero-collection call `map(f)` is left to
# Base rather than being captured by this method (a bare `Vararg` matched it).
# NOTE(review): broadcasting expands singleton dimensions, whereas `Base.map`
# over arrays expects matching shapes — confirm this difference is intended.
function Base.map(
    f, a::AnyAbstractBlockSparseArray, as::Vararg{AnyAbstractBlockSparseArray}
)
    return f.(a, as...)
end
# TODO: @derive AbstractArrayInterface
# Copy a block sparse source into `a_dest` through `sparse_copy!`; yields `a_dest`.
Base.copy!(a_dest::AbstractArray, a_src::AnyAbstractBlockSparseArray) =
    (sparse_copy!(a_dest, a_src); a_dest)
# TODO: @derive AbstractArrayInterface
# Copy a block sparse source into any `AbstractArray` via `sparse_copyto!`; yields `a_dest`.
Base.copyto!(a_dest::AbstractArray, a_src::AnyAbstractBlockSparseArray) =
    (sparse_copyto!(a_dest, a_src); a_dest)
# TODO: @derive AbstractArrayInterface
# Fix ambiguity error: a dedicated method for `LayoutArray` destinations
# (presumably to disambiguate against a `copyto!` method defined for
# `LayoutArray` — confirm). Same behavior as the `AbstractArray` method.
Base.copyto!(a_dest::LayoutArray, a_src::AnyAbstractBlockSparseArray) =
    (sparse_copyto!(a_dest, a_src); a_dest)
# TODO: @derive AbstractArrayInterface
# Copy a `Transpose` of a block sparse matrix (sharing element type `T`) into
# `a_dest` via `sparse_copyto!`; yields `a_dest`.
Base.copyto!(
    a_dest::AbstractMatrix, a_src::Transpose{T,<:AbstractBlockSparseMatrix{T}}
) where {T} = (sparse_copyto!(a_dest, a_src); a_dest)
# TODO: @derive AbstractArrayInterface
# Copy an `Adjoint` of a block sparse matrix (sharing element type `T`) into
# `a_dest` via `sparse_copyto!`; yields `a_dest`.
Base.copyto!(
    a_dest::AbstractMatrix, a_src::Adjoint{T,<:AbstractBlockSparseMatrix{T}}
) where {T} = (sparse_copyto!(a_dest, a_src); a_dest)
# TODO: @derive AbstractArrayInterface
# TODO: Define as `a_dest .= PermutedDimsArray(a_src, perm)`
# Permute the dimensions of a block sparse source into `a_dest` through
# `sparse_permutedims!`; yields `a_dest`.
Base.permutedims!(a_dest, a_src::AnyAbstractBlockSparseArray, perm) =
    (sparse_permutedims!(a_dest, a_src, perm); a_dest)
# TODO: @derive AbstractArrayInterface
# `mapreduce` over block sparse arrays, forwarded (with all keyword arguments)
# to `sparse_mapreduce`.
#
# At least one array is required, so a zero-collection call
# `mapreduce(f, op; kwargs...)` is left to Base instead of being captured by
# this method (a bare `Vararg` matched it).
function Base.mapreduce(
    f, op, a::AnyAbstractBlockSparseArray, as::Vararg{AnyAbstractBlockSparseArray};
    kwargs...,
)
    return sparse_mapreduce(f, op, a, as...; kwargs...)
end
# TODO: @derive AbstractArrayInterface
# TODO: Why isn't this calling `mapreduce` already?
# Zero test delegated to `sparse_iszero` applied to the array's `blocks`.
Base.iszero(a::AnyAbstractBlockSparseArray) = sparse_iszero(blocks(a))
# TODO: @derive AbstractArrayInterface
# TODO: Why isn't this calling `mapreduce` already?
# Realness test delegated to `sparse_isreal` applied to the array's `blocks`.
Base.isreal(a::AnyAbstractBlockSparseArray) = sparse_isreal(blocks(a))