Skip to content

Commit

Permalink
Deprecating unused arguments in connection pools's get_connection fun…
Browse files Browse the repository at this point in the history
…ctions (#3517)
  • Loading branch information
petyaslavova authored Feb 20, 2025
1 parent 799716c commit 8427c7b
Show file tree
Hide file tree
Showing 19 changed files with 196 additions and 113 deletions.
16 changes: 6 additions & 10 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ async def initialize(self: _RedisT) -> _RedisT:
if self.single_connection_client:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection("_")
self.connection = await self.connection_pool.get_connection()

self._event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(
Expand Down Expand Up @@ -638,7 +638,7 @@ async def execute_command(self, *args, **options):
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)
conn = self.connection or await pool.get_connection()

if self.single_connection_client:
await self._single_conn_lock.acquire()
Expand Down Expand Up @@ -712,7 +712,7 @@ def __init__(self, connection_pool: ConnectionPool):

async def connect(self):
if self.connection is None:
self.connection = await self.connection_pool.get_connection("MONITOR")
self.connection = await self.connection_pool.get_connection()

async def __aenter__(self):
await self.connect()
Expand Down Expand Up @@ -900,9 +900,7 @@ async def connect(self):
Ensure that the PubSub is connected
"""
if self.connection is None:
self.connection = await self.connection_pool.get_connection(
"pubsub", self.shard_hint
)
self.connection = await self.connection_pool.get_connection()
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
Expand Down Expand Up @@ -1370,9 +1368,7 @@ async def immediate_execute_command(self, *args, **options):
conn = self.connection
# if this is the first call, we need a connection
if not conn:
conn = await self.connection_pool.get_connection(
command_name, self.shard_hint
)
conn = await self.connection_pool.get_connection()
self.connection = conn

return await conn.retry.call_with_retry(
Expand Down Expand Up @@ -1568,7 +1564,7 @@ async def execute(self, raise_on_error: bool = True) -> List[Any]:

conn = self.connection
if not conn:
conn = await self.connection_pool.get_connection("MULTI", self.shard_hint)
conn = await self.connection_pool.get_connection()
# assign to self.connection so reset() releases the connection
# back to the pool after we're done
self.connection = conn
Expand Down
16 changes: 13 additions & 3 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from ..auth.token import TokenInterface
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
from ..utils import format_error_message
from ..utils import deprecated_args, format_error_message

# the functionality is available in 3.11.x but has a major issue before
# 3.11.3. See https://github.com/redis/redis-py/issues/2633
Expand Down Expand Up @@ -1087,7 +1087,12 @@ def can_get_connection(self) -> bool:
or len(self._in_use_connections) < self.max_connections
)

async def get_connection(self, command_name, *keys, **options):
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.0.3",
)
async def get_connection(self, command_name=None, *keys, **options):
async with self._lock:
"""Get a connected connection from the pool"""
connection = self.get_available_connection()
Expand Down Expand Up @@ -1255,7 +1260,12 @@ def __init__(
self._condition = asyncio.Condition()
self.timeout = timeout

async def get_connection(self, command_name, *keys, **options):
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.0.3",
)
async def get_connection(self, command_name=None, *keys, **options):
"""Gets a connection from the pool, blocking until one is available"""
try:
async with self._condition:
Expand Down
14 changes: 6 additions & 8 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def __init__(
self.connection = None
self._single_connection_client = single_connection_client
if self._single_connection_client:
self.connection = self.connection_pool.get_connection("_")
self.connection = self.connection_pool.get_connection()
self._event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(
self.connection, ClientType.SYNC, self.single_connection_lock
Expand Down Expand Up @@ -608,7 +608,7 @@ def _execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
pool = self.connection_pool
command_name = args[0]
conn = self.connection or pool.get_connection(command_name, **options)
conn = self.connection or pool.get_connection()

if self._single_connection_client:
self.single_connection_lock.acquire()
Expand Down Expand Up @@ -667,7 +667,7 @@ class Monitor:

def __init__(self, connection_pool):
self.connection_pool = connection_pool
self.connection = self.connection_pool.get_connection("MONITOR")
self.connection = self.connection_pool.get_connection()

def __enter__(self):
self.connection.send_command("MONITOR")
Expand Down Expand Up @@ -840,9 +840,7 @@ def execute_command(self, *args):
# subscribed to one or more channels

if self.connection is None:
self.connection = self.connection_pool.get_connection(
"pubsub", self.shard_hint
)
self.connection = self.connection_pool.get_connection()
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
Expand Down Expand Up @@ -1397,7 +1395,7 @@ def immediate_execute_command(self, *args, **options):
conn = self.connection
# if this is the first call, we need a connection
if not conn:
conn = self.connection_pool.get_connection(command_name, self.shard_hint)
conn = self.connection_pool.get_connection()
self.connection = conn

return conn.retry.call_with_retry(
Expand Down Expand Up @@ -1583,7 +1581,7 @@ def execute(self, raise_on_error: bool = True) -> List[Any]:

conn = self.connection
if not conn:
conn = self.connection_pool.get_connection("MULTI", self.shard_hint)
conn = self.connection_pool.get_connection()
# assign to self.connection so reset() releases the connection
# back to the pool after we're done
self.connection = conn
Expand Down
18 changes: 10 additions & 8 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from redis.retry import Retry
from redis.utils import (
HIREDIS_AVAILABLE,
deprecated_args,
dict_merge,
list_keys_to_dict,
merge_result,
Expand All @@ -52,10 +53,13 @@ def get_node_name(host: str, port: Union[str, int]) -> str:
return f"{host}:{port}"


@deprecated_args(
allowed_args=["redis_node"],
reason="Use get_connection(redis_node) instead",
version="5.0.3",
)
def get_connection(redis_node, *args, **options):
return redis_node.connection or redis_node.connection_pool.get_connection(
args[0], **options
)
return redis_node.connection or redis_node.connection_pool.get_connection()


def parse_scan_result(command, res, **options):
Expand Down Expand Up @@ -1151,7 +1155,7 @@ def _execute_command(self, target_node, *args, **kwargs):
moved = False

redis_node = self.get_redis_connection(target_node)
connection = get_connection(redis_node, *args, **kwargs)
connection = get_connection(redis_node)
if asking:
connection.send_command("ASKING")
redis_node.parse_response(connection, "ASKING", **kwargs)
Expand Down Expand Up @@ -1822,9 +1826,7 @@ def execute_command(self, *args):
self.node = node
redis_connection = self.cluster.get_redis_connection(node)
self.connection_pool = redis_connection.connection_pool
self.connection = self.connection_pool.get_connection(
"pubsub", self.shard_hint
)
self.connection = self.connection_pool.get_connection()
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
Expand Down Expand Up @@ -2184,7 +2186,7 @@ def _send_cluster_commands(
if node_name not in nodes:
redis_node = self.get_redis_connection(node)
try:
connection = get_connection(redis_node, c.args)
connection = get_connection(redis_node)
except (ConnectionError, TimeoutError):
for n in nodes.values():
n.connection_pool.release(n.connection)
Expand Down
16 changes: 14 additions & 2 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
HIREDIS_AVAILABLE,
SSL_AVAILABLE,
compare_versions,
deprecated_args,
ensure_string,
format_error_message,
get_lib_version,
Expand Down Expand Up @@ -1461,8 +1462,14 @@ def _checkpid(self) -> None:
finally:
self._fork_lock.release()

def get_connection(self, command_name: str, *keys, **options) -> "Connection":
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.0.3",
)
def get_connection(self, command_name=None, *keys, **options) -> "Connection":
"Get a connection from the pool"

self._checkpid()
with self._lock:
try:
Expand Down Expand Up @@ -1683,7 +1690,12 @@ def make_connection(self):
self._connections.append(connection)
return connection

def get_connection(self, command_name, *keys, **options):
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.0.3",
)
def get_connection(self, command_name=None, *keys, **options):
"""
Get a connection, blocking for ``self.timeout`` until a connection
is available from the pool.
Expand Down
65 changes: 65 additions & 0 deletions redis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,71 @@ def wrapper(*args, **kwargs):
return decorator


def warn_deprecated_arg_usage(
arg_name: Union[list, str],
function_name: str,
reason: str = "",
version: str = "",
stacklevel: int = 2,
):
import warnings

msg = (
f"Call to '{function_name}' function with deprecated"
f" usage of input argument/s '{arg_name}'."
)
if reason:
msg += f" ({reason})"
if version:
msg += f" -- Deprecated since version {version}."
warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel)


def deprecated_args(
args_to_warn: list = ["*"],
allowed_args: list = [],
reason: str = "",
version: str = "",
):
"""
Decorator to mark specified args of a function as deprecated.
If '*' is in args_to_warn, all arguments will be marked as deprecated.
"""

def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Get function argument names
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]

provided_args = dict(zip(arg_names, args))
provided_args.update(kwargs)

provided_args.pop("self", None)
for allowed_arg in allowed_args:
provided_args.pop(allowed_arg, None)

for arg in args_to_warn:
if arg == "*" and len(provided_args) > 0:
warn_deprecated_arg_usage(
list(provided_args.keys()),
func.__name__,
reason,
version,
stacklevel=3,
)
elif arg in provided_args:
warn_deprecated_arg_usage(
arg, func.__name__, reason, version, stacklevel=3
)

return func(*args, **kwargs)

return wrapper

return decorator


def _set_info_logger():
"""
Set up a logger that log info logs to stdout.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def call_with_retry(self, _, __):
mock_conn = mock.AsyncMock(spec=Connection)
mock_conn.retry = Retry_()

async def get_conn(_):
async def get_conn():
# Validate only one client is created in single-client mode when
# concurrent requests are made
nonlocal init_call_count
Expand Down
Loading

0 comments on commit 8427c7b

Please sign in to comment.