4040)
4141
4242from glide import (
43+ OK ,
4344 ConditionalChange ,
4445 ExpirySet ,
4546 ExpiryType ,
7778 BaseModelV1 = BaseModelV2
7879
7980import wrapt
80- from redis .exceptions import ResponseError
8181
8282import reflex .istate .dynamic
8383from reflex import constants
100100 ImmutableStateError ,
101101 InvalidStateManagerMode ,
102102 LockExpiredError ,
103+ RedisConfigError ,
103104 ReflexRuntimeError ,
104105 SetUndefinedStateVarError ,
105106 StateSchemaMismatchError ,
@@ -3217,29 +3218,33 @@ class StateManagerRedis(StateManager):
32173218 }
32183219
32193220 # This lock is used to ensure we only subscribe to keyspace events once per token and worker
3220- # _pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
3221-
3222- _pubsub_clients : Dict [bytes , GlideClient ] = pydantic .PrivateAttr ({})
3221+ _pubsub_locks : Dict [bytes , asyncio .Lock ] = pydantic .PrivateAttr ({})
32233222
32243223 async def get_redis (self ) -> GlideClient :
32253224 """Get the redis client.
32263225
32273226 Returns:
32283227 The redis client.
3228+
3229+ Raises:
3230+ RedisConfigError: If the redis client could not be configured.
32293231 """
32303232 if self .redis is not None :
32313233 return self .redis
32323234 redis = await prerequisites .get_redis ()
32333235 assert redis is not None
3234- try :
3235- _ = await redis .config_set (
3236- {"notify-keyspace-events" : self ._redis_notify_keyspace_events },
3236+ config_result = await redis .config_set (
3237+ {"notify-keyspace-events" : self ._redis_notify_keyspace_events },
3238+ )
3239+ # Some redis servers only allow out-of-band configuration, so ignore errors here.
3240+ if (
3241+ config_result != OK
3242+ and not environment .REFLEX_IGNORE_REDIS_CONFIG_ERROR .get ()
3243+ ):
3244+ raise RedisConfigError (
3245+ f"Failed to set notify-keyspace-events: { config_result } "
32373246 )
3238- # TODO: adjust exception for glide
3239- except ResponseError :
3240- # Some redis servers only allow out-of-band configuration, so ignore errors here.
3241- if not environment .REFLEX_IGNORE_REDIS_CONFIG_ERROR .get ():
3242- raise
3247+
32433248 self .redis = redis
32443249 return redis
32453250
@@ -3407,6 +3412,7 @@ async def set_state(
34073412 """
34083413 # Check that we're holding the lock.
34093414 redis = await self .get_redis ()
3415+
34103416 if lock_id is not None and await redis .get (self ._lock_key (token )) != lock_id :
34113417 raise LockExpiredError (
34123418 f"Lock expired for token { token } while processing. Consider increasing "
@@ -3440,9 +3446,15 @@ async def set_state(
34403446 _ = await redis .set (
34413447 _substate_key (client_token , state ),
34423448 pickle_state ,
3443- expiry = self .expiry ,
3444- # ex=self.token_expiration,
3449+ expiry = ExpirySet (
3450+ expiry_type = ExpiryType .MILLSEC ,
3451+ value = self .token_expiration ,
3452+ ),
34453453 )
3454+ # if str(res) != OK:
3455+ # raise RuntimeError(
3456+ # f"Failed to set state for token {token}. {res} {OK}"
3457+ # )
34463458
34473459 # Wait for substates to be persisted.
34483460 for t in tasks :
@@ -3478,18 +3490,6 @@ def _lock_key(token: str) -> bytes:
34783490 client_token = _split_substate_key (token )[0 ]
34793491 return f"{ client_token } _lock" .encode ()
34803492
3481- @property
3482- def expiry (self ) -> ExpirySet :
3483- """Get the expiry set for the token.
3484-
3485- Returns:
3486- The expiry set for the token.
3487- """
3488- return ExpirySet (
3489- expiry_type = ExpiryType .SEC ,
3490- value = self .token_expiration ,
3491- )
3492-
34933493 async def _try_get_lock (self , lock_key : bytes , lock_id : bytes ) -> bool | None :
34943494 """Try to get a redis lock for a token.
34953495
@@ -3504,10 +3504,13 @@ async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
35043504 response = await redis .set (
35053505 lock_key ,
35063506 lock_id ,
3507- expiry = self .expiry ,
3507+ expiry = ExpirySet (
3508+ expiry_type = ExpiryType .MILLSEC ,
3509+ value = self .lock_expiration ,
3510+ ),
35083511 conditional_set = ConditionalChange .ONLY_IF_DOES_NOT_EXIST ,
35093512 )
3510- return bool (response )
3513+ return str (response ) == OK
35113514
35123515 async def get_pubsub (self , lock_key : bytes ) -> GlideClient :
35133516 """Get the pubsub client for a lock key channel.
@@ -3519,12 +3522,10 @@ async def get_pubsub(self, lock_key: bytes) -> GlideClient:
35193522 The pubsub client.
35203523 """
35213524 lock_key_channel = f"__keyspace@0__:{ lock_key .decode ()} "
3522- if lock_key_channel in self ._pubsub_clients :
3523- return self ._pubsub_clients [lock_key_channel ]
35243525 pubsub_config = GlideClientConfiguration .PubSubSubscriptions (
35253526 channels_and_patterns = {
35263527 GlideClientConfiguration .PubSubChannelModes .Pattern : {lock_key_channel },
3527- GlideClientConfiguration .PubSubChannelModes .Exact : {lock_key_channel },
3528+ # GlideClientConfiguration.PubSubChannelModes.Exact: {lock_key_channel},
35283529 },
35293530 callback = None ,
35303531 context = None ,
@@ -3534,7 +3535,6 @@ async def get_pubsub(self, lock_key: bytes) -> GlideClient:
35343535 )
35353536 assert config is not None
35363537 pubsub = await GlideClient .create (config )
3537- self ._pubsub_clients [lock_key ] = pubsub
35383538 return pubsub
35393539
35403540 async def _wait_lock (self , lock_key : bytes , lock_id : bytes ) -> None :
@@ -3545,58 +3545,48 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
35453545 Args:
35463546 lock_key: The redis key for the lock.
35473547 lock_id: The ID of the lock.
3548-
3549- Raises:
3550- ResponseError: when the keyspace config cannot be set.
35513548 """
35523549 state_is_locked = False
35533550 # Enable keyspace notifications for the lock key, so we know when it is available.
35543551 redis = await self .get_redis ()
3555- pubsub = await self .get_pubsub (lock_key )
3556- # async with self.redis.pubsub() as pubsub:
3557- # await pubsub.psubscribe(lock_key_channel)
3558- # await pubsub.get_pubsub_message()
3559- count = 0
3560- while not state_is_locked :
3561- count += 1
3562- if count > 10000 :
3563- raise Exception ("Could not obtain lock" )
3564- # wait for the lock to be released
3565- print ("waiting for lock to be released" )
3566- while True :
3567- if not await redis .exists ([lock_key ]):
3568- # if not pubsub.try_get_pubsub_message():
3569- break # key was removed, try to get the lock again
3570- message = await pubsub .get_pubsub_message (
3571- # ignore_subscribe_messages=True,
3572- # timeout=self.lock_expiration / 1000.0,
3573- )
3574- # if message.pattern is None:
3575- # # raise Exception("Pattern is None")
3576- # continue
3577- # raise Exception(message)
3578- if message .message in self ._redis_keyspace_lock_release_events :
3579- break
3580- state_is_locked = await self ._try_get_lock (lock_key , lock_id )
3552+ if lock_key not in self ._pubsub_locks :
3553+ self ._pubsub_locks [lock_key ] = asyncio .Lock ()
3554+ async with self ._pubsub_locks [lock_key ]:
3555+ pubsub = await self .get_pubsub (lock_key )
3556+ while not state_is_locked :
3557+ # wait for the lock to be released
3558+ while True :
3559+ # check if we missed lock release events
3560+ if await redis .exists ([lock_key ]) == 0 :
3561+ break # key was removed, try to get the lock again
3562+
3563+ try :
3564+ # TODO: alternative to ignore_subscribe_messages?
3565+ message = await asyncio .wait_for (
3566+ pubsub .get_pubsub_message (),
3567+ timeout = self .lock_expiration / 1000.0 ,
3568+ )
3569+ except asyncio .TimeoutError :
3570+ continue
3571+ if message .message in self ._redis_keyspace_lock_release_events :
3572+ break
3573+ state_is_locked = await self ._try_get_lock (lock_key , lock_id )
35813574
35823575 @override
3583- async def disconnect (self , token : str ):
3576+ async def disconnect (self , token : str ) -> None :
35843577 """Disconnect the token from the redis client.
35853578
35863579 Args:
35873580 token: The token to disconnect.
35883581 """
35893582 lock_key = self ._lock_key (token )
3590- # if lock := self._pubsub_locks.get(lock_key):
3591- # if lock.locked():
3592- # lock.release()
3593- # del self._pubsub_locks[lock_key]
3594- if client := self ._pubsub_clients .get (self ._lock_key (token )):
3595- await client .close ()
3596- del self ._pubsub_clients [lock_key ]
3583+ if lock := self ._pubsub_locks .get (lock_key ):
3584+ if lock .locked ():
3585+ lock .release ()
3586+ del self ._pubsub_locks [lock_key ]
35973587
35983588 @contextlib .asynccontextmanager
3599- async def _lock (self , token : str ):
3589+ async def _lock (self , token : str ) -> AsyncIterator [ bytes ] :
36003590 """Obtain a redis lock for a token.
36013591
36023592 Args:
@@ -3626,8 +3616,10 @@ async def _lock(self, token: str):
36263616 # only delete our lock
36273617 redis = await self .get_redis ()
36283618 _ = await redis .delete ([lock_key ])
3619+ # if not res:
3620+ # raise RuntimeError(f"Failed to release lock for token {token}")
36293621
3630- async def close (self ):
3622+ async def close (self ) -> None :
36313623 """Explicitly close the redis connection and connection_pool.
36323624
36333625 It is necessary in testing scenarios to close between asyncio test cases
@@ -3636,14 +3628,9 @@ async def close(self):
36363628
36373629 Note: Connections will be automatically reopened when needed.
36383630 """
3639- # await self.redis.aclose(close_connection_pool=True)
3640- # TODO: is this needed with glide?
3641- redis = await self .get_redis ()
3642- await redis .close ()
3643-
3644- for pubsub in self ._pubsub_clients .values ():
3645- await pubsub .close ()
3646- self ._pubsub_clients = {}
3631+ if self .redis is not None :
3632+ await self .redis .close ()
3633+ self .redis = None
36473634
36483635
36493636def get_state_manager () -> StateManager :
0 commit comments