From f54c6d2f0d953f1676608578a2c40160e62f1d69 Mon Sep 17 00:00:00 2001 From: Jake Roth Date: Mon, 9 Sep 2024 18:22:30 -0500 Subject: [PATCH 1/2] add iterative-pc, splitting-pc, and halving-pc to union find structure --- src/DataStructures.jl | 1 + src/disjoint_set.jl | 53 +++++++++++++++++++++++++++++++++++++- test/bench_disjoint_set.jl | 42 +++++++++++++++++++++++++++++- test/test_disjoint_set.jl | 27 +++++++++++++++++++ 4 files changed, 121 insertions(+), 2 deletions(-) diff --git a/src/DataStructures.jl b/src/DataStructures.jl index 5e5d767ec..cc99eccc6 100644 --- a/src/DataStructures.jl +++ b/src/DataStructures.jl @@ -68,6 +68,7 @@ module DataStructures include("queue.jl") include("accumulator.jl") include("disjoint_set.jl") + export PCRecursive, PCIterative, PCHalving, PCSplitting include("heaps.jl") include("default_dict.jl") diff --git a/src/disjoint_set.jl b/src/disjoint_set.jl index 364645472..220f97902 100644 --- a/src/disjoint_set.jl +++ b/src/disjoint_set.jl @@ -60,13 +60,64 @@ function _find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} return p end +# iterative path compression: makes every node on the path point directly to the root +@inline function find_root_iterative!(parents::Vector{T}, x::Integer) where {T<:Integer} + current = x + # find the root of the tree + @inbounds while parents[current] != current + current = parents[current] + end + root = current + # compress the path: make every node point directly to the root + current = x + @inbounds while parents[current] != root + p = parents[current] # temporarily store the parent + parents[current] = root # point directly to the root + current = p # move to the next node in the original path + end + return root +end + +# path-halving and path-splitting are a one-pass forms of path compression with inverse-ackerman complexity +# e.g., see p.19 of https://www.cs.princeton.edu/courses/archive/spr11/cos423/Lectures/PathCompressionAnalysisII.pdf + +# path-halving: every node on the path points to its grandparent +@inline function find_root_halving!(parents::Vector{T}, x::Integer) where {T<:Integer} + current = x # use a separate variable 'current' to track traversal + @inbounds while parents[current] != current + @inbounds parents[current] = parents[parents[current]] # point to grandparent + @inbounds current = parents[current] # move to grandparent + end + return current +end + +# path-splitting: every node on the path points to its grandparent +@inline function find_root_splitting!(parents::Vector{T}, x::Integer) where {T<:Integer} + @inbounds while parents[x] != x + p = parents[x] # store the current parent + parents[x] = parents[p] # point to grandparent + x = p # move to parent + end + return x +end + + +struct PCRecursive end # path compression types +struct PCIterative end # path compression types +struct PCHalving end # path compression types +struct PCSplitting end # path compression types + """ find_root!(s::IntDisjointSet{T}, x::T) Find the root element of the subset that contains an member `x`. Path compression happens here. """ -find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) +@inline find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) # default +@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCRecursive) where {T<:Integer} = find_root_impl!(s.parents, x) +@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCIterative) where {T<:Integer} = find_root_iterative!(s.parents, x) +@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCHalving) where {T<:Integer} = find_root_halving!(s.parents, x) +@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCSplitting) where {T<:Integer} = find_root_splitting!(s.parents, x) """ in_same_set(s::IntDisjointSet{T}, x::T, y::T) diff --git a/test/bench_disjoint_set.jl b/test/bench_disjoint_set.jl index d4712b5e6..8cc4bb09a 100644 --- a/test/bench_disjoint_set.jl +++ b/test/bench_disjoint_set.jl @@ -1,6 +1,6 @@ # Benchmark on disjoint set forests -using DataStructures +using DataStructures, BenchmarkTools # do 10^6 random unions over 10^6 element set @@ -29,3 +29,43 @@ x = rand(1:n, T) y = rand(1:n, T) @time batch_union!(s, x, y) + +#= +benchmark `find` operation +=# + +function create_disjoint_set_struct(n::Int) + parents = [1; collect(1:n-1)] # each element's parent is its predecessor + ranks = zeros(Int, n) # ranks are all zero + DataStructures.IntDisjointSet(parents, ranks, n) +end + +# benchmarking function +function benchmark_find_root(n::Int) + println("Benchmarking recursive path compression implementation (find_root_impl!):") + if n >= 10^5 + println("Recursive may path compression may encounter stack-overflow; skipping") + else + s = create_disjoint_set_struct(n) + @btime DataStructures.find_root!($s, $n, DataStructures.PCRecursive()) + end + + println("Benchmarking iterative path compression implementation (find_root_iterative!):") + s = create_disjoint_set_struct(n) # reset parents + @btime DataStructures.find_root!($s, $n, DataStructures.PCIterative()) + + println("Benchmarking path-halving implementation (find_root_halving!):") + s = create_disjoint_set_struct(n) # reset parents + @btime DataStructures.find_root!($s, $n, DataStructures.PCHalving()) + + println("Benchmarking path-splitting implementation (find_root_path_splitting!):") + s = create_disjoint_set_struct(n) # reset parents + @btime DataStructures.find_root!($s, $n, DataStructures.PCSplitting()) +end + +# run benchmark tests +benchmark_find_root(1_000) +benchmark_find_root(10_000) +benchmark_find_root(100_000) +benchmark_find_root(1_000_000) +benchmark_find_root(10_000_000) \ No newline at end of file diff --git a/test/test_disjoint_set.jl b/test/test_disjoint_set.jl index a03465470..e96b3d9ce 100644 --- a/test/test_disjoint_set.jl +++ b/test/test_disjoint_set.jl @@ -48,10 +48,19 @@ @test union!(s, T(8), T(5)) == T(8) @test num_groups(s) == T(7) @test find_root!(s, T(6)) == T(8) + @test find_root!(s, T(6), :iterative) == T(8) + @test find_root!(s, T(6), :halving) == T(8) + @test find_root!(s, T(6), :splitting) == T(8) union!(s, T(2), T(6)) @test find_root!(s, T(2)) == T(8) root1 = find_root!(s, T(6)) + root1 = find_root!(s, T(6), :iterative) + root1 = find_root!(s, T(6), :halving) + root1 = find_root!(s, T(6), :splitting) root2 = find_root!(s, T(2)) + root2 = find_root!(s, T(2), :iterative) + root2 = find_root!(s, T(2), :halving) + root2 = find_root!(s, T(2), :splitting) @test root_union!(s, T(root1), T(root2)) == T(8) @test union!(s, T(5), T(6)) == T(8) end @@ -136,6 +145,24 @@ @test root1 != root2 root_union!(s, 7, 3) @test find_root!(s, 7) == find_root!(s, 3) + + root1 = find_root!(s, 7, :iterative) + root2 = find_root!(s, 3, :iterative) + @test root1 != root2 + root_union!(s, 7, 3) + @test find_root!(s, 7, :iterative) == find_root!(s, 3, :iterative) + + root1 = find_root!(s, 7, :halving) + root2 = find_root!(s, 3, :halving) + @test root1 != root2 + root_union!(s, 7, 3) + @test find_root!(s, 7, :halving) == find_root!(s, 3, :halving) + + root1 = find_root!(s, 7, :splitting) + root2 = find_root!(s, 3, :splitting) + @test root1 != root2 + root_union!(s, 7, 3) + @test find_root!(s, 7, :splitting) == find_root!(s, 3, :splitting) end @testset "Some tests using non-integer disjoint sets" begin From bad54e02b40208c2a2019884aa22ecfaaae0edf8 Mon Sep 17 00:00:00 2001 From: Jake Roth Date: Mon, 9 Sep 2024 18:55:42 -0500 Subject: [PATCH 2/2] cleanup --- src/disjoint_set.jl | 4 ++ test/bench_disjoint_set.jl | 10 +-- test/test_disjoint_set.jl | 138 +++++++++++++++++++++++++++++-------- 3 files changed, 120 insertions(+), 32 deletions(-) diff --git a/src/disjoint_set.jl b/src/disjoint_set.jl index 220f97902..6b00608a9 100644 --- a/src/disjoint_set.jl +++ b/src/disjoint_set.jl @@ -242,6 +242,10 @@ end Find the root element of the subset in `s` which has the element `x` as a member. """ find_root!(s::DisjointSet{T}, x::T) where {T} = s.revmap[find_root!(s.internal, s.intmap[x])] +find_root!(s::DisjointSet{T}, x::T, ::PCIterative) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCIterative())] +find_root!(s::DisjointSet{T}, x::T, ::PCRecursive) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCRecursive())] +find_root!(s::DisjointSet{T}, x::T, ::PCHalving) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCHalving())] +find_root!(s::DisjointSet{T}, x::T, ::PCSplitting) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCSplitting())] """ in_same_set(s::DisjointSet{T}, x::T, y::T) diff --git a/test/bench_disjoint_set.jl b/test/bench_disjoint_set.jl index 8cc4bb09a..f7c0796f4 100644 --- a/test/bench_disjoint_set.jl +++ b/test/bench_disjoint_set.jl @@ -37,7 +37,7 @@ benchmark `find` operation function create_disjoint_set_struct(n::Int) parents = [1; collect(1:n-1)] # each element's parent is its predecessor ranks = zeros(Int, n) # ranks are all zero - DataStructures.IntDisjointSet(parents, ranks, n) + IntDisjointSet(parents, ranks, n) end # benchmarking function @@ -47,20 +47,20 @@ function benchmark_find_root(n::Int) println("Recursive may path compression may encounter stack-overflow; skipping") else s = create_disjoint_set_struct(n) - @btime DataStructures.find_root!($s, $n, DataStructures.PCRecursive()) + @btime find_root!($s, $n, PCRecursive()) end println("Benchmarking iterative path compression implementation (find_root_iterative!):") s = create_disjoint_set_struct(n) # reset parents - @btime DataStructures.find_root!($s, $n, DataStructures.PCIterative()) + @btime find_root!($s, $n, PCIterative()) println("Benchmarking path-halving implementation (find_root_halving!):") s = create_disjoint_set_struct(n) # reset parents - @btime DataStructures.find_root!($s, $n, DataStructures.PCHalving()) + @btime find_root!($s, $n, PCHalving()) println("Benchmarking path-splitting implementation (find_root_path_splitting!):") s = create_disjoint_set_struct(n) # reset parents - @btime DataStructures.find_root!($s, $n, DataStructures.PCSplitting()) + @btime find_root!($s, $n, PCSplitting()) end # run benchmark tests diff --git a/test/test_disjoint_set.jl b/test/test_disjoint_set.jl index e96b3d9ce..d146f309b 100644 --- a/test/test_disjoint_set.jl +++ b/test/test_disjoint_set.jl @@ -29,10 +29,16 @@ @test num_groups(s) == T(9) @test in_same_set(s, T(2), T(3)) @test find_root!(s, T(3)) == T(2) + @test find_root!(s, T(3), PCIterative()) == T(2) + @test find_root!(s, T(3), PCHalving()) == T(2) + @test find_root!(s, T(3), PCSplitting()) == T(2) union!(s, T(3), T(2)) @test num_groups(s) == T(9) @test in_same_set(s, T(2), T(3)) @test find_root!(s, T(3)) == T(2) + @test find_root!(s, T(3), PCIterative()) == T(2) + @test find_root!(s, T(3), PCHalving()) == T(2) + @test find_root!(s, T(3), PCSplitting()) == T(2) end @testset "more tests" begin @@ -48,19 +54,19 @@ @test union!(s, T(8), T(5)) == T(8) @test num_groups(s) == T(7) @test find_root!(s, T(6)) == T(8) - @test find_root!(s, T(6), :iterative) == T(8) - @test find_root!(s, T(6), :halving) == T(8) - @test find_root!(s, T(6), :splitting) == T(8) + @test find_root!(s, T(6), PCIterative()) == T(8) + @test find_root!(s, T(6), PCHalving()) == T(8) + @test find_root!(s, T(6), PCSplitting()) == T(8) union!(s, T(2), T(6)) @test find_root!(s, T(2)) == T(8) root1 = find_root!(s, T(6)) - root1 = find_root!(s, T(6), :iterative) - root1 = find_root!(s, T(6), :halving) - root1 = find_root!(s, T(6), :splitting) + root1 = find_root!(s, T(6), PCIterative()) + root1 = find_root!(s, T(6), PCHalving()) + root1 = find_root!(s, T(6), PCSplitting()) root2 = find_root!(s, T(2)) - root2 = find_root!(s, T(2), :iterative) - root2 = find_root!(s, T(2), :halving) - root2 = find_root!(s, T(2), :splitting) + root2 = find_root!(s, T(2), PCIterative()) + root2 = find_root!(s, T(2), PCHalving()) + root2 = find_root!(s, T(2), PCSplitting()) @test root_union!(s, T(root1), T(root2)) == T(8) @test union!(s, T(5), T(6)) == T(8) end @@ -107,6 +113,12 @@ r = [find_root!(s, i) for i in 1 : 10] @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCIterative()) for i in 1 : 10] + @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCHalving()) for i in 1 : 10] + @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCSplitting()) for i in 1 : 10] + @test isequal(r, collect(1:10)) end @testset "union!" begin @@ -126,6 +138,57 @@ @test num_groups(s) == 2 end + @testset "union! PCIterative" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCIterative()) == find_root!(s, y, PCIterative()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCIterative()) + @test union!(s, 3, 5) == find_root!(s, 1, PCIterative()) + @test union!(s, 7, 9) == find_root!(s, 7, PCIterative()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + + @testset "union! PCHalving" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCHalving()) == find_root!(s, y, PCHalving()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCHalving()) + @test union!(s, 3, 5) == find_root!(s, 1, PCHalving()) + @test union!(s, 7, 9) == find_root!(s, 7, PCHalving()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + + @testset "union! PCSplitting" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCSplitting()) == find_root!(s, y, PCSplitting()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCSplitting()) + @test union!(s, 3, 5) == find_root!(s, 1, PCSplitting()) + @test union!(s, 7, 9) == find_root!(s, 7, PCSplitting()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + @testset "r0" begin r0 = [ find_root!(s,i) for i in 1:10 ] # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 @@ -139,30 +202,51 @@ @test isequal(r, r0) end + @testset "r0 Iterative" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCIterative()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + + @testset "r0 Splitting" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCSplitting()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + + @testset "r0 Halving" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCHalving()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + @testset "root_union!" begin root1 = find_root!(s, 7) root2 = find_root!(s, 3) @test root1 != root2 root_union!(s, 7, 3) @test find_root!(s, 7) == find_root!(s, 3) - - root1 = find_root!(s, 7, :iterative) - root2 = find_root!(s, 3, :iterative) - @test root1 != root2 - root_union!(s, 7, 3) - @test find_root!(s, 7, :iterative) == find_root!(s, 3, :iterative) - - root1 = find_root!(s, 7, :halving) - root2 = find_root!(s, 3, :halving) - @test root1 != root2 - root_union!(s, 7, 3) - @test find_root!(s, 7, :halving) == find_root!(s, 3, :halving) - - root1 = find_root!(s, 7, :splitting) - root2 = find_root!(s, 3, :splitting) - @test root1 != root2 - root_union!(s, 7, 3) - @test find_root!(s, 7, :splitting) == find_root!(s, 3, :splitting) end @testset "Some tests using non-integer disjoint sets" begin