diff --git a/bec_ipython_client/tests/end-2-end/test_scans_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_e2e.py index 238883c11..53fcb19ab 100644 --- a/bec_ipython_client/tests/end-2-end/test_scans_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_scans_e2e.py @@ -114,7 +114,7 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture): bec.metadata.update({"unit_test": "test_mv_scan_nested_device"}) dev = bec.device_manager.devices scans.mv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False).wait() - if not bec.connector._messages_queue.empty(): + if not bec.connector._message_callbacks_queue.empty(): print("Waiting for messages to be processed") time.sleep(0.5) current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"] @@ -126,7 +126,7 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture): current_pos_hexapod_y, 20, atol=dev.hexapod._config["deviceConfig"].get("tolerance", 0.5) ) scans.umv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False) - if not bec.connector._messages_queue.empty(): + if not bec.connector._message_callbacks_queue.empty(): print("Waiting for messages to be processed") time.sleep(0.5) current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"] diff --git a/bec_lib/bec_lib/messaging_services.py b/bec_lib/bec_lib/messaging_services.py index 50138a76c..aceee0793 100644 --- a/bec_lib/bec_lib/messaging_services.py +++ b/bec_lib/bec_lib/messaging_services.py @@ -127,7 +127,6 @@ def __init__(self, redis_connector: RedisConnector) -> None: self._redis_connector.register( MessageEndpoints.available_messaging_services(), cb=self._on_new_scope_change_msg, - parent=self, from_start=True, ) @@ -142,21 +141,19 @@ def set_default_scope(self, scope: str | list[str] | None) -> None: raise ValueError(f"Scope '{scope}' is not available for this messaging service.") self._default_scope = scope - @staticmethod def _on_new_scope_change_msg( - message: dict[str, 
messages.AvailableMessagingServicesMessage], parent: MessagingService + self, message: dict[str, messages.AvailableMessagingServicesMessage] ) -> None: """ Callback for scope changes. Currently a placeholder for future functionality. Args: message (dict[str, messages.AvailableMessagingServicesMessage]): The scope change message. - parent (MessagingService): The parent messaging service instance. """ msg = message["data"] # pylint: disable=protected-access - parent._service_config = msg - parent._update_messaging_services(msg) + self._service_config = msg + self._update_messaging_services(msg) def _update_messaging_services( self, service_info: messages.AvailableMessagingServicesMessage diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index da66d17ed..0e392eb2a 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -17,6 +17,7 @@ import time import traceback import warnings +from collections import defaultdict from collections.abc import MutableMapping, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -31,6 +32,7 @@ Generator, Iterable, Literal, + NamedTuple, ParamSpec, TypedDict, TypeVar, @@ -40,6 +42,7 @@ import louie import redis.client import redis.exceptions +from astroid.nodes import Unknown from redis.backoff import ExponentialBackoff from redis.client import Pipeline, Redis from redis.retry import Retry @@ -57,6 +60,7 @@ ) from bec_lib.serialization import MsgpackSerialization +logger = bec_logger.logger if TYPE_CHECKING: # pragma: no cover from concurrent.futures import Future @@ -84,6 +88,11 @@ class InvalidItemForOperation(ValueError): ... class WrongArguments(ValueError): ... 
+def _error_log_with_context(msg: str): + context = "".join(traceback.format_stack(limit=5)[:-1]) + logger.error(msg + f" Context:\n{context}") + + def _raise_incompatible_message(msg, endpoint): raise IncompatibleMessageForEndpoint( f"Message type {type(msg)} is not compatible with endpoint {endpoint}. Expected {endpoint.message_type}" @@ -198,34 +207,194 @@ class GeneratorExecution: @dataclass -class StreamSubscriptionInfo: - id: str - topic: str - newest_only: bool - from_start: bool +class StreamSubInfo: cb_ref: Callable - kwargs: dict + kwargs: dict[str, Unknown] def __eq__(self, other): - if not isinstance(other, StreamSubscriptionInfo): + if not isinstance(other, StreamSubInfo): return False - return ( - self.topic == other.topic - and self.cb_ref == other.cb_ref - and self.from_start == other.from_start - ) + return self.cb_ref == other.cb_ref + + def __hash__(self) -> int: + return self.cb_ref.__hash__() @dataclass -class DirectReadingStreamSubscriptionInfo(StreamSubscriptionInfo): +class DirectReadStreamSubInfo(StreamSubInfo): stop_event: threading.Event - thread: threading.Thread | None = None + thread: threading.Thread + + def __hash__(self) -> int: + return self.cb_ref.__hash__() @dataclass class StreamMessage: msg: dict - callbacks: Iterable[tuple[Callable, dict]] + callbacks: Iterable[tuple[Callable, dict[str, Unknown]]] + + +class StreamSubsEntry(NamedTuple): + read_id: str + subs: set[StreamSubInfo] + + +StreamResponseList = list[tuple[bytes, list[tuple[bytes, dict[bytes, bytes]]]]] +StreamSubsRegistry = dict[str, StreamSubsEntry] + + +class StreamSubs: + def __init__(self) -> None: + """Manager for stream subscriptions. 
Since operations often need to be combined, + use the lock directly at point of call, it is generally not used in the methods.""" + self.lock = threading.RLock() + + self._subs: StreamSubsRegistry = {} + self._direct_read_subs: dict[ + str, dict[DirectReadStreamSubInfo, DirectReadStreamSubInfo] + ] = {} + self.from_start_subs: dict[str, set[StreamSubInfo]] = {} + + @property + def normal_subs(self): + return {t: s.subs for t, s in self._subs.items()} + + @property + def all_topics(self): + with self.lock: + from_start_keys = [k for k in self.from_start_subs if self.from_start_subs[k] != set()] + dr_sub_keys = [k for k in self._direct_read_subs if self._direct_read_subs[k] != set()] + return list(set((*self._subs.keys(), *dr_sub_keys, *from_start_keys))) + + def topic_ids(self) -> dict[str, str]: + """Get Redis read Ids for active subscriptions""" + return {topic: infos.read_id for topic, infos in self._subs.items()} + + def update_normal_ids(self, updated_ids: dict[str, str]): + for topic, id in updated_ids.items(): + if topic in self._subs: + self._subs[topic] = StreamSubsEntry(id, self._subs[topic].subs) + + def from_start_topics(self) -> set[str]: + """Get topics for new `from_start` subscriptions which haven't been read yet""" + return set(self.from_start_subs.keys()) + + def end_id(self, topic: str): + """Return the last read id for a given topic if given, or "+" """ + return self._subs[topic].read_id if topic in self._subs else "+" + + def move_from_start_to_normal(self, topics_and_end_ids: dict[str, str]): + if topics_and_end_ids.keys() != self.from_start_subs.keys(): + _error_log_with_context( + f"Mismatch of subs to move! {topics_and_end_ids.keys()=}, {self.from_start_subs.keys()=} Was a lock forgotten?" + ) + for topic in topics_and_end_ids: + if topic in self._subs: + if topics_and_end_ids[topic] != self._subs[topic].read_id: + _error_log_with_context(f"Mismatch of ID! 
Was a lock forgotten?") + for sub in self.from_start_subs.pop(topic): + self._subs[topic].subs.add(sub) # type: ignore + else: + self._subs[topic] = StreamSubsEntry( + read_id=topics_and_end_ids[topic], subs=self.from_start_subs.pop(topic) + ) + + def is_already_registered(self, topic: str, new_sub: StreamSubInfo): + return ( + (topic in self.from_start_subs and new_sub in self.from_start_subs[topic]) + or (topic in self._direct_read_subs and new_sub in self._direct_read_subs[topic]) + or (topic in self._subs and new_sub in self._subs[topic].subs) + ) + + def _check_registered(self, topic: str, new_sub: StreamSubInfo): + if self.is_already_registered(topic, new_sub): + raise ValueError(f"Received duplicate subscription for {new_sub=}.") + + def add_direct_listener(self, topic: str, new_sub: DirectReadStreamSubInfo): + self._check_registered(topic, new_sub) + if not topic in self._direct_read_subs: + self._direct_read_subs[topic] = {} + self._direct_read_subs[topic][new_sub] = new_sub + new_sub.thread.start() + + def add(self, from_start: bool, last_id: str, topic: str, new_sub: StreamSubInfo): + self._check_registered(topic, new_sub) + if from_start: + if topic in self.from_start_subs: + subs = self.from_start_subs[topic] + else: + subs = set() + self.from_start_subs[topic] = subs + else: + if not topic in self._subs: + subs = set() + self._subs[topic] = StreamSubsEntry(read_id=last_id, subs=subs) + else: + subs = self._subs[topic].subs + subs.add(new_sub) + + @staticmethod + def _kill_direct_stream(sub: DirectReadStreamSubInfo, topic: str): + sub.stop_event.set() + sub.thread.join(timeout=1) + if sub.thread.is_alive(): + _error_log_with_context( + f"RedisConnector direct stream callback thread for {topic=}, {sub.cb_ref=} failed to shutdown" + ) + + def remove(self, topic: str, cb: Callable | None = None) -> bool: + removed = False + if cb is None: # Remove all subs for the given topic + removed |= bool(self.from_start_subs.pop(topic, False)) + removed |= 
bool(self._subs.pop(topic, False)) + if (subs := self._direct_read_subs.pop(topic, None)) is not None: + for sub in subs: + self._kill_direct_stream(sub, topic) + removed = True + return removed + test_subinfo = StreamSubInfo(louie.saferef.safe_ref(cb), {}) + if topic in self.from_start_subs and test_subinfo in self.from_start_subs[topic]: + self.from_start_subs[topic].remove(test_subinfo) + removed = True + if len(self.from_start_subs[topic]) == 0: + del self.from_start_subs[topic] + if topic in self._direct_read_subs and test_subinfo in self._direct_read_subs[topic]: + sub = self._direct_read_subs[topic].pop(test_subinfo) # type: ignore # hash is the same + self._kill_direct_stream(sub, topic) + removed = True + if len(self._direct_read_subs[topic]) == 0: + del self._direct_read_subs[topic] + if topic in self._subs and test_subinfo in self._subs[topic].subs: + self._subs[topic].subs.remove(test_subinfo) + removed = True + if len(self._subs[topic].subs) == 0: + del self._subs[topic] + return removed + + def gc_cb_refs(self): + for topic, entry in list(self._subs.items()): + for info in list(entry.subs): + if not info.cb_ref(): + entry.subs.remove(info) + if len(self._subs[topic].subs) == 0: + del self._subs[topic] + for topic, entry in list(self._direct_read_subs.items()): + for info in list(entry.keys()): + if not info.cb_ref(): + info.stop_event.set() + info.thread.join(0.05) + if info.thread.is_alive(): + _error_log_with_context(f"Failed to garbage collect in 0.05s {info}") + del entry[info] + if self._direct_read_subs[topic] == {}: + del self._direct_read_subs[topic] + for topic, subs in list(self.from_start_subs.items()): + for info in list(subs): + if not info.cb_ref(): + subs.remove(info) + if len(self.from_start_subs[topic]) == 0: + del self.from_start_subs[topic] class RedisConnector: @@ -276,16 +445,15 @@ def redis_connect_func(_redis_conn): collections.defaultdict(list) ) self._topics_cb_lock = threading.Lock() - self._stream_topics_subscription = 
collections.defaultdict(list) - self._stream_topics_subscription_lock = threading.Lock() + self._stream_subs = StreamSubs() self._events_listener_thread: threading.Thread | None = None self._stream_events_listener_thread: threading.Thread | None = None self._events_dispatcher_thread: threading.Thread | None = None - self._messages_queue = queue.Queue() + self._message_callbacks_queue = queue.Queue() self._stop_events_listener_thread = threading.Event() self._stop_stream_events_listener_thread = threading.Event() - self.stream_keys: dict[str, str] = {} + self.stream_keys: dict[str, str] = {} # for explicit reads, not subscriptions self._generator_executor = ThreadPoolExecutor() @@ -392,12 +560,12 @@ def shutdown(self, per_thread_timeout_s: float | None = None): self._stream_events_listener_thread.join(timeout=per_thread_timeout_s) self._stream_events_listener_thread = None if self._events_dispatcher_thread: - self._messages_queue.put(StopIteration) + self._message_callbacks_queue.put(StopIteration) self._events_dispatcher_thread.join(timeout=per_thread_timeout_s) self._events_dispatcher_thread = None # this will take care of shutting down direct listening threads - self._unregister_stream(self._stream_topics_subscription) + self._unregister_stream(self._stream_subs.all_topics) # release all connections self._pubsub_conn.close() @@ -463,7 +631,7 @@ def raise_alarm(self, severity: Alarms, info: ErrorInfo, metadata: dict | None = >>> connector.raise_alarm( severity=Alarms.WARNING, info=ErrorInfo( - id=str(uuid.uuid4()), + id=str(uuid.uuid4()),_stream_topic_subscriptions error_message="ValueError", compact_error_message="test alarm", exception_type="ValueError", @@ -571,6 +739,19 @@ def _normalize_patterns(self, patterns) -> list[str]: raise ValueError("register: patterns must be a string or a list of strings") return patterns + def any_stream_is_registered( + self, topics: EndpointInfo | str | list[EndpointInfo] | list[str], cb: Callable + ) -> bool: + """Check if any 
stream in `topics` is already registered with this callback. + Does not check if the topic is a stream in Redis, it will just return False.""" + with self._stream_subs.lock: + return any( + self._stream_subs.is_already_registered( + topic, StreamSubInfo(louie.saferef.safe_ref(cb), {}) + ) + for topic in self._convert_endpointinfo(topics)[0] + ) + def register( self, topics: str | list[str] | EndpointInfo | list[EndpointInfo] | None = None, @@ -648,7 +829,7 @@ def register( self._topics_cb[topic].append(item) self._start_events_dispatcher_thread(start_thread) - def _add_direct_stream_listener(self, topic, cb_ref, **kwargs): + def _create_direct_stream_listener(self, topic, cb_ref, kwargs): """ Add a direct listener for a topic. This is used when newest_only is True. @@ -658,123 +839,98 @@ def _add_direct_stream_listener(self, topic, cb_ref, **kwargs): kwargs (dict): additional keyword arguments to be transmitted to the callback Returns: - None + DirectReadStreamSubInfo with an unstarted thread """ - info = DirectReadingStreamSubscriptionInfo( - id="-", - topic=topic, - newest_only=True, - from_start=False, - cb_ref=cb_ref, - kwargs=kwargs, - stop_event=threading.Event(), + stop_event = threading.Event() + thread = threading.Thread( + target=self._direct_stream_listener, args=(topic, stop_event, cb_ref, kwargs) ) - if info in self._stream_topics_subscription[topic]: - raise RuntimeError("Already registered stream topic with the same callback") - - info.thread = threading.Thread(target=self._direct_stream_listener, args=(info,)) - with self._stream_topics_subscription_lock: - self._stream_topics_subscription[topic].append(info) - info.thread.start() - - def _direct_stream_listener(self, info: DirectReadingStreamSubscriptionInfo): - stop_event = info.stop_event - cb_ref = info.cb_ref - kwargs = info.kwargs - topic = info.topic + return DirectReadStreamSubInfo(cb_ref, kwargs, stop_event, thread) + + def _direct_stream_listener(self, topic: str, stop_event: 
threading.Event, cb_ref, kwargs): + read_id = "-" while not stop_event.is_set(): - ret = self._redis_conn.xrevrange(topic, "+", info.id, count=1) - if not ret: - time.sleep(0.1) + if not (response := self._redis_conn.xrevrange(topic, "+", read_id, count=1)): + stop_event.wait(timeout=0.1) continue - redis_id, msg_dict = ret[0] # type: ignore : we are using Redis synchronously + redis_id, msg_dict = response[0] # type: ignore : we are using Redis synchronously timestamp, _, ind = redis_id.partition(b"-") - info.id = f"{timestamp.decode()}-{int(ind.decode())+1}" + read_id = f"{timestamp.decode()}-{int(ind.decode())+1}" stream_msg = StreamMessage( {key.decode(): MsgpackSerialization.loads(val) for key, val in msg_dict.items()}, ((cb_ref, kwargs),), ) - self._messages_queue.put(stream_msg) - - def _get_stream_topics_id(self) -> tuple[dict, dict]: - stream_topics_id = {} - from_start_stream_topics_id = {} - with self._stream_topics_subscription_lock: - for topic, subscription_info_list in self._stream_topics_subscription.items(): - for info in subscription_info_list: - if isinstance(info, DirectReadingStreamSubscriptionInfo): - continue - if info.from_start: - from_start_stream_topics_id[topic] = info.id - else: - stream_topics_id[topic] = info.id - return from_start_stream_topics_id, stream_topics_id + self._message_callbacks_queue.put(stream_msg) - def _handle_stream_msg_list(self, msg_list, from_start=False): - for topic, msgs in msg_list: - subscription_info_list = self._stream_topics_subscription[topic.decode()] - for index, record in msgs: - callbacks = [] - for info in subscription_info_list: - info.id = index.decode() - if from_start and not info.from_start: - continue - callbacks.append((info.cb_ref, info.kwargs)) - if callbacks: + def _handle_stream_msg_list( + self, redis_response: StreamResponseList, subs: dict[str, set[StreamSubInfo]] + ): + new_ids = {} + for btopic, msgs in redis_response: + for read_id, record in msgs: + topic: str = btopic.decode() if 
isinstance(btopic, bytes) else btopic # type: ignore + if callbacks := subs.get(topic): msg_dict = { k.decode(): MsgpackSerialization.loads(msg) for k, msg in record.items() } - msg = StreamMessage(msg_dict, callbacks) - self._messages_queue.put(msg) - for info in subscription_info_list: - info.from_start = False + msg = StreamMessage(msg_dict, [(cb.cb_ref, cb.kwargs) for cb in callbacks]) + self._message_callbacks_queue.put(msg) + new_ids[topic] = read_id.decode() + return new_ids + + def _try_read_streams(self, topics_ids: dict[str, str], from_start: bool = False): + try: + if from_start: + return [(t, self._redis_conn.xrange(t, "-", end)) for t, end in topics_ids.items()] + else: + return self._redis_conn.xread(topics_ids, block=200) or [] # type: ignore strs are fine key and id types + except redis.exceptions.ConnectionError: + logger.error("Failed to connect to redis. Is the server running?") + except redis.exceptions.NoPermissionError: + logger.error(f"Permission denied for stream topics: {set(topics_ids.keys())}") + # pylint: disable=broad-except + except Exception: + sys.excepthook(*sys.exc_info()) # type: ignore # inside except + + def _read_from_start_streams_and_migrate(self) -> bool: + """Returns whether there was an error""" + with self._stream_subs.lock: + if from_start_topics := self._stream_subs.from_start_topics(): + topics_and_end_ids = {t: self._stream_subs.end_id(t) for t in from_start_topics} + response = self._try_read_streams(topics_and_end_ids, from_start=True) + if response is not None: + updated_end_ids = self._handle_stream_msg_list( + response, self._stream_subs.from_start_subs + ) + new_end_ids = {t: "0-0" for t in from_start_topics} + new_end_ids.update(updated_end_ids) + self._stream_subs.move_from_start_to_normal(new_end_ids) + else: + return True + return False def _get_stream_messages_loop(self) -> None: """ Get stream messages loop. This method is run in a separate thread and listens for messages from the redis server. 
""" - error = False - while not self._stop_stream_events_listener_thread.is_set(): - try: - from_start_stream_topics_id, stream_topics_id = self._get_stream_topics_id() - if not any((stream_topics_id, from_start_stream_topics_id)): - self._stop_stream_events_listener_thread.wait(timeout=0.1) - continue - msg_list = [] - from_start_msg_list = [] - # first handle the 'from_start' streams ; - # in the case of reading from start what is expected is to call the - # callbacks for existing items, without waiting for a new element to be added - # to the stream - if from_start_stream_topics_id: - # read the streams contents from beginning - from_start_msg_list = self._redis_conn.xread( - from_start_stream_topics_id, block=200 - ) - if stream_topics_id: - msg_list = self._redis_conn.xread(stream_topics_id, block=200) - except redis.exceptions.ConnectionError: - if not error: - error = True - bec_logger.logger.error("Failed to connect to redis. Is the server running?") - self._stop_stream_events_listener_thread.wait(timeout=1) - except redis.exceptions.NoPermissionError: - bec_logger.logger.error( - f"Permission denied for stream topics: \n Topics id: {from_start_stream_topics_id}, Stream topics id: {stream_topics_id}" - ) - if not error: - error = True + # first clear any dead callbacks + with self._stream_subs.lock: + self._stream_subs.gc_cb_refs() + # First read the "from_start" streams, up until any id which is already in the normal + # subs, then all those them to the normal streams + error = self._read_from_start_streams_and_migrate() + # Then read all the normal streams + with self._stream_subs.lock: + normal_topics = self._stream_subs.topic_ids() + normal_subs = self._stream_subs.normal_subs + if normal_topics and (response := self._try_read_streams(normal_topics)) is not None: + updated_ids = self._handle_stream_msg_list(response, normal_subs) + with self._stream_subs.lock: + self._stream_subs.update_normal_ids(updated_ids) + if error: # Encountered an error on 
xread, wait a while without the lock self._stop_stream_events_listener_thread.wait(timeout=1) - # pylint: disable=broad-except - except Exception: - sys.excepthook(*sys.exc_info()) # type: ignore # inside except - else: - error = False - with self._stream_topics_subscription_lock: - self._handle_stream_msg_list(from_start_msg_list, from_start=True) - self._handle_stream_msg_list(msg_list) def _register_stream( self, @@ -805,50 +961,27 @@ def _register_stream( cb_ref = louie.saferef.safe_ref(cb) self._start_events_dispatcher_thread(start_thread) - - if newest_only: - # if newest_only is True, we need to provide a separate callback for each topic, - # directly calling the callback. This is because we need to have a backpressure - # mechanism in place, and we cannot rely on the dispatcher thread to handle it. + with self._stream_subs.lock: for topic in topics: - self._add_direct_stream_listener(topic, cb_ref, **kwargs) - else: - with self._stream_topics_subscription_lock: - for topic in topics: + if newest_only: + new_sub = self._create_direct_stream_listener(topic, cb_ref, kwargs) + self._stream_subs.add_direct_listener(topic, new_sub) + else: + new_sub = StreamSubInfo(cb_ref, kwargs) try: stream_info = self._redis_conn.xinfo_stream(topic) except redis.exceptions.ResponseError: - # no such key - last_id = "0-0" + last_id = "0-0" # no such key else: last_id = stream_info["last-entry"][0].decode() # type: ignore # we are using the sync Redis client - new_subscription = StreamSubscriptionInfo( - id="0-0" if from_start else last_id, - topic=topic, - newest_only=newest_only, - from_start=from_start, - cb_ref=cb_ref, - kwargs=kwargs, - ) - subscriptions = self._stream_topics_subscription[topic] - if new_subscription in subscriptions: - # raise an error if attempted to register a stream with the same callback, - # whereas it has already been registered as a 'direct reading' stream with - # newest_only=True ; it is clearly an error case that would produce weird results - 
index = subscriptions.index(new_subscription) - if isinstance(subscriptions[index], DirectReadingStreamSubscriptionInfo): - raise RuntimeError( - "Already registered stream topic with the same callback with 'newest_only=True'" - ) - else: - subscriptions.append(new_subscription) + self._stream_subs.add(from_start, last_id, topic, new_sub) - if self._stream_events_listener_thread is None: - # create the thread that will get all messages for this connector - self._stream_events_listener_thread = threading.Thread( - target=self._get_stream_messages_loop - ) - self._stream_events_listener_thread.start() + if self._stream_events_listener_thread is None: + # create the thread that will get all messages for this connector + self._stream_events_listener_thread = threading.Thread( + target=self._get_stream_messages_loop + ) + self._stream_events_listener_thread.start() def _filter_topics_cb(self, topics: list, cb: Callable | None): unsubscribe_list = [] @@ -870,60 +1003,43 @@ def _filter_topics_cb(self, topics: list, cb: Callable | None): def unregister(self, topics=None, patterns=None, cb=None): if self._events_listener_thread is None: return - + if topics and patterns: + _error_log_with_context( + f"Unsubscribe called with both {topics=} and {patterns=}. Topics will be ignored in favour of patterns." 
+ ) if patterns is not None: patterns = self._normalize_patterns(patterns) # see if registered streams can be unregistered for pattern in patterns: - self._unregister_stream( - fnmatch.filter(self._stream_topics_subscription, pattern), cb - ) + self._unregister_stream(fnmatch.filter(self._stream_subs.all_topics, pattern), cb) pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb) if pubsub_unsubscribe_list: self._pubsub_conn.punsubscribe(pubsub_unsubscribe_list) - else: + elif topics is not None: topics, _ = self._convert_endpointinfo(topics, check_message_op=False) if not self._unregister_stream(topics, cb): unsubscribe_list = self._filter_topics_cb(topics, cb) if unsubscribe_list: self._pubsub_conn.unsubscribe(unsubscribe_list) + else: + with self._topics_cb_lock: + topics = list(self._topics_cb.keys()) + self.unregister(topics, cb) + self.unregister(self._stream_subs.all_topics, cb) def _unregister_stream(self, topics: list[str], cb: Callable | None = None) -> bool: - """ - Unregister a stream listener. 
- - Args: - topics (list[str]): list of stream topics - - Returns: - bool: True if the stream listener has been removed, False otherwise - """ - unsubscribe_list = [] - with self._stream_topics_subscription_lock: - for topic in topics: - subscription_infos = self._stream_topics_subscription[topic] - # remove from list if callback corresponds - self._stream_topics_subscription[topic] = list( - filter(lambda sub_info: cb and sub_info.cb_ref() is not cb, subscription_infos) - ) - if not self._stream_topics_subscription[topic]: - # no callbacks left, unsubscribe - unsubscribe_list += subscription_infos - # clean the topics that have been unsubscribed - for subscription_info in unsubscribe_list: - if isinstance(subscription_info, DirectReadingStreamSubscriptionInfo): - subscription_info.stop_event.set() - if subscription_info.thread: - subscription_info.thread.join() - # it is possible to register the same stream multiple times with different - # callbacks, in this case when unregistering with cb=None (unregister all) - # the topic can be deleted multiple times, hence try...except in code below - try: - del self._stream_topics_subscription[subscription_info.topic] - except KeyError: - pass + """Unregister callbacks from a list of topics. 
Returns true if any were removed""" + with self._stream_subs.lock: + return any([self._stream_subs.remove(topic, cb) for topic in topics]) - return len(unsubscribe_list) > 0 + def _garbage_collect_cb_refs(self): + """Only handles normal subscriptions, for streams, see StreamSubs.gc_cb_refs()""" + with self._topics_cb_lock: + for topic, subs in list(self._topics_cb.items()): + for cb_ref, kwargs in reversed(subs): + if not cb_ref(): + idx = self._topics_cb[topic].index((cb_ref, kwargs)) + self._topics_cb[topic].pop(idx) def _get_messages_loop(self) -> None: """ @@ -935,6 +1051,7 @@ def _get_messages_loop(self) -> None: """ error = False while not self._stop_events_listener_thread.is_set(): + self._garbage_collect_cb_refs() try: msg = self._pubsub_conn.get_message(timeout=0.2) except redis.exceptions.ConnectionError: @@ -948,7 +1065,7 @@ def _get_messages_loop(self) -> None: else: error = False if msg is not None: - self._messages_queue.put(msg) + self._message_callbacks_queue.put(msg) def _execute_callback(self, cb, msg, kwargs): try: @@ -959,17 +1076,16 @@ def _execute_callback(self, cb, msg, kwargs): else: if inspect.isgenerator(g): # reschedule execution to delineate the generator - self._messages_queue.put(g) + self._message_callbacks_queue.put(g) def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessage): if inspect.isgenerator(msg): g = msg fut = self._generator_executor.submit(next, g) - self._messages_queue.put(GeneratorExecution(fut, g)) + self._message_callbacks_queue.put(GeneratorExecution(fut, g)) elif isinstance(msg, StreamMessage): for cb_ref, kwargs in msg.callbacks: - cb = cb_ref() - if cb: + if cb := cb_ref(): self._execute_callback(cb, msg.msg, kwargs) elif isinstance(msg, GeneratorExecution): fut, g = msg.fut, msg.g @@ -980,9 +1096,9 @@ def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessag pass else: fut = self._generator_executor.submit(g.send, res) - 
self._messages_queue.put(GeneratorExecution(fut, g)) + self._message_callbacks_queue.put(GeneratorExecution(fut, g)) else: - self._messages_queue.put(GeneratorExecution(fut, g)) + self._message_callbacks_queue.put(GeneratorExecution(fut, g)) else: channel = msg["channel"].decode() with self._topics_cb_lock: @@ -992,8 +1108,7 @@ def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessag callbacks = self._topics_cb[channel] msg_obj = MessageObject(topic=channel, value=MsgpackSerialization.loads(msg["data"])) for cb_ref, kwargs in callbacks: - cb = cb_ref() - if cb: + if cb := cb_ref(): self._execute_callback(cb, msg_obj, kwargs) def poll_messages(self, timeout: float | None = None) -> bool: @@ -1011,7 +1126,7 @@ def poll_messages(self, timeout: float | None = None) -> bool: while True: try: # wait for a message and return it before timeout expires - msg = self._messages_queue.get(timeout=remaining_timeout, block=True) + msg = self._message_callbacks_queue.get(timeout=remaining_timeout, block=True) except queue.Empty as exc: remaining_timeout = cast(float, remaining_timeout) timeout = cast(float, timeout) @@ -1032,7 +1147,7 @@ def poll_messages(self, timeout: float | None = None) -> bool: bec_logger.logger.error(f"Error handling message {msg}:\n{content}") if timeout is None: - if self._messages_queue.empty(): + if self._message_callbacks_queue.empty(): # no message to process return True else: diff --git a/bec_lib/tests/test_messaging_service.py b/bec_lib/tests/test_messaging_service.py index 70552de78..e6a29649f 100644 --- a/bec_lib/tests/test_messaging_service.py +++ b/bec_lib/tests/test_messaging_service.py @@ -26,9 +26,7 @@ def scilog_service(connected_connector): ], session_services=[], ) - SciLogMessagingService._on_new_scope_change_msg( - message={"data": available_services}, parent=service - ) + service._on_new_scope_change_msg(message={"data": available_services}) yield service @@ -48,9 +46,7 @@ def signal_service(connected_connector): 
], session_services=[], ) - SignalMessagingService._on_new_scope_change_msg( - message={"data": available_services}, parent=service - ) + service._on_new_scope_change_msg(message={"data": available_services}) yield service @@ -271,9 +267,7 @@ def test_disabled_service_cannot_create_message(connected_connector): ], session_services=[], ) - SciLogMessagingService._on_new_scope_change_msg( - message={"data": available_services}, parent=service - ) + service._on_new_scope_change_msg(message={"data": available_services}) with pytest.raises(RuntimeError, match="Messaging service 'scilog' is not enabled."): service.new() @@ -295,9 +289,7 @@ def test_disabled_service_cannot_send_message(connected_connector): ], session_services=[], ) - SciLogMessagingService._on_new_scope_change_msg( - message={"data": available_services}, parent=service - ) + service._on_new_scope_change_msg(message={"data": available_services}) message = service.new() message.add_text("Test message") @@ -315,9 +307,7 @@ def test_disabled_service_cannot_send_message(connected_connector): ], session_services=[], ) - SciLogMessagingService._on_new_scope_change_msg( - message={"data": disabled_services}, parent=service - ) + service._on_new_scope_change_msg(message={"data": available_services}) with pytest.raises(RuntimeError, match="Messaging service 'scilog' is not enabled."): message.send() @@ -386,9 +376,7 @@ def test_signal_message_service_uses_default_scope(connected_connector): ], session_services=[], ) - SignalMessagingService._on_new_scope_change_msg( - message={"data": available_services}, parent=service - ) + service._on_new_scope_change_msg(message={"data": available_services}) service.set_default_scope("user") message = service.new() diff --git a/bec_lib/tests/test_redis_connector_fakeredis.py b/bec_lib/tests/test_redis_connector_fakeredis.py index 43026ed45..3ac30f063 100644 --- a/bec_lib/tests/test_redis_connector_fakeredis.py +++ b/bec_lib/tests/test_redis_connector_fakeredis.py @@ -1,3 +1,4 
@@ +import gc import threading import time from unittest import mock @@ -369,20 +370,22 @@ def test_redis_connector_register_stream(connected_connector): connector.poll_messages() cb_mock1.assert_not_called() cb_mock2.assert_called_once_with({"data": 2}, a=2) + assert "test" in connector._stream_subs.all_topics connector.unregister("test") - assert connector._stream_topics_subscription["test"] == [] + assert connector._stream_subs.all_topics == [] +@pytest.mark.timeout(10) def test_redis_connector_register_stream_identical(connected_connector): connector = connected_connector received_event1 = mock.Mock(spec=[]) received_event2 = mock.Mock(spec=[]) - connector.register(TestStreamEndpoint, cb=received_event1, start_thread=False) connector.register(TestStreamEndpoint, cb=received_event1, start_thread=False) connector.register(TestStreamEndpoint, cb=received_event2, start_thread=False) connector.register(TestStreamEndpoint2, cb=received_event1, start_thread=False) + connector.register(TestStreamEndpoint2, cb=received_event2, start_thread=False) connector.xadd("test", {"data": 1}) connector.poll_messages(timeout=1) assert received_event1.call_count == 1 @@ -392,14 +395,11 @@ def test_redis_connector_register_stream_identical(connected_connector): assert received_event1.call_count == 2 try: - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): connector.register( TestStreamEndpoint2, cb=received_event1, newest_only=True, start_thread=False ) - connector.register( - TestStreamEndpoint2, cb=received_event2, newest_only=True, start_thread=False - ) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): connector.register(TestStreamEndpoint2, cb=received_event2, start_thread=False) finally: connector.unregister(TestStreamEndpoint2) @@ -427,7 +427,8 @@ def test_redis_connector_register_stream_list(connected_connector, endpoint): connector.poll_messages() assert mock.call({"data": 2}, a=1) in cb_mock.mock_calls connector.unregister(endpoint) - 
assert len(connector._stream_topics_subscription) == 0 + all_topics = connector._stream_subs.all_topics + assert len(all_topics) == 0 @pytest.mark.timeout(10) @@ -448,12 +449,14 @@ def test_redis_connector_register_stream_from_start(connected_connector): cb_mock1.assert_called_once_with({"data": 3}, a=1) cb_mock2.assert_called_once_with({"data": 3}, a=2) cb_mock1.reset_mock() + connector.unregister(TestStreamEndpoint, cb=cb_mock1) connector.register(TestStreamEndpoint, cb=cb_mock1, from_start=True, start_thread=False, a=3) connector.poll_messages(timeout=1) cb_mock1.assert_has_calls( [mock.call({"data": 1}, a=3), mock.call({"data": 2}, a=3), mock.call({"data": 3}, a=3)] ) cb_mock1.reset_mock() + connector.unregister(TestStreamEndpoint, cb=cb_mock1) connector.register(TestStreamEndpoint, cb=cb_mock1, start_thread=False, a=4) with pytest.raises(TimeoutError): connector.poll_messages(timeout=1) @@ -608,3 +611,122 @@ def cb(msg): assert res.metrics["m2"].value == 5.5 assert res.metrics["m3"].value == "test" assert res.metrics["m4"].value is True + + +def test_merging_streams_does_not_skip_messages(connected_connector: RedisConnector): + connector = connected_connector + cb_normal = mock.Mock(spec=[]) # spec is here to remove all attributes + cb_from_start = mock.Mock(spec=[]) # spec is here to remove all attributes + + connector.xadd("test", {"data": 1}) + connector.xadd("test", {"data": 2}) + + connector.register(TestStreamEndpoint, cb=cb_normal, start_thread=False, key="normal") + with pytest.raises(TimeoutError): + connector.poll_messages(timeout=0.1) + cb_normal.assert_not_called() + + connector.xadd("test", {"data": 3}) + connector.poll_messages() + cb_normal.assert_called_once_with({"data": 3}, key="normal") + cb_normal.reset_mock() + + assert (id_3 := connected_connector._stream_subs.end_id("test")) != "+" + + connector.xadd("test", {"data": 4}) + connector.xadd("test", {"data": 5}) + cb_normal.assert_not_called() + connector.register( + TestStreamEndpoint, 
cb=cb_from_start, from_start=True, start_thread=False, key="from_start" + ) + + connected_connector._read_from_start_streams_and_migrate() + connector.poll_messages(timeout=0) + connector.poll_messages(timeout=0) + connector.poll_messages(timeout=0) + + assert cb_from_start.call_count == 3 + + with pytest.raises(TimeoutError): + connector.poll_messages(timeout=0) + + assert cb_from_start.call_count == 3 + assert connected_connector._stream_subs.from_start_subs == {} + assert connected_connector._stream_subs.end_id("test") == id_3 + + connector.poll_messages() + + assert cb_from_start.call_count == 5 + assert cb_normal.call_count == 2 + + +def test_subs_garbage_collection(connected_connector): + sub1 = mock.MagicMock(spec=[]) + sub2 = mock.MagicMock(spec=[]) + sub3 = mock.MagicMock(spec=[]) + + connected_connector.register("test", cb=sub1) + connected_connector.register("test", cb=sub2) + connected_connector.register("test", cb=sub3) + + assert len(connected_connector._topics_cb["test"]) == 3 + connected_connector._garbage_collect_cb_refs() + assert len(connected_connector._topics_cb["test"]) == 3 + del sub2 + gc.collect() + connected_connector._garbage_collect_cb_refs() + assert len(connected_connector._topics_cb["test"]) == 2 + + +def test_stream_subs_garbage_collection(connected_connector): + sub1 = mock.MagicMock(spec=[]) + sub2 = mock.MagicMock(spec=[]) + sub3 = mock.MagicMock(spec=[]) + + connected_connector.register(TestStreamEndpoint, cb=sub1) + connected_connector.register(TestStreamEndpoint, cb=sub2) + connected_connector.register(TestStreamEndpoint, cb=sub3) + + assert len(connected_connector._stream_subs._subs["test"].subs) == 3 + connected_connector._stream_subs.gc_cb_refs() + assert len(connected_connector._stream_subs._subs["test"].subs) == 3 + + del sub2 + gc.collect() + connected_connector._stream_subs.gc_cb_refs() + + assert len(connected_connector._stream_subs._subs["test"].subs) == 2 + + sub4 = mock.MagicMock(spec=[]) + sub5 =
mock.MagicMock(spec=[]) + + connected_connector.register(TestStreamEndpoint, cb=sub4, from_start=True) + connected_connector.register(TestStreamEndpoint, cb=sub5, from_start=True) + + assert len(connected_connector._stream_subs._subs["test"].subs) == 2 + assert len(connected_connector._stream_subs.from_start_subs["test"]) == 2 + + del sub4 + del sub3 + gc.collect() + connected_connector._stream_subs.gc_cb_refs() + + assert len(connected_connector._stream_subs._subs["test"].subs) == 1 + assert len(connected_connector._stream_subs.from_start_subs["test"]) == 1 + + sub6 = mock.MagicMock(spec=[]) + connected_connector.register(TestStreamEndpoint, cb=sub6, newest_only=True) + + assert len(connected_connector._stream_subs._subs["test"].subs) == 1 + assert len(connected_connector._stream_subs.from_start_subs["test"]) == 1 + assert len(connected_connector._stream_subs._direct_read_subs["test"]) == 1 + + del sub1 + del sub5 + del sub6 + gc.collect() + connected_connector._stream_subs.gc_cb_refs() + + assert "test" not in connected_connector._stream_subs._subs + assert "test" not in connected_connector._stream_subs.from_start_subs + assert "test" not in connected_connector._stream_subs._direct_read_subs