From 43b899e17d63eed1fdeedaf8ecf0a2fb7e42fb40 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 27 Jun 2025 11:30:54 +0200 Subject: [PATCH 1/8] Rewrite MPFuture to use native SharedMemory --- hivemind/utils/mpfuture.py | 143 +++++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 62 deletions(-) diff --git a/hivemind/utils/mpfuture.py b/hivemind/utils/mpfuture.py index c3192e926..a0109787c 100644 --- a/hivemind/utils/mpfuture.py +++ b/hivemind/utils/mpfuture.py @@ -9,17 +9,14 @@ from concurrent.futures import InvalidStateError from contextlib import nullcontext from enum import Enum, auto -from multiprocessing.reduction import ForkingPickler +from multiprocessing import shared_memory from typing import Any, Callable, Dict, Generic, Optional, TypeVar from weakref import ref -import torch # used for py3.7-compatible shared memory - from hivemind.utils.logging import get_logger logger = get_logger(__name__) -torch.multiprocessing.set_sharing_strategy(os.environ.get("HIVEMIND_MEMORY_SHARING_STRATEGY", "file_system")) # flavour types ResultType = TypeVar("ResultType") @@ -28,34 +25,6 @@ TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED} -class SharedBytes: - """ - A process-wide object that allocates large chunks of shared memory and partitions it into individual bytes. - - Note: this process is only responsible for bulk allocation, it does not manage/free unused bytes. - The chunks are deallocated by the garbage collector, - when it detects that all processes no longer use any bytes from this chunk. - """ - - _lock = mp.Lock() - _pid: Optional[PID] = None - _buffer: Optional[torch.Tensor] = None - _index: int = 0 - - @classmethod - def next(cls) -> torch.Tensor: - """Create another shared byte value, represented as a scalar uint8 tensor""" - with cls._lock: - if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer): - buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16)) - cls._pid = os.getpid() - cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_() - cls._index = 0 - - cls._index += 1 - return cls._buffer[cls._index - 1] - - class UpdateType(Enum): RESULT = auto() EXCEPTION = auto() @@ -90,7 +59,11 @@ def __init__(self, *, use_lock: bool = True): self._maybe_initialize_mpfuture_backend() self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int - self._shared_state_code = SharedBytes.next() + + # Create a dedicated 1-byte shared memory for this future's state + self._shared_memory = shared_memory.SharedMemory(create=True, size=1) + self._shared_state_code = memoryview(self._shared_memory.buf) + self._shared_memory_name = self._shared_memory.name self._state_cache: Dict[State, State] = {} # mapping from global to cached local future used that makes updates immediately # available on setter side; dictionary-based cache works because future can visit any state at most once @@ -111,14 +84,13 @@ def __init__(self, *, use_lock: bool = True): @property def _state(self) -> State: - shared_state = ALL_STATES[self._shared_state_code.item()] + shared_state = ALL_STATES[self._shared_state_code[0]] return self._state_cache.get(shared_state, shared_state) @_state.setter def _state(self, new_state: State): - with torch.inference_mode(): - self._shared_state_code[...] = ALL_STATES.index(new_state) - if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set(): + self._shared_state_code[0] = ALL_STATES.index(new_state) + if new_state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set(): self._set_event_threadsafe() def _set_event_threadsafe(self): @@ -164,7 +136,6 @@ def reset_backend(): MPFuture._active_pid = None MPFuture._initialization_lock = mp.Lock() MPFuture._update_lock = mp.Lock() - SharedBytes._lock = mp.Lock() @classmethod def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection): @@ -176,24 +147,30 @@ def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection) uid, update_type, payload = receiver_pipe.recv() future = None - future_ref = cls._active_futures.pop(uid, None) + future_ref = cls._active_futures.get(uid) if future_ref is not None: future = future_ref() if future is None: # The MPFuture instance is already destroyed in this process # (the caller is not interested in the result) + cls._active_futures.pop(uid, None) # Clean up the stale reference continue + + # Process the update and set the corresponding state if update_type == UpdateType.RESULT: future.set_result(payload) + future._state = base.FINISHED elif update_type == UpdateType.EXCEPTION: future.set_exception(payload) + future._state = base.FINISHED elif update_type == UpdateType.CANCEL: future.cancel() + future._state = base.CANCELLED else: raise RuntimeError(f"Received unexpected update type {update_type}") except (BrokenPipeError, EOFError, ConnectionError): - logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})") + logger.debug(f"Update pipe was shut down unexpectedly (pid={pid})") except Exception as e: logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})") @@ -212,6 +189,8 @@ def set_result(self, result: ResultType): elif self._state in TERMINAL_STATES: raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})") else: + # Don't update shared state immediately in subprocess - let the origin process do it + # This prevents race condition where shared state says "finished" but result isn't ready yet self._state_cache[self._state], self._result = base.FINISHED, result self._send_update(UpdateType.RESULT, result) @@ -222,6 +201,7 @@ def set_exception(self, exception: Optional[BaseException]): elif self._state in TERMINAL_STATES: raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})") else: + # Don't update shared state immediately in subprocess - let the origin process do it self._state_cache[self._state], self._exception = base.FINISHED, exception self._send_update(UpdateType.EXCEPTION, exception) @@ -232,6 +212,7 @@ def cancel(self) -> bool: elif self._state in [base.RUNNING, base.FINISHED]: return False else: + # Don't update shared state immediately in subprocess - let the origin process do it self._state_cache[self._state] = base.CANCELLED self._send_update(UpdateType.CANCEL) return True @@ -248,25 +229,32 @@ def set_running_or_notify_cancel(self): ) def result(self, timeout: Optional[float] = None) -> ResultType: - if self._state not in TERMINAL_STATES: - if os.getpid() != self._origin_pid: + if os.getpid() != self._origin_pid: + # Non-origin process: check shared state and return cached result + if self._state not in TERMINAL_STATES: raise RuntimeError("Only the process that created MPFuture can await result") - return super().result(timeout) - elif self._state == base.CANCELLED: - raise base.CancelledError() - elif self._exception: - raise self._exception + elif self._state == base.CANCELLED: + raise base.CancelledError() + elif self._exception: + raise self._exception + else: + return self._result else: - return self._result + # Origin process: use parent's result() method which properly waits for completion + # The parent class handles the waiting and state management correctly + return super().result(timeout) def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]: - if self._state not in TERMINAL_STATES: - if os.getpid() != self._origin_pid: + if os.getpid() != self._origin_pid: + # Non-origin process: check shared state and return cached exception + if self._state not in TERMINAL_STATES: raise RuntimeError("Only the process that created MPFuture can await exception") + elif self._state == base.CANCELLED: + raise base.CancelledError() + return self._exception + else: + # Origin process: always use parent's exception() method which properly waits return super().exception(timeout) - elif self._state == base.CANCELLED: - raise base.CancelledError() - return self._exception def done(self) -> bool: return self._state in TERMINAL_STATES @@ -292,15 +280,33 @@ def __await__(self): raise asyncio.CancelledError() def __del__(self): - if getattr(self, "_origin_pid", None) == os.getpid() and MPFuture._active_futures is not None: + is_origin_process = getattr(self, "_origin_pid", None) == os.getpid() + + if is_origin_process and MPFuture._active_futures is not None: MPFuture._active_futures.pop(self._uid, None) if getattr(self, "_aio_event", None): self._aio_event.set() + # Clean up shared memory if we're the origin process + if is_origin_process and hasattr(self, "_shared_memory"): + try: + self._shared_memory.unlink() # Remove from system + self._shared_memory.close() # Close our handle + except (FileNotFoundError, AttributeError, BufferError): + pass # Already cleaned up or not accessible + + # Release the memoryview reference + if hasattr(self, "_shared_state_code"): + try: + self._shared_state_code.release() + except (AttributeError, BufferError): + pass # already released or not a releasable view + def __getstate__(self): return dict( _sender_pipe=self._sender_pipe, - _shared_state_code=ForkingPickler.dumps(self._shared_state_code).tobytes(), + _shared_state_code=bytes(self._shared_state_code), + _shared_memory_name=self._shared_memory_name, _origin_pid=self._origin_pid, _uid=self._uid, _use_lock=self._use_lock, @@ -310,14 +316,27 @@ def __getstate__(self): def __setstate__(self, state): self._sender_pipe = state["_sender_pipe"] + self._shared_memory_name = state.get("_shared_memory_name") + try: - self._shared_state_code = ForkingPickler.loads(state["_shared_state_code"]) - except RuntimeError: - # If the origin process garbage-collects all instances of MPFuture using the same shmem buffer, - # the underlying buffer is freed, and we will get RuntimeError ("unable to open shared memory object") - # here since it is not possible to connect to this buffer anymore. To address this, we just replace - # the buffer with a non-shared tensor since the origin process doesn't care about our state anymore. - self._shared_state_code = torch.tensor([ALL_STATES.index(base.PENDING)], dtype=torch.uint8) + # Try to reconnect to the shared memory + if self._shared_memory_name: + try: + # Reconnect to existing shared memory (don't store reference since we don't own it) + reconnected_mem = shared_memory.SharedMemory(name=self._shared_memory_name) + self._shared_state_code = memoryview(reconnected_mem.buf) + except FileNotFoundError: + # Shared memory no longer exists, fall back to local copy + state_bytes = state["_shared_state_code"] + self._shared_state_code = memoryview(bytearray(state_bytes)) + else: + # No shared memory name available, use local copy + state_bytes = state["_shared_state_code"] + self._shared_state_code = memoryview(bytearray(state_bytes)) + except (RuntimeError, FileNotFoundError): + # If the shared memory is no longer available, fall back to local copy + self._shared_state_code = memoryview(bytearray([ALL_STATES.index(base.PENDING)])) + self._origin_pid, self._uid = state["_origin_pid"], state["_uid"] self._result, self._exception = state["_result"], state["_exception"] self._use_lock = state["_use_lock"] From 27240f81acab0d5ba998b31baa99001c3d552134 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 27 Jun 2025 12:03:28 +0200 Subject: [PATCH 2/8] Remove setting of HIVEMIND_MEMORY_SHARING_STRATEGY --- .github/workflows/run-tests.yml | 3 --- modal_ci.py | 15 ++++----------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index d1c3fe3ff..bef6c6fba 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -41,7 +41,6 @@ jobs: - name: Test run: | cd tests - export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor pytest --durations=0 --durations-min=1.0 -v build_and_test_p2pd: runs-on: ubuntu-latest @@ -72,7 +71,6 @@ jobs: - name: Test run: | cd tests - export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor pytest -k "p2p" -v codecov_in_develop_mode: runs-on: ubuntu-latest @@ -101,7 +99,6 @@ jobs: pip install -e . --no-use-pep517 - name: Test run: | - export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor pytest --cov hivemind --cov-config=pyproject.toml -v tests - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/modal_ci.py b/modal_ci.py index 43b977ee7..84d1f9e55 100644 --- a/modal_ci.py +++ b/modal_ci.py @@ -81,15 +81,10 @@ def setup_environment(*, build_p2pd=False): subprocess.run(install_cmd, check=True) - environment = os.environ.copy() - environment["HIVEMIND_MEMORY_SHARING_STRATEGY"] = "file_descriptor" - - return environment - @app.function(image=image, timeout=600, cpu=8, memory=8192) def run_tests(): - environment = setup_environment(build_p2pd=False) + setup_environment(build_p2pd=False) subprocess.run( [ @@ -102,13 +97,12 @@ def run_tests(): "tests", ], check=True, - env=environment, ) @app.function(image=image, timeout=900, cpu=8, memory=8192, secrets=[codecov_secret]) def run_codecov(): - environment = setup_environment(build_p2pd=False) + setup_environment(build_p2pd=False) subprocess.run( [ @@ -120,10 +114,10 @@ def run_codecov(): "tests", ], check=True, - env=environment, ) # Forward GitHub Actions environment variables to the codecov command + environment = os.environ.copy() environment.update( { "CODECOV_TOKEN": os.environ["CODECOV_TOKEN"], @@ -149,7 +143,7 @@ def run_codecov(): @app.function(image=image_with_golang, timeout=600, cpu=1, memory=4096) def build_and_test_p2pd(): - environment = setup_environment(build_p2pd=True) + setup_environment(build_p2pd=True) subprocess.run( [ @@ -160,5 +154,4 @@ def build_and_test_p2pd(): "tests", ], check=True, - env=environment, ) From 3236b13190f6b4667757841f905b0dc5f5e8dc24 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 27 Jun 2025 12:04:44 +0200 Subject: [PATCH 3/8] Make the import of SharedMemory direct --- hivemind/utils/mpfuture.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hivemind/utils/mpfuture.py b/hivemind/utils/mpfuture.py index a0109787c..baee6f95e 100644 --- a/hivemind/utils/mpfuture.py +++ b/hivemind/utils/mpfuture.py @@ -9,7 +9,7 @@ from concurrent.futures import InvalidStateError from contextlib import nullcontext from enum import Enum, auto -from multiprocessing import shared_memory +from multiprocessing.shared_memory import SharedMemory from typing import Any, Callable, Dict, Generic, Optional, TypeVar from weakref import ref @@ -61,7 +61,7 @@ def __init__(self, *, use_lock: bool = True): self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int # Create a dedicated 1-byte shared memory for this future's state - self._shared_memory = shared_memory.SharedMemory(create=True, size=1) + self._shared_memory = SharedMemory(create=True, size=1) self._shared_state_code = memoryview(self._shared_memory.buf) self._shared_memory_name = self._shared_memory.name self._state_cache: Dict[State, State] = {} @@ -323,7 +323,7 @@ def __setstate__(self, state): if self._shared_memory_name: try: # Reconnect to existing shared memory (don't store reference since we don't own it) - reconnected_mem = shared_memory.SharedMemory(name=self._shared_memory_name) + reconnected_mem = SharedMemory(name=self._shared_memory_name) self._shared_state_code = memoryview(reconnected_mem.buf) except FileNotFoundError: # Shared memory no longer exists, fall back to local copy From f988258f2001dd9618ef1defc12b1d23b505a8ee Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 27 Jun 2025 13:43:37 +0200 Subject: [PATCH 4/8] Use tempfile/mmap instead --- hivemind/utils/mpfuture.py | 89 +++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 39 deletions(-) diff --git a/hivemind/utils/mpfuture.py b/hivemind/utils/mpfuture.py index baee6f95e..f12e85477 100644 --- a/hivemind/utils/mpfuture.py +++ b/hivemind/utils/mpfuture.py @@ -2,14 +2,15 @@ import asyncio import concurrent.futures._base as base +import mmap import multiprocessing as mp import os +import tempfile import threading import uuid from concurrent.futures import InvalidStateError from contextlib import nullcontext from enum import Enum, auto -from multiprocessing.shared_memory import SharedMemory from typing import Any, Callable, Dict, Generic, Optional, TypeVar from weakref import ref @@ -60,10 +61,13 @@ def __init__(self, *, use_lock: bool = True): self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int - # Create a dedicated 1-byte shared memory for this future's state - self._shared_memory = SharedMemory(create=True, size=1) - self._shared_state_code = memoryview(self._shared_memory.buf) - self._shared_memory_name = self._shared_memory.name + # Create shared state using mmap with temporary file (avoids SharedMemory cleanup issues) + self._temp_file = tempfile.NamedTemporaryFile(delete=False) + self._temp_file.write(bytes([ALL_STATES.index(base.PENDING)])) + self._temp_file.flush() + self._mmap = mmap.mmap(self._temp_file.fileno(), 1) + self._shared_state_code = memoryview(self._mmap) + self._temp_file_name = self._temp_file.name self._state_cache: Dict[State, State] = {} # mapping from global to cached local future used that makes updates immediately # available on setter side; dictionary-based cache works because future can visit any state at most once @@ -146,10 +150,8 @@ def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection) break # backend was reset, a new background thread has started uid, update_type, payload = receiver_pipe.recv() - future = None future_ref = cls._active_futures.get(uid) - if future_ref is not None: - future = future_ref() + future = future_ref() if future_ref else None if future is None: # The MPFuture instance is already destroyed in this process @@ -280,33 +282,45 @@ def __await__(self): raise asyncio.CancelledError() def __del__(self): + # Only clean up if we have the necessary attributes to avoid exceptions during teardown + if not hasattr(self, "_origin_pid"): + return + is_origin_process = getattr(self, "_origin_pid", None) == os.getpid() - if is_origin_process and MPFuture._active_futures is not None: + if is_origin_process and MPFuture._active_futures is not None and hasattr(self, "_uid"): MPFuture._active_futures.pop(self._uid, None) if getattr(self, "_aio_event", None): self._aio_event.set() - # Clean up shared memory if we're the origin process - if is_origin_process and hasattr(self, "_shared_memory"): - try: - self._shared_memory.unlink() # Remove from system - self._shared_memory.close() # Close our handle - except (FileNotFoundError, AttributeError, BufferError): - pass # Already cleaned up or not accessible - - # Release the memoryview reference - if hasattr(self, "_shared_state_code"): - try: - self._shared_state_code.release() - except (AttributeError, BufferError): - pass # already released or not a releasable view + # Clean up mmap and temp file if we're the origin process + if is_origin_process: + if hasattr(self, "_shared_state_code"): + try: + self._shared_state_code.release() + except (BufferError, ValueError): + pass + if hasattr(self, "_mmap"): + try: + self._mmap.close() + except (OSError, ValueError): + pass + if hasattr(self, "_temp_file"): + try: + self._temp_file.close() + except (OSError, ValueError): + pass + if hasattr(self, "_temp_file_name"): + try: + os.unlink(self._temp_file_name) + except (OSError, FileNotFoundError): + pass def __getstate__(self): return dict( _sender_pipe=self._sender_pipe, _shared_state_code=bytes(self._shared_state_code), - _shared_memory_name=self._shared_memory_name, + _temp_file_name=getattr(self, "_temp_file_name", None), _origin_pid=self._origin_pid, _uid=self._uid, _use_lock=self._use_lock, @@ -316,26 +330,23 @@ def __getstate__(self): def __setstate__(self, state): self._sender_pipe = state["_sender_pipe"] - self._shared_memory_name = state.get("_shared_memory_name") + self._temp_file_name = state.get("_temp_file_name") + # Try to reconnect to the shared state file try: - # Try to reconnect to the shared memory - if self._shared_memory_name: - try: - # Reconnect to existing shared memory (don't store reference since we don't own it) - reconnected_mem = SharedMemory(name=self._shared_memory_name) - self._shared_state_code = memoryview(reconnected_mem.buf) - except FileNotFoundError: - # Shared memory no longer exists, fall back to local copy - state_bytes = state["_shared_state_code"] - self._shared_state_code = memoryview(bytearray(state_bytes)) + if self._temp_file_name and os.path.exists(self._temp_file_name): + # Reconnect to existing mmap + with open(self._temp_file_name, "r+b") as f: + self._mmap = mmap.mmap(f.fileno(), 1) + self._shared_state_code = memoryview(self._mmap) else: - # No shared memory name available, use local copy + # Fall back to local copy state_bytes = state["_shared_state_code"] self._shared_state_code = memoryview(bytearray(state_bytes)) - except (RuntimeError, FileNotFoundError): - # If the shared memory is no longer available, fall back to local copy - self._shared_state_code = memoryview(bytearray([ALL_STATES.index(base.PENDING)])) + except (OSError, ValueError): + # If mmap fails, fall back to local copy + state_bytes = state["_shared_state_code"] + self._shared_state_code = memoryview(bytearray(state_bytes)) self._origin_pid, self._uid = state["_origin_pid"], state["_uid"] self._result, self._exception = state["_result"], state["_exception"] From ae79958dc4c11572c482ac5ac4c54d0f475ef3d4 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 27 Jun 2025 20:03:30 +0200 Subject: [PATCH 5/8] Rearrange the code --- hivemind/utils/mpfuture.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/hivemind/utils/mpfuture.py b/hivemind/utils/mpfuture.py index f12e85477..246cb95d3 100644 --- a/hivemind/utils/mpfuture.py +++ b/hivemind/utils/mpfuture.py @@ -11,6 +11,7 @@ from concurrent.futures import InvalidStateError from contextlib import nullcontext from enum import Enum, auto +from multiprocessing.connection import Connection from typing import Any, Callable, Dict, Generic, Optional, TypeVar from weakref import ref @@ -21,7 +22,7 @@ # flavour types ResultType = TypeVar("ResultType") -PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection +PID, UID, State, PipeEnd = int, int, str, Connection ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED} @@ -98,6 +99,9 @@ def _state(self, new_state: State): self._set_event_threadsafe() def _set_event_threadsafe(self): + if not self._loop or not self._aio_event: + return + try: running_loop = asyncio.get_running_loop() except RuntimeError: @@ -106,9 +110,7 @@ def _set_event_threadsafe(self): async def _event_setter(): self._aio_event.set() - if self._loop.is_closed(): - return # do nothing, the loop is already closed - elif self._loop.is_running() and running_loop == self._loop: + if self._loop.is_running() and running_loop == self._loop: asyncio.create_task(_event_setter()) elif self._loop.is_running() and running_loop != self._loop: asyncio.run_coroutine_threadsafe(_event_setter(), self._loop) @@ -290,8 +292,6 @@ def __del__(self): if is_origin_process and MPFuture._active_futures is not None and hasattr(self, "_uid"): MPFuture._active_futures.pop(self._uid, None) - if getattr(self, "_aio_event", None): - self._aio_event.set() # Clean up mmap and temp file if we're the origin process if is_origin_process: From dcf2a015d4386eeab4198a665923f08cd0581a9e Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 27 Jun 2025 20:59:01 +0200 Subject: [PATCH 6/8] Attempt to fix P2P errors --- hivemind/p2p/p2p_daemon.py | 32 +++++--- hivemind/p2p/p2p_daemon_bindings/control.py | 81 +++++++++++++-------- 2 files changed, 69 insertions(+), 44 deletions(-) diff --git a/hivemind/p2p/p2p_daemon.py b/hivemind/p2p/p2p_daemon.py index a6c9484c9..bb85f998d 100644 --- a/hivemind/p2p/p2p_daemon.py +++ b/hivemind/p2p/p2p_daemon.py @@ -684,18 +684,26 @@ def _maddrs_to_str(maddrs: List[Multiaddr]) -> str: async def _read_outputs(self, ready: asyncio.Future) -> None: last_line = None - while True: - line = await self._child.stdout.readline() - if not line: # Stream closed - break - last_line = line.rstrip().decode(errors="ignore") - - self._log_p2pd_message(last_line) - if last_line.startswith("Peer ID:"): - ready.set_result(None) - - if not ready.done(): - ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}")) + try: + while True: + try: + line = await self._child.stdout.readline() + if not line: # Stream closed + break + last_line = line.rstrip().decode(errors="ignore") + + self._log_p2pd_message(last_line) + if last_line.startswith("Peer ID:"): + ready.set_result(None) + except (asyncio.CancelledError, RuntimeError): + # Task was cancelled or event loop closed + break + + if not ready.done(): + ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}")) + except (asyncio.CancelledError, RuntimeError): + # Task was cancelled or event loop closed during cleanup + pass @staticmethod def _log_p2pd_message(line: str) -> None: diff --git a/hivemind/p2p/p2p_daemon_bindings/control.py b/hivemind/p2p/p2p_daemon_bindings/control.py index 059267f31..e3928606d 100644 --- a/hivemind/p2p/p2p_daemon_bindings/control.py +++ b/hivemind/p2p/p2p_daemon_bindings/control.py @@ -170,46 +170,63 @@ async def listen(self) -> AsyncIterator["ControlClient"]: yield self async def _read_from_persistent_conn(self, reader: asyncio.StreamReader): - while True: - resp = p2pd_pb.PersistentConnectionResponse() - try: - await read_pbmsg_safe(reader, resp) - except asyncio.IncompleteReadError: - break - - call_id = UUID(bytes=resp.callId) - - if resp.HasField("callUnaryResponse"): - if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"): - self._pending_calls[call_id].set_result(resp.callUnaryResponse.response) - elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"): - remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore")) - self._pending_calls[call_id].set_exception(remote_exc) - else: - logger.debug(f"Received unexpected unary call: {resp}") + try: + while True: + resp = p2pd_pb.PersistentConnectionResponse() + try: + await read_pbmsg_safe(reader, resp) + except asyncio.IncompleteReadError: + break + + call_id = UUID(bytes=resp.callId) - elif resp.HasField("requestHandling"): - handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling)) - self._handler_tasks[call_id] = handler_task + if resp.HasField("callUnaryResponse"): + if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"): + self._pending_calls[call_id].set_result(resp.callUnaryResponse.response) + elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"): + remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore")) + self._pending_calls[call_id].set_exception(remote_exc) + else: + logger.debug(f"Received unexpected unary call: {resp}") - elif call_id in self._handler_tasks and resp.HasField("cancel"): - cancel_task_if_running(self._handler_tasks[call_id]) + elif resp.HasField("requestHandling"): + handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling)) + self._handler_tasks[call_id] = handler_task - elif call_id in self._pending_calls and resp.HasField("daemonError"): - daemon_exc = P2PDaemonError(resp.daemonError.message) - self._pending_calls[call_id].set_exception(daemon_exc) + elif call_id in self._handler_tasks and resp.HasField("cancel"): + cancel_task_if_running(self._handler_tasks[call_id]) - elif call_id in self._pending_calls: - self._pending_calls[call_id].set_result(None) + elif call_id in self._pending_calls and resp.HasField("daemonError"): + daemon_exc = P2PDaemonError(resp.daemonError.message) + self._pending_calls[call_id].set_exception(daemon_exc) - else: - logger.debug(f"Received unexpected response from daemon: {resp}") + elif call_id in self._pending_calls: + self._pending_calls[call_id].set_result(None) + + else: + logger.debug(f"Received unexpected response from daemon: {resp}") + except asyncio.CancelledError: + # Task was cancelled, clean up gracefully + pass async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter): - with closing(writer): + try: while True: - msg = await self._pending_messages.get() - await write_pbmsg(writer, msg) + try: + msg = await self._pending_messages.get() + await write_pbmsg(writer, msg) + except (asyncio.CancelledError, RuntimeError): + # Task was cancelled or event loop closed + break + finally: + # Close writer safely, avoiding "Event loop is closed" errors + try: + if not writer.is_closing(): + writer.close() + await writer.wait_closed() + except (RuntimeError, ConnectionError): + # Event loop might be closed or connection already closed + pass async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest): if request.proto not in self.unary_handlers: From eafc0d93d1719dc2dbf18cbf1bafccab058614d3 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 27 Jun 2025 21:07:20 +0200 Subject: [PATCH 7/8] Attempt to fix P2P errors --- hivemind/p2p/p2p_daemon.py | 2 +- hivemind/p2p/p2p_daemon_bindings/control.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hivemind/p2p/p2p_daemon.py b/hivemind/p2p/p2p_daemon.py index bb85f998d..f84d94520 100644 --- a/hivemind/p2p/p2p_daemon.py +++ b/hivemind/p2p/p2p_daemon.py @@ -698,7 +698,7 @@ async def _read_outputs(self, ready: asyncio.Future) -> None: except (asyncio.CancelledError, RuntimeError): # Task was cancelled or event loop closed break - + if not ready.done(): ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}")) except (asyncio.CancelledError, RuntimeError): diff --git a/hivemind/p2p/p2p_daemon_bindings/control.py b/hivemind/p2p/p2p_daemon_bindings/control.py index e3928606d..30b4dbf85 100644 --- a/hivemind/p2p/p2p_daemon_bindings/control.py +++ b/hivemind/p2p/p2p_daemon_bindings/control.py @@ -5,7 +5,7 @@ """ import asyncio -from contextlib import asynccontextmanager, closing +from contextlib import asynccontextmanager from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple from uuid import UUID, uuid4 From 4bb054f29f5528974a78155d39927b18899920a0 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 28 Jun 2025 00:38:44 +0200 Subject: [PATCH 8/8] Simplify the code --- hivemind/p2p/p2p_daemon_bindings/control.py | 11 +++++------ hivemind/utils/asyncio.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/hivemind/p2p/p2p_daemon_bindings/control.py b/hivemind/p2p/p2p_daemon_bindings/control.py index 30b4dbf85..8ca49fff8 100644 --- a/hivemind/p2p/p2p_daemon_bindings/control.py +++ b/hivemind/p2p/p2p_daemon_bindings/control.py @@ -212,12 +212,11 @@ async def _read_from_persistent_conn(self, reader: asyncio.StreamReader): async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter): try: while True: - try: - msg = await self._pending_messages.get() - await write_pbmsg(writer, msg) - except (asyncio.CancelledError, RuntimeError): - # Task was cancelled or event loop closed - break + msg = await self._pending_messages.get() + await write_pbmsg(writer, msg) + except (asyncio.CancelledError, RuntimeError): + # Task was cancelled or event loop closed + pass finally: # Close writer safely, avoiding "Event loop is closed" errors try: diff --git a/hivemind/utils/asyncio.py b/hivemind/utils/asyncio.py index cf761f9ae..da73be689 100644 --- a/hivemind/utils/asyncio.py +++ b/hivemind/utils/asyncio.py @@ -205,6 +205,6 @@ def cancel_task_if_running(task: Optional[asyncio.Task]) -> None: if loop.is_running(): task.cancel() except RuntimeError as e: - # Only ignore event loop closure errors - if "Event loop is closed" not in str(e): + # Ignore event loop closure and missing event loop errors + if "Event loop is closed" not in str(e) and "There is no current event loop" not in str(e): raise