-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patharraylayouts.jl
54 lines (49 loc) · 1.92 KB
/
arraylayouts.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
using ArrayLayouts: ArrayLayouts, DualLayout, MemoryLayout, MulAdd
using BlockArrays: BlockLayout
using SparseArraysBase: SparseLayout
using TypeParameterAccessors: parenttype, similartype
## TODO: Bring this back.
## function ArrayLayouts.MemoryLayout(arraytype::Type{<:AnyAbstractBlockSparseArray})
## outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
## inner_layout = typeof(MemoryLayout(blocktype(arraytype)))
## return BlockLayout{outer_layout,inner_layout}()
## end
# TODO: Generalize to `BlockSparseVectorLike`/`AnyBlockSparseVector`.
# Memory layout of the (lazy) adjoint of a block-sparse vector type:
# a `DualLayout` parameterized by the memory-layout type of the wrapped
# (parent) vector type.
function ArrayLayouts.MemoryLayout(
  arraytype::Type{<:Adjoint{<:Any,<:AbstractBlockSparseVector}}
)
  vector_layout_type = typeof(MemoryLayout(parenttype(arraytype)))
  return DualLayout{vector_layout_type}()
end
# TODO: Generalize to `BlockSparseVectorLike`/`AnyBlockSparseVector`.
# Memory layout of the (lazy) transpose of a block-sparse vector type:
# a `DualLayout` parameterized by the memory-layout type of the wrapped
# (parent) vector type.
function ArrayLayouts.MemoryLayout(
  arraytype::Type{<:Transpose{<:Any,<:AbstractBlockSparseVector}}
)
  vector_layout_type = typeof(MemoryLayout(parenttype(arraytype)))
  return DualLayout{vector_layout_type}()
end
# Allocate the output of a lazy product `MulAdd` whose two factors both have a
# block-sparse layout (`BlockLayout{<:SparseLayout}`).
#
# `A` and `B` bind the 5th and 6th type parameters of `MulAdd` (the factor
# array types in ArrayLayouts' definition). The result is a `BlockSparseArray`
# with element type `elt` and `length(axes)` dimensions; its block type is
# derived from the block type of `A` only.
function Base.similar(
  mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
  elt::Type,
  axes,
) where {A,B}
  # TODO: Check that this equals `similartype(blocktype(B), elt, axes)`,
  # or maybe promote them?
  ndims_dest = length(axes)
  dest_blocktype = similartype(blocktype(A), elt, axes)
  dest_arraytype = BlockSparseArray{elt,ndims_dest,dest_blocktype}
  return similar(dest_arraytype, axes)
end
# Materialize a SubArray view.
#
# Copies the view `a` (whose layout is `BlockLayout{<:SparseLayout}`) into a
# newly constructed `BlockSparseArray` with the given `axes`, using
# broadcasting assignment, and returns it.
#
# The destination block type starts from the parent's block type and is
# adapted with `similartype` to `eltype(a)` and to `length(axes)` dimensions.
# Without this adaptation, a view with fewer dimensions than its parent would
# pair an `ndims(parent)`-dimensional block type with a
# `length(axes)`-dimensional `BlockSparseArray` type.
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
  # TODO: Define `blocktype`/`blockstype` for `SubArray` wrapping `BlockSparseArray`.
  # TODO: Use `similar`?
  blocktype_dest = similartype(blocktype(parent(a)), eltype(a), axes)
  a_dest = BlockSparseArray{eltype(a),length(axes),blocktype_dest}(axes)
  a_dest .= a
  return a_dest
end
# Materialize a SubArray view.
#
# Specialization for views whose axes are all `Base.OneTo`: instead of a
# `BlockSparseArray`, the view is copied (via broadcasting assignment) into a
# single array of the parent's block type, allocated with `undef` and sized by
# `length.(axes)`.
function ArrayLayouts.sub_materialize(
  layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
)
  # Query `blocktype` on the parent: it may not be defined for a `SubArray`
  # wrapping a `BlockSparseArray` (an open TODO in this file). `similartype`
  # then adapts that block type to `eltype(a)` and to `length(axes)`
  # dimensions, so that a view with fewer dimensions than its parent still
  # gets a block type whose ndims matches `length.(axes)`.
  blocktype_dest = similartype(blocktype(parent(a)), eltype(a), axes)
  a_dest = blocktype_dest(undef, length.(axes))
  a_dest .= a
  return a_dest
end