diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml
index d1c3fe3ff..bef6c6fba 100644
--- a/.github/workflows/run-tests.yml
+++ b/.github/workflows/run-tests.yml
@@ -41,7 +41,6 @@ jobs:
       - name: Test
         run: |
           cd tests
-          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
           pytest --durations=0 --durations-min=1.0 -v
   build_and_test_p2pd:
     runs-on: ubuntu-latest
@@ -72,7 +71,6 @@ jobs:
      - name: Test
         run: |
           cd tests
-          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
           pytest -k "p2p" -v
   codecov_in_develop_mode:
     runs-on: ubuntu-latest
@@ -101,7 +99,6 @@ jobs:
           pip install -e . --no-use-pep517
       - name: Test
         run: |
-          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
           pytest --cov hivemind --cov-config=pyproject.toml -v tests
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v3
diff --git a/hivemind/p2p/p2p_daemon.py b/hivemind/p2p/p2p_daemon.py
index a6c9484c9..f84d94520 100644
--- a/hivemind/p2p/p2p_daemon.py
+++ b/hivemind/p2p/p2p_daemon.py
@@ -684,18 +684,26 @@ def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:

     async def _read_outputs(self, ready: asyncio.Future) -> None:
         last_line = None
-        while True:
-            line = await self._child.stdout.readline()
-            if not line:  # Stream closed
-                break
-            last_line = line.rstrip().decode(errors="ignore")
-
-            self._log_p2pd_message(last_line)
-            if last_line.startswith("Peer ID:"):
-                ready.set_result(None)
-
-        if not ready.done():
-            ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
+        try:
+            while True:
+                try:
+                    line = await self._child.stdout.readline()
+                    if not line:  # Stream closed
+                        break
+                    last_line = line.rstrip().decode(errors="ignore")
+
+                    self._log_p2pd_message(last_line)
+                    if last_line.startswith("Peer ID:"):
+                        ready.set_result(None)
+                except (asyncio.CancelledError, RuntimeError):
+                    # Task was cancelled or event loop closed
+                    break
+
+            if not ready.done():
+                ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
+        except (asyncio.CancelledError, RuntimeError):
+            # Task was cancelled or event loop closed during cleanup
+            pass

     @staticmethod
     def _log_p2pd_message(line: str) -> None:
diff --git a/hivemind/p2p/p2p_daemon_bindings/control.py b/hivemind/p2p/p2p_daemon_bindings/control.py
index 059267f31..8ca49fff8 100644
--- a/hivemind/p2p/p2p_daemon_bindings/control.py
+++ b/hivemind/p2p/p2p_daemon_bindings/control.py
@@ -5,7 +5,7 @@
 """

 import asyncio
-from contextlib import asynccontextmanager, closing
+from contextlib import asynccontextmanager
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 from uuid import UUID, uuid4

@@ -170,46 +170,62 @@ async def listen(self) -> AsyncIterator["ControlClient"]:
             yield self

     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
-        while True:
-            resp = p2pd_pb.PersistentConnectionResponse()
-            try:
-                await read_pbmsg_safe(reader, resp)
-            except asyncio.IncompleteReadError:
-                break
-
-            call_id = UUID(bytes=resp.callId)
-
-            if resp.HasField("callUnaryResponse"):
-                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
-                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
-                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
-                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
-                    self._pending_calls[call_id].set_exception(remote_exc)
-                else:
-                    logger.debug(f"Received unexpected unary call: {resp}")
+        try:
+            while True:
+                resp = p2pd_pb.PersistentConnectionResponse()
+                try:
+                    await read_pbmsg_safe(reader, resp)
+                except asyncio.IncompleteReadError:
+                    break
+
+                call_id = UUID(bytes=resp.callId)

-            elif resp.HasField("requestHandling"):
-                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
-                self._handler_tasks[call_id] = handler_task
+                if resp.HasField("callUnaryResponse"):
+                    if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                        self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                    elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                        remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                        self._pending_calls[call_id].set_exception(remote_exc)
+                    else:
+                        logger.debug(f"Received unexpected unary call: {resp}")

-            elif call_id in self._handler_tasks and resp.HasField("cancel"):
-                cancel_task_if_running(self._handler_tasks[call_id])
+                elif resp.HasField("requestHandling"):
+                    handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                    self._handler_tasks[call_id] = handler_task

-            elif call_id in self._pending_calls and resp.HasField("daemonError"):
-                daemon_exc = P2PDaemonError(resp.daemonError.message)
-                self._pending_calls[call_id].set_exception(daemon_exc)
+                elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                    cancel_task_if_running(self._handler_tasks[call_id])

-            elif call_id in self._pending_calls:
-                self._pending_calls[call_id].set_result(None)
+                elif call_id in self._pending_calls and resp.HasField("daemonError"):
+                    daemon_exc = P2PDaemonError(resp.daemonError.message)
+                    self._pending_calls[call_id].set_exception(daemon_exc)

-            else:
-                logger.debug(f"Received unexpected response from daemon: {resp}")
+                elif call_id in self._pending_calls:
+                    self._pending_calls[call_id].set_result(None)
+
+                else:
+                    logger.debug(f"Received unexpected response from daemon: {resp}")
+        except asyncio.CancelledError:
+            # Task was cancelled, clean up gracefully
+            pass

     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
-        with closing(writer):
+        try:
             while True:
                 msg = await self._pending_messages.get()
                 await write_pbmsg(writer, msg)
+        except (asyncio.CancelledError, RuntimeError):
+            # Task was cancelled or event loop closed
+            pass
+        finally:
+            # Close writer safely, avoiding "Event loop is closed" errors
+            try:
+                if not writer.is_closing():
+                    writer.close()
+                    await writer.wait_closed()
+            except (RuntimeError, ConnectionError):
+                # Event loop might be closed or connection already closed
+                pass

     async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
         if request.proto not in self.unary_handlers:
diff --git a/hivemind/utils/asyncio.py b/hivemind/utils/asyncio.py
index cf761f9ae..da73be689 100644
--- a/hivemind/utils/asyncio.py
+++ b/hivemind/utils/asyncio.py
@@ -205,6 +205,6 @@ def cancel_task_if_running(task: Optional[asyncio.Task]) -> None:
         if loop.is_running():
             task.cancel()
     except RuntimeError as e:
-        # Only ignore event loop closure errors
-        if "Event loop is closed" not in str(e):
+        # Ignore event loop closure and missing event loop errors
+        if "Event loop is closed" not in str(e) and "There is no current event loop" not in str(e):
             raise
diff --git a/hivemind/utils/mpfuture.py b/hivemind/utils/mpfuture.py
index c3192e926..246cb95d3 100644
--- a/hivemind/utils/mpfuture.py
+++ b/hivemind/utils/mpfuture.py
@@ -2,60 +2,31 @@

 import asyncio
 import concurrent.futures._base as base
+import mmap
 import multiprocessing as mp
 import os
+import tempfile
 import threading
 import uuid
 from concurrent.futures import InvalidStateError
 from contextlib import nullcontext
 from enum import Enum, auto
-from multiprocessing.reduction import ForkingPickler
+from multiprocessing.connection import Connection
 from typing import Any, Callable, Dict, Generic, Optional, TypeVar
 from weakref import ref

-import torch  # used for py3.7-compatible shared memory
-
 from hivemind.utils.logging import get_logger

 logger = get_logger(__name__)

-torch.multiprocessing.set_sharing_strategy(os.environ.get("HIVEMIND_MEMORY_SHARING_STRATEGY", "file_system"))

 # flavour types
 ResultType = TypeVar("ResultType")
-PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
+PID, UID, State, PipeEnd = int, int, str, Connection
 ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
 TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}


-class SharedBytes:
-    """
-    A process-wide object that allocates large chunks of shared memory and partitions it into individual bytes.
-
-    Note: this process is only responsible for bulk allocation, it does not manage/free unused bytes.
-    The chunks are deallocated by the garbage collector,
-    when it detects that all processes no longer use any bytes from this chunk.
-    """
-
-    _lock = mp.Lock()
-    _pid: Optional[PID] = None
-    _buffer: Optional[torch.Tensor] = None
-    _index: int = 0
-
-    @classmethod
-    def next(cls) -> torch.Tensor:
-        """Create another shared byte value, represented as a scalar uint8 tensor"""
-        with cls._lock:
-            if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
-                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
-                cls._pid = os.getpid()
-                cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
-                cls._index = 0
-
-            cls._index += 1
-            return cls._buffer[cls._index - 1]
-
-
 class UpdateType(Enum):
     RESULT = auto()
     EXCEPTION = auto()
@@ -90,7 +61,14 @@ def __init__(self, *, use_lock: bool = True):
             self._maybe_initialize_mpfuture_backend()

         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
-        self._shared_state_code = SharedBytes.next()
+
+        # Create shared state using mmap with temporary file (avoids SharedMemory cleanup issues)
+        self._temp_file = tempfile.NamedTemporaryFile(delete=False)
+        self._temp_file.write(bytes([ALL_STATES.index(base.PENDING)]))
+        self._temp_file.flush()
+        self._mmap = mmap.mmap(self._temp_file.fileno(), 1)
+        self._shared_state_code = memoryview(self._mmap)
+        self._temp_file_name = self._temp_file.name
         self._state_cache: Dict[State, State] = {}  # mapping from global to cached local future used that makes updates immediately
         # available on setter side; dictionary-based cache works because future can visit any state at most once

@@ -111,17 +89,19 @@ def __init__(self, *, use_lock: bool = True):

     @property
     def _state(self) -> State:
-        shared_state = ALL_STATES[self._shared_state_code.item()]
+        shared_state = ALL_STATES[self._shared_state_code[0]]
         return self._state_cache.get(shared_state, shared_state)

     @_state.setter
     def _state(self, new_state: State):
-        with torch.inference_mode():
-            self._shared_state_code[...] = ALL_STATES.index(new_state)
-        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
+        self._shared_state_code[0] = ALL_STATES.index(new_state)
+        if new_state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
             self._set_event_threadsafe()

     def _set_event_threadsafe(self):
+        if not self._loop or not self._aio_event:
+            return
+
         try:
             running_loop = asyncio.get_running_loop()
         except RuntimeError:
@@ -130,9 +110,7 @@ def _set_event_threadsafe(self):
         async def _event_setter():
             self._aio_event.set()

-        if self._loop.is_closed():
-            return  # do nothing, the loop is already closed
-        elif self._loop.is_running() and running_loop == self._loop:
+        if self._loop.is_running() and running_loop == self._loop:
             asyncio.create_task(_event_setter())
         elif self._loop.is_running() and running_loop != self._loop:
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
@@ -164,7 +142,6 @@ def reset_backend():
         MPFuture._active_pid = None
         MPFuture._initialization_lock = mp.Lock()
         MPFuture._update_lock = mp.Lock()
-        SharedBytes._lock = mp.Lock()

     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
@@ -175,25 +152,29 @@ def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
                     break  # backend was reset, a new background thread has started

                 uid, update_type, payload = receiver_pipe.recv()
-                future = None
-                future_ref = cls._active_futures.pop(uid, None)
-                if future_ref is not None:
-                    future = future_ref()
+                future_ref = cls._active_futures.get(uid)
+                future = future_ref() if future_ref else None

                 if future is None:
                     # The MPFuture instance is already destroyed in this process
                     # (the caller is not interested in the result)
+                    cls._active_futures.pop(uid, None)  # Clean up the stale reference
                     continue
+
+                # Process the update and set the corresponding state
                 if update_type == UpdateType.RESULT:
                     future.set_result(payload)
+                    future._state = base.FINISHED
                 elif update_type == UpdateType.EXCEPTION:
                     future.set_exception(payload)
+                    future._state = base.FINISHED
                 elif update_type == UpdateType.CANCEL:
                     future.cancel()
+                    future._state = base.CANCELLED
                 else:
                     raise RuntimeError(f"Received unexpected update type {update_type}")
             except (BrokenPipeError, EOFError, ConnectionError):
-                logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
+                logger.debug(f"Update pipe was shut down unexpectedly (pid={pid})")
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")

@@ -212,6 +193,8 @@ def set_result(self, result: ResultType):
         elif self._state in TERMINAL_STATES:
             raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
         else:
+            # Don't update shared state immediately in subprocess - let the origin process do it
+            # This prevents race condition where shared state says "finished" but result isn't ready yet
             self._state_cache[self._state], self._result = base.FINISHED, result
             self._send_update(UpdateType.RESULT, result)

@@ -222,6 +205,7 @@ def set_exception(self, exception: Optional[BaseException]):
         elif self._state in TERMINAL_STATES:
             raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
         else:
+            # Don't update shared state immediately in subprocess - let the origin process do it
             self._state_cache[self._state], self._exception = base.FINISHED, exception
             self._send_update(UpdateType.EXCEPTION, exception)

@@ -232,6 +216,7 @@ def cancel(self) -> bool:
         elif self._state in [base.RUNNING, base.FINISHED]:
             return False
         else:
+            # Don't update shared state immediately in subprocess - let the origin process do it
             self._state_cache[self._state] = base.CANCELLED
             self._send_update(UpdateType.CANCEL)
             return True
@@ -248,25 +233,32 @@ def set_running_or_notify_cancel(self):
         )

     def result(self, timeout: Optional[float] = None) -> ResultType:
-        if self._state not in TERMINAL_STATES:
-            if os.getpid() != self._origin_pid:
+        if os.getpid() != self._origin_pid:
+            # Non-origin process: check shared state and return cached result
+            if self._state not in TERMINAL_STATES:
                 raise RuntimeError("Only the process that created MPFuture can await result")
-            return super().result(timeout)
-        elif self._state == base.CANCELLED:
-            raise base.CancelledError()
-        elif self._exception:
-            raise self._exception
+            elif self._state == base.CANCELLED:
+                raise base.CancelledError()
+            elif self._exception:
+                raise self._exception
+            else:
+                return self._result
         else:
-            return self._result
+            # Origin process: use parent's result() method which properly waits for completion
+            # The parent class handles the waiting and state management correctly
+            return super().result(timeout)

     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
-        if self._state not in TERMINAL_STATES:
-            if os.getpid() != self._origin_pid:
+        if os.getpid() != self._origin_pid:
+            # Non-origin process: check shared state and return cached exception
+            if self._state not in TERMINAL_STATES:
                 raise RuntimeError("Only the process that created MPFuture can await exception")
+            elif self._state == base.CANCELLED:
+                raise base.CancelledError()
+            return self._exception
+        else:
+            # Origin process: always use parent's exception() method which properly waits
             return super().exception(timeout)
-        elif self._state == base.CANCELLED:
-            raise base.CancelledError()
-        return self._exception

     def done(self) -> bool:
         return self._state in TERMINAL_STATES
@@ -292,15 +284,43 @@ def __await__(self):
             raise asyncio.CancelledError()

     def __del__(self):
-        if getattr(self, "_origin_pid", None) == os.getpid() and MPFuture._active_futures is not None:
+        # Only clean up if we have the necessary attributes to avoid exceptions during teardown
+        if not hasattr(self, "_origin_pid"):
+            return
+
+        is_origin_process = getattr(self, "_origin_pid", None) == os.getpid()
+
+        if is_origin_process and MPFuture._active_futures is not None and hasattr(self, "_uid"):
             MPFuture._active_futures.pop(self._uid, None)
-        if getattr(self, "_aio_event", None):
-            self._aio_event.set()
+
+        # Clean up mmap and temp file if we're the origin process
+        if is_origin_process:
+            if hasattr(self, "_shared_state_code"):
+                try:
+                    self._shared_state_code.release()
+                except (BufferError, ValueError):
+                    pass
+            if hasattr(self, "_mmap"):
+                try:
+                    self._mmap.close()
+                except (OSError, ValueError):
+                    pass
+            if hasattr(self, "_temp_file"):
+                try:
+                    self._temp_file.close()
+                except (OSError, ValueError):
+                    pass
+            if hasattr(self, "_temp_file_name"):
+                try:
+                    os.unlink(self._temp_file_name)
+                except (OSError, FileNotFoundError):
+                    pass

     def __getstate__(self):
         return dict(
             _sender_pipe=self._sender_pipe,
-            _shared_state_code=ForkingPickler.dumps(self._shared_state_code).tobytes(),
+            _shared_state_code=bytes(self._shared_state_code),
+            _temp_file_name=getattr(self, "_temp_file_name", None),
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -310,14 +330,24 @@ def __getstate__(self):

     def __setstate__(self, state):
         self._sender_pipe = state["_sender_pipe"]
+        self._temp_file_name = state.get("_temp_file_name")
+
+        # Try to reconnect to the shared state file
         try:
-            self._shared_state_code = ForkingPickler.loads(state["_shared_state_code"])
-        except RuntimeError:
-            # If the origin process garbage-collects all instances of MPFuture using the same shmem buffer,
-            # the underlying buffer is freed, and we will get RuntimeError ("unable to open shared memory object")
-            # here since it is not possible to connect to this buffer anymore. To address this, we just replace
-            # the buffer with a non-shared tensor since the origin process doesn't care about our state anymore.
-            self._shared_state_code = torch.tensor([ALL_STATES.index(base.PENDING)], dtype=torch.uint8)
+            if self._temp_file_name and os.path.exists(self._temp_file_name):
+                # Reconnect to existing mmap
+                with open(self._temp_file_name, "r+b") as f:
+                    self._mmap = mmap.mmap(f.fileno(), 1)
+                    self._shared_state_code = memoryview(self._mmap)
+            else:
+                # Fall back to local copy
+                state_bytes = state["_shared_state_code"]
+                self._shared_state_code = memoryview(bytearray(state_bytes))
+        except (OSError, ValueError):
+            # If mmap fails, fall back to local copy
+            state_bytes = state["_shared_state_code"]
+            self._shared_state_code = memoryview(bytearray(state_bytes))
+
         self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
diff --git a/modal_ci.py b/modal_ci.py
index 43b977ee7..84d1f9e55 100644
--- a/modal_ci.py
+++ b/modal_ci.py
@@ -81,15 +81,10 @@ def setup_environment(*, build_p2pd=False):

     subprocess.run(install_cmd, check=True)

-    environment = os.environ.copy()
-    environment["HIVEMIND_MEMORY_SHARING_STRATEGY"] = "file_descriptor"
-
-    return environment
-

 @app.function(image=image, timeout=600, cpu=8, memory=8192)
 def run_tests():
-    environment = setup_environment(build_p2pd=False)
+    setup_environment(build_p2pd=False)

     subprocess.run(
         [
@@ -102,13 +97,12 @@ def run_tests():
             "tests",
         ],
         check=True,
-        env=environment,
     )


 @app.function(image=image, timeout=900, cpu=8, memory=8192, secrets=[codecov_secret])
 def run_codecov():
-    environment = setup_environment(build_p2pd=False)
+    setup_environment(build_p2pd=False)

     subprocess.run(
         [
@@ -120,10 +114,10 @@ def run_codecov():
             "tests",
         ],
         check=True,
-        env=environment,
     )

     # Forward GitHub Actions environment variables to the codecov command
+    environment = os.environ.copy()
     environment.update(
         {
             "CODECOV_TOKEN": os.environ["CODECOV_TOKEN"],
@@ -149,7 +143,7 @@ def run_codecov():

 @app.function(image=image_with_golang, timeout=600, cpu=1, memory=4096)
 def build_and_test_p2pd():
-    environment = setup_environment(build_p2pd=True)
+    setup_environment(build_p2pd=True)

     subprocess.run(
         [
@@ -160,5 +154,4 @@ def build_and_test_p2pd():
             "tests",
         ],
         check=True,
-        env=environment,
     )
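
Note (illustrative addendum, not part of the patch): the mpfuture.py change replaces the torch shared-memory byte with a one-byte mmap over a named temporary file, which other processes reconnect to by path in `__setstate__`. A minimal standalone sketch of that mechanism, assuming POSIX shared file mappings (names such as `origin_view` and `remote_view` are hypothetical):

```python
import mmap
import os
import tempfile

# Origin side: back a single state byte with a named temp file.
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp.write(b"\x00")  # e.g. the index of base.PENDING
tmp.flush()
origin_view = mmap.mmap(tmp.fileno(), 1)

# Remote side (e.g. after unpickling in another process): reopen by name.
with open(tmp.name, "r+b") as f:
    remote_view = memoryview(mmap.mmap(f.fileno(), 1))

origin_view[0] = 2          # origin writes a new state code
assert remote_view[0] == 2  # remote observes it through the shared mapping

origin_view.close()
os.unlink(tmp.name)
```

Because the mapping is re-established from the file path rather than from a pickled shared tensor, it survives the origin process garbage-collecting other futures (the "unable to open shared memory object" failure mode described in the removed `__setstate__` comment), at the cost of the explicit unlink now performed in `__del__`.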