Skip to content

Commit 9854b9e

Browse files
committed
Override Distributed.create_worker to execute a custom expression on worker initialization
1 parent 684f06f commit 9854b9e

File tree

7 files changed

+208
-48
lines changed

7 files changed

+208
-48
lines changed

docs/literate_example.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ observations .= process_member_data(SimDir(simulation.output_dir))
167167
# The simplest backend is the `JuliaBackend`, which runs all ensemble members sequentially and does not require `Distributed.jl`.
168168
# For more information, see the [`Backends`](https://clima.github.io/ClimaCalibrate.jl/dev/backends/) page.
169169
eki = CAL.calibrate(
170-
CAL.WorkerBackend,
170+
CAL.JuliaBackend,
171+
#md # CAL.WorkerBackend # We can't use this backend in Literate.jl
171172
ensemble_size,
172173
n_iterations,
173174
observations,

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ ClimaCalibrate.submit_pbs_job
5252
ClimaCalibrate.initialize
5353
ClimaCalibrate.save_G_ensemble
5454
ClimaCalibrate.update_ensemble
55+
ClimaCalibrate.update_ensemble!
5556
ClimaCalibrate.ExperimentConfig
5657
ClimaCalibrate.get_prior
5758
ClimaCalibrate.get_param_dict

src/backends.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ function calibrate(
292292
noise;
293293
ekp_kwargs...,
294294
)
295-
return calibrate(b, eki, n_iterations, prior, output_dir; worker_pool)
295+
return calibrate(b, eki, n_iterations, prior, output_dir)
296296
end
297297

298298
function calibrate(

src/workers.jl

Lines changed: 151 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using Logging
33

44
export SlurmManager, PBSManager, set_worker_loggers
55

6+
worker_timeout() = parse(Float64, get(ENV, "JULIA_WORKER_TIMEOUT", "300.0"))
7+
68
get_worker_pool() = workers() == [1] ? WorkerPool() : default_worker_pool()
79

810
function run_worker_iteration(
@@ -21,7 +23,7 @@ function run_worker_iteration(
2123
remotecall_wait(forward_model, w, iter, m)
2224
end
2325
end
24-
26+
isempty(all_known_workers.workers) && @info "No workers currently available"
2527
@sync while !isempty(work_to_do)
2628
# Add new workers to worker_pool
2729
all_workers = get_worker_pool()
@@ -40,7 +42,7 @@ function run_worker_iteration(
4042
push!(worker_pool, worker)
4143
end
4244
else
43-
println("no workers available")
45+
@debug "no workers available"
4446
sleep(10) # Wait for workers to become available
4547
end
4648
end
@@ -75,7 +77,7 @@ addprocs(SlurmManager(ntasks=4))
7577
7678
# Pass additional arguments to `srun`
7779
addprocs(SlurmManager(ntasks=4), gpus_per_task=1)
78-
80+
```
7981
# Related functions
8082
- `calibrate(WorkerBackend, ...)`: Perform calibration using workers
8183
- `remotecall(func, worker_id, args...)`: Execute functions on specific workers
@@ -100,7 +102,6 @@ function Distributed.manage(
100102
)
101103
if op == :register
102104
set_worker_logger(id)
103-
evaluate_initial_expression(id, manager.expr)
104105
end
105106
end
106107

@@ -313,8 +314,6 @@ Workers inherit the current Julia environment by default.
313314
# Related Functions
314315
- `calibrate(WorkerBackend, ...)`: Perform worker calibration
315316
- `remotecall(func, worker_id, args...)`: Execute functions on specific workers
316-
317-
See also: [`addprocs`](@ref), [`Distributed`](@ref), [`SlurmManager`](@ref)
318317
"""
319318
struct PBSManager <: ClusterManager
320319
ntasks::Integer
@@ -457,6 +456,7 @@ This function should be called from the worker process.
457456
"""
458457
function set_worker_logger()
459458
@eval Main using Logging
459+
redirect_stderr(stdout)
460460
io = open("worker_$(myid()).log", "w")
461461
logger = SimpleLogger(io)
462462
Base.global_logger(logger)
@@ -478,3 +478,148 @@ function set_worker_loggers(workers = workers())
478478
end
479479
end
480480
end
481+
482+
# Copied from Distributed.jl in order to evaluate the manager's expression on worker initialization
483+
function Distributed.create_worker(
484+
manager::Union{SlurmManager, PBSManager},
485+
wconfig,
486+
)
487+
# only node 1 can add new nodes, since nobody else has the full list of address:port
488+
@assert Distributed.LPROC.id == 1
489+
timeout = worker_timeout()
490+
491+
# initiate a connect. Does not wait for connection completion in case of TCP.
492+
w = Distributed.Worker()
493+
local r_s, w_s
494+
try
495+
(r_s, w_s) = Distributed.connect(manager, w.id, wconfig)
496+
catch ex
497+
try
498+
Distributed.deregister_worker(w.id)
499+
kill(manager, w.id, wconfig)
500+
finally
501+
rethrow(ex)
502+
end
503+
end
504+
505+
w = Distributed.Worker(w.id, r_s, w_s, manager; config = wconfig)
506+
# install a finalizer to perform cleanup if necessary
507+
finalizer(w) do w
508+
if myid() == 1
509+
Distributed.manage(w.manager, w.id, w.config, :finalize)
510+
end
511+
end
512+
513+
# set when the new worker has finished connections with all other workers
514+
ntfy_oid = Distributed.RRID()
515+
rr_ntfy_join = Distributed.lookup_ref(ntfy_oid)
516+
rr_ntfy_join.waitingfor = myid()
517+
518+
# Start a new task to handle inbound messages from connected worker in master.
519+
# Also calls `wait_connected` on TCP streams.
520+
Distributed.process_messages(w.r_stream, w.w_stream, false)
521+
522+
# send address information of all workers to the new worker.
523+
# Cluster managers set the address of each worker in `WorkerConfig.connect_at`.
524+
# A new worker uses this to setup an all-to-all network if topology :all_to_all is specified.
525+
# Workers with higher pids connect to workers with lower pids. Except process 1 (master) which
526+
# initiates connections to all workers.
527+
528+
# Connection Setup Protocol:
529+
# - Master sends 16-byte cookie followed by 16-byte version string and a JoinPGRP message to all workers
530+
# - On each worker
531+
# - Worker responds with a 16-byte version followed by a JoinCompleteMsg
532+
# - Connects to all workers less than its pid. Sends the cookie, version and an IdentifySocket message
533+
# - Workers with incoming connection requests write back their Version and an IdentifySocketAckMsg message
534+
# - On master, receiving a JoinCompleteMsg triggers rr_ntfy_join (signifies that worker setup is complete)
535+
536+
join_list = []
537+
if Distributed.PGRP.topology === :all_to_all
538+
# need to wait for lower worker pids to have completed connecting, since the numerical value
539+
# of pids is relevant to the connection process, i.e., higher pids connect to lower pids and they
540+
# require the value of config.connect_at which is set only upon connection completion
541+
for jw in Distributed.PGRP.workers
542+
if (jw.id != 1) && (jw.id < w.id)
543+
# wait for wl to join
544+
# We should access this atomically using (@atomic jw.state)
545+
# but this is only recently supported
546+
if jw.state === Distributed.W_CREATED
547+
lock(jw.c_state) do
548+
wait(jw.c_state)
549+
end
550+
end
551+
push!(join_list, jw)
552+
end
553+
end
554+
555+
elseif Distributed.PGRP.topology === :custom
556+
# wait for requested workers to be up before connecting to them.
557+
filterfunc(x) =
558+
(x.id != 1) &&
559+
isdefined(x, :config) &&
560+
(
561+
notnothing(x.config.ident) in
562+
something(wconfig.connect_idents, [])
563+
)
564+
565+
wlist = filter(filterfunc, Distributed.PGRP.workers)
566+
waittime = 0
567+
while wconfig.connect_idents !== nothing &&
568+
length(wlist) < length(wconfig.connect_idents)
569+
if waittime >= timeout
570+
error("peer workers did not connect within $timeout seconds")
571+
end
572+
sleep(1.0)
573+
waittime += 1
574+
wlist = filter(filterfunc, Distributed.PGRP.workers)
575+
end
576+
577+
for wl in wlist
578+
lock(wl.c_state) do
579+
if (@atomic wl.state) === Distributed.W_CREATED
580+
# wait for wl to join
581+
wait(wl.c_state)
582+
end
583+
end
584+
push!(join_list, wl)
585+
end
586+
end
587+
588+
all_locs = Base.mapany(
589+
x ->
590+
isa(x, Distributed.Worker) ?
591+
(something(x.config.connect_at, ()), x.id) : ((), x.id, true),
592+
join_list,
593+
)
594+
Distributed.send_connection_hdr(w, true)
595+
enable_threaded_blas = something(wconfig.enable_threaded_blas, false)
596+
597+
join_message = Distributed.JoinPGRPMsg(
598+
w.id,
599+
all_locs,
600+
Distributed.PGRP.topology,
601+
enable_threaded_blas,
602+
Distributed.isclusterlazy(),
603+
)
604+
Distributed.send_msg_now(
605+
w,
606+
Distributed.MsgHeader(Distributed.RRID(0, 0), ntfy_oid),
607+
join_message,
608+
)
609+
610+
# Ensure the initial expression is evaluated before any other code
611+
@info "Evaluating initial expression on worker $(w.id)"
612+
evaluate_initial_expression(w.id, manager.expr)
613+
614+
@async Distributed.manage(w.manager, w.id, w.config, :register)
615+
616+
# wait for rr_ntfy_join with timeout
617+
if timedwait(() -> isready(rr_ntfy_join), timeout) === :timed_out
618+
error("worker did not connect within $timeout seconds")
619+
end
620+
lock(Distributed.client_refs) do
621+
delete!(Distributed.PGRP.refs, ntfy_oid)
622+
end
623+
624+
return w.id
625+
end

test/pbs_manager_unit_tests.jl

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ using Test, ClimaCalibrate, Distributed, Logging
2121
rmprocs(p)
2222
@test nprocs() == 1
2323
@test workers() == [1]
24+
25+
# Test broken arguments
26+
@test_throws TaskFailedException p = addprocs(PBSManager(1), time = "w")
2427
end
2528

26-
@testset "PBSManager - multiple processes" begin
29+
@testset "Test PBSManager multiple tasks, output file" begin
2730
out_file = "pbs_unit_test.out"
2831
p = addprocs(
2932
PBSManager(2),
@@ -37,34 +40,6 @@ end
3740
@test workers() == p
3841
@test remotecall_fetch(+, p[1], 1, 1) == 2
3942

40-
@everywhere using ClimaCalibrate
41-
# Test function with no arguments
42-
p = workers()
43-
@test ClimaCalibrate.map_remotecall_fetch(myid) == p
44-
45-
# single argument
46-
x = rand(5)
47-
@test ClimaCalibrate.map_remotecall_fetch(identity, x) == fill(x, length(p))
48-
49-
# multiple arguments
50-
@test ClimaCalibrate.map_remotecall_fetch(+, 2, 3) == fill(5, length(p))
51-
52-
# Test specified workers list
53-
@test length(ClimaCalibrate.map_remotecall_fetch(myid; workers = p[1:2])) ==
54-
2
55-
56-
# Test with more complex data structure
57-
d = Dict("a" => 1, "b" => 2)
58-
@test ClimaCalibrate.map_remotecall_fetch(identity, d) == fill(d, length(p))
59-
60-
loggers = ClimaCalibrate.set_worker_loggers()
61-
@test length(loggers) == length(p)
62-
@test typeof(loggers) == Vector{Base.CoreLogging.SimpleLogger}
63-
64-
rmprocs(p)
65-
@test nprocs() == 1
66-
@test workers() == [1]
67-
6843
@test isfile(out_file)
6944
rm(out_file)
7045
end

test/slurm_manager_unit_tests.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,49 @@ using Test, ClimaCalibrate, Distributed, Logging
2727
# Test incorrect generic arguments
2828
@test_throws TaskFailedException p = addprocs(SlurmManager(1), time = "w")
2929
end
30+
31+
@testset "SlurmManager Initialization Expressions" begin
32+
p = addprocs(SlurmManager(1; expr = :(@info "test")))
33+
rmprocs(p)
34+
test_logger = TestLogger()
35+
with_logger(test_logger) do
36+
p = addprocs(SlurmManager(1; expr = :(w + 2)))
37+
rmprocs(p)
38+
end
39+
@test test_logger.logs[end].message == "Initial worker expression errored:"
40+
end
41+
42+
@testset "Test remotecall utilities" begin
43+
p = addprocs(SlurmManager(2))
44+
@test nprocs() == length(p) + 1
45+
@test workers() == p
46+
@test remotecall_fetch(+, p[1], 1, 1) == 2
47+
48+
@everywhere using ClimaCalibrate
49+
# Test function with no arguments
50+
p = workers()
51+
@test ClimaCalibrate.map_remotecall_fetch(myid) == p
52+
53+
# single argument
54+
x = rand(5)
55+
@test ClimaCalibrate.map_remotecall_fetch(identity, x) == fill(x, length(p))
56+
57+
# multiple arguments
58+
@test ClimaCalibrate.map_remotecall_fetch(+, 2, 3) == fill(5, length(p))
59+
60+
# Test specified workers list
61+
@test length(ClimaCalibrate.map_remotecall_fetch(myid; workers = p[1:2])) ==
62+
2
63+
64+
# Test with more complex data structure
65+
d = Dict("a" => 1, "b" => 2)
66+
@test ClimaCalibrate.map_remotecall_fetch(identity, d) == fill(d, length(p))
67+
68+
loggers = ClimaCalibrate.set_worker_loggers()
69+
@test length(loggers) == length(p)
70+
@test typeof(loggers) == Vector{Base.CoreLogging.SimpleLogger}
71+
72+
rmprocs(p)
73+
@test nprocs() == 1
74+
@test workers() == [1]
75+
end

test/worker_backend.jl

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ include(
88
"utils.jl",
99
),
1010
)
11-
1211
# Expression to run on worker initialization, used instead of @everywhere
1312
expr = quote
13+
using ClimaCalibrate
1414
include(
1515
joinpath(
1616
pkgdir(ClimaCalibrate),
@@ -35,16 +35,6 @@ if nworkers() == 1
3535
end
3636
end
3737

38-
@everywhere using ClimaCalibrate
39-
@everywhere include(
40-
joinpath(
41-
pkgdir(ClimaCalibrate),
42-
"experiments",
43-
"surface_fluxes_perfect_model",
44-
"model_interface.jl",
45-
),
46-
)
47-
4838
eki = calibrate(
4939
WorkerBackend,
5040
ensemble_size,
@@ -75,6 +65,8 @@ convergence_plot(
7565
g_vs_iter_plot(eki)
7666

7767
@testset "Restarts" begin
68+
initialize(ensemble_size, observation, variance, prior, output_dir)
69+
7870
last_iter = ClimaCalibrate.last_completed_iteration(output_dir)
7971
@test last_iter == n_iterations - 1
8072
ClimaCalibrate.run_worker_iteration(

0 commit comments

Comments
 (0)