Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix async client safety #3512

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ using `invoke standalone-tests`; similarly, RedisCluster tests can be run by usi
Each run of tests starts and stops the various dockers required. Sometimes
things get stuck, an `invoke clean` can help.

## Linting and Formatting

Call `invoke linters` to run linters without also running tests. This command will
only report issues, not fix them automatically. Run `invoke formatters` to
automatically format your code.

## Documentation

If relevant, update the code documentation, via docstrings, or in `/docs`.
Expand Down
40 changes: 38 additions & 2 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,12 @@ def __init__(
# on a set of redis commands
self._single_conn_lock = asyncio.Lock()

# When used as an async context manager, we need to increment and decrement
# a usage counter so that we can close the connection pool when no one is
# using the client.
self._usage_counter = 0
self._usage_lock = asyncio.Lock()

def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
Expand Down Expand Up @@ -562,10 +568,40 @@ def client(self) -> "Redis":
)

async def __aenter__(self: _RedisT) -> _RedisT:
return await self.initialize()
"""
Async context manager entry. Increments a usage counter so that the
connection pool is only closed (via aclose()) when no context is using
the client.
"""
async with self._usage_lock:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest adding another function _increment_usage for this operation. It will be easier to read and follow the code if here you call _increment_usage and below in the except clause you call _decrement_usage. The same applies to cluster.py changes as well.

self._usage_counter += 1
try:
# Initialize the client (i.e. establish connection, etc.)
return await self.initialize()
except Exception:
# If initialization fails, decrement the counter to keep it in sync
async with self._usage_lock:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can directly use the new function _decrement_usage here. The same applies to cluster.py changes as well.

self._usage_counter -= 1
raise

async def _decrement_usage(self) -> int:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A helper method is required so we can use it in the shield().

"""
Helper coroutine to decrement the usage counter while holding the lock.
Returns the new value of the usage counter.
"""
async with self._usage_lock:
self._usage_counter -= 1
return self._usage_counter

async def __aexit__(self, exc_type, exc_value, traceback):
await self.aclose()
"""
Async context manager exit. Decrements a usage counter. If this is the
last exit (counter becomes zero), the client closes its connection pool.
"""
current_usage = await asyncio.shield(self._decrement_usage())
if current_usage == 0:
# This was the last active context, so disconnect the pool.
await asyncio.shield(self.aclose())

_DEL_MESSAGE = "Unclosed Redis client"

Expand Down
42 changes: 39 additions & 3 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,12 @@ def __init__(
self._initialize = True
self._lock: Optional[asyncio.Lock] = None

# When used as an async context manager, we need to increment and decrement
# a usage counter so that we can close the connection pool when no one is
# using the client.
self._usage_counter = 0
self._usage_lock = asyncio.Lock()

async def initialize(self) -> "RedisCluster":
"""Get all nodes from startup nodes & creates connections if not initialized."""
if self._initialize:
Expand Down Expand Up @@ -415,10 +421,40 @@ async def close(self) -> None:
await self.aclose()

async def __aenter__(self) -> "RedisCluster":
return await self.initialize()
"""
Async context manager entry. Increments a usage counter so that the
connection pool is only closed (via aclose()) when no context is using
the client.
"""
async with self._usage_lock:
self._usage_counter += 1
try:
# Initialize the client (i.e. establish connection, etc.)
return await self.initialize()
except Exception:
# If initialization fails, decrement the counter to keep it in sync
async with self._usage_lock:
self._usage_counter -= 1
raise

async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
await self.aclose()
async def _decrement_usage(self) -> int:
"""
Helper coroutine to decrement the usage counter while holding the lock.
Returns the new value of the usage counter.
"""
async with self._usage_lock:
self._usage_counter -= 1
return self._usage_counter

async def __aexit__(self, exc_type, exc_value, traceback):
"""
Async context manager exit. Decrements a usage counter. If this is the
last exit (counter becomes zero), the client closes its connection pool.
"""
current_usage = await asyncio.shield(self._decrement_usage())
if current_usage == 0:
# This was the last active context, so disconnect the pool.
await asyncio.shield(self.aclose())

def __await__(self) -> Generator[Any, None, "RedisCluster"]:
return self.initialize().__await__()
Expand Down
6 changes: 6 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def linters(c):
run("vulture redis whitelist.py --min-confidence 80")
run("flynt --fail-on-change --dry-run tests redis")

@task
def formatters(c):
"""Format code"""
run("black --target-version py37 tests redis")
run("isort tests redis")


@task
def all_tests(c):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_asyncio/test_usage_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import asyncio

import pytest


@pytest.mark.asyncio
async def test_usage_counter(r):
async def dummy_task():
async with r:
await asyncio.sleep(0.01)

tasks = [dummy_task() for _ in range(20)]
await asyncio.gather(*tasks)

# After all tasks have completed, the usage counter should be back to zero.
assert r._usage_counter == 0
Loading