Skip to content

Commit 90f44f6

Browse files
committed
fixup! Add support for worker state callbacks
1 parent d0ab810 commit 90f44f6

File tree

2 files changed

+74
-31
lines changed

2 files changed

+74
-31
lines changed

src/cluster.jl

+56-22
Original file line numberDiff line numberDiff line change
@@ -457,22 +457,34 @@ end
457457
```
458458
"""
459459
function addprocs(manager::ClusterManager; kwargs...)
460+
params = merge(default_addprocs_params(manager), Dict{Symbol, Any}(kwargs))
461+
460462
init_multi()
461463

462464
cluster_mgmt_from_master_check()
463465

464-
new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager; kwargs...)
466+
new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager, params)
467+
468+
callback_tasks = Dict{Any, Task}()
465469
for worker in new_workers
466-
for callback in values(worker_added_callbacks)
467-
callback(worker)
470+
for (name, callback) in worker_added_callbacks
471+
callback_tasks[name] = Threads.@spawn callback(worker)
468472
end
469473
end
470474

475+
running_callbacks = () -> ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
476+
while timedwait(() -> isempty(running_callbacks()), params[:callback_warning_interval]) === :timed_out
477+
callbacks_str = join(running_callbacks(), ", ")
478+
@warn "Waiting for these worker-added callbacks to finish: $(callbacks_str)"
479+
end
480+
481+
# Wait on the tasks so that exceptions bubble up
482+
wait.(values(callback_tasks))
483+
471484
return new_workers
472485
end
473486

474-
function addprocs_locked(manager::ClusterManager; kwargs...)
475-
params = merge(default_addprocs_params(manager), Dict{Symbol,Any}(kwargs))
487+
function addprocs_locked(manager::ClusterManager, params)
476488
topology(Symbol(params[:topology]))
477489

478490
if PGRP.topology !== :all_to_all
@@ -559,7 +571,8 @@ default_addprocs_params() = Dict{Symbol,Any}(
559571
:exeflags => ``,
560572
:env => [],
561573
:enable_threaded_blas => false,
562-
:lazy => true)
574+
:lazy => true,
575+
:callback_warning_interval => 10)
563576

564577

565578
function setup_launched_worker(manager, wconfig, launched_q)
@@ -872,6 +885,8 @@ end
872885
function _add_callback(f, key, dict)
873886
if !hasmethod(f, Tuple{Int})
874887
throw(ArgumentError("Callback function is invalid, it must be able to accept a single Int argument"))
888+
elseif haskey(dict, key)
889+
throw(ArgumentError("A callback function with key '$(key)' already exists"))
875890
end
876891

877892
if isnothing(key)
@@ -889,14 +904,23 @@ _remove_callback(key, dict) = delete!(dict, key)
889904
890905
Register a callback to be called on the master process whenever a worker is
891906
added. The callback will be called with the added worker ID,
892-
e.g. `f(w::Int)`. Returns a unique key for the callback.
907+
e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
908+
not specified.
909+
910+
The worker-added callbacks will be executed concurrently. If one throws an
911+
exception it will not be caught and will bubble up through [`addprocs()`](@ref) as a `TaskFailedException`.
912+
913+
Keep in mind that the callbacks will add to the time taken to launch workers, so
914+
try to either keep the callbacks fast to execute, or do the actual
915+
initialization asynchronously by spawning a task in the callback (beware of race
916+
conditions if you do this).
893917
"""
894918
add_worker_added_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_added_callbacks)
895919

896920
"""
897921
remove_worker_added_callback(key)
898922
899-
Remove the callback for `key`.
923+
Remove the callback for `key` that was added with [`add_worker_added_callback()`](@ref).
900924
"""
901925
remove_worker_added_callback(key) = _remove_callback(key, worker_added_callbacks)
902926

@@ -905,18 +929,19 @@ remove_worker_added_callback(key) = _remove_callback(key, worker_added_callbacks
905929
906930
Register a callback to be called on the master process immediately before a
907931
worker is removed with [`rmprocs()`](@ref). The callback will be called with the
908-
worker ID, e.g. `f(w::Int)`. Returns a unique key for the callback.
932+
worker ID, e.g. `f(w::Int)`. Chooses and returns a unique key for the callback
933+
if `key` is not specified.
909934
910-
All callbacks will be executed asynchronously and if they don't all finish
911-
before the `callback_timeout` passed to `rmprocs()` then the process will be
912-
removed anyway.
935+
All worker-exiting callbacks will be executed concurrently and if they don't
936+
all finish before the `callback_timeout` passed to `rmprocs()` then the process
937+
will be removed anyway.
913938
"""
914939
add_worker_exiting_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_exiting_callbacks)
915940

916941
"""
917942
remove_worker_exiting_callback(key)
918943
919-
Remove the callback for `key`.
944+
Remove the callback for `key` that was added with [`add_worker_exiting_callback()`](@ref).
920945
"""
921946
remove_worker_exiting_callback(key) = _remove_callback(key, worker_exiting_callbacks)
922947

@@ -926,14 +951,17 @@ remove_worker_exiting_callback(key) = _remove_callback(key, worker_exiting_callb
926951
Register a callback to be called on the master process when a worker has exited
927952
for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
928953
segfaulting etc). The callback will be called with the worker ID,
929-
e.g. `f(w::Int)`. Returns a unique key for the callback.
954+
e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
955+
not specified.
956+
957+
If the callback throws an exception it will be caught and logged as an error.
930958
"""
931959
add_worker_exited_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_exited_callbacks)
932960

933961
"""
934962
remove_worker_exited_callback(key)
935963
936-
Remove the callback for `key`.
964+
Remove the callback for `key` that was added with [`add_worker_exited_callback()`](@ref).
937965
"""
938966
remove_worker_exited_callback(key) = _remove_callback(key, worker_exited_callbacks)
939967

@@ -1176,15 +1204,17 @@ function _rmprocs(pids, waitfor, callback_timeout)
11761204
lock(worker_lock)
11771205
try
11781206
# Run the callbacks
1179-
callback_tasks = Task[]
1207+
callback_tasks = Dict{Any, Task}()
11801208
for pid in pids
1181-
for callback in values(worker_exiting_callbacks)
1182-
push!(callback_tasks, Threads.@spawn callback(pid))
1209+
for (name, callback) in worker_exiting_callbacks
1210+
callback_tasks[name] = Threads.@spawn callback(pid)
11831211
end
11841212
end
11851213

1186-
if timedwait(() -> all(istaskdone.(callback_tasks)), callback_timeout) === :timed_out
1187-
@warn "Some callbacks timed out, continuing to remove workers anyway"
1214+
if timedwait(() -> all(istaskdone.(values(callback_tasks))), callback_timeout) === :timed_out
1215+
timedout_callbacks = ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
1216+
callbacks_str = join(timedout_callbacks, ", ")
1217+
@warn "Some worker-exiting callbacks have not yet finished, continuing to remove workers anyway. These are the callbacks still running: $(callbacks_str)"
11881218
end
11891219

11901220
rmprocset = Union{LocalProcess, Worker}[]
@@ -1335,8 +1365,12 @@ function deregister_worker(pg, pid)
13351365

13361366
# Call callbacks on the master
13371367
if myid() == 1
1338-
for callback in values(worker_exited_callbacks)
1339-
callback(pid)
1368+
for (name, callback) in worker_exited_callbacks
1369+
try
1370+
callback(pid)
1371+
catch ex
1372+
@error "Error when running worker-exited callback '$(name)'" exception=(ex, catch_backtrace())
1373+
end
13401374
end
13411375
end
13421376

test/distributed_exec.jl

+18-9
Original file line numberDiff line numberDiff line change
@@ -1937,41 +1937,50 @@ include("splitrange.jl")
19371937
end
19381938

19391939
@testset "Worker state callbacks" begin
1940-
if nprocs() > 1
1941-
rmprocs(workers())
1942-
end
1940+
rmprocs(other_workers())
19431941

19441942
# Smoke test to ensure that all the callbacks are executed
19451943
added_workers = Int[]
19461944
exiting_workers = Int[]
19471945
exited_workers = Int[]
1948-
added_key = DistributedNext.add_worker_added_callback(pid -> push!(added_workers, pid))
1946+
added_key = DistributedNext.add_worker_added_callback(pid -> (push!(added_workers, pid); error("foo")))
19491947
exiting_key = DistributedNext.add_worker_exiting_callback(pid -> push!(exiting_workers, pid))
19501948
exited_key = DistributedNext.add_worker_exited_callback(pid -> push!(exited_workers, pid))
19511949

1952-
pid = only(addprocs(1))
1950+
# Test that the worker-added exception bubbles up
1951+
@test_throws TaskFailedException addprocs(1)
1952+
1953+
pid = only(workers())
19531954
@test added_workers == [pid]
19541955
rmprocs(workers())
19551956
@test exiting_workers == [pid]
19561957
@test exited_workers == [pid]
19571958

1959+
# Trying to reset an existing callback should fail
1960+
@test_throws ArgumentError DistributedNext.add_worker_added_callback(Returns(nothing); key=added_key)
1961+
19581962
# Remove the callbacks
19591963
DistributedNext.remove_worker_added_callback(added_key)
19601964
DistributedNext.remove_worker_exiting_callback(exiting_key)
19611965
DistributedNext.remove_worker_exited_callback(exited_key)
19621966

1963-
# Test that the `callback_timeout` option works
1967+
# Test that the worker-exiting `callback_timeout` option works and that we
1968+
# get warnings about slow worker-added callbacks.
19641969
event = Base.Event()
19651970
callback_task = nothing
1971+
added_key = DistributedNext.add_worker_added_callback(_ -> sleep(0.5))
19661972
exiting_key = DistributedNext.add_worker_exiting_callback(_ -> (callback_task = current_task(); wait(event)))
1967-
addprocs(1)
19681973

1969-
@test_logs (:warn, r"Some callbacks timed out.+") rmprocs(workers(); callback_timeout=0.5)
1974+
@test_logs (:warn, r"Waiting for these worker-added callbacks.+") match_mode=:any addprocs(1; callback_warning_interval=0.05)
1975+
DistributedNext.remove_worker_added_callback(added_key)
1976+
1977+
@test_logs (:warn, r"Some worker-exiting callbacks have not yet finished.+") rmprocs(workers(); callback_timeout=0.5)
1978+
DistributedNext.remove_worker_exiting_callback(exiting_key)
19701979

19711980
notify(event)
19721981
wait(callback_task)
19731982

1974-
# Test that the previous callbacks were indeed removed
1983+
# Test that the initial callbacks were indeed removed
19751984
@test length(added_workers) == 1
19761985
@test length(exiting_workers) == 1
19771986
@test length(exited_workers) == 1

0 commit comments

Comments
 (0)