Commit 85c95be

Ensure worker reconnect registers existing tasks properly (#5103)
Resolve a deadlock triggered by a worker reconnect raising an exception
1 parent ca4e020 commit 85c95be

4 files changed: 194 additions, 56 deletions
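The core of the fix is a new scheduler-to-worker stream message: when a task finishes on a worker other than the one the scheduler believes is processing it, the scheduler now sends a best-effort "cancel-compute" message instead of returning early. Below is a minimal, self-contained sketch of that message shape and the worker-side dispatch; the ToyWorker class and its handle method are hypothetical stand-ins for distributed's actual comm machinery, only the "op"/"key"/"reason" fields and the waiting/ready restriction come from this commit.

class ToyWorker:
    def __init__(self):
        # task key -> state, loosely mirroring waiting/ready/executing/memory
        self.tasks = {"f1": "waiting", "f2": "executing"}
        self.log = []
        # handlers keyed by the message "op", as in Worker.stream_handlers
        self.stream_handlers = {"cancel-compute": self.cancel_compute}

    def cancel_compute(self, key, reason):
        # Best effort: only tasks that have not started running are dropped.
        if self.tasks.get(key) in ("waiting", "ready"):
            self.log.append((key, "cancel-compute", reason))
            del self.tasks[key]

    def handle(self, msg):
        op = msg.pop("op")
        self.stream_handlers[op](**msg)


worker = ToyWorker()
# Message shape the scheduler sends when another worker already finished the key
worker.handle({"op": "cancel-compute", "key": "f1", "reason": "Finished on different worker"})
worker.handle({"op": "cancel-compute", "key": "f2", "reason": "Finished on different worker"})
assert "f1" not in worker.tasks           # waiting task was released
assert worker.tasks["f2"] == "executing"  # executing task is left alone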

distributed/scheduler.py

Lines changed: 32 additions & 16 deletions
@@ -2696,7 +2696,13 @@ def transition_processing_memory(
                     ws,
                     key,
                 )
-                return recommendations, client_msgs, worker_msgs
+                worker_msgs[ts._processing_on.address] = [
+                    {
+                        "op": "cancel-compute",
+                        "key": key,
+                        "reason": "Finished on different worker",
+                    }
+                ]

         has_compute_startstop: bool = False
         compute_start: double
@@ -4234,19 +4240,25 @@ async def add_worker(
         client_msgs: dict = {}
         worker_msgs: dict = {}
         if nbytes:
+            assert isinstance(nbytes, dict)
             for key in nbytes:
                 ts: TaskState = parent._tasks.get(key)
-                if ts is not None and ts._state in ("processing", "waiting"):
-                    t: tuple = parent._transition(
-                        key,
-                        "memory",
-                        worker=address,
-                        nbytes=nbytes[key],
-                        typename=types[key],
-                    )
-                    recommendations, client_msgs, worker_msgs = t
-                    parent._transitions(recommendations, client_msgs, worker_msgs)
-                    recommendations = {}
+                if ts is not None:
+                    if ts.state == "memory":
+                        self.add_keys(worker=address, keys=[key])
+                    else:
+                        t: tuple = parent._transition(
+                            key,
+                            "memory",
+                            worker=address,
+                            nbytes=nbytes[key],
+                            typename=types[key],
+                        )
+                        recommendations, client_msgs, worker_msgs = t
+                        parent._transitions(
+                            recommendations, client_msgs, worker_msgs
+                        )
+                        recommendations = {}

         for ts in list(parent._unrunnable):
             valid: set = self.valid_workers(ts)
@@ -4659,10 +4671,15 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
         ts: TaskState = parent._tasks.get(key)
         if ts is None:
             return recommendations, client_msgs, worker_msgs
+
+        if ts.state == "memory":
+            self.add_keys(worker=worker, keys=[key])
+            return recommendations, client_msgs, worker_msgs
+
         ws: WorkerState = parent._workers_dv[worker]
         ts._metadata.update(kwargs["metadata"])

-        if ts._state == "processing":
+        if ts._state != "released":
             r: tuple = parent._transition(key, "memory", worker=worker, **kwargs)
             recommendations, client_msgs, worker_msgs = r

@@ -5580,12 +5597,11 @@ async def gather(self, comm=None, keys=None, serializers=None):
             # Remove suspicious workers from the scheduler but allow them to
             # reconnect.
             await asyncio.gather(
-                *[
+                *(
                     self.remove_worker(address=worker, close=False)
                     for worker in missing_workers
-                ]
+                )
             )
-
         recommendations: dict
         client_msgs: dict = {}
         worker_msgs: dict = {}
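
In add_worker, the scheduler now distinguishes keys that are already in memory (the reconnecting worker is simply recorded as an additional replica) from keys in any other state (they are transitioned to memory). A minimal sketch of that branching follows; TaskState, register_reconnected_keys, and the replica bookkeeping are simplified stand-ins under stated assumptions, not the scheduler's real API surface.

from dataclasses import dataclass, field


@dataclass
class TaskState:
    state: str = "processing"
    who_has: set = field(default_factory=set)


def register_reconnected_keys(tasks, nbytes, address):
    """Register keys reported by a (re)connecting worker."""
    assert isinstance(nbytes, dict)
    for key in nbytes:
        ts = tasks.get(key)
        if ts is None:
            continue  # scheduler no longer knows this task; nothing to do
        if ts.state == "memory":
            # Already finished elsewhere: just record the extra replica.
            ts.who_has.add(address)
        else:
            # Otherwise accept the worker's copy as the result.
            ts.state = "memory"
            ts.who_has = {address}


tasks = {"x": TaskState(state="memory", who_has={"tcp://w1"}), "y": TaskState()}
register_reconnected_keys(tasks, {"x": 8, "y": 8}, "tcp://w2")
assert tasks["x"].who_has == {"tcp://w1", "tcp://w2"}
assert tasks["y"].state == "memory" and tasks["y"].who_has == {"tcp://w2"}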

distributed/tests/test_scheduler.py

Lines changed: 72 additions & 39 deletions
@@ -43,7 +43,7 @@
     tls_only_security,
     varying,
 )
-from distributed.worker import dumps_function, dumps_task
+from distributed.worker import dumps_function, dumps_task, get_worker

 if sys.version_info < (3, 8):
     try:
@@ -2182,38 +2182,67 @@ async def test_gather_no_workers(c, s, a, b):
     assert list(res["keys"]) == ["x"]


+@pytest.mark.slow
+@pytest.mark.parametrize("reschedule_different_worker", [True, False])
+@pytest.mark.parametrize("swap_data_insert_order", [True, False])
 @gen_cluster(client=True, client_kwargs={"direct_to_workers": False})
-async def test_gather_allow_worker_reconnect(c, s, a, b):
+async def test_gather_allow_worker_reconnect(
+    c, s, a, b, reschedule_different_worker, swap_data_insert_order
+):
     """
     Test that client resubmissions allow failed workers to reconnect and re-use
     their results. Failure scenario would be a connection issue during result
     gathering.
     Upon connection failure, the worker is flagged as suspicious and removed
     from the scheduler. If the worker is healthy and reconnects we want to use
     its results instead of recomputing them.
+
+    See also distributed.tests.test_worker.py::test_worker_reconnects_mid_compute
     """
     # GH3246
-    already_calculated = []
-
-    import time
-
-    def inc_slow(x):
-        # Once the graph below is rescheduled this computation runs again. We
-        # need to sleep for at least 0.5 seconds to give the worker a chance to
-        # reconnect (Heartbeat timing). In slow CI situations, the actual
-        # reconnect might take a bit longer, therefore wait more
-        if x in already_calculated:
-            time.sleep(2)
-        already_calculated.append(x)
+    if reschedule_different_worker:
+        from distributed.diagnostics.plugin import SchedulerPlugin
+
+        class SwitchRestrictions(SchedulerPlugin):
+            def __init__(self, scheduler):
+                self.scheduler = scheduler
+
+            def transition(self, key, start, finish, **kwargs):
+                if key in ("reducer", "final") and finish == "memory":
+                    self.scheduler.tasks[key]._worker_restrictions = {b.address}
+
+        plugin = SwitchRestrictions(s)
+        s.add_plugin(plugin)
+
+    from distributed import Lock
+
+    b_address = b.address
+
+    def inc_slow(x, lock):
+        w = get_worker()
+        if w.address == b_address:
+            with lock:
+                return x + 1
         return x + 1

-    x = c.submit(inc_slow, 1)
-    y = c.submit(inc_slow, 2)
+    lock = Lock()
+
+    await lock.acquire()
+
+    x = c.submit(inc_slow, 1, lock, workers=[a.address], allow_other_workers=True)
+
+    def reducer(*args):
+        return get_worker().address

-    def reducer(x, y):
-        return x + y
+    def finalizer(addr):
+        if swap_data_insert_order:
+            w = get_worker()
+            new_data = {k: w.data[k] for k in list(w.data.keys())[::-1]}
+            w.data = new_data
+        return addr

-    z = c.submit(reducer, x, y)
+    z = c.submit(reducer, x, key="reducer", workers=[a.address])
+    fin = c.submit(finalizer, z, key="final", workers=[a.address])

     s.rpc = await FlakyConnectionPool(failing_connections=1)

@@ -2227,9 +2256,31 @@ def reducer(x, y):
     ) as client_logger:
         # Gather using the client (as an ordinary user would)
         # Upon a missing key, the client will reschedule the computations
-        res = await c.gather(z)
+        res = None
+        while not res:
+            try:
+                # This reduces test runtime by about a second since we're
+                # depending on a worker heartbeat for a reconnect.
+                res = await asyncio.wait_for(fin, 0.1)
+            except asyncio.TimeoutError:
+                await a.heartbeat()
+
+        # Ensure that we're actually reusing the result
+        assert res == a.address
+        await lock.release()
+
+    while not all(all(ts.state == "memory" for ts in w.tasks.values()) for w in [a, b]):
+        await asyncio.sleep(0.01)

-        assert res == 5
+    assert z.key in a.tasks
+    assert z.key not in b.tasks
+    assert b.executed_count == 1
+    for w in [a, b]:
+        assert x.key in w.tasks
+        assert w.tasks[x.key].state == "memory"
+    while not len(s.tasks[x.key].who_has) == 2:
+        await asyncio.sleep(0.01)
+    assert len(s.tasks[z.key].who_has) == 1

     sched_logger = sched_logger.getvalue()
     client_logger = client_logger.getvalue()
@@ -2245,24 +2296,6 @@ def reducer(x, y):
     # is rather an artifact and not the intention
     assert "Workers don't have promised key" in sched_logger

-    # Once the worker reconnects, it will also submit the keys it holds such
-    # that the scheduler again knows about the result.
-    # The final reduce step should then be used from the re-connected worker
-    # instead of recomputing it.
-    transitions_to_processing = [
-        (key, start, timestamp)
-        for key, start, finish, recommendations, timestamp in s.transition_log
-        if finish == "processing" and "reducer" in key
-    ]
-    assert len(transitions_to_processing) == 1
-
-    finish_processing_transitions = 0
-    for transition in s.transition_log:
-        key, start, finish, recommendations, timestamp = transition
-        if "reducer" in key and finish == "processing":
-            finish_processing_transitions += 1
-    assert finish_processing_transitions == 1
-

 @gen_cluster(client=True)
 async def test_too_many_groups(c, s, a, b):

distributed/tests/test_worker.py

Lines changed: 65 additions & 0 deletions
@@ -2372,6 +2372,71 @@ async def test_hold_on_to_replicas(c, s, *workers):
         await asyncio.sleep(0.01)


+@gen_cluster(client=True)
+async def test_worker_reconnects_mid_compute(c, s, a, b):
+    """
+    This test ensures that if a worker disconnects while computing a result, the scheduler will still accept the result.
+
+    There is also an edge case tested which ensures that the reconnect is
+    successful if a task is currently executing, see
+    https://github.com/dask/distributed/issues/5078
+
+    See also distributed.tests.test_scheduler.py::test_gather_allow_worker_reconnect
+    """
+    with captured_logger("distributed.scheduler") as s_logs:
+        # Let's put one task in memory to ensure the reconnect has tasks in
+        # different states
+        f1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
+        await f1
+        a_address = a.address
+        a.periodic_callbacks["heartbeat"].stop()
+        await a.heartbeat()
+        a.heartbeat_active = True
+
+        from distributed import Lock
+
+        def fast_on_a(lock):
+            w = get_worker()
+            import time
+
+            if w.address != a_address:
+                lock.acquire()
+            else:
+                time.sleep(1)
+
+        lock = Lock()
+        # We want to be sure that A is the only one computing this result
+        async with lock:
+
+            f2 = c.submit(
+                fast_on_a, lock, workers=[a.address], allow_other_workers=True
+            )
+
+            while f2.key not in a.tasks:
+                await asyncio.sleep(0.01)
+
+            await s.stream_comms[a.address].close()
+
+            assert len(s.workers) == 1
+            a.heartbeat_active = False
+            await a.heartbeat()
+            assert len(s.workers) == 2
+            # Since B is locked, this is ensured to originate from A
+            await f2
+
+    assert "Unexpected worker completed task" in s_logs.getvalue()
+
+    while not len(s.tasks[f2.key].who_has) == 2:
+        await asyncio.sleep(0.001)
+
+    # Ensure that all keys have been properly registered and will also be
+    # cleaned up nicely.
+    del f1, f2
+
+    while any(w.tasks for w in [a, b]):
+        await asyncio.sleep(0.001)
+
+
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
 async def test_forget_dependents_after_release(c, s, a):

distributed/worker.py

Lines changed: 25 additions & 1 deletion
@@ -699,6 +699,7 @@ def __init__(
         stream_handlers = {
             "close": self.close,
             "compute-task": self.add_task,
+            "cancel-compute": self.cancel_compute,
             "free-keys": self.handle_free_keys,
             "superfluous-data": self.handle_superfluous_data,
             "steal-request": self.steal_request,
@@ -901,7 +902,14 @@ async def _register_with_scheduler(self):
                         keys=list(self.data),
                         nthreads=self.nthreads,
                         name=self.name,
-                        nbytes={ts.key: ts.get_nbytes() for ts in self.tasks.values()},
+                        nbytes={
+                            ts.key: ts.get_nbytes()
+                            for ts in self.tasks.values()
+                            # Only if the task is in memory this is a sensible
+                            # result since otherwise it simply submits the
+                            # default value
+                            if ts.state == "memory"
+                        },
                         types={k: typename(v) for k, v in self.data.items()},
                         now=time(),
                         resources=self.total_resources,
@@ -1544,6 +1552,22 @@ async def set_resources(self, **resources):
     # Task Management #
     ###################

+    def cancel_compute(self, key, reason):
+        """
+        Cancel a task on a best effort basis. This is only possible while a task
+        is in state `waiting` or `ready`.
+        Nothing will happen otherwise.
+        """
+        ts = self.tasks.get(key)
+        if ts and ts.state in ("waiting", "ready"):
+            self.log.append((key, "cancel-compute", reason))
+            ts.scheduler_holds_ref = False
+            # All possible dependents of TS should not be in state Processing on
+            # scheduler side and therefore should not be assigned to a worker,
+            # yet.
+            assert not ts.dependents
+            self.release_key(key, reason=reason, report=False)
+
     def add_task(
         self,
         key,

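On the worker side, _register_with_scheduler now reports nbytes only for tasks that are actually in memory, since a task without a finished result would otherwise contribute a default size estimate rather than a real measurement. A small sketch of that filtering follows; ToyTask, DEFAULT_SIZE, and the registration dict are illustrative assumptions, not distributed's API.

DEFAULT_SIZE = 8  # placeholder for a default sizeof estimate


class ToyTask:
    def __init__(self, key, state, nbytes=None):
        self.key = key
        self.state = state
        self.nbytes = nbytes

    def get_nbytes(self):
        # Mirrors the problem the filter avoids: tasks without a result
        # fall back to a default value rather than a real measurement.
        return self.nbytes if self.nbytes is not None else DEFAULT_SIZE


tasks = {
    "done": ToyTask("done", "memory", nbytes=123),
    "running": ToyTask("running", "executing"),
}

# Only in-memory tasks contribute a meaningful size to the registration message.
nbytes = {ts.key: ts.get_nbytes() for ts in tasks.values() if ts.state == "memory"}
assert nbytes == {"done": 123}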