Skip to content

Commit 5003c16

Browse files
committed
Try safe_recv()
1 parent 6dc02d7 commit 5003c16

File tree

9 files changed

+25
-16
lines changed

9 files changed

+25
-16
lines changed

hivemind/averaging/averager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
enter_asynchronously,
3838
switch_to_uvloop,
3939
)
40+
from hivemind.utils.compat import safe_recv
4041
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
4142
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
4243
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -313,7 +314,7 @@ async def _run():
313314
if not self._inner_pipe.poll():
314315
continue
315316
try:
316-
method, args, kwargs = self._inner_pipe.recv()
317+
method, args, kwargs = safe_recv(self._inner_pipe)
317318
except (OSError, ConnectionError, RuntimeError) as e:
318319
logger.exception(e)
319320
await asyncio.sleep(self.request_timeout)
@@ -774,7 +775,7 @@ def _background_thread_fetch_current_state(
774775
"""
775776
while True:
776777
try:
777-
trigger, future = pipe.recv()
778+
trigger, future = safe_recv(pipe)
778779
except BaseException as e:
779780
logger.debug(f"Averager background thread finished: {repr(e)}")
780781
break

hivemind/dht/dht.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
1515
from hivemind.p2p import P2P, PeerID
1616
from hivemind.utils import MPFuture, get_logger, switch_to_uvloop
17+
from hivemind.utils.compat import safe_recv
1718
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration
1819

1920
logger = get_logger(__name__)
@@ -126,7 +127,7 @@ async def _run():
126127
if not self._inner_pipe.poll():
127128
continue
128129
try:
129-
method, args, kwargs = self._inner_pipe.recv()
130+
method, args, kwargs = safe_recv(self._inner_pipe)
130131
except (OSError, ConnectionError, RuntimeError) as e:
131132
logger.exception(e)
132133
await asyncio.sleep(self._node.protocol.wait_timeout)

hivemind/moe/server/server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from hivemind.moe.server.runtime import Runtime
2727
from hivemind.p2p import PeerInfo
2828
from hivemind.proto.runtime_pb2 import CompressionType
29+
from hivemind.utils.compat import safe_recv
2930
from hivemind.utils.logging import get_logger
3031
from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
3132

@@ -314,7 +315,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
314315
runner.start()
315316
# once the server is ready, runner will send us
316317
# either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
317-
start_ok, data = pipe.recv()
318+
start_ok, data = safe_recv(pipe)
318319
if start_ok:
319320
yield data
320321
pipe.send("SHUTDOWN") # on exit from context, send shutdown signal
@@ -339,7 +340,7 @@ def _server_runner(pipe, *args, **kwargs):
339340
try:
340341
dht_maddrs = server.dht.get_visible_maddrs()
341342
pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
342-
pipe.recv() # wait for shutdown signal
343+
safe_recv(pipe) # wait for shutdown signal
343344

344345
finally:
345346
logger.info("Shutting down server...")

hivemind/moe/server/task_pool.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616

1717
from hivemind.utils import get_logger
18+
from hivemind.utils.compat import safe_recv
1819
from hivemind.utils.mpfuture import InvalidStateError, MPFuture
1920

2021
logger = get_logger(__name__)
@@ -195,7 +196,7 @@ def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
195196

196197
while True:
197198
logger.debug(f"{self.name} waiting for results from runtime")
198-
batch_index, batch_outputs_or_exception = self.outputs_receiver.recv()
199+
batch_index, batch_outputs_or_exception = safe_recv(self.outputs_receiver)
199200
batch_tasks = pending_batches.pop(batch_index)
200201

201202
if isinstance(batch_outputs_or_exception, BaseException):
@@ -234,7 +235,7 @@ def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[to
234235
if not self.batch_receiver.poll(timeout):
235236
raise TimeoutError()
236237

237-
batch_index, batch_inputs = self.batch_receiver.recv()
238+
batch_index, batch_inputs = safe_recv(self.batch_receiver)
238239
batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
239240
return batch_index, batch_inputs
240241

hivemind/utils/mpfuture.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import torch # used for py3.7-compatible shared memory
1616

17+
from hivemind.utils.compat import safe_recv
1718
from hivemind.utils.logging import get_logger
1819

1920
logger = get_logger(__name__)
@@ -182,7 +183,7 @@ def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection)
182183
if cls._pipe_waiter_thread is not threading.current_thread():
183184
break # backend was reset, a new background thread has started
184185

185-
uid, update_type, payload = receiver_pipe.recv()
186+
uid, update_type, payload = safe_recv(receiver_pipe)
186187
future = None
187188
future_ref = cls._active_futures.pop(uid, None)
188189
if future_ref is not None:

tests/test_dht_crypto.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from hivemind.dht.crypto import RSASignatureValidator
99
from hivemind.dht.node import DHTNode
1010
from hivemind.dht.validation import DHTRecord
11+
from hivemind.utils.compat import safe_recv
1112
from hivemind.utils.crypto import RSAPrivateKey
1213
from hivemind.utils.timed_storage import get_dht_time
1314

@@ -79,8 +80,8 @@ def test_validator_instance_is_picklable():
7980

8081

8182
def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
82-
validator = conn.recv()
83-
record = conn.recv()
83+
validator = safe_recv(conn)
84+
record = safe_recv(conn)
8485

8586
record = dataclasses.replace(record, value=validator.sign_value(record))
8687

@@ -101,7 +102,7 @@ def test_signing_in_different_process():
101102
)
102103
parent_conn.send(record)
103104

104-
signed_record = parent_conn.recv()
105+
signed_record = safe_recv(parent_conn)
105106
assert b"[signature:" in signed_record.value
106107
assert validator.validate(signed_record)
107108

tests/test_dht_protocol.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from hivemind.dht import DHTID
1313
from hivemind.dht.protocol import DHTProtocol
1414
from hivemind.dht.storage import DictionaryDHTValue
15+
from hivemind.utils.compat import safe_recv
1516

1617
logger = get_logger(__name__)
1718

@@ -56,7 +57,7 @@ def launch_protocol_listener(
5657
dht_id = DHTID.generate()
5758
process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
5859
process.start()
59-
peer_id, visible_maddrs = local_conn.recv()
60+
peer_id, visible_maddrs = safe_recv(local_conn)
6061

6162
return dht_id, process, peer_id, visible_maddrs
6263

tests/test_p2p_daemon.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
1515
from hivemind.proto import dht_pb2, test_pb2
16+
from hivemind.utils.compat import safe_recv
1617
from hivemind.utils.serializer import MSGPackSerializer
1718

1819
from test_utils.networking import get_free_port
@@ -328,8 +329,8 @@ async def test_call_peer_different_processes():
328329
proc = mp.Process(target=server_target, args=(handler_name, server_side, response_received))
329330
proc.start()
330331

331-
peer_id = client_side.recv()
332-
peer_maddrs = client_side.recv()
332+
peer_id = safe_recv(client_side)
333+
peer_maddrs = safe_recv(client_side)
333334

334335
client = await P2P.create(initial_peers=peer_maddrs)
335336
client_pid = client._child.pid

tests/test_util_modules.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
1414
from hivemind.proto.runtime_pb2 import CompressionType
1515
from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
16+
from hivemind.utils.compat import safe_recv
1617
from hivemind.utils.asyncio import (
1718
achain,
1819
aenumerate,
@@ -260,7 +261,7 @@ def _check_result_and_set(future):
260261
p = mp.Process(target=_future_creator)
261262
p.start()
262263

263-
future1, future2 = receiver.recv()
264+
future1, future2 = safe_recv(receiver)
264265
future1.set_result(123)
265266

266267
with pytest.raises(RuntimeError):
@@ -309,7 +310,7 @@ def _run_peer():
309310
p = mp.Process(target=_run_peer)
310311
p.start()
311312

312-
some_fork_futures = receiver.recv()
313+
some_fork_futures = safe_recv(receiver)
313314

314315
time.sleep(0.1) # giving enough time for the futures to be destroyed
315316
assert len(hivemind.MPFuture._active_futures) == 700

0 commit comments

Comments
 (0)