@@ -372,6 +372,29 @@ def remove(self, topic: str, cb: Callable | None = None) -> bool:
372372 del self ._subs [topic ]
373373 return removed
374374
375+ def gc_cb_refs (self ):
376+ for topic , entry in list (self ._subs .items ()):
377+ for info in entry .subs :
378+ if not info .cb_ref ():
379+ entry .subs .remove (info )
380+ if len (self ._subs [topic ].subs ) == 0 :
381+ del self ._subs [topic ]
382+ for topic , entry in list (self ._direct_read_subs .items ()):
383+ for info in entry :
384+ if not info .cb_ref ():
385+ info .stop_event .set ()
386+ info .thread .join (0.05 )
387+ if info .thread .is_alive ():
388+ _error_log_with_context (f"Failed to garbage collect in 0.05s { info } " )
389+ if self ._direct_read_subs [topic ] == {}:
390+ del self ._direct_read_subs [topic ]
391+ for topic , subs in list (self .from_start_subs .items ()):
392+ for info in subs :
393+ if not info .cb_ref ():
394+ subs .remove (info )
395+ if len (self .from_start_subs [topic ]) == 0 :
396+ del self .from_start_subs [topic ]
397+
375398
376399class RedisConnector :
377400 """
@@ -429,7 +452,7 @@ def redis_connect_func(_redis_conn):
429452 self ._message_callbacks_queue = queue .Queue ()
430453 self ._stop_events_listener_thread = threading .Event ()
431454 self ._stop_stream_events_listener_thread = threading .Event ()
432- self .stream_keys : dict [str , str ] = {}
455+ self .stream_keys : dict [str , str ] = {} # for explicit reads, not subscriptions
433456
434457 self ._generator_executor = ThreadPoolExecutor ()
435458
@@ -891,6 +914,9 @@ def _get_stream_messages_loop(self) -> None:
891914 for messages from the redis server.
892915 """
893916 while not self ._stop_stream_events_listener_thread .is_set ():
917+ # first clear any dead callbacks
918+ with self ._stream_subs .lock :
919+ self ._stream_subs .gc_cb_refs ()
894920 # First read the "from_start" streams, up until any id which is already in the normal
895921 # subs, then all those them to the normal streams
896922 error = self ._read_from_start_streams_and_migrate ()
@@ -976,7 +1002,10 @@ def _filter_topics_cb(self, topics: list, cb: Callable | None):
9761002 def unregister (self , topics = None , patterns = None , cb = None ):
9771003 if self ._events_listener_thread is None :
9781004 return
979-
1005+ if topics and patterns :
1006+ _error_log_with_context (
1007+ f"Unsubscribe called with both { topics = } and { patterns = } . Topics will be ignored in favour of patterns."
1008+ )
9801009 if patterns is not None :
9811010 patterns = self ._normalize_patterns (patterns )
9821011 # see if registered streams can be unregistered
@@ -985,18 +1014,32 @@ def unregister(self, topics=None, patterns=None, cb=None):
9851014 pubsub_unsubscribe_list = self ._filter_topics_cb (patterns , cb )
9861015 if pubsub_unsubscribe_list :
9871016 self ._pubsub_conn .punsubscribe (pubsub_unsubscribe_list )
988- else :
1017+ elif topics is not None :
9891018 topics , _ = self ._convert_endpointinfo (topics , check_message_op = False )
9901019 if not self ._unregister_stream (topics , cb ):
9911020 unsubscribe_list = self ._filter_topics_cb (topics , cb )
9921021 if unsubscribe_list :
9931022 self ._pubsub_conn .unsubscribe (unsubscribe_list )
1023+ else :
1024+ with self ._topics_cb_lock :
1025+ topics = list (self ._topics_cb .keys ())
1026+ self .unregister (topics , cb )
1027+ self .unregister (self ._stream_subs .all_topics , cb )
9941028
9951029 def _unregister_stream (self , topics : list [str ], cb : Callable | None = None ) -> bool :
9961030 """Unregister callbacks from a list of topics. Returns true if any were removed"""
9971031 with self ._stream_subs .lock :
9981032 return any ([self ._stream_subs .remove (topic , cb ) for topic in topics ])
9991033
1034+ def _garbage_collect_cb_refs (self ):
1035+ """Only handles normal subscriptions, for streams, see StreamSubs.gc_cb_refs()"""
1036+ with self ._topics_cb_lock :
1037+ for topic , subs in list (self ._topics_cb .items ()):
1038+ for cb_ref , kwargs in reversed (subs ):
1039+ if not cb_ref ():
1040+ idx = self ._topics_cb [topic ].index ((cb_ref , kwargs ))
1041+ self ._topics_cb [topic ].pop (idx )
1042+
10001043 def _get_messages_loop (self ) -> None :
10011044 """
10021045 Get messages loop. This method is run in a separate thread and listens
@@ -1007,6 +1050,7 @@ def _get_messages_loop(self) -> None:
10071050 """
10081051 error = False
10091052 while not self ._stop_events_listener_thread .is_set ():
1053+ self ._garbage_collect_cb_refs ()
10101054 try :
10111055 msg = self ._pubsub_conn .get_message (timeout = 0.2 )
10121056 except redis .exceptions .ConnectionError :
@@ -1040,8 +1084,7 @@ def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessag
10401084 self ._message_callbacks_queue .put (GeneratorExecution (fut , g ))
10411085 elif isinstance (msg , StreamMessage ):
10421086 for cb_ref , kwargs in msg .callbacks :
1043- cb = cb_ref ()
1044- if cb :
1087+ if cb := cb_ref ():
10451088 self ._execute_callback (cb , msg .msg , kwargs )
10461089 elif isinstance (msg , GeneratorExecution ):
10471090 fut , g = msg .fut , msg .g
@@ -1064,8 +1107,7 @@ def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessag
10641107 callbacks = self ._topics_cb [channel ]
10651108 msg_obj = MessageObject (topic = channel , value = MsgpackSerialization .loads (msg ["data" ]))
10661109 for cb_ref , kwargs in callbacks :
1067- cb = cb_ref ()
1068- if cb :
1110+ if cb := cb_ref ():
10691111 self ._execute_callback (cb , msg_obj , kwargs )
10701112
10711113 def poll_messages (self , timeout : float | None = None ) -> bool :
0 commit comments