-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathblockedpermutation.jl
175 lines (142 loc) · 5.91 KB
/
blockedpermutation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
using BlockArrays:
BlockArrays, Block, blockfirsts, blocklasts, blocklength, blocklengths, blocks
using EllipsisNotation: Ellipsis, var".."
using TupleTools: TupleTools
# Identity permutation `(1, 2, ..., len)`; `len` is anything `ntuple` accepts as a length.
function trivialperm(len)
    return ntuple(identity, len)
end
# Check whether `t` equals the identity permutation `(1, 2, ..., length(t))`.
istrivialperm(t::Tuple) = t == trivialperm(length(t))
# Unwrap the compile-time value stored in a `Val`, e.g. `value(Val(3)) == 3`.
function value(::Val{N}) where {N}
    return N
end
# Concatenate a sequence of tuples into one flat tuple. The recursion merges the first
# two tuples at a time, so the result type is known from the argument types.
_flatten_tuples() = ()
_flatten_tuples(t::Tuple) = t
function _flatten_tuples(first::Tuple, second::Tuple, rest::Tuple...)
    merged = (first..., second...)
    return _flatten_tuples(merged, rest...)
end

# Flatten a tuple of tuples, e.g. `flatten_tuples(((1, 2), (3,))) == (1, 2, 3)`.
function flatten_tuples(ts::Tuple)
    return _flatten_tuples(ts...)
end
# Normalize one block specifier.
# - Tuples and `..` (an `Ellipsis`) pass through unchanged.
# - Any other value `x` is wrapped as the one-element tuple `(x,)`.
collect_tuple(t::Tuple) = t
collect_tuple(e::Ellipsis) = e
collect_tuple(x) = (x,)
#
# =============================== AbstractBlockPermutation ===============================
#
# Supertype of blocked permutations: blocked tuples whose flat entries are expected to
# form a permutation. `BlockLength` is the number of blocks.
abstract type AbstractBlockPermutation{BlockLength} <: AbstractBlockTuple{BlockLength} end
# Constructor used when a blocked permutation has to be "widened" to a plain
# `BlockedTuple`. NOTE(review): presumably for operations whose result is not guaranteed
# to be a permutation; confirm against the callers of `widened_constructorof`.
widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple
# Split a flat permutation into blocks of the given lengths, e.g.
# blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1))
# The lengths may be a tuple or a `Val` wrapping one (to keep them known at compile
# time). Both forms are forwarded to `BlockedTuple`.
# TODO: Optimize with StaticNumbers.jl or generated functions, see:
# https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567
function blockperm(perm::Tuple{Vararg{Int}}, lengths::Tuple{Vararg{Int}})
    blocked = BlockedTuple(perm, lengths)
    return blockedperm(blocked)
end
function blockperm(perm::Tuple{Vararg{Int}}, lengths::Val)
    blocked = BlockedTuple(perm, lengths)
    return blockedperm(blocked)
end
# Inverse of a blocked permutation: invert the flat permutation, then block the result
# with the same block lengths as the input. The lengths are wrapped in a `Val` so they
# stay available at compile time.
function Base.invperm(bp::AbstractBlockPermutation)
    inverse_flat = invperm(Tuple(bp))
    lengths = Val(blocklengths(bp))
    return blockperm(inverse_flat, lengths)
end
#
# Constructors
#
# Permute `v` block by block: for each block of `bp`, gather the entries `v[i]` for the
# indices `i` in that block. The result has one entry per block.
# This is an out-of-place, blocked analogue of `Base.permute!`.
function blockpermute(v, bp::AbstractBlockPermutation)
    gather(block) = map(i -> v[i], block)
    return map(gather, blocks(bp))
end
# Build a blocked permutation from its blocks, e.g. `blockedperm((4, 3), (2, 1))`.
# The optional `length` keyword (`nothing` or a `Val`) is moved to the first positional
# argument so that the next step can dispatch on it.
function blockedperm(
    permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing
)
    return blockedperm(length, permblocks...)
end
# No explicit total length was given: use the sum of the block lengths, passed on as a
# `Val` so it is known at compile time.
function blockedperm(::Nothing, permblocks::Tuple{Vararg{Int}}...)
    # `init=0` keeps the total an `Int` even with zero blocks (`zero(Bool)` produced
    # `Val(false)` there). This matches the `sum(BlockLengths; init=0)` check in the
    # `BlockedPermutation` constructor.
    total = sum(length, permblocks; init=0)
    return blockedperm(Val(total), permblocks...)
end
# Bare integers are shorthand for one-element blocks:
# blockedperm((3, 2), 1) == blockedperm((3, 2), (1,))
# Every block specifier is normalized with `collect_tuple`: `Int`s become 1-tuples, while
# tuples and `..` are kept. Keywords are forwarded unchanged.
function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...)
    normalized = map(collect_tuple, permblocks)
    return blockedperm(normalized...; kwargs...)
end
function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwargs...)
    normalized = map(collect_tuple, permblocks)
    return blockedperm(normalized...; kwargs...)
end
# Convert any blocked tuple into a blocked permutation with the same blocks. The total
# length is passed as a `Val`.
function blockedperm(bt::AbstractBlockTuple)
    total = Val(length(bt))
    return blockedperm(total, blocks(bt)...)
end
# Total number of dimensions when completing a permutation that contains `..`.
# Without an explicit `length`, it is the largest index that was specified. `init=0`
# makes this 0 (instead of an error on an empty reduction) when no index was given.
function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
    return maximum(specified_perm; init=0)
end
# An explicit `length=Val(n)` takes precedence over the specified indices.
function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}})
    return value(vallength)
end
# blockedperm((4, 3), .., 1) == blockedperm((4, 3), 2, 1)
# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), 2, 5, 1)
# The single `..` is replaced, in place, by every index in `1:len` that was not given
# explicitly, each as its own block. `len` is `length` if provided, otherwise the
# largest explicit index.
function blockedperm(
    permblocks::Union{Tuple{Vararg{Int}},Ellipsis}...; length::Union{Val,Nothing}=nothing
)
    is_ellipsis(x) = x isa Ellipsis
    # Exactly one `..` is allowed; otherwise the fill-in position is ambiguous.
    @assert isone(count(is_ellipsis, permblocks))
    explicit_blocks = filter(!is_ellipsis, permblocks)
    ellipsis_position = findfirst(is_ellipsis, permblocks)
    explicit_perm = flatten_tuples(explicit_blocks)
    total_length = _blockedperm_length(length, explicit_perm)
    fill_dims = Tuple(setdiff(Base.OneTo(total_length), explicit_perm))
    completed_blocks = TupleTools.insertat(permblocks, ellipsis_position, fill_dims)
    return blockedperm(completed_blocks...)
end
# `indexin` variant that returns a `blockedperm`: each argument in `subs` becomes one
# block, holding the result of `BaseExtensions.indexin(sub, collection)`.
function blockedperm_indexin(collection, subs...)
    positions = map(subs) do sub
        BaseExtensions.indexin(sub, collection)
    end
    return blockedperm(positions...)
end
#
# ================================== BlockedPermutation ==================================
#
# Blocked permutation that stores its flat entries in the `flat` field.
# `BlockLength` (the number of blocks) is the first type parameter because that is
# convenient for dispatch; it is shared with the `AbstractBlockPermutation{BlockLength}`
# supertype. `BlockLengths` is the tuple of block lengths and `Flat` the field's type.
# The inner constructor throws `DimensionMismatch` when these are inconsistent.
struct BlockedPermutation{BlockLength,BlockLengths,Flat} <:
       AbstractBlockPermutation{BlockLength}
    flat::Flat
    function BlockedPermutation{BlockLength,BlockLengths}(
        flat::Tuple
    ) where {BlockLength,BlockLengths}
        length(flat) == sum(BlockLengths; init=0) ||
            throw(DimensionMismatch("Invalid total length"))
        length(BlockLengths) == BlockLength ||
            throw(DimensionMismatch("Invalid total blocklength"))
        any(<(0), BlockLengths) && throw(DimensionMismatch("Invalid block length"))
        return new{BlockLength,BlockLengths,typeof(flat)}(flat)
    end
end
# Base interface
# The flat (unblocked) permutation is the stored `flat` field.
function Base.Tuple(bp::BlockedPermutation)
    return getfield(bp, :flat)
end
# BlockArrays interface
# The block lengths are a type parameter, so they are available from the type alone.
BlockArrays.blocklengths(::Type{<:BlockedPermutation{<:Any,L}}) where {L} = L
# Construct a `BlockedPermutation` from its blocks; the leading `Val` argument is
# ignored here. The flattened entries are checked to form a permutation.
function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...)
    nblocks = length(permblocks)
    lengths = map(length, permblocks)
    bp = BlockedPermutation{nblocks,lengths}(flatten_tuples(permblocks))
    # NOTE(review): `@assert` may be disabled at some optimization levels, so this is not
    # a guaranteed input check.
    @assert isperm(bp)
    return bp
end
#
# ============================== BlockedTrivialPermutation ===============================
#
# Identity permutation `(1, 2, ..., n)`. `n` may also be a `Val`, in which case `ntuple`
# produces a tuple whose length is part of its type.
trivialperm(n::Union{Integer,Val}) = ntuple(i -> i, n)
# Identity permutation whose block structure lives entirely in the type.
# `BlockLength` is the number of blocks and `BlockLengths` the tuple of block lengths.
# The struct has no fields.
struct BlockedTrivialPermutation{BlockLength,BlockLengths} <:
AbstractBlockPermutation{BlockLength} end
# The flat tuple is regenerated on demand as `(1, 2, ..., n)`, with `n` the total length.
function Base.Tuple(bp::BlockedTrivialPermutation)
    n = length(bp)
    return trivialperm(n)
end
# BlockArrays interface
# The block lengths are a type parameter, so they are available from the type alone.
BlockArrays.blocklengths(::Type{<:BlockedTrivialPermutation{<:Any,L}}) where {L} = L
# A trivial blocked permutation is already a blocked permutation; return it unchanged.
function blockedperm(tp::BlockedTrivialPermutation)
    return tp
end
# Identity permutation split into blocks of the given lengths. The lengths, and their
# count, are stored as type parameters.
function blockedtrivialperm(lengths::Tuple{Vararg{Int}})
    nblocks = length(lengths)
    return BlockedTrivialPermutation{nblocks,lengths}()
end
# Identity permutation with the same block lengths as the blocked tuple `bt`.
trivialperm(bt::AbstractBlockTuple) = blockedtrivialperm(blocklengths(bt))
# The identity permutation is its own inverse.
function Base.invperm(tp::BlockedTrivialPermutation)
    return tp
end