
Commit ab1b7ef

Overload Distributed.create_worker to execute a custom expression on worker initialization

1 parent 684f06f commit ab1b7ef

File tree

6 files changed: +206 −47 lines
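The change lets SlurmManager and PBSManager carry a Julia expression that each new worker evaluates as soon as it joins the cluster, before any other remote calls run. A minimal usage sketch (the `expr` keyword on `SlurmManager` is exercised in the new Slurm tests below; whether `PBSManager` accepts the same keyword is assumed here, and `init_expr`/`p` are illustrative names):

    using Distributed, ClimaCalibrate

    # Each worker evaluates this expression immediately after connecting,
    # so the package is loaded before the first forward-model remotecall arrives.
    init_expr = :(using ClimaCalibrate)
    p = addprocs(SlurmManager(2; expr = init_expr))

    # ... run the calibration on the workers ...
    rmprocs(p)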

docs/literate_example.jl

Lines changed: 2 additions & 1 deletion
@@ -167,7 +167,8 @@ observations .= process_member_data(SimDir(simulation.output_dir))
 # The simplest backend is the `JuliaBackend`, which runs all ensemble members sequentially and does not require `Distributed.jl`.
 # For more information, see the [`Backends`](https://clima.github.io/ClimaCalibrate.jl/dev/backends/) page.
 eki = CAL.calibrate(
-    CAL.WorkerBackend,
+    CAL.JuliaBackend,
+    #md # CAL.WorkerBackend # We can't use this backend in Literate.jl
     ensemble_size,
     n_iterations,
     observations,

docs/src/api.md

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ ClimaCalibrate.submit_pbs_job
 ClimaCalibrate.initialize
 ClimaCalibrate.save_G_ensemble
 ClimaCalibrate.update_ensemble
+ClimaCalibrate.update_ensemble!
 ClimaCalibrate.ExperimentConfig
 ClimaCalibrate.get_prior
 ClimaCalibrate.get_param_dict

src/workers.jl

Lines changed: 150 additions & 6 deletions
@@ -3,6 +3,8 @@ using Logging
 
 export SlurmManager, PBSManager, set_worker_loggers
 
+worker_timeout() = parse(Float64, get(ENV, "JULIA_WORKER_TIMEOUT", "300.0"))
+
 get_worker_pool() = workers() == [1] ? WorkerPool() : default_worker_pool()
 
 function run_worker_iteration(
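worker_timeout reads the standard JULIA_WORKER_TIMEOUT environment variable and defaults to 300 seconds; the copied create_worker below uses it for both the peer-connection wait and the join wait. A small sketch of raising it for slow-starting cluster jobs:

    # Give Slurm/PBS workers up to 10 minutes to connect before erroring out.
    ENV["JULIA_WORKER_TIMEOUT"] = "600"
    worker_timeout()  # 600.0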
@@ -21,7 +23,7 @@ function run_worker_iteration(
             remotecall_wait(forward_model, w, iter, m)
         end
     end
-
+    isempty(all_known_workers.workers) && @info "No workers currently available"
     @sync while !isempty(work_to_do)
         # Add new workers to worker_pool
         all_workers = get_worker_pool()
@@ -40,7 +42,7 @@ function run_worker_iteration(
                 push!(worker_pool, worker)
             end
         else
-            println("no workers available")
+            @debug "no workers available"
             sleep(10) # Wait for workers to become available
         end
     end
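Switching from println to @debug silences the polling loop by default. If the message is needed while diagnosing a stalled iteration, debug logging can be enabled in the usual way, for example:

    # Surface @debug messages from ClimaCalibrate in the current session.
    ENV["JULIA_DEBUG"] = "ClimaCalibrate"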
@@ -75,7 +77,7 @@ addprocs(SlurmManager(ntasks=4))
 
 # Pass additional arguments to `srun`
 addprocs(SlurmManager(ntasks=4), gpus_per_task=1)
-
+```
 # Related functions
 - `calibrate(WorkerBackend, ...)`: Perform calibration using workers
 - `remotecall(func, worker_id, args...)`: Execute functions on specific workers
@@ -100,7 +102,6 @@ function Distributed.manage(
 )
     if op == :register
         set_worker_logger(id)
-        evaluate_initial_expression(id, manager.expr)
     end
 end
 
@@ -313,8 +314,6 @@ Workers inherit the current Julia environment by default.
 # Related Functions
 - `calibrate(WorkerBackend, ...)`: Perform worker calibration
 - `remotecall(func, worker_id, args...)`: Execute functions on specific workers
-
-See also: [`addprocs`](@ref), [`Distributed`](@ref), [`SlurmManager`](@ref)
 """
 struct PBSManager <: ClusterManager
     ntasks::Integer
@@ -478,3 +477,148 @@ function set_worker_loggers(workers = workers())
         end
     end
 end
+
+# Copied from Distributed.jl in order to evaluate the manager's expression on worker initialization
+function Distributed.create_worker(
+    manager::Union{SlurmManager, PBSManager},
+    wconfig,
+)
+    # only node 1 can add new nodes, since nobody else has the full list of address:port
+    @assert Distributed.LPROC.id == 1
+    timeout = worker_timeout()
+
+    # initiate a connect. Does not wait for connection completion in case of TCP.
+    w = Distributed.Worker()
+    local r_s, w_s
+    try
+        (r_s, w_s) = Distributed.connect(manager, w.id, wconfig)
+    catch ex
+        try
+            Distributed.deregister_worker(w.id)
+            kill(manager, w.id, wconfig)
+        finally
+            rethrow(ex)
+        end
+    end
+
+    w = Distributed.Worker(w.id, r_s, w_s, manager; config = wconfig)
+    # install a finalizer to perform cleanup if necessary
+    finalizer(w) do w
+        if myid() == 1
+            Distributed.manage(w.manager, w.id, w.config, :finalize)
+        end
+    end
+
+    # set when the new worker has finished connections with all other workers
+    ntfy_oid = Distributed.RRID()
+    rr_ntfy_join = Distributed.lookup_ref(ntfy_oid)
+    rr_ntfy_join.waitingfor = myid()
+
+    # Start a new task to handle inbound messages from connected worker in master.
+    # Also calls `wait_connected` on TCP streams.
+    Distributed.process_messages(w.r_stream, w.w_stream, false)
+
+    # send address information of all workers to the new worker.
+    # Cluster managers set the address of each worker in `WorkerConfig.connect_at`.
+    # A new worker uses this to setup an all-to-all network if topology :all_to_all is specified.
+    # Workers with higher pids connect to workers with lower pids. Except process 1 (master) which
+    # initiates connections to all workers.
+
+    # Connection Setup Protocol:
+    # - Master sends 16-byte cookie followed by 16-byte version string and a JoinPGRP message to all workers
+    # - On each worker
+    #   - Worker responds with a 16-byte version followed by a JoinCompleteMsg
+    #   - Connects to all workers less than its pid. Sends the cookie, version and an IdentifySocket message
+    #   - Workers with incoming connection requests write back their Version and an IdentifySocketAckMsg message
+    # - On master, receiving a JoinCompleteMsg triggers rr_ntfy_join (signifies that worker setup is complete)
+
+    join_list = []
+    if Distributed.PGRP.topology === :all_to_all
+        # need to wait for lower worker pids to have completed connecting, since the numerical value
+        # of pids is relevant to the connection process, i.e., higher pids connect to lower pids and they
+        # require the value of config.connect_at which is set only upon connection completion
+        for jw in Distributed.PGRP.workers
+            if (jw.id != 1) && (jw.id < w.id)
+                # wait for wl to join
+                # We should access this atomically using (@atomic jw.state)
+                # but this is only recently supported
+                if jw.state === Distributed.W_CREATED
+                    lock(jw.c_state) do
+                        wait(jw.c_state)
+                    end
+                end
+                push!(join_list, jw)
+            end
+        end
+
+    elseif Distributed.PGRP.topology === :custom
+        # wait for requested workers to be up before connecting to them.
+        filterfunc(x) =
+            (x.id != 1) &&
+            isdefined(x, :config) &&
+            (
+                notnothing(x.config.ident) in
+                something(wconfig.connect_idents, [])
+            )
+
+        wlist = filter(filterfunc, Distributed.PGRP.workers)
+        waittime = 0
+        while wconfig.connect_idents !== nothing &&
+                  length(wlist) < length(wconfig.connect_idents)
+            if waittime >= timeout
+                error("peer workers did not connect within $timeout seconds")
+            end
+            sleep(1.0)
+            waittime += 1
+            wlist = filter(filterfunc, Distributed.PGRP.workers)
+        end
+
+        for wl in wlist
+            lock(wl.c_state) do
+                if (@atomic wl.state) === Distributed.W_CREATED
+                    # wait for wl to join
+                    wait(wl.c_state)
+                end
+            end
+            push!(join_list, wl)
+        end
+    end
+
+    all_locs = Base.mapany(
+        x ->
+            isa(x, Distributed.Worker) ?
+            (something(x.config.connect_at, ()), x.id) : ((), x.id, true),
+        join_list,
+    )
+    Distributed.send_connection_hdr(w, true)
+    enable_threaded_blas = something(wconfig.enable_threaded_blas, false)
+
+    join_message = Distributed.JoinPGRPMsg(
+        w.id,
+        all_locs,
+        Distributed.PGRP.topology,
+        enable_threaded_blas,
+        Distributed.isclusterlazy(),
+    )
+    Distributed.send_msg_now(
+        w,
+        Distributed.MsgHeader(Distributed.RRID(0, 0), ntfy_oid),
+        join_message,
+    )
+
+    # Ensure the initial expression is evaluated before any other code
+    @info "Evaluating initial expression on worker $(w.id)"
+    evaluate_initial_expression(w.id, manager.expr)
+
+    @async Distributed.manage(w.manager, w.id, w.config, :register)
+
+    # wait for rr_ntfy_join with timeout
+    if timedwait(() -> isready(rr_ntfy_join), timeout) === :timed_out
+        error("worker did not connect within $timeout seconds")
+    end
+    lock(Distributed.client_refs) do
+        delete!(Distributed.PGRP.refs, ntfy_oid)
+    end
+
+    return w.id
+end
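The body above is a near-verbatim copy of Distributed.create_worker; the functional difference is the evaluate_initial_expression call placed after the JoinPGRPMsg is sent but before the :register callback, so the manager's expr runs on the worker ahead of any user remote calls. The helper itself is not part of this diff; a hypothetical sketch of what it is assumed to do, based on the log message checked in the new Slurm tests, could be:

    # Hypothetical sketch only, not the ClimaCalibrate implementation:
    # evaluate `expr` in Main on worker `id`, logging rather than rethrowing failures.
    function evaluate_initial_expression_sketch(id, expr)
        isnothing(expr) && return nothing
        try
            Distributed.remotecall_fetch(Core.eval, id, Main, expr)
        catch e
            @error "Initial worker expression errored:" exception = e
        end
    end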

test/pbs_manager_unit_tests.jl

Lines changed: 4 additions & 29 deletions
@@ -21,9 +21,12 @@ using Test, ClimaCalibrate, Distributed, Logging
     rmprocs(p)
     @test nprocs() == 1
     @test workers() == [1]
+
+    # Test broken arguments
+    @test_throws TaskFailedException p = addprocs(PBSManager(1), time = "w")
 end
 
-@testset "PBSManager - multiple processes" begin
+@testset "Test PBSManager multiple tasks, output file" begin
     out_file = "pbs_unit_test.out"
     p = addprocs(
         PBSManager(2),
@@ -37,34 +40,6 @@ end
     @test workers() == p
     @test remotecall_fetch(+, p[1], 1, 1) == 2
 
-    @everywhere using ClimaCalibrate
-    # Test function with no arguments
-    p = workers()
-    @test ClimaCalibrate.map_remotecall_fetch(myid) == p
-
-    # single argument
-    x = rand(5)
-    @test ClimaCalibrate.map_remotecall_fetch(identity, x) == fill(x, length(p))
-
-    # multiple arguments
-    @test ClimaCalibrate.map_remotecall_fetch(+, 2, 3) == fill(5, length(p))
-
-    # Test specified workers list
-    @test length(ClimaCalibrate.map_remotecall_fetch(myid; workers = p[1:2])) ==
-          2
-
-    # Test with more complex data structure
-    d = Dict("a" => 1, "b" => 2)
-    @test ClimaCalibrate.map_remotecall_fetch(identity, d) == fill(d, length(p))
-
-    loggers = ClimaCalibrate.set_worker_loggers()
-    @test length(loggers) == length(p)
-    @test typeof(loggers) == Vector{Base.CoreLogging.SimpleLogger}
-
-    rmprocs(p)
-    @test nprocs() == 1
-    @test workers() == [1]
-
     @test isfile(out_file)
     rm(out_file)
 end

test/slurm_manager_unit_tests.jl

Lines changed: 46 additions & 0 deletions
@@ -27,3 +27,49 @@ using Test, ClimaCalibrate, Distributed, Logging
     # Test incorrect generic arguments
     @test_throws TaskFailedException p = addprocs(SlurmManager(1), time = "w")
 end
+
+@testset "SlurmManager Initialization Expressions" begin
+    p = addprocs(SlurmManager(1; expr = :(@info "test")))
+    rmprocs(p)
+    test_logger = TestLogger()
+    with_logger(test_logger) do
+        p = addprocs(SlurmManager(1; expr = :(w + 2)))
+        rmprocs(p)
+    end
+    @test test_logger.logs[end].message == "Initial worker expression errored:"
+end
+
+@testset "Test remotecall utilities" begin
+    p = addprocs(SlurmManager(2))
+    @test nprocs() == length(p) + 1
+    @test workers() == p
+    @test remotecall_fetch(+, p[1], 1, 1) == 2
+
+    @everywhere using ClimaCalibrate
+    # Test function with no arguments
+    p = workers()
+    @test ClimaCalibrate.map_remotecall_fetch(myid) == p
+
+    # single argument
+    x = rand(5)
+    @test ClimaCalibrate.map_remotecall_fetch(identity, x) == fill(x, length(p))
+
+    # multiple arguments
+    @test ClimaCalibrate.map_remotecall_fetch(+, 2, 3) == fill(5, length(p))
+
+    # Test specified workers list
+    @test length(ClimaCalibrate.map_remotecall_fetch(myid; workers = p[1:2])) ==
+          2
+
+    # Test with more complex data structure
+    d = Dict("a" => 1, "b" => 2)
+    @test ClimaCalibrate.map_remotecall_fetch(identity, d) == fill(d, length(p))
+
+    loggers = ClimaCalibrate.set_worker_loggers()
+    @test length(loggers) == length(p)
+    @test typeof(loggers) == Vector{Base.CoreLogging.SimpleLogger}
+
+    rmprocs(p)
+    @test nprocs() == 1
+    @test workers() == [1]
+end

test/worker_backend.jl

Lines changed: 3 additions & 11 deletions
@@ -8,9 +8,9 @@ include(
         "utils.jl",
     ),
 )
-
 # Expression to run on worker initialization, used instead of @everywhere
 expr = quote
+    using ClimaCalibrate
     include(
         joinpath(
             pkgdir(ClimaCalibrate),
@@ -35,16 +35,6 @@ if nworkers() == 1
     end
 end
 
-@everywhere using ClimaCalibrate
-@everywhere include(
-    joinpath(
-        pkgdir(ClimaCalibrate),
-        "experiments",
-        "surface_fluxes_perfect_model",
-        "model_interface.jl",
-    ),
-)
-
 eki = calibrate(
     WorkerBackend,
     ensemble_size,
@@ -75,6 +65,8 @@ convergence_plot(
 g_vs_iter_plot(eki)
 
 @testset "Restarts" begin
+    initialize(ensemble_size, observation, variance, prior, output_dir)
+
     last_iter = ClimaCalibrate.last_completed_iteration(output_dir)
     @test last_iter == n_iterations - 1
     ClimaCalibrate.run_worker_iteration(
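With the @everywhere block removed, worker setup in this test relies entirely on the expr quoted above and passed to the cluster manager. The addprocs call itself sits in the portion of the file not shown in this diff, so the following wiring is an assumption for illustration only:

    # Assumed shape of the worker startup in this test: the quoted `expr`
    # (using ClimaCalibrate + include of model_interface.jl) runs on each new worker.
    if nworkers() == 1
        addprocs(SlurmManager(ensemble_size; expr = expr))
    end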
