Skip to content

Commit b6ab4c6

Browse files
committed
feat: garbage collect subs
1 parent 4c74599 commit b6ab4c6

1 file changed

Lines changed: 49 additions & 7 deletions

File tree

bec_lib/bec_lib/redis_connector.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,29 @@ def remove(self, topic: str, cb: Callable | None = None) -> bool:
372372
del self._subs[topic]
373373
return removed
374374

375+
def gc_cb_refs(self):
376+
for topic, entry in list(self._subs.items()):
377+
for info in entry.subs:
378+
if not info.cb_ref():
379+
entry.subs.remove(info)
380+
if len(self._subs[topic].subs) == 0:
381+
del self._subs[topic]
382+
for topic, entry in list(self._direct_read_subs.items()):
383+
for info in entry:
384+
if not info.cb_ref():
385+
info.stop_event.set()
386+
info.thread.join(0.05)
387+
if info.thread.is_alive():
388+
_error_log_with_context(f"Failed to garbage collect in 0.05s {info}")
389+
if self._direct_read_subs[topic] == {}:
390+
del self._direct_read_subs[topic]
391+
for topic, subs in list(self.from_start_subs.items()):
392+
for info in subs:
393+
if not info.cb_ref():
394+
subs.remove(info)
395+
if len(self.from_start_subs[topic]) == 0:
396+
del self.from_start_subs[topic]
397+
375398

376399
class RedisConnector:
377400
"""
@@ -429,7 +452,7 @@ def redis_connect_func(_redis_conn):
429452
self._message_callbacks_queue = queue.Queue()
430453
self._stop_events_listener_thread = threading.Event()
431454
self._stop_stream_events_listener_thread = threading.Event()
432-
self.stream_keys: dict[str, str] = {}
455+
self.stream_keys: dict[str, str] = {} # for explicit reads, not subscriptions
433456

434457
self._generator_executor = ThreadPoolExecutor()
435458

@@ -891,6 +914,9 @@ def _get_stream_messages_loop(self) -> None:
891914
for messages from the redis server.
892915
"""
893916
while not self._stop_stream_events_listener_thread.is_set():
917+
# first clear any dead callbacks
918+
with self._stream_subs.lock:
919+
self._stream_subs.gc_cb_refs()
894920
# First read the "from_start" streams, up until any id which is already in the normal
895921
# subs, then all those them to the normal streams
896922
error = self._read_from_start_streams_and_migrate()
@@ -976,7 +1002,10 @@ def _filter_topics_cb(self, topics: list, cb: Callable | None):
9761002
def unregister(self, topics=None, patterns=None, cb=None):
9771003
if self._events_listener_thread is None:
9781004
return
979-
1005+
if topics and patterns:
1006+
_error_log_with_context(
1007+
f"Unsubscribe called with both {topics=} and {patterns=}. Topics will be ignored in favour of patterns."
1008+
)
9801009
if patterns is not None:
9811010
patterns = self._normalize_patterns(patterns)
9821011
# see if registered streams can be unregistered
@@ -985,18 +1014,32 @@ def unregister(self, topics=None, patterns=None, cb=None):
9851014
pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb)
9861015
if pubsub_unsubscribe_list:
9871016
self._pubsub_conn.punsubscribe(pubsub_unsubscribe_list)
988-
else:
1017+
elif topics is not None:
9891018
topics, _ = self._convert_endpointinfo(topics, check_message_op=False)
9901019
if not self._unregister_stream(topics, cb):
9911020
unsubscribe_list = self._filter_topics_cb(topics, cb)
9921021
if unsubscribe_list:
9931022
self._pubsub_conn.unsubscribe(unsubscribe_list)
1023+
else:
1024+
with self._topics_cb_lock:
1025+
topics = list(self._topics_cb.keys())
1026+
self.unregister(topics, cb)
1027+
self.unregister(self._stream_subs.all_topics, cb)
9941028

9951029
def _unregister_stream(self, topics: list[str], cb: Callable | None = None) -> bool:
9961030
"""Unregister callbacks from a list of topics. Returns true if any were removed"""
9971031
with self._stream_subs.lock:
9981032
return any([self._stream_subs.remove(topic, cb) for topic in topics])
9991033

1034+
def _garbage_collect_cb_refs(self):
1035+
"""Only handles normal subscriptions, for streams, see StreamSubs.gc_cb_refs()"""
1036+
with self._topics_cb_lock:
1037+
for topic, subs in list(self._topics_cb.items()):
1038+
for cb_ref, kwargs in reversed(subs):
1039+
if not cb_ref():
1040+
idx = self._topics_cb[topic].index((cb_ref, kwargs))
1041+
self._topics_cb[topic].pop(idx)
1042+
10001043
def _get_messages_loop(self) -> None:
10011044
"""
10021045
Get messages loop. This method is run in a separate thread and listens
@@ -1007,6 +1050,7 @@ def _get_messages_loop(self) -> None:
10071050
"""
10081051
error = False
10091052
while not self._stop_events_listener_thread.is_set():
1053+
self._garbage_collect_cb_refs()
10101054
try:
10111055
msg = self._pubsub_conn.get_message(timeout=0.2)
10121056
except redis.exceptions.ConnectionError:
@@ -1040,8 +1084,7 @@ def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessag
10401084
self._message_callbacks_queue.put(GeneratorExecution(fut, g))
10411085
elif isinstance(msg, StreamMessage):
10421086
for cb_ref, kwargs in msg.callbacks:
1043-
cb = cb_ref()
1044-
if cb:
1087+
if cb := cb_ref():
10451088
self._execute_callback(cb, msg.msg, kwargs)
10461089
elif isinstance(msg, GeneratorExecution):
10471090
fut, g = msg.fut, msg.g
@@ -1064,8 +1107,7 @@ def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessag
10641107
callbacks = self._topics_cb[channel]
10651108
msg_obj = MessageObject(topic=channel, value=MsgpackSerialization.loads(msg["data"]))
10661109
for cb_ref, kwargs in callbacks:
1067-
cb = cb_ref()
1068-
if cb:
1110+
if cb := cb_ref():
10691111
self._execute_callback(cb, msg_obj, kwargs)
10701112

10711113
def poll_messages(self, timeout: float | None = None) -> bool:

0 commit comments

Comments
 (0)