Skip to content

Commit 7a3ceed

Browse files
committed
implement bitonic sorting network for SVectors
1 parent 95f2578 commit 7a3ceed

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ include("abstractarray.jl")
123123
include("indexing.jl")
124124
include("broadcast.jl")
125125
include("mapreduce.jl")
126+
include("sort.jl")
126127
include("arraymath.jl")
127128
include("linalg.jl")
128129
include("matrix_multiply.jl")

src/sort.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import Base.@_inline_meta
2+
import Base.Order: Ordering, Forward, ReverseOrdering, ord
3+
import Base.Sort: Algorithm, defalg, lt, sort
4+
5+
6+
struct BitonicSortAlg <: Algorithm end
7+
struct MinSizeSortAlg <: Algorithm end
8+
struct MinDepthSortAlg <: Algorithm end
9+
const MinSortAlg = Union{MinSizeSortAlg,MinDepthSortAlg}
10+
11+
const BitonicSort = BitonicSortAlg()
12+
const MinSizeSort = MinSizeSortAlg()
13+
const MinDepthSort = MinDepthSortAlg()
14+
15+
defalg(::SVector) = BitonicSort
16+
17+
function sort(a::SVector;
18+
alg::Algorithm = defalg(a),
19+
lt = isless,
20+
by = identity,
21+
rev::Union{Bool,Nothing} = nothing,
22+
order::Ordering = Forward)
23+
ordr = ord(lt, by, rev, order)
24+
println(Size(a))
25+
(Size(a) == Size(0) || Size(a) == Size(1)) && return a
26+
return _sort(Size(a), alg, ordr, a)
27+
end
28+
29+
_sort(::Size{T}, alg, _, _) where T =
30+
error("sorting algorithm $alg unimplemented for static array of size $T")
31+
32+
@generated function _sort(::Size{S}, ::BitonicSortAlg, order, a) where {S}
33+
function swap_expr(i, j, rev)
34+
ai = Symbol('a', i)
35+
aj = Symbol('a', j)
36+
order = rev ? :revorder : :order
37+
return :( ($ai, $aj) = lt($order, $ai, $aj) ? ($ai, $aj) : ($aj, $ai) )
38+
end
39+
40+
function merge_exprs(idx, rev)
41+
exprs = Expr[]
42+
length(idx) == 1 && return exprs
43+
44+
ci = 2^(ceil(Int, log2(length(idx))) - 1)
45+
# TODO: generate simd code for these swaps
46+
for i in first(idx):last(idx)-ci
47+
push!(exprs, swap_expr(i, i+ci, rev))
48+
end
49+
append!(exprs, merge_exprs(idx[1:ci], rev))
50+
append!(exprs, merge_exprs(idx[ci+1:end], rev))
51+
return exprs
52+
end
53+
54+
function sort_exprs(idx, rev=false)
55+
exprs = Expr[]
56+
length(idx) == 1 && return exprs
57+
58+
append!(exprs, sort_exprs(idx[1:end÷2], !rev))
59+
append!(exprs, sort_exprs(idx[end÷2+1:end], rev))
60+
append!(exprs, merge_exprs(idx, rev))
61+
return exprs
62+
end
63+
64+
idx = 1:prod(S)
65+
symlist = (Symbol('a', i) for i in idx)
66+
sym_exprs = (:( $ai = a[$i] ) for (i, ai) in enumerate(symlist))
67+
return quote
68+
@_inline_meta
69+
revorder = Base.Order.ReverseOrdering(order)
70+
@inbounds ($(sym_exprs...);)
71+
($(sort_exprs(idx)...);)
72+
return SVector(($(symlist...)))
73+
end
74+
end
75+
76+
77+
## TODO: manually implementing minimal sorting networks for small lengths might
78+
## be worthwhile
79+
#
80+
#@inline _cmpswap(order, a, b) = lt(order, a, b) ? (a, b) : (b, a)
81+
#
82+
#macro _cmpswap(order, a, b)
83+
# return esc(:( ($a, $b) = _cmpswap(order, $a, $b) ))
84+
#end
85+
#
86+
#@inline _sort(::Size{(2,)}, _, order, (a1, a2)) = SVector(_cmpswap(order, a1, a2))
87+
#
88+
#@inline function _sort(::Size{(3,)}, ::MinSortAlg, order, (a1, a2, a3))
89+
# @_cmpswap order a1 a3
90+
# @_cmpswap order a1 a2
91+
# @_cmpswap order a2 a3
92+
# return SVector(a1, a2, a3)
93+
#end
94+
#
95+
#@inline function _sort(::Size{(4,)}, ::MinSortAlg, order, (a1, a2, a3, a4))
96+
# @_cmpswap order a1 a3
97+
# @_cmpswap order a2 a4
98+
# @_cmpswap order a1 a2
99+
# @_cmpswap order a3 a4
100+
# @_cmpswap order a2 a3
101+
# return SVector(a1, a2, a3, a4)
102+
#end

0 commit comments

Comments
 (0)