From ef81d8b0087a22f3facfd68959a1421b7fc21cc6 Mon Sep 17 00:00:00 2001 From: Alister Burt Date: Tue, 7 Oct 2025 13:14:34 -0700 Subject: [PATCH 1/6] Add clear terminology, type hints, and failing tests for spec name consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Establish four key concepts with precise terminology: - Spec name: String key in worker_spec dict (e.g., "0", "my-worker") - Worker name: String scheduler sees (e.g., "0-0", "0-1") - Worker spec: Dict with 'cls', 'options', 'group' (in worker_spec) - Worker instance: Actual Worker/Nanny object (in workers) Changes: - Rename _new_worker_name() → _new_spec_name() for clarity - Add type hints: worker_spec: dict[str, dict], workers: dict[str, Worker | Nanny] - Update SpecCluster docstring with Terminology and Grouped Workers sections - Fix examples to use string keys ("0", "1") and consistent variable names - Add failing tests (TDD): test_new_spec_name_returns_string, test_worker_spec_keys_are_strings - Ensure spec numbers are strings in worker spec dictionary --- distributed/deploy/spec.py | 166 +++++++++++------- distributed/deploy/tests/test_spec_cluster.py | 63 ++++++- 2 files changed, 154 insertions(+), 75 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 7da310a2e6..9d88e96d08 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -7,10 +7,10 @@ import logging import math import weakref -from collections.abc import Awaitable, Generator +from collections.abc import Awaitable, Generator, Iterable from contextlib import suppress from inspect import isawaitable -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast from tornado import gen from tornado.ioloop import IOLoop @@ -137,16 +137,48 @@ class does handle all of the logic around asynchronously cleanly setting up and tearing things down at the right times. Hopefully it can form a base for other more user-centric classes. + Terminology + ----------- + **Spec name**: The string key in the ``worker_spec`` dictionary (e.g., ``"0"``, + ``"my-worker"``). This identifies a worker specification entry. + + **Worker name**: The actual name a worker reports to the scheduler (e.g., ``"0"``, + ``"0-0"``, ``"0-1"``). This is what appears in ``scheduler.workers``. + + For **regular workers**: spec name == worker name (one-to-one mapping) + For **grouped workers**: one spec name → multiple worker names (one-to-many mapping) + + Grouped Workers + --------------- + A single spec entry can generate multiple Dask workers by including a ``"group"`` + element with suffixes. This is useful for: + - HPC systems (e.g., SLURM) where multiple processes are allocated together + - Any worker class that manages multiple workers as a unit (e.g., MultiWorker) + + >>> cluster.worker_spec = { + ... "0": {"cls": MultiWorker, "options": {"processes": 3}, "group": ["-0", "-1", "-2"]}, + ... "1": {"cls": MultiWorker, "options": {"processes": 2}, "group": ["-0", "-1"]} + ... } + + The scheduler sees individual workers with concatenated names: + + >>> [ws.name for ws in cluster.scheduler.workers.values()] + ["0-0", "0-1", "0-2", "1-0", "1-1"] + + When any worker in a group fails, the entire spec is removed so the group + can be recreated as a unit (important for HPC where the whole allocation fails). + Parameters ---------- - workers: dict - A dictionary mapping names to worker classes and their specifications - See example below + workers: dict[str, dict], optional + A dictionary mapping spec names (strings) to worker specifications. + Each worker spec is a dict with 'cls' and optionally 'options' and 'group'. + Spec names must be strings. scheduler: dict, optional - A similar mapping for a scheduler - worker: dict - A specification of a single worker. - This is used for any new workers that are created. + A specification for the scheduler with 'cls' and 'options' keys + worker: dict, optional + A worker specification template used when calling scale(). + This template is used to auto-generate new worker specs. asynchronous: bool If this is intended to be used directly within an event loop with async/await @@ -161,17 +193,17 @@ class does handle all of the logic around asynchronously cleanly setting up Examples -------- - To create a SpecCluster you specify how to set up a Scheduler and Workers + To create a SpecCluster you specify worker specifications and a scheduler spec >>> from dask.distributed import Scheduler, Worker, Nanny >>> scheduler = {'cls': Scheduler, 'options': {"dashboard_address": ':8787'}} - >>> workers = { + >>> worker_spec = { ... 'my-worker': {"cls": Worker, "options": {"nthreads": 1}}, ... 'my-nanny': {"cls": Nanny, "options": {"nthreads": 2}}, ... } - >>> cluster = SpecCluster(scheduler=scheduler, workers=workers) + >>> cluster = SpecCluster(scheduler=scheduler, workers=worker_spec) - The worker spec is stored as the ``.worker_spec`` attribute + The worker specs are stored in the ``.worker_spec`` attribute >>> cluster.worker_spec { @@ -179,8 +211,8 @@ class does handle all of the logic around asynchronously cleanly setting up 'my-nanny': {"cls": Nanny, "options": {"nthreads": 2}}, } - While the instantiation of this spec is stored in the ``.workers`` - attribute + The actual Worker instances created from these specs are stored in the + ``.workers`` attribute >>> cluster.workers { @@ -188,53 +220,36 @@ class does handle all of the logic around asynchronously cleanly setting up 'my-nanny': } - Should the spec change, we can await the cluster or call the - ``._correct_state`` method to align the actual state to the specified - state. + Should the worker_spec change, we can await the cluster or call the + ``._correct_state`` method to align the actual Worker instances to the + specified state. - We can also ``.scale(...)`` the cluster, which adds new workers of a given - form. + We can also ``.scale(...)`` the cluster, which adds new worker specs using + the template provided via the ``worker`` parameter. - >>> worker = {'cls': Worker, 'options': {}} - >>> cluster = SpecCluster(scheduler=scheduler, worker=worker) + >>> worker_template = {'cls': Worker, 'options': {}} + >>> cluster = SpecCluster(scheduler=scheduler, worker=worker_template) >>> cluster.worker_spec {} >>> cluster.scale(3) >>> cluster.worker_spec { - 0: {'cls': Worker, 'options': {}}, - 1: {'cls': Worker, 'options': {}}, - 2: {'cls': Worker, 'options': {}}, + "0": {'cls': Worker, 'options': {}}, + "1": {'cls': Worker, 'options': {}}, + "2": {'cls': Worker, 'options': {}}, } Note that above we are using the standard ``Worker`` and ``Nanny`` classes, however in practice other classes could be used that handle resource - management like ``KubernetesPod`` or ``SLURMJob``. The spec does not need - to conform to the expectations of the standard Dask Worker class. It just - needs to be called with the provided options, support ``__await__`` and - ``close`` methods and the ``worker_address`` property.. - - Also note that uniformity of the specification is not required. Other API - could be added externally (in subclasses) that adds workers of different - specifications into the same dictionary. - - If a single entry in the spec will generate multiple dask workers then - please provide a `"group"` element to the spec, that includes the suffixes - that will be added to each name (this should be handled by your worker - class). - - >>> cluster.worker_spec - { - 0: {"cls": MultiWorker, "options": {"processes": 3}, "group": ["-0", "-1", -2"]} - 1: {"cls": MultiWorker, "options": {"processes": 2}, "group": ["-0", "-1"]} - } - - These suffixes should correspond to the names used by the workers when - they deploy. - - >>> [ws.name for ws in cluster.scheduler.workers.values()] - ["0-0", "0-1", "0-2", "1-0", "1-1"] + management like ``KubernetesPod`` or ``SLURMJob``. Worker specs do not need + to conform to the expectations of the standard Dask Worker class. They just + need to be called with the provided options, support ``__await__`` and + ``close`` methods and the ``worker_address`` property. + + Also note that uniformity of worker specs is not required. Other API + could be added externally (in subclasses) that adds worker specs of different + types into the same worker_spec dictionary. """ _instances: ClassVar[weakref.WeakSet[SpecCluster]] = weakref.WeakSet() @@ -260,10 +275,10 @@ def __init__( self._created = weakref.WeakSet() self.scheduler_spec = copy.copy(scheduler) - self.worker_spec = copy.copy(workers) or {} - self.new_spec = copy.copy(worker) + self.worker_spec: dict[str, dict[str, Any]] = copy.copy(workers) or {} + self.new_spec: dict[str, Any] | None = copy.copy(worker) self.scheduler = None - self.workers = {} + self.workers: dict[str, Worker | Nanny] = {} self._i = 0 self.security = security or Security() self._futures = set() @@ -538,37 +553,56 @@ def scale(self, n=0, memory=None, cores=None): if self.asynchronous: return NoOpAwaitable() - def _new_worker_name(self, worker_number): - """Returns new worker name. + def _new_spec_name(self, spec_number: int) -> str: + """Returns new spec name (key for worker_spec dict). - This can be overridden in SpecCluster derived classes to customise the - worker names. + This generates a spec name for auto-created worker specs. For regular + workers, the spec name will also be the worker name. For grouped workers, + the spec name is the prefix, and actual worker names will have suffixes + appended (e.g., spec name "0" with group ["-0", "-1"] creates workers + "0-0" and "0-1"). + + This can be overridden in SpecCluster derived classes to customize spec + naming. + + Parameters + ---------- + spec_number : int + The numeric identifier for this spec (typically from self._i) + + Returns + ------- + str + The spec name to use as a key in worker_spec dict """ - return worker_number + return str(spec_number) - def new_worker_spec(self): - """Return name and spec for the next worker + def new_worker_spec(self) -> dict[str, dict[str, Any]]: + """Return name and spec for the next worker spec Returns ------- - d: dict mapping names to worker specs + dict[str, dict] + A dictionary with a single entry mapping a spec name (string) to + a worker specification dict See Also -------- scale """ - new_worker_name = self._new_worker_name(self._i) - while new_worker_name in self.worker_spec: + spec_name = self._new_spec_name(self._i) + while spec_name in self.worker_spec: self._i += 1 - new_worker_name = self._new_worker_name(self._i) + spec_name = self._new_spec_name(self._i) - return {new_worker_name: self.new_spec} + return {spec_name: cast(dict[str, Any], self.new_spec)} @property def _supports_scaling(self): return bool(self.new_spec) - async def scale_down(self, workers): + async def scale_down(self, workers: Iterable[str]) -> None: + """Scale down by removing worker specs.""" # We may have groups, if so, map worker addresses to job names if not all(w in self.worker_spec for w in workers): mapping = {} diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index ce39d20bb1..f87c4e8500 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -484,30 +484,31 @@ async def test_run_spec(c, s): @gen_test() -async def test_run_spec_cluster_worker_names(): +async def test_run_spec_cluster_custom_spec_names(): + """Test that _new_spec_name() can be overridden to customize spec names""" worker = {"cls": Worker, "options": {"nthreads": 1}} class MyCluster(SpecCluster): - def _new_worker_name(self, worker_number): - return f"prefix-{self.name}-{worker_number}-suffix" + def _new_spec_name(self, spec_number): + return f"prefix-{self.name}-{spec_number}-suffix" async with SpecCluster( asynchronous=True, scheduler=scheduler, worker=worker ) as cluster: cluster.scale(2) await cluster - worker_names = [0, 1] - assert list(cluster.worker_spec) == worker_names - assert sorted(list(cluster.workers)) == worker_names + spec_names = [0, 1] + assert list(cluster.worker_spec) == spec_names + assert sorted(list(cluster.workers)) == spec_names async with MyCluster( asynchronous=True, scheduler=scheduler, worker=worker, name="test-name" ) as cluster: - worker_names = ["prefix-test-name-0-suffix", "prefix-test-name-1-suffix"] + spec_names = ["prefix-test-name-0-suffix", "prefix-test-name-1-suffix"] cluster.scale(2) await cluster - assert list(cluster.worker_spec) == worker_names - assert sorted(list(cluster.workers)) == worker_names + assert list(cluster.worker_spec) == spec_names + assert sorted(list(cluster.workers)) == spec_names @gen_test() @@ -544,3 +545,47 @@ async def test_shutdown_scheduler(): assert isinstance(s, Scheduler) assert s.status == Status.closed + + +@gen_test() +async def test_new_spec_name_returns_string(): + """Test that _new_spec_name() returns strings, not integers. + + Spec names (keys in worker_spec dict) should always be strings, whether + auto-generated or user-provided. This ensures type consistency and + eliminates int/str conversion issues throughout the codebase. + """ + async with SpecCluster( + workers={}, scheduler=scheduler, asynchronous=True + ) as cluster: + # Test that _new_spec_name returns a string + name = cluster._new_spec_name(0) + assert isinstance(name, str), f"Expected str, got {type(name).__name__}" + assert name == "0" + + name = cluster._new_spec_name(42) + assert isinstance(name, str), f"Expected str, got {type(name).__name__}" + assert name == "42" + + +@gen_test() +async def test_worker_spec_keys_are_strings(): + """Test that worker_spec keys are strings after scaling. + + When workers are added via scale(), the resulting spec names (keys in + worker_spec dict) should be strings to maintain consistency with + user-provided specs. + """ + worker_template = {"cls": Worker, "options": {"nthreads": 1}} + async with SpecCluster( + workers={}, scheduler=scheduler, worker=worker_template, asynchronous=True + ) as cluster: + # Scale up to create auto-generated worker specs + cluster.scale(2) + await cluster + + # All keys in worker_spec should be strings + for key in cluster.worker_spec.keys(): + assert isinstance( + key, str + ), f"Expected str key, got {type(key).__name__}: {key}" From ab91dc639ae391cf81b7936e9d69d7e8c0e9273d Mon Sep 17 00:00:00 2001 From: Alister Burt Date: Tue, 7 Oct 2025 13:17:53 -0700 Subject: [PATCH 2/6] fix test_run_spec_cluster_custom_spec_names to also use string names as now expected --- distributed/deploy/tests/test_spec_cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index f87c4e8500..069cda8900 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -497,7 +497,7 @@ def _new_spec_name(self, spec_number): ) as cluster: cluster.scale(2) await cluster - spec_names = [0, 1] + spec_names = ["0", "1"] assert list(cluster.worker_spec) == spec_names assert sorted(list(cluster.workers)) == spec_names From 3bf4f3be33796e5d86c7f3c00777b6fa1e5545a4 Mon Sep 17 00:00:00 2001 From: Alister Burt Date: Tue, 7 Oct 2025 13:27:48 -0700 Subject: [PATCH 3/6] =?UTF-8?q?Add=20helper=20methods=20for=20spec=20name?= =?UTF-8?q?=20=E2=86=94=20worker=20name=20conversion=20and=20associated=20?= =?UTF-8?q?tests=20to=20enable=20grouped=20worker=20support=20in=20SpecClu?= =?UTF-8?q?ster=20=20=20-=20`=5Fspec=5Fname=5Fto=5Fworker=5Fnames()`:=20Ma?= =?UTF-8?q?ps=20spec=20name=20=E2=86=92=20worker=20names=20=20=20=20=20-?= =?UTF-8?q?=20Regular=20workers:=201:1=20mapping=20(spec=20"0"=20=E2=86=92?= =?UTF-8?q?=20worker=20"0")=20=20=20=20=20-=20Grouped=20workers:=201:many?= =?UTF-8?q?=20mapping=20(spec=20"0"=20=E2=86=92=20workers=20"0-0",=20"0-1"?= =?UTF-8?q?,=20"0-2")?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `_worker_name_to_spec_name()`: Maps worker name → spec name - Handles both regular and grouped workers - Returns None if worker name not found --- distributed/deploy/spec.py | 87 +++++++++++++ distributed/deploy/tests/test_spec_cluster.py | 117 ++++++++++++++++++ 2 files changed, 204 insertions(+) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 9d88e96d08..700f88f0c4 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -577,6 +577,93 @@ def _new_spec_name(self, spec_number: int) -> str: """ return str(spec_number) + def _spec_name_to_worker_names(self, spec_name: str) -> set[str]: + """Convert a spec name to the set of worker names it generates. + + For regular workers, the spec name equals the worker name (1:1 mapping). + For grouped workers, one spec name maps to multiple worker names (1:many). + + Parameters + ---------- + spec_name : str + The spec name (key in worker_spec dict) + + Returns + ------- + set[str] + Set of worker names the scheduler will see for this spec + + Examples + -------- + Regular worker (no "group" key): + >>> cluster.worker_spec = {"0": {"cls": Worker, "options": {}}} + >>> cluster._spec_name_to_worker_names("0") + {"0"} + + Grouped worker (has "group" key): + >>> cluster.worker_spec = { + ... "0": {"cls": MultiWorker, "options": {}, "group": ["-0", "-1", "-2"]} + ... } + >>> cluster._spec_name_to_worker_names("0") + {"0-0", "0-1", "0-2"} + """ + if spec_name not in self.worker_spec: + return set() + + spec = self.worker_spec[spec_name] + if "group" in spec: + # Grouped worker: concatenate spec_name with each suffix + return {spec_name + suffix for suffix in spec["group"]} + else: + # Regular worker: spec name == worker name + return {spec_name} + + def _worker_name_to_spec_name(self, worker_name: str) -> str | None: + """Convert a worker name to its corresponding spec name. + + For regular workers, the worker name equals the spec name. + For grouped workers, extract the spec name prefix from the worker name. + + Parameters + ---------- + worker_name : str + The worker name (as seen by the scheduler) + + Returns + ------- + str | None + The spec name (key in worker_spec dict), or None if not found + + Examples + -------- + Regular worker: + >>> cluster.worker_spec = {"0": {"cls": Worker, "options": {}}} + >>> cluster._worker_name_to_spec_name("0") + "0" + + Grouped worker: + >>> cluster.worker_spec = { + ... "0": {"cls": MultiWorker, "options": {}, "group": ["-0", "-1", "-2"]} + ... } + >>> cluster._worker_name_to_spec_name("0-1") + "0" + + Not found: + >>> cluster._worker_name_to_spec_name("nonexistent") + None + """ + # First check if worker_name is directly a spec name (regular worker) + if worker_name in self.worker_spec: + return worker_name + + # For grouped workers, check each spec to see if this worker belongs to it + for spec_name in self.worker_spec: + worker_names = self._spec_name_to_worker_names(spec_name) + if worker_name in worker_names: + return spec_name + + return None + def new_worker_spec(self) -> dict[str, dict[str, Any]]: """Return name and spec for the next worker spec diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 069cda8900..8829fbc7eb 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -589,3 +589,120 @@ async def test_worker_spec_keys_are_strings(): assert isinstance( key, str ), f"Expected str key, got {type(key).__name__}: {key}" + + +@gen_test() +async def test_spec_name_to_worker_names_regular(): + """Test _spec_name_to_worker_names() with regular (non-grouped) workers.""" + worker_template = {"cls": Worker, "options": {"nthreads": 1}} + async with SpecCluster( + workers={}, scheduler=scheduler, worker=worker_template, asynchronous=True + ) as cluster: + # Scale to create regular workers + cluster.scale(2) + await cluster + + # Regular workers: spec name == worker name (1:1) + assert cluster._spec_name_to_worker_names("0") == {"0"} + assert cluster._spec_name_to_worker_names("1") == {"1"} + + # Non-existent spec returns empty set + assert cluster._spec_name_to_worker_names("nonexistent") == set() + + +@gen_test() +async def test_spec_name_to_worker_names_grouped(): + """Test _spec_name_to_worker_names() with grouped workers.""" + async with SpecCluster( + workers={ + "0": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-0", "-1", "-2"], + }, + "1": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-a", "-b"], + }, + }, + scheduler=scheduler, + asynchronous=True, + ) as cluster: + # Grouped workers: one spec name → multiple worker names + assert cluster._spec_name_to_worker_names("0") == {"0-0", "0-1", "0-2"} + assert cluster._spec_name_to_worker_names("1") == {"1-a", "1-b"} + + +@gen_test() +async def test_worker_name_to_spec_name_regular(): + """Test _worker_name_to_spec_name() with regular (non-grouped) workers.""" + worker_template = {"cls": Worker, "options": {"nthreads": 1}} + async with SpecCluster( + workers={}, scheduler=scheduler, worker=worker_template, asynchronous=True + ) as cluster: + # Scale to create regular workers + cluster.scale(2) + await cluster + + # Regular workers: worker name == spec name + assert cluster._worker_name_to_spec_name("0") == "0" + assert cluster._worker_name_to_spec_name("1") == "1" + + # Non-existent worker returns None + assert cluster._worker_name_to_spec_name("nonexistent") is None + + +@gen_test() +async def test_worker_name_to_spec_name_grouped(): + """Test _worker_name_to_spec_name() with grouped workers.""" + async with SpecCluster( + workers={ + "0": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-0", "-1", "-2"], + }, + "1": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-a", "-b"], + }, + }, + scheduler=scheduler, + asynchronous=True, + ) as cluster: + # All workers from group "0" map back to spec "0" + assert cluster._worker_name_to_spec_name("0-0") == "0" + assert cluster._worker_name_to_spec_name("0-1") == "0" + assert cluster._worker_name_to_spec_name("0-2") == "0" + + # All workers from group "1" map back to spec "1" + assert cluster._worker_name_to_spec_name("1-a") == "1" + assert cluster._worker_name_to_spec_name("1-b") == "1" + + # Non-existent worker returns None + assert cluster._worker_name_to_spec_name("nonexistent") is None + + +@gen_test() +async def test_worker_name_to_spec_name_mixed(): + """Test _worker_name_to_spec_name() with mixed regular and grouped workers.""" + async with SpecCluster( + workers={ + "regular": {"cls": Worker, "options": {"nthreads": 1}}, + "grouped": { + "cls": Worker, + "options": {"nthreads": 1}, + "group": ["-0", "-1"], + }, + }, + scheduler=scheduler, + asynchronous=True, + ) as cluster: + # Regular worker + assert cluster._worker_name_to_spec_name("regular") == "regular" + + # Grouped workers + assert cluster._worker_name_to_spec_name("grouped-0") == "grouped" + assert cluster._worker_name_to_spec_name("grouped-1") == "grouped" From 1da03477feb46ae0c356169de01a2edaae5d1ca2 Mon Sep 17 00:00:00 2001 From: Alister Burt Date: Tue, 7 Oct 2025 13:46:46 -0700 Subject: [PATCH 4/6] Updated core SpecCluster methods to support grouped workers: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. _update_worker_status(): When a grouped worker fails, removes entire spec 2. _correct_state_internal(): Maps spec names → worker names before retiring workers 3. scale(): Correctly identifies launched specs by mapping worker names → spec names 4. scale_down(): Simplified using helper methods, handles both regular and grouped workers --- distributed/deploy/spec.py | 72 ++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 700f88f0c4..510df6f915 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -371,7 +371,21 @@ async def _correct_state_internal(self) -> None: to_close = set(self.workers) - set(self.worker_spec) if to_close: if self.scheduler.status == Status.running: - await self.scheduler_comm.retire_workers(workers=list(to_close)) + # Map spec names to worker names for retirement + workers_to_retire: list[str] = [] + for spec_name in to_close: + worker_names = self._spec_name_to_worker_names(spec_name) + # Only retire workers that actually exist in the scheduler + scheduler_worker_names = { + w["name"] for w in self.scheduler_info["workers"].values() + } + workers_to_retire.extend(worker_names & scheduler_worker_names) + + if workers_to_retire: + await self.scheduler_comm.retire_workers( + workers=workers_to_retire + ) + tasks = [ asyncio.create_task(self.workers[w].close()) for w in to_close @@ -426,6 +440,14 @@ def f(): self._futures.add(asyncio.ensure_future(self.workers[name].close())) del self.workers[name] + spec_name = self._worker_name_to_spec_name(name) + if spec_name and spec_name in self.worker_spec: + spec = self.worker_spec[spec_name] + + # Grouped worker: remove entire spec so adaptive can recreate it + if "group" in spec: + del self.worker_spec[spec_name] + delay = parse_timedelta( dask.config.get("distributed.deploy.lost-worker-timeout") ) @@ -535,9 +557,17 @@ def scale(self, n=0, memory=None, cores=None): n = max(n, int(math.ceil(cores / self._threads_per_worker()))) if len(self.worker_spec) > n: - not_yet_launched = set(self.worker_spec) - { + # Build set of launched spec names by mapping worker names back to spec names + scheduler_worker_names = { v["name"] for v in self.scheduler_info["workers"].values() } + launched_spec_names = set() + for worker_name in scheduler_worker_names: + spec_name = self._worker_name_to_spec_name(worker_name) + if spec_name: + launched_spec_names.add(spec_name) + + not_yet_launched = set(self.worker_spec) - launched_spec_names while len(self.worker_spec) > n and not_yet_launched: del self.worker_spec[not_yet_launched.pop()] @@ -689,22 +719,28 @@ def _supports_scaling(self): return bool(self.new_spec) async def scale_down(self, workers: Iterable[str]) -> None: - """Scale down by removing worker specs.""" - # We may have groups, if so, map worker addresses to job names - if not all(w in self.worker_spec for w in workers): - mapping = {} - for name, spec in self.worker_spec.items(): - if "group" in spec: - for suffix in spec["group"]: - mapping[str(name) + suffix] = name - else: - mapping[name] = name - - workers = {mapping.get(w, w) for w in workers} - - for w in workers: - if w in self.worker_spec: - del self.worker_spec[w] + """Scale down by removing worker specs. + + Parameters + ---------- + workers : Iterable[str] + Worker names (as seen by the scheduler) to scale down + """ + # Map worker names to spec names (handles both regular and grouped workers) + spec_names_to_remove = set() + for worker_name in workers: + # First check if it's directly a spec name (for backward compatibility) + if worker_name in self.worker_spec: + spec_names_to_remove.add(worker_name) + else: + # Otherwise, map worker name to spec name + spec_name = self._worker_name_to_spec_name(worker_name) + if spec_name: + spec_names_to_remove.add(spec_name) + + for spec_name in spec_names_to_remove: + if spec_name in self.worker_spec: + del self.worker_spec[spec_name] await self scale_up = scale # backwards compatibility From eef6e1684de0b558081eb0998c1e516ce7adcae6 Mon Sep 17 00:00:00 2001 From: Alister Burt Date: Tue, 7 Oct 2025 14:54:47 -0700 Subject: [PATCH 5/6] Fix SpecCluster grouped worker scaling semantics to be conservative MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes scale() behavior for grouped workers to use conservative (floor) rounding instead of ceiling, ensuring resource limits are never exceeded: - scale(n): Rounds DOWN to complete specs (e.g., scale(5) with 2-worker specs → 4 workers, not 6) - scale(memory/cores): Uses floor division for grouped workers to stay under limits - Special case: When scaling from 0 workers, creates at least 1 spec to prevent deadlock (e.g., scale(1) with 2-worker specs → 2 workers) This makes all scaling parameters consistent ("at most X") and prevents resource overcommitment, OOM, and CPU oversubscription. Updated scale() docstring with comprehensive documentation of conservative behavior and updated tests to reflect new semantics. --- distributed/deploy/spec.py | 181 +++++++++++++- distributed/deploy/tests/test_spec_cluster.py | 230 +++++++++++++++++- 2 files changed, 392 insertions(+), 19 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 510df6f915..d2586a0ac0 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -523,17 +523,28 @@ async def __aenter__(self): raise def _threads_per_worker(self) -> int: - """Return the number of threads per worker for new workers""" + """Return the number of threads per worker for new workers. + + For grouped workers, this returns the threads per individual worker + (total spec threads divided by number of workers in the group). + """ if not self.new_spec: # pragma: no cover raise ValueError("To scale by cores= you must specify cores per worker") for name in ["nthreads", "ncores", "threads", "cores"]: with suppress(KeyError): - return self.new_spec["options"][name] + total_threads = self.new_spec["options"][name] + # For grouped workers, divide by number of workers in the group + workers_per_spec = self._workers_per_spec(self.new_spec) + return total_threads // workers_per_spec raise RuntimeError("unreachable") def _memory_per_worker(self) -> int: - """Return the memory limit per worker for new workers""" + """Return the memory limit per worker for new workers. + + For grouped workers, this returns the memory per individual worker + (total spec memory divided by number of workers in the group). + """ if not self.new_spec: # pragma: no cover raise ValueError( "to scale by memory= your worker definition must include a " @@ -542,21 +553,135 @@ def _memory_per_worker(self) -> int: for name in ["memory_limit", "memory"]: with suppress(KeyError): - return parse_bytes(self.new_spec["options"][name]) + total_memory = parse_bytes(self.new_spec["options"][name]) + # For grouped workers, divide by number of workers in the group + workers_per_spec = self._workers_per_spec(self.new_spec) + return total_memory // workers_per_spec raise ValueError( "to use scale(memory=...) your worker definition must include a " "memory_limit definition" ) + def _count_workers_in_specs(self) -> int: + """Count total number of workers across all specs. + + For regular workers, each spec = 1 worker. + For grouped workers, each spec = number of group members. + + Returns + ------- + int + Total number of workers that would be created by current worker_spec + """ + total = 0 + for _spec_name, spec in self.worker_spec.items(): + if "group" in spec: + total += len(spec["group"]) + else: + total += 1 + return total + + def _workers_per_spec(self, spec: dict[str, Any]) -> int: + """Get number of workers a single spec will create. + + Parameters + ---------- + spec : dict + Worker specification dict + + Returns + ------- + int + Number of workers this spec creates (1 for regular, len(group) for grouped) + """ + if "group" in spec: + return len(spec["group"]) + return 1 + def scale(self, n=0, memory=None, cores=None): + """Scale cluster to a target number of workers or resource level. + + Parameters + ---------- + n : int, optional + Target maximum number of workers. For grouped workers, rounds down to + complete specs. Default is 0. + memory : str, optional + Target total memory (e.g., "10 GB"). Scales conservatively - will NOT + exceed this limit. For grouped workers, rounds down to complete specs. + cores : int, optional + Target total cores/threads. Scales conservatively - will NOT exceed + this limit. For grouped workers, rounds down to complete specs. + + Notes + ----- + **All scaling is conservative (rounds down to number of complete specs):** + - Ensures limits are not exceeded (except special case below) + - Prevents resource overcommitment and surprises + - For grouped workers, may get fewer workers than requested + + **Special case - minimum viability:** + - If target > 0 and current workers = 0, creates at least 1 spec + - Prevents deadlock where no workers can be created + - Example: `scale(1)` with 2-worker specs → 1 spec = 2 workers (exceeds target!) + + **Examples:** + - `scale(5)` with 2-worker specs → 2 specs = 4 workers (not 5) + - `scale(1)` with 2-worker specs → 1 spec = 2 workers (special case!) + - `scale(memory="6GB")` with 4GB/spec → 1 spec = 4GB (not 8GB) + - `scale(cores=10)` with 4 cores/spec → 2 specs = 8 cores (not 12) + + **Why conservative?** + - User expectation: "at most N" not "at least N" + - Safety: prevents OOM, CPU oversubscription + - Consistency: all parameters use same rounding + + Examples + -------- + Regular workers (1 worker per spec): + >>> cluster.scale(5) # Creates 5 specs = 5 workers + + Grouped workers (2 workers per spec): + >>> cluster.scale(5) # Creates 2 specs = 4 workers (conservative) + >>> cluster.scale(6) # Creates 3 specs = 6 workers (exact match) + >>> cluster.scale(memory="6GB") # With 4GB/spec: creates 1 spec = 4GB + """ if memory is not None: - n = max(n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker()))) + # For grouped workers, scale by complete specs to avoid exceeding limit + # Use floor division to be conservative (never exceed requested memory) + if self.new_spec and "group" in self.new_spec: + memory_per_spec = self._memory_per_worker() * self._workers_per_spec( + self.new_spec + ) + target_specs = int(parse_bytes(memory) // memory_per_spec) + n = max(n, target_specs * self._workers_per_spec(self.new_spec)) + else: + # Regular workers: use ceiling (match old behavior) + n = max( + n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker())) + ) if cores is not None: - n = max(n, int(math.ceil(cores / self._threads_per_worker()))) + # For grouped workers, scale by complete specs to avoid exceeding limit + # Use floor division to be conservative (never exceed requested cores) + if self.new_spec and "group" in self.new_spec: + cores_per_spec = self._threads_per_worker() * self._workers_per_spec( + self.new_spec + ) + target_specs = int(cores // cores_per_spec) + n = max(n, target_specs * self._workers_per_spec(self.new_spec)) + else: + # Regular workers: use ceiling (match old behavior) + n = max(n, int(math.ceil(cores / self._threads_per_worker()))) + + # n is the target number of workers (not specs) + # For grouped workers, we need to scale by specs, where each spec creates multiple workers - if len(self.worker_spec) > n: + current_worker_count = self._count_workers_in_specs() + + # Scale down if we have too many workers + if current_worker_count > n: # Build set of launched spec names by mapping worker names back to spec names scheduler_worker_names = { v["name"] for v in self.scheduler_info["workers"].values() @@ -568,15 +693,45 @@ def scale(self, n=0, memory=None, cores=None): launched_spec_names.add(spec_name) not_yet_launched = set(self.worker_spec) - launched_spec_names - while len(self.worker_spec) > n and not_yet_launched: - del self.worker_spec[not_yet_launched.pop()] - while len(self.worker_spec) > n: - self.worker_spec.popitem() + # Remove unlaunched specs first + while current_worker_count > n and not_yet_launched: + spec_name = not_yet_launched.pop() + spec = self.worker_spec[spec_name] + workers_in_spec = self._workers_per_spec(spec) + del self.worker_spec[spec_name] + current_worker_count -= workers_in_spec + + # Remove launched specs if still over target + while current_worker_count > n and self.worker_spec: + spec_name, spec = self.worker_spec.popitem() + workers_in_spec = self._workers_per_spec(spec) + current_worker_count -= workers_in_spec + # Scale up if we need more workers if self.status not in (Status.closing, Status.closed): - while len(self.worker_spec) < n: - self.worker_spec.update(self.new_worker_spec()) + while current_worker_count < n: + # For grouped workers, check if adding next spec would exceed target + # This ensures we never exceed the requested worker count (conservative scaling) + workers_in_next_spec = ( + self._workers_per_spec(self.new_spec) if self.new_spec else 1 + ) + if current_worker_count + workers_in_next_spec > n: + # Don't add spec if it would exceed target + # Exception: if we have 0 workers and n > 0, add at least one spec + # This ensures we can always create workers when requested (avoids deadlock) + if current_worker_count == 0 and n > 0: + pass # Add the spec even if it exceeds target + else: + break + + new_spec_dict = self.new_worker_spec() + self.worker_spec.update(new_spec_dict) + # Get the spec we just added to count its workers + spec_name = list(new_spec_dict.keys())[0] + spec = new_spec_dict[spec_name] + workers_in_spec = self._workers_per_spec(spec) + current_worker_count += workers_in_spec self.loop.add_callback(self._correct_state) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 8829fbc7eb..6804a9e9f5 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -435,7 +435,8 @@ async def test_MultiWorker(): ) as cluster: s = cluster.scheduler async with Client(cluster, asynchronous=True) as client: - cluster.scale(2) + # Scale to 4 workers (2 specs with 2 workers each) + cluster.scale(4) await cluster assert len(cluster.worker_spec) == 2 await client.wait_for_workers(4) @@ -448,20 +449,26 @@ async def test_MultiWorker(): workers_line = re.search("(Workers.+)", cluster._repr_html_()).group(1) assert re.match("Workers.*4", workers_line) - cluster.scale(1) + # Scale to 2 workers (1 spec with 2 workers) + cluster.scale(2) await cluster assert len(s.workers) == 2 + # Scale to 6 GB memory: 4GB per spec, conservatively scales to 1 spec = 4GB + # (rounds DOWN to avoid exceeding 6GB limit) cluster.scale(memory="6GB") await cluster - assert len(cluster.worker_spec) == 2 - assert len(s.workers) == 4 + assert len(cluster.worker_spec) == 1 + assert len(s.workers) == 2 assert cluster.plan == {ws.name for ws in s.workers.values()} + # Scale to 10 cores: 4 cores per spec, conservatively scales to 2 specs = 8 cores + # (rounds DOWN to avoid exceeding 10 cores limit) cluster.scale(cores=10) await cluster - assert len(cluster.workers) == 3 + assert len(cluster.workers) == 2 + # Adaptive with maximum=4 means maximum 4 workers = 2 specs adapt = cluster.adapt(minimum=0, maximum=4) for _ in range(adapt.wait_count): # relax down to 0 workers @@ -469,9 +476,64 @@ async def test_MultiWorker(): await cluster assert not s.workers + # Submit work - adaptive will request workers based on workload future = client.submit(lambda x: x + 1, 10) await future - assert len(cluster.workers) == 1 + # With 2-worker specs and conservative scaling with minimum viability: + # When adaptive requests 1 worker, scale creates 1 spec = 2 workers + # (special case: creates at least 1 spec when scaling from 0) + assert len(cluster.workers) == 1 # 1 spec created + + +@gen_test() +async def test_grouped_worker_death_removes_spec(): + """Test that when a single worker in a group dies, the entire spec is removed.""" + with dask.config.set({"distributed.deploy.lost-worker-timeout": "100ms"}): + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2, "memory_limit": "2 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + # Scale to 4 workers (2 specs with 2 workers each) + cluster.scale(4) + await cluster + assert len(cluster.worker_spec) == 2 + await client.wait_for_workers(4) + + # Get the spec names + spec_names = list(cluster.worker_spec.keys()) + assert len(spec_names) == 2 + + # Get worker names for the first spec + first_spec_name = spec_names[0] + worker_names = cluster._spec_name_to_worker_names(first_spec_name) + assert len(worker_names) == 2 + + # Kill one worker from the first group + worker_to_kill = list(worker_names)[0] + worker_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == worker_to_kill + ][0] + + # Simulate abrupt worker death (like HPC pre-emption) + await cluster.scheduler.remove_worker( + address=worker_addr, close=False, stimulus_id="test" + ) + + # Wait for lost-worker-timeout + await asyncio.sleep(0.2) + + # The entire spec should be removed (not just the one worker) + assert first_spec_name not in cluster.worker_spec + # The other spec should still exist + assert spec_names[1] in cluster.worker_spec @gen_cluster(client=True, nthreads=[]) @@ -706,3 +768,159 @@ async def test_worker_name_to_spec_name_mixed(): # Grouped workers assert cluster._worker_name_to_spec_name("grouped-0") == "grouped" assert cluster._worker_name_to_spec_name("grouped-1") == "grouped" + + +@gen_test() +async def test_unexpected_close_whole_worker_group(): + """Test that when all workers in a group die abruptly, the spec is removed and recreated.""" + with dask.config.set({"distributed.deploy.lost-worker-timeout": "100ms"}): + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2, "memory_limit": "2 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + # Scale to 4 workers (2 specs with 2 workers each) + cluster.scale(4) + await cluster + assert len(cluster.worker_spec) == 2 + await client.wait_for_workers(4) + + # Get the spec names + spec_names = list(cluster.worker_spec.keys()) + assert len(spec_names) == 2 + + # Get all worker names for the first spec + first_spec_name = spec_names[0] + worker_names = cluster._spec_name_to_worker_names(first_spec_name) + assert len(worker_names) == 2 + + # Kill all workers from the first group (simulate HPC job kill) + for worker_name in worker_names: + worker_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == worker_name + ][0] + await cluster.scheduler.remove_worker( + address=worker_addr, close=False, stimulus_id="test" + ) + + # Wait for lost-worker-timeout + await asyncio.sleep(0.2) + + # The entire spec should be removed + assert first_spec_name not in cluster.worker_spec + # The other spec should still exist + assert spec_names[1] in cluster.worker_spec + # Should have 1 spec remaining + assert len(cluster.worker_spec) == 1 + + # With adaptive enabled (minimum=4 workers), the cluster should recreate the missing spec + cluster.adapt(minimum=4, maximum=4) + await client.wait_for_workers(4) + + # Should have 2 specs again (but with a new spec name for the recreated one) + assert len(cluster.worker_spec) == 2 + # Old spec name should not exist + assert first_spec_name not in cluster.worker_spec + # Should have 4 workers total + assert len(cluster.scheduler.workers) == 4 + + +@gen_test() +async def test_scale_down_with_grouped_workers(): + """Test that scale_down correctly maps worker names to spec names.""" + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2, "memory_limit": "2 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + # Scale to 4 workers (2 specs with 2 workers each) + cluster.scale(4) + await cluster + assert len(cluster.worker_spec) == 2 + + # Get spec names + spec_names = list(cluster.worker_spec.keys()) + first_spec_name = spec_names[0] + + # Get worker names for the first spec + worker_names = cluster._spec_name_to_worker_names(first_spec_name) + worker_names_list = list(worker_names) + + # Call scale_down with actual worker names (what scheduler knows) + await cluster.scale_down(worker_names_list) + + # The first spec should be removed + assert first_spec_name not in cluster.worker_spec + # The second spec should still exist + assert spec_names[1] in cluster.worker_spec + # Should have 1 spec and 2 workers left + assert len(cluster.worker_spec) == 1 + + +@gen_test() +async def test_mixed_regular_and_grouped_workers(): + """Test cluster with both regular and grouped worker specs.""" + with dask.config.set({"distributed.deploy.lost-worker-timeout": "100ms"}): + async with SpecCluster( + workers={ + "regular-1": {"cls": Worker, "options": {"nthreads": 2}}, + "regular-2": {"cls": Worker, "options": {"nthreads": 2}}, + "grouped": { + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2}, + "group": ["-0", "-1"], + }, + }, + scheduler=scheduler, + asynchronous=True, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + # Should have 3 specs, 4 workers total (2 regular + 2 grouped) + await client.wait_for_workers(4) + assert len(cluster.worker_spec) == 3 + + # Test regular worker failure - spec should remain + regular_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == "regular-1" + ][0] + await cluster.scheduler.remove_worker( + address=regular_addr, close=False, stimulus_id="test" + ) + await asyncio.sleep(0.2) + + # Regular worker spec should still exist (cluster can recreate it) + assert "regular-1" in cluster.worker_spec + assert len(cluster.worker_spec) == 3 + + # Test grouped worker failure - entire spec should be removed + grouped_worker_names = cluster._spec_name_to_worker_names("grouped") + one_grouped_worker = list(grouped_worker_names)[0] + grouped_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == one_grouped_worker + ][0] + await cluster.scheduler.remove_worker( + address=grouped_addr, close=False, stimulus_id="test" + ) + await asyncio.sleep(0.2) + + # Grouped spec should be removed entirely + assert "grouped" not in cluster.worker_spec + # Regular specs should still exist + assert "regular-1" in cluster.worker_spec + assert "regular-2" in cluster.worker_spec + assert len(cluster.worker_spec) == 2 From 6901bc1b0bb602faa64db01b62841a1fcf67c4c0 Mon Sep 17 00:00:00 2001 From: Alister Burt Date: Tue, 7 Oct 2025 15:46:56 -0700 Subject: [PATCH 6/6] Fix grouped worker spec removal in SpecCluster MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `_update_worker_status()` method was checking worker names (e.g., "2-0") against `self.workers` which is keyed by spec names (e.g., "2"). This caused the grouped worker spec removal code to never execute. The bug manifested when: 1. Spec "0" is removed (leaving spec "1") 2. Adaptive scaling creates new spec "2" 3. Attempting to remove spec "2" fails silently Fix: - Map worker name → spec name first using `_worker_name_to_spec_name()` - Use spec name to access `self.workers[spec_name]` for grouped workers - Properly close and remove MultiWorker instances before removing specs Testing: - Added `test_grouped_worker_spec_removal_multiple_rounds` to validate spec removal works across multiple rounds with different spec names also checked that all usages of self.workers used correct keys (spec names) --- distributed/deploy/spec.py | 47 +++++--- distributed/deploy/tests/test_spec_cluster.py | 105 ++++++++++++++++++ 2 files changed, 136 insertions(+), 16 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index d2586a0ac0..963e526316 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -148,6 +148,11 @@ class does handle all of the logic around asynchronously cleanly setting up For **regular workers**: spec name == worker name (one-to-one mapping) For **grouped workers**: one spec name → multiple worker names (one-to-many mapping) + **Important**: The ``self.workers`` dict is always keyed by **spec names** (not worker + names), mapping to Worker class instances. When accessing this dict with + a worker name from the scheduler, you must first map it to a spec name using + ``_worker_name_to_spec_name()``. + Grouped Workers --------------- A single spec entry can generate multiple Dask workers by including a ``"group"`` @@ -429,24 +434,34 @@ def _update_worker_status(self, op, msg): name = self.scheduler_info["workers"][msg]["name"] def f(): - if ( - name in self.workers - and msg not in self.scheduler_info["workers"] - and not any( - d["name"] == name - for d in self.scheduler_info["workers"].values() - ) - ): - self._futures.add(asyncio.ensure_future(self.workers[name].close())) - del self.workers[name] - + # Find the spec this worker belongs to spec_name = self._worker_name_to_spec_name(name) - if spec_name and spec_name in self.worker_spec: - spec = self.worker_spec[spec_name] - # Grouped worker: remove entire spec so adaptive can recreate it - if "group" in spec: - del self.worker_spec[spec_name] + # Check if worker/spec is still missing (not re-registered) + if msg not in self.scheduler_info["workers"] and not any( + d["name"] == name for d in self.scheduler_info["workers"].values() + ): + # For regular workers: close and remove from self.workers + if spec_name and spec_name == name and name in self.workers: + self._futures.add( + asyncio.ensure_future(self.workers[name].close()) + ) + del self.workers[name] + + # For grouped workers: remove the entire spec + if spec_name and spec_name in self.worker_spec: + spec = self.worker_spec[spec_name] + if "group" in spec: + # Close the MultiWorker instance + if spec_name in self.workers: + self._futures.add( + asyncio.ensure_future( + self.workers[spec_name].close() + ) + ) + del self.workers[spec_name] + # Remove the spec so adaptive can recreate it + del self.worker_spec[spec_name] delay = parse_timedelta( dask.config.get("distributed.deploy.lost-worker-timeout") diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 6804a9e9f5..8368fe48a9 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -536,6 +536,111 @@ async def test_grouped_worker_death_removes_spec(): assert spec_names[1] in cluster.worker_spec +@gen_test() +async def test_grouped_worker_spec_removal_multiple_rounds(): + """Test that spec removal works correctly for multiple rounds with different spec names. + + This test ensures that the spec removal mechanism in _update_worker_status() correctly: + 1. Maps worker names to spec names (e.g., "2-0" -> "2") + 2. Closes and removes the MultiWorker instance from self.workers + 3. Removes the spec from worker_spec + 4. Works for any spec name, not just "0" + + This catches bugs where worker names ("0-0") were incorrectly checked against + self.workers keys (which are spec names like "0"). + """ + with dask.config.set({"distributed.deploy.lost-worker-timeout": "100ms"}): + async with SpecCluster( + scheduler=scheduler, + worker={ + "cls": MultiWorker, + "options": {"n": 2, "nthreads": 2, "memory_limit": "2 GB"}, + "group": ["-0", "-1"], + }, + asynchronous=True, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + # Scale to 4 workers (2 specs) + cluster.scale(4) + await cluster + await client.wait_for_workers(4) + + # Initial state: 2 specs, 2 MultiWorker instances + assert len(cluster.worker_spec) == 2 + assert len(cluster.workers) == 2 + initial_specs = set(cluster.worker_spec.keys()) + + # Round 1: Remove spec "0" + spec_to_remove = "0" + assert spec_to_remove in cluster.worker_spec + assert spec_to_remove in cluster.workers + + worker_names = cluster._spec_name_to_worker_names(spec_to_remove) + worker_to_kill = list(worker_names)[0] + worker_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == worker_to_kill + ][0] + + await cluster.scheduler.remove_worker( + address=worker_addr, close=False, stimulus_id="test-round-1" + ) + await asyncio.sleep(0.2) + + # Verify spec "0" is completely removed + assert spec_to_remove not in cluster.worker_spec + assert ( + spec_to_remove not in cluster.workers + ) # MultiWorker instance removed + assert len(cluster.worker_spec) == 1 + assert len(cluster.workers) == 1 + + # Scale back up to create a new spec (will be "2" since "0" was removed) + cluster.scale(4) + await client.wait_for_workers(4) + + # Should have 2 specs again, but with different names + assert len(cluster.worker_spec) == 2 + assert len(cluster.workers) == 2 + current_specs = set(cluster.worker_spec.keys()) + + # Specs should be "1" and "2" (not "0") + assert "0" not in current_specs + assert "1" in current_specs + assert "2" in current_specs + + # Round 2: Remove spec "2" (this would fail with the old buggy code) + spec_to_remove = "2" + assert spec_to_remove in cluster.worker_spec + assert spec_to_remove in cluster.workers + + worker_names = cluster._spec_name_to_worker_names(spec_to_remove) + worker_to_kill = list(worker_names)[0] + worker_addr = [ + addr + for addr, ws in cluster.scheduler.workers.items() + if ws.name == worker_to_kill + ][0] + + await cluster.scheduler.remove_worker( + address=worker_addr, close=False, stimulus_id="test-round-2" + ) + await asyncio.sleep(0.2) + + # Verify spec "2" is completely removed + assert spec_to_remove not in cluster.worker_spec + assert ( + spec_to_remove not in cluster.workers + ) # MultiWorker instance removed + assert len(cluster.worker_spec) == 1 + assert len(cluster.workers) == 1 + + # Only spec "1" should remain + assert list(cluster.worker_spec.keys()) == ["1"] + assert list(cluster.workers.keys()) == ["1"] + + @gen_cluster(client=True, nthreads=[]) async def test_run_spec(c, s): workers = await run_spec(worker_spec, s.address)