From 390047b4898cc428314443f16859a9bc4468ae89 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Wed, 8 Jan 2025 19:49:46 +0100 Subject: [PATCH] Make AbstractWorkerPool methods thread-safe and more consistent Previously they did not handle dead workers in the same way. In particular `take!` would remove dead workers but none of the other methods did, leading to cases where `isready` might return true but `take!` would still block. Now they should all be consistent with each other; except for `wait` which will block if the pool is empty, unlike `take!` which will throw an exception. This seems like a reasonable tradeoff to minimize breakage while still ensuring that 'take!() will block if wait() blocks' holds. In theory one could put the dead worker checks in other methods like `length` and `put!`, but the checks would still need to be in `take!`/`isready` etc so it seems simpler to just acknowledge the lack of thread-safety in these methods upfront. --- docs/src/_changelog.md | 5 +++ src/workerpool.jl | 78 ++++++++++++++++++++++++++++++++-------- test/distributed_exec.jl | 19 ++++++++++ 3 files changed, 88 insertions(+), 14 deletions(-) diff --git a/docs/src/_changelog.md b/docs/src/_changelog.md index 5d9207f..133749d 100644 --- a/docs/src/_changelog.md +++ b/docs/src/_changelog.md @@ -11,6 +11,11 @@ This documents notable changes in DistributedNext.jl. The format is based on ### Fixed - Fixed a cause of potential hangs when exiting the process ([#16]). +- Modified the default implementations of methods like `take!` and `wait` on + [`AbstractWorkerPool`](@ref) to be threadsafe and behave more consistently + with each other. This is technically breaking, but it's a strict bugfix to + correct previous inconsistent behaviour so it will still land in a minor + release. ### Added - A watcher mechanism has been added to detect when both the Distributed stdlib diff --git a/src/workerpool.jl b/src/workerpool.jl index 92b02ff..db8eedd 100644 --- a/src/workerpool.jl +++ b/src/workerpool.jl @@ -16,6 +16,10 @@ The default implementations of the above (on a `AbstractWorkerPool`) require fie - `channel::Channel{Int}` - `workers::Set{Int}` where `channel` contains free worker pids and `workers` is the set of all workers associated with this pool. + +The default implementations of the above handle dead workers by removing them +from the pool. Be aware that since workers could die at any time, depending on +the results of functions like `length` or `isready` is not thread-safe. """ abstract type AbstractWorkerPool end @@ -71,7 +75,43 @@ deserialize(S::AbstractSerializer, t::Type{T}) where {T<:WorkerPool} = T(deseria wp_local_push!(pool::AbstractWorkerPool, w::Int) = (push!(pool.workers, w); put!(pool.channel, w); pool) wp_local_length(pool::AbstractWorkerPool) = length(pool.workers) -wp_local_isready(pool::AbstractWorkerPool) = isready(pool.channel) + +function check_valid_worker!(pool::AbstractWorkerPool, worker) + if !id_in_procs(worker) + # We abuse the Channel lock to provide thread-safety when we modify the + # worker set. + @lock pool.channel delete!(pool.workers, worker) + return false + else + return true + end +end + +function default_and_empty(pool::AbstractWorkerPool) + length(pool) == 0 && pool === default_worker_pool() +end + +function wp_local_isready(pool::AbstractWorkerPool) + if default_and_empty(pool) + # This state wouldn't block take!() so we return true + return true + end + + # Otherwise we lock the channel to prevent anyone else from touching it and + # take!() until we either run out of workers or get a valid one. Locking is + # necessary to avoid blocking on take!() or fetch(). + @lock pool.channel begin + while isready(pool.channel) + worker = take!(pool.channel) + if check_valid_worker!(pool, worker) + put!(pool.channel, worker) + break + end + end + + return isready(pool.channel) + end +end function wp_local_put!(pool::AbstractWorkerPool, w::Int) # In case of default_worker_pool, the master is implicitly considered a worker, i.e., @@ -101,29 +141,39 @@ function wp_local_take!(pool::AbstractWorkerPool) # Find an active worker worker = 0 while true - if length(pool) == 0 - if pool === default_worker_pool() - # No workers, the master process is used as a worker - worker = 1 - break - else - throw(ErrorException("No active worker available in pool")) - end + if default_and_empty(pool) + # No workers, the master process is used as a worker + worker = 1 + break + elseif length(pool) == 0 + throw(ErrorException("No active worker available in pool")) end worker = take!(pool.channel) - if id_in_procs(worker) + if check_valid_worker!(pool, worker) break - else - delete!(pool.workers, worker) # Remove invalid worker from pool end end return worker end function wp_local_wait(pool::AbstractWorkerPool) - wait(pool.channel) - return nothing + if default_and_empty(pool) + # This state wouldn't block take!() so we return + return nothing + end + + while true + # We don't use take!(::AbstractWorkerPool) because that will throw if + # the pool is empty. This will wait forever until one becomes + # available. + worker = take!(pool.channel) + + if check_valid_worker!(pool, worker) + put!(pool.channel, worker) + return nothing + end + end end function remotecall_pool(rc_f, f, pool::AbstractWorkerPool, args...; kwargs...) diff --git a/test/distributed_exec.jl b/test/distributed_exec.jl index 8bfc462..e5e5c9b 100644 --- a/test/distributed_exec.jl +++ b/test/distributed_exec.jl @@ -744,6 +744,25 @@ end status = timedwait(() -> isready(f), 10) @test status == :ok + # Test behaviour with missing workers. Note that pool_workers is assigned + # such that the FIFO behaviour of Channel's will ensure that all the tested + # methods will see the bad_worker first. + bad_worker = maximum(workers()) + 1 + pool_workers = [bad_worker, 1] + + wp = WorkerPool(pool_workers) + @test take!(wp) == 1 # Test take!() + @test !isready(wp) + @test bad_worker ∉ wp.workers + + @test !isready(WorkerPool([bad_worker])) + + wp = WorkerPool(pool_workers) + # This should not hang, and it should end up removing the dead worker + wait(wp) + @test isready(wp) + @test bad_worker ∉ wp.workers + # CachingPool tests wp = CachingPool(workers()) @test [1:100...] == pmap(x->x, wp, 1:100)