From 28964c1ec4fc481141f6025248845c5e22588a41 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Tue, 11 Feb 2025 15:58:39 +0200 Subject: [PATCH] Backport from master (5.3.0b5) (#3506) * Fixed flacky TokenManager test (#3468) * Fixed flacky TokenManager test * Fixed additional flacky test * Removed token count assertion * Skipped test on version 3.9 * Fix incorrect attribute reuse (#3456) add CacheEntry Co-authored-by: zhousheng06 Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * Expand type for EncodedT (#3472) As of PEP 688, type checkers will no longer implicitly consider bytearray to be compatible with bytes * Moved self._lock initialisation to Pool constructor (#3473) * Moved self._lock initialisation to Pool constructor * Added test case * Codestyle fixes * Added correct annotations * DOC-4423: add TCEs for various command pages (#3476) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * DOC-4345 added testable JSON search examples for home page (#3407) * DOC-4345 added testable JSON search examples for home page * DOC-4345 avoid possible non-deterministic results in tests * DOC-4345 close connection at end of example * DOC-4345 remove unnecessary blank lines * Adding unit text fixes to improve compatibility with MacOS. (#3486) * Adding unit text fixes to improve compatibility with MacOS. * Applying review comments * Unifying the exception msg validation pattern for both test_connection.py files --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * Add return type to `close` functions (#3496) * Add types to ConnectionPool.from_url (#3495) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * Add types to execute method of pipelines (#3494) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * DOC-4796 fixed capped lists example (#3493) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * typing for client __init__ (#3357) * typing for client __init__ * typing with string literals * retry_on_error more specific typing * retry typing * fix lint --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> * test: Updated CredentialProvider test infrastructure (#3502) * test: Updated CredentialProvider test infrastructure * Added linter exclusion * Updated dev dependency * Codestyle fixes * Updated async test infra * Added missing constant * Updated package version * Updated testing versions and docs * Updated server versions * Fixed test --------- Co-authored-by: zs-neo <48560952+zs-neo@users.noreply.github.com> Co-authored-by: zhousheng06 Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Co-authored-by: David Dougherty Co-authored-by: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Co-authored-by: petyaslavova Co-authored-by: Patrick Arminio Co-authored-by: Artur Mostowski --- .github/actions/run-tests/action.yml | 6 +- .github/workflows/integration.yaml | 4 +- dev_requirements.txt | 2 +- docker-compose.yml | 3 +- docs/advanced_features.rst | 15 ++- doctests/cmds_cnxmgmt.py | 36 +++++++ doctests/cmds_hash.py | 24 +++++ doctests/cmds_list.py | 123 +++++++++++++++++++++++ doctests/cmds_servermgmt.py | 30 ++++++ doctests/cmds_set.py | 35 +++++++ doctests/dt_list.py | 6 +- doctests/home_json.py | 137 ++++++++++++++++++++++++++ redis/asyncio/client.py | 2 +- redis/client.py | 104 ++++++++++--------- redis/cluster.py | 6 +- redis/connection.py | 15 ++- redis/typing.py | 2 +- setup.py | 2 +- tests/conftest.py | 99 ++++++++++++------- tests/test_asyncio/conftest.py | 97 +++++++++++------- tests/test_asyncio/test_connection.py | 30 +++--- tests/test_auth/test_token_manager.py | 36 +++---- tests/test_commands.py | 2 +- tests/test_connection.py | 25 ++--- tests/test_connection_pool.py | 28 +++++- tests/test_multiprocessing.py | 4 + 26 files changed, 670 insertions(+), 203 deletions(-) create mode 100644 doctests/cmds_cnxmgmt.py create mode 100644 doctests/cmds_list.py create mode 100644 doctests/cmds_servermgmt.py create mode 100644 doctests/cmds_set.py create mode 100644 doctests/home_json.py diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index 5ca6bf5a09..e5dcef03ff 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -56,9 +56,9 @@ runs: # Mapping of redis version to stack version declare -A redis_stack_version_mapping=( - ["7.4.1"]="7.4.0-v1" - ["7.2.6"]="7.2.0-v13" - ["6.2.16"]="6.2.6-v17" + ["7.4.2"]="7.4.0-v3" + ["7.2.7"]="7.2.0-v15" + ["6.2.17"]="6.2.6-v19" ) if [[ -v redis_stack_version_mapping[$REDIS_VERSION] ]]; then diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 7c74de5290..7e92cfb92d 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -29,7 +29,7 @@ env: COVERAGE_CORE: sysmon REDIS_IMAGE: redis:latest REDIS_STACK_IMAGE: redis/redis-stack-server:latest - CURRENT_REDIS_VERSION: '7.4.1' + CURRENT_REDIS_VERSION: '7.4.2' jobs: dependency-audit: @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}', '7.2.6', '6.2.16'] + redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] python-version: ['3.8', '3.12'] parser-backend: ['plain'] event-loop: ['asyncio'] diff --git a/dev_requirements.txt b/dev_requirements.txt index be74470ec2..728536d6fb 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -16,4 +16,4 @@ uvloop vulture>=2.3.0 wheel>=0.30.0 numpy>=1.24.0 -redis-entraid==0.1.0b1 +redis-entraid==0.3.0b1 diff --git a/docker-compose.yml b/docker-compose.yml index 7804f09c8a..60657d5653 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -103,7 +103,7 @@ services: - all redis-stack: - image: ${REDIS_STACK_IMAGE:-redis/redis-stack-server:edge} + image: ${REDIS_STACK_IMAGE:-redis/redis-stack-server:latest} container_name: redis-stack ports: - 6479:6379 @@ -112,6 +112,7 @@ services: profiles: - standalone - all-stack + - all redis-stack-graph: image: redis/redis-stack-server:6.2.6-v15 diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index 10b7b4681b..cebf241e6c 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -471,31 +471,30 @@ command is received. Token-based authentication ~~~~~~~~~~~~~~~~~~~~~~~~~~ -Since redis-py version 5.3.0 new StreamableCredentialProvider interface was introduced. -This interface describes a CredentialProvider with an ability to stream an events that will be handled by listener. +Since redis-py version 5.3.0 new `StreamableCredentialProvider` interface was introduced. +This interface describes a `CredentialProvider` with an ability to stream an events that will be handled by listener. -To keep redis-py with minimal dependencies needed to run it, we decided to separate StreamableCredentialProvider +To keep redis-py with minimal dependencies needed to run it, we decided to separate `StreamableCredentialProvider` implementations in a separate packages. So If you're interested to try this feature please add them as a separate dependency to your project. `EntraIdCredentialProvider` is a first implementation that allows you to integrate redis-py with Azure Cache for Redis -service. It will allows you to obtain a tokens from Microsoft EntraID and authenticate/re-authenticate your connections +service. It will allows you to obtain a tokens from `Microsoft EntraID` and authenticate/re-authenticate your connections with it in a background mode. To get `EntraIdCredentialProvider` you need to install following package: `pip install redis-entraid` -To setup a credential provider, first you have to create and configure an IdentityProvider and provide -TokenAuthConfig object. +To setup a credential provider, please use one of the factory methods bundled with package. `Here's a quick guide how to do this -`_ +`_ Now all you have to do is to pass an instance of `EntraIdCredentialProvider` via constructor, available for sync and async clients: .. code:: python - >>> cred_provider = EntraIdCredentialProvider(auth_config) + >>> cred_provider = create_from_service_principal(CLIENT_ID, CLIENT_SECRET, TENANT_ID) >>> r = Redis(credential_provider=cred_provider) >>> r.ping() diff --git a/doctests/cmds_cnxmgmt.py b/doctests/cmds_cnxmgmt.py new file mode 100644 index 0000000000..c691f723cf --- /dev/null +++ b/doctests/cmds_cnxmgmt.py @@ -0,0 +1,36 @@ +# EXAMPLE: cmds_cnxmgmt +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START auth1 +# REMOVE_START +r.config_set("requirepass", "temp_pass") +# REMOVE_END +res1 = r.auth(password="temp_pass") +print(res1) # >>> True + +res2 = r.auth(password="temp_pass", username="default") +print(res2) # >>> True + +# REMOVE_START +assert res1 == True +assert res2 == True +r.config_set("requirepass", "") +# REMOVE_END +# STEP_END + +# STEP_START auth2 +# REMOVE_START +r.acl_setuser("test-user", enabled=True, passwords=["+strong_password"], commands=["+acl"]) +# REMOVE_END +res = r.auth(username="test-user", password="strong_password") +print(res) # >>> True + +# REMOVE_START +assert res == True +r.acl_deluser("test-user") +# REMOVE_END +# STEP_END diff --git a/doctests/cmds_hash.py b/doctests/cmds_hash.py index 0bc1cb8038..65bbd52d60 100644 --- a/doctests/cmds_hash.py +++ b/doctests/cmds_hash.py @@ -61,3 +61,27 @@ r.delete("myhash") # REMOVE_END # STEP_END + +# STEP_START hgetall +res10 = r.hset("myhash", mapping={"field1": "Hello", "field2": "World"}) + +res11 = r.hgetall("myhash") +print(res11) # >>> { "field1": "Hello", "field2": "World" } + +# REMOVE_START +assert res11 == { "field1": "Hello", "field2": "World" } +r.delete("myhash") +# REMOVE_END +# STEP_END + +# STEP_START hvals +res10 = r.hset("myhash", mapping={"field1": "Hello", "field2": "World"}) + +res11 = r.hvals("myhash") +print(res11) # >>> [ "Hello", "World" ] + +# REMOVE_START +assert res11 == [ "Hello", "World" ] +r.delete("myhash") +# REMOVE_END +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_list.py b/doctests/cmds_list.py new file mode 100644 index 0000000000..cce2d540a8 --- /dev/null +++ b/doctests/cmds_list.py @@ -0,0 +1,123 @@ +# EXAMPLE: cmds_list +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START lpush +res1 = r.lpush("mylist", "world") +print(res1) # >>> 1 + +res2 = r.lpush("mylist", "hello") +print(res2) # >>> 2 + +res3 = r.lrange("mylist", 0, -1) +print(res3) # >>> [ "hello", "world" ] + +# REMOVE_START +assert res3 == [ "hello", "world" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START lrange +res4 = r.rpush("mylist", "one"); +print(res4) # >>> 1 + +res5 = r.rpush("mylist", "two") +print(res5) # >>> 2 + +res6 = r.rpush("mylist", "three") +print(res6) # >>> 3 + +res7 = r.lrange('mylist', 0, 0) +print(res7) # >>> [ 'one' ] + +res8 = r.lrange('mylist', -3, 2) +print(res8) # >>> [ 'one', 'two', 'three' ] + +res9 = r.lrange('mylist', -100, 100) +print(res9) # >>> [ 'one', 'two', 'three' ] + +res10 = r.lrange('mylist', 5, 10) +print(res10) # >>> [] + +# REMOVE_START +assert res7 == [ 'one' ] +assert res8 == [ 'one', 'two', 'three' ] +assert res9 == [ 'one', 'two', 'three' ] +assert res10 == [] +r.delete('mylist') +# REMOVE_END +# STEP_END + +# STEP_START llen +res11 = r.lpush("mylist", "World") +print(res11) # >>> 1 + +res12 = r.lpush("mylist", "Hello") +print(res12) # >>> 2 + +res13 = r.llen("mylist") +print(res13) # >>> 2 + +# REMOVE_START +assert res13 == 2 +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START rpush +res14 = r.rpush("mylist", "hello") +print(res14) # >>> 1 + +res15 = r.rpush("mylist", "world") +print(res15) # >>> 2 + +res16 = r.lrange("mylist", 0, -1) +print(res16) # >>> [ "hello", "world" ] + +# REMOVE_START +assert res16 == [ "hello", "world" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START lpop +res17 = r.rpush("mylist", *["one", "two", "three", "four", "five"]) +print(res17) # >>> 5 + +res18 = r.lpop("mylist") +print(res18) # >>> "one" + +res19 = r.lpop("mylist", 2) +print(res19) # >>> ['two', 'three'] + +res17 = r.lrange("mylist", 0, -1) +print(res17) # >>> [ "four", "five" ] + +# REMOVE_START +assert res17 == [ "four", "five" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START rpop +res18 = r.rpush("mylist", *["one", "two", "three", "four", "five"]) +print(res18) # >>> 5 + +res19 = r.rpop("mylist") +print(res19) # >>> "five" + +res20 = r.rpop("mylist", 2) +print(res20) # >>> ['four', 'three'] + +res21 = r.lrange("mylist", 0, -1) +print(res21) # >>> [ "one", "two" ] + +# REMOVE_START +assert res21 == [ "one", "two" ] +r.delete("mylist") +# REMOVE_END +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_servermgmt.py b/doctests/cmds_servermgmt.py new file mode 100644 index 0000000000..6ad2b6acb2 --- /dev/null +++ b/doctests/cmds_servermgmt.py @@ -0,0 +1,30 @@ +# EXAMPLE: cmds_servermgmt +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START flushall +# REMOVE_START +r.set("foo", "1") +r.set("bar", "2") +r.set("baz", "3") +# REMOVE_END +res1 = r.flushall(asynchronous=False) +print(res1) # >>> True + +res2 = r.keys() +print(res2) # >>> [] + +# REMOVE_START +assert res1 == True +assert res2 == [] +# REMOVE_END +# STEP_END + +# STEP_START info +res3 = r.info() +print(res3) +# >>> {'redis_version': '7.4.0', 'redis_git_sha1': 'c9d29f6a',...} +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_set.py b/doctests/cmds_set.py new file mode 100644 index 0000000000..ece74e8cf0 --- /dev/null +++ b/doctests/cmds_set.py @@ -0,0 +1,35 @@ +# EXAMPLE: cmds_set +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START sadd +res1 = r.sadd("myset", "Hello", "World") +print(res1) # >>> 2 + +res2 = r.sadd("myset", "World") +print(res2) # >>> 0 + +res3 = r.smembers("myset") +print(res3) # >>> {'Hello', 'World'} + +# REMOVE_START +assert res3 == {'Hello', 'World'} +r.delete('myset') +# REMOVE_END +# STEP_END + +# STEP_START smembers +res4 = r.sadd("myset", "Hello", "World") +print(res4) # >>> 2 + +res5 = r.smembers("myset") +print(res5) # >>> {'Hello', 'World'} + +# REMOVE_START +assert res5 == {'Hello', 'World'} +r.delete('myset') +# REMOVE_END +# STEP_END \ No newline at end of file diff --git a/doctests/dt_list.py b/doctests/dt_list.py index be8a4b8562..111da8eb08 100644 --- a/doctests/dt_list.py +++ b/doctests/dt_list.py @@ -165,20 +165,20 @@ # REMOVE_END # STEP_START ltrim -res27 = r.lpush("bikes:repairs", "bike:1", "bike:2", "bike:3", "bike:4", "bike:5") +res27 = r.rpush("bikes:repairs", "bike:1", "bike:2", "bike:3", "bike:4", "bike:5") print(res27) # >>> 5 res28 = r.ltrim("bikes:repairs", 0, 2) print(res28) # >>> True res29 = r.lrange("bikes:repairs", 0, -1) -print(res29) # >>> ['bike:5', 'bike:4', 'bike:3'] +print(res29) # >>> ['bike:1', 'bike:2', 'bike:3'] # STEP_END # REMOVE_START assert res27 == 5 assert res28 is True -assert res29 == ["bike:5", "bike:4", "bike:3"] +assert res29 == ["bike:1", "bike:2", "bike:3"] r.delete("bikes:repairs") # REMOVE_END diff --git a/doctests/home_json.py b/doctests/home_json.py new file mode 100644 index 0000000000..922c83d2fe --- /dev/null +++ b/doctests/home_json.py @@ -0,0 +1,137 @@ +# EXAMPLE: py_home_json +""" +JSON examples from redis-py "home" page" + https://redis.io/docs/latest/develop/connect/clients/python/redis-py/#example-indexing-and-querying-json-documents +""" + +# STEP_START import +import redis +from redis.commands.json.path import Path +import redis.commands.search.aggregation as aggregations +import redis.commands.search.reducers as reducers +from redis.commands.search.field import TextField, NumericField, TagField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import Query +import redis.exceptions +# STEP_END + +# STEP_START connect +r = redis.Redis(decode_responses=True) +# STEP_END + +# REMOVE_START +try: + r.ft("idx:users").dropindex(True) +except redis.exceptions.ResponseError: + pass + +r.delete("user:1", "user:2", "user:3") +# REMOVE_END +# STEP_START create_data +user1 = { + "name": "Paul John", + "email": "paul.john@example.com", + "age": 42, + "city": "London" +} + +user2 = { + "name": "Eden Zamir", + "email": "eden.zamir@example.com", + "age": 29, + "city": "Tel Aviv" +} + +user3 = { + "name": "Paul Zamir", + "email": "paul.zamir@example.com", + "age": 35, + "city": "Tel Aviv" +} +# STEP_END + +# STEP_START make_index +schema = ( + TextField("$.name", as_name="name"), + TagField("$.city", as_name="city"), + NumericField("$.age", as_name="age") +) + +indexCreated = r.ft("idx:users").create_index( + schema, + definition=IndexDefinition( + prefix=["user:"], index_type=IndexType.JSON + ) +) +# STEP_END +# Tests for 'make_index' step. +# REMOVE_START +assert indexCreated +# REMOVE_END + +# STEP_START add_data +user1Set = r.json().set("user:1", Path.root_path(), user1) +user2Set = r.json().set("user:2", Path.root_path(), user2) +user3Set = r.json().set("user:3", Path.root_path(), user3) +# STEP_END +# Tests for 'add_data' step. +# REMOVE_START +assert user1Set +assert user2Set +assert user3Set +# REMOVE_END + +# STEP_START query1 +findPaulResult = r.ft("idx:users").search( + Query("Paul @age:[30 40]") +) + +print(findPaulResult) +# >>> Result{1 total, docs: [Document {'id': 'user:3', ... +# STEP_END +# Tests for 'query1' step. +# REMOVE_START +assert str(findPaulResult) == ( + "Result{1 total, docs: [Document {'id': 'user:3', 'payload': None, " + + "'json': '{\"name\":\"Paul Zamir\",\"email\":" + + "\"paul.zamir@example.com\",\"age\":35,\"city\":\"Tel Aviv\"}'}]}" +) +# REMOVE_END + +# STEP_START query2 +citiesResult = r.ft("idx:users").search( + Query("Paul").return_field("$.city", as_field="city") +).docs + +print(citiesResult) +# >>> [Document {'id': 'user:1', 'payload': None, ... +# STEP_END +# Tests for 'query2' step. +# REMOVE_START +citiesResult.sort(key=lambda doc: doc['id']) + +assert str(citiesResult) == ( + "[Document {'id': 'user:1', 'payload': None, 'city': 'London'}, " + + "Document {'id': 'user:3', 'payload': None, 'city': 'Tel Aviv'}]" +) +# REMOVE_END + +# STEP_START query3 +req = aggregations.AggregateRequest("*").group_by( + '@city', reducers.count().alias('count') +) + +aggResult = r.ft("idx:users").aggregate(req).rows +print(aggResult) +# >>> [['city', 'London', 'count', '1'], ['city', 'Tel Aviv', 'count', '2']] +# STEP_END +# Tests for 'query3' step. +# REMOVE_START +aggResult.sort(key=lambda row: row[1]) + +assert str(aggResult) == ( + "[['city', 'London', 'count', '1'], ['city', 'Tel Aviv', 'count', '2']]" +) +# REMOVE_END + +r.close() diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9478d539d7..7c17938714 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1554,7 +1554,7 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception): await self.reset() raise - async def execute(self, raise_on_error: bool = True): + async def execute(self, raise_on_error: bool = True) -> List[Any]: """Execute all the commands in the current pipeline""" stack = self.command_stack if not stack and not self.watching: diff --git a/redis/client.py b/redis/client.py index a7c1364a10..5a9f4fafb5 100755 --- a/redis/client.py +++ b/redis/client.py @@ -4,7 +4,17 @@ import time import warnings from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Type, + Union, +) from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( @@ -53,6 +63,11 @@ str_if_bytes, ) +if TYPE_CHECKING: + import ssl + + import OpenSSL + SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -175,47 +190,47 @@ def from_pool( def __init__( self, - host="localhost", - port=6379, - db=0, - password=None, - socket_timeout=None, - socket_connect_timeout=None, - socket_keepalive=None, - socket_keepalive_options=None, - connection_pool=None, - unix_socket_path=None, - encoding="utf-8", - encoding_errors="strict", - charset=None, - errors=None, - decode_responses=False, - retry_on_timeout=False, - retry_on_error=None, - ssl=False, - ssl_keyfile=None, - ssl_certfile=None, - ssl_cert_reqs="required", - ssl_ca_certs=None, - ssl_ca_path=None, - ssl_ca_data=None, - ssl_check_hostname=False, - ssl_password=None, - ssl_validate_ocsp=False, - ssl_validate_ocsp_stapled=False, - ssl_ocsp_context=None, - ssl_ocsp_expected_cert=None, - ssl_min_version=None, - ssl_ciphers=None, - max_connections=None, - single_connection_client=False, - health_check_interval=0, - client_name=None, - lib_name="redis-py", - lib_version=get_lib_version(), - username=None, - retry=None, - redis_connect_func=None, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + charset: Optional[str] = None, + errors: Optional[str] = None, + decode_responses: bool = False, + retry_on_timeout: bool = False, + retry_on_error: Optional[List[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_check_hostname: bool = False, + ssl_password: Optional[str] = None, + ssl_validate_ocsp: bool = False, + ssl_validate_ocsp_stapled: bool = False, + ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None, + ssl_ocsp_expected_cert: Optional[str] = None, + ssl_min_version: Optional["ssl.TLSVersion"] = None, + ssl_ciphers: Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Optional[Retry] = None, + redis_connect_func: Optional[Callable[[], None]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, cache: Optional[CacheInterface] = None, @@ -550,7 +565,7 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): self.close() - def close(self): + def close(self) -> None: # In case a connection property does not yet exist # (due to a crash earlier in the Redis() constructor), return # immediately as there is nothing to clean-up. @@ -1551,11 +1566,10 @@ def _disconnect_raise_reset( conn.retry_on_error is None or isinstance(error, tuple(conn.retry_on_error)) is False ): - self.reset() raise error - def execute(self, raise_on_error=True): + def execute(self, raise_on_error: bool = True) -> List[Any]: """Execute all the commands in the current pipeline""" stack = self.command_stack if not stack and not self.watching: diff --git a/redis/cluster.py b/redis/cluster.py index 38bd5dde1a..e8f47afe25 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1244,7 +1244,7 @@ def _execute_command(self, target_node, *args, **kwargs): raise ClusterError("TTL exhausted.") - def close(self): + def close(self) -> None: try: with self._lock: if self.nodes_manager: @@ -1686,7 +1686,7 @@ def initialize(self): # If initialize was called after a MovedError, clear it self._moved_exception = None - def close(self): + def close(self) -> None: self.default_node = None for node in self.nodes_cache.values(): if node.redis_connection: @@ -2067,7 +2067,7 @@ def annotate_exception(self, exception, number, command): ) exception.args = (msg,) + exception.args[1:] - def execute(self, raise_on_error=True): + def execute(self, raise_on_error: bool = True) -> List[Any]: """ Execute all the commands in the current pipeline """ diff --git a/redis/connection.py b/redis/connection.py index 9d29b4aba6..d47f46590b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,7 +9,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -904,9 +904,11 @@ def read_response( and self._cache.get(self._current_command_cache_key).status != CacheEntryStatus.IN_PROGRESS ): - return copy.deepcopy( + res = copy.deepcopy( self._cache.get(self._current_command_cache_key).cache_value ) + self._current_command_cache_key = None + return res response = self._conn.read_response( disable_decoding=disable_decoding, @@ -932,6 +934,8 @@ def read_response( cache_entry.cache_value = response self._cache.set(cache_entry) + self._current_command_cache_key = None + return response def pack_command(self, *args): @@ -1259,6 +1263,9 @@ def parse_url(url): return kwargs +_CP = TypeVar("_CP", bound="ConnectionPool") + + class ConnectionPool: """ Create a connection pool. ``If max_connections`` is set, then this @@ -1274,7 +1281,7 @@ class ConnectionPool: """ @classmethod - def from_url(cls, url, **kwargs): + def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: """ Return a connection pool configured from the given URL. @@ -1374,6 +1381,7 @@ def __init__( # will notice the first thread already did the work and simply # release the lock. self._fork_lock = threading.Lock() + self._lock = threading.Lock() self.reset() def __repr__(self) -> (str, str): @@ -1391,7 +1399,6 @@ def get_protocol(self): return self.connection_kwargs.get("protocol", None) def reset(self) -> None: - self._lock = threading.Lock() self._created_connections = 0 self._available_connections = [] self._in_use_connections = set() diff --git a/redis/typing.py b/redis/typing.py index b4d442c444..24ad607480 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -20,7 +20,7 @@ Number = Union[int, float] -EncodedT = Union[bytes, memoryview] +EncodedT = Union[bytes, bytearray, memoryview] DecodedT = Union[str, int, float] EncodableT = Union[EncodedT, DecodedT] AbsExpiryT = Union[int, datetime] diff --git a/setup.py b/setup.py index 167cd5ee07..81bbedfe9f 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="5.3.0b4", + version="5.3.0b5", packages=find_packages( include=[ "redis", diff --git a/tests/conftest.py b/tests/conftest.py index a900cea8bf..fc732c0d72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import time from datetime import datetime, timezone from enum import Enum -from typing import Callable, TypeVar +from typing import Callable, TypeVar, Union from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse @@ -17,6 +17,7 @@ from redis import Sentinel from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.cache import ( CacheConfig, @@ -29,12 +30,21 @@ from redis.credentials import CredentialProvider from redis.exceptions import RedisClusterException from redis.retry import Retry -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) from redis_entraid.identity_provider import ( ManagedIdentityIdType, + ManagedIdentityProviderConfig, ManagedIdentityType, - create_provider_from_managed_identity, - create_provider_from_service_principal, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, ) from tests.ssl_utils import get_tls_certificates @@ -623,17 +633,33 @@ def identity_provider(request) -> IdentityProviderInterface: return mock_identity_provider() auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) if auth_type == "MANAGED_IDENTITY": - return _get_managed_identity_provider(request) + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) - return _get_service_principal_provider(request) +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + + return _get_service_principal_provider_config(request) + + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_ID_VALUE", None) + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -641,23 +667,24 @@ def _get_managed_identity_provider(request): kwargs = {} identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( + return ManagedIdentityProviderConfig( identity_type=identity_type, resource=resource, id_type=id_type, id_value=id_value, - authority=authority, - **kwargs, + kwargs=kwargs, ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -671,14 +698,14 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(",") - return create_provider_from_service_principal( + return ServicePrincipalIdentityProviderConfig( client_id=client_id, client_credential=client_credential, scopes=scopes, timeout=timeout, token_kwargs=token_kwargs, - authority=authority, - **kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, ) @@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider: return cred_provider_class(**cred_provider_kwargs) idp = identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = cred_provider_kwargs.get("block_for_initial", False) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) - delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), ) - auth_config = TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, ) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 8833426af1..fb6c51140e 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -17,14 +17,24 @@ from redis.asyncio.retry import Retry from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.credentials import CredentialProvider -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) from redis_entraid.identity_provider import ( ManagedIdentityIdType, + ManagedIdentityProviderConfig, ManagedIdentityType, - create_provider_from_managed_identity, - create_provider_from_service_principal, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, ) from tests.conftest import REDIS_INFO @@ -255,17 +265,33 @@ def identity_provider(request) -> IdentityProviderInterface: return mock_identity_provider() auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) if auth_type == "MANAGED_IDENTITY": - return _get_managed_identity_provider(request) + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) + + +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} - return _get_service_principal_provider(request) + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + return _get_service_principal_provider_config(request) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_ID_VALUE", None) + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -273,23 +299,24 @@ def _get_managed_identity_provider(request): kwargs = {} identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( + return ManagedIdentityProviderConfig( identity_type=identity_type, resource=resource, id_type=id_type, id_value=id_value, - authority=authority, - **kwargs, + kwargs=kwargs, ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -303,14 +330,14 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(",") - return create_provider_from_service_principal( + return ServicePrincipalIdentityProviderConfig( client_id=client_id, client_credential=client_credential, scopes=scopes, timeout=timeout, token_kwargs=token_kwargs, - authority=authority, - **kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, ) @@ -322,31 +349,29 @@ def get_credential_provider(request) -> CredentialProvider: return cred_provider_class(**cred_provider_kwargs) idp = identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = cred_provider_kwargs.get("block_for_initial", False) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) - delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), ) - auth_config = TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, ) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index e584fc6999..d4956f16e9 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -1,6 +1,7 @@ import asyncio import socket import types +from errno import ECONNREFUSED from unittest.mock import patch import pytest @@ -36,15 +37,16 @@ async def test_invalid_response(create_redis): fake_stream = MockStream(raw + b"\r\n") parser: _AsyncRESPBase = r.connection._parser - with mock.patch.object(parser, "_stream", fake_stream): - with pytest.raises(InvalidResponse) as cm: - await parser.read_response() + if isinstance(parser, _AsyncRESPBase): - assert str(cm.value) == f"Protocol Error: {raw!r}" + exp_err = f"Protocol Error: {raw!r}" else: - assert ( - str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte' - ) + exp_err = f'Protocol error, got "{raw.decode()}" as reply type byte' + + with mock.patch.object(parser, "_stream", fake_stream): + with pytest.raises(InvalidResponse, match=exp_err): + await parser.read_response() + await r.connection.disconnect() @@ -170,10 +172,9 @@ async def test_connect_timeout_error_without_retry(): conn._connect = mock.AsyncMock() conn._connect.side_effect = socket.timeout - with pytest.raises(TimeoutError) as e: + with pytest.raises(TimeoutError, match="Timeout connecting to server"): await conn.connect() assert conn._connect.call_count == 1 - assert str(e.value) == "Timeout connecting to server" @pytest.mark.onlynoncluster @@ -531,17 +532,14 @@ async def test_format_error_message(conn, error, expected_message): async def test_network_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = rf"^Error {ECONNREFUSED} connecting to 127.0.0.1:9999.(.+)$" + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(host="127.0.0.1", port=9999) await redis.set("a", "b") - assert str(e.value).startswith("Error 111 connecting to 127.0.0.1:9999. Connect") async def test_unix_socket_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(unix_socket_path="unix:///tmp/a.sock") await redis.set("a", "b") - assert ( - str(e.value) - == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." - ) diff --git a/tests/test_auth/test_token_manager.py b/tests/test_auth/test_token_manager.py index bb396e246c..cdbf60889d 100644 --- a/tests/test_auth/test_token_manager.py +++ b/tests/test_auth/test_token_manager.py @@ -17,17 +17,17 @@ class TestTokenManager: @pytest.mark.parametrize( - "exp_refresh_ratio,tokens_refreshed", + "exp_refresh_ratio", [ - (0.9, 2), - (0.28, 4), + 0.9, + 0.28, ], ids=[ - "Refresh ratio = 0.9, 2 tokens in 0,1 second", - "Refresh ratio = 0.28, 4 tokens in 0,1 second", + "Refresh ratio = 0.9", + "Refresh ratio = 0.28", ], ) - def test_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): + def test_success_token_renewal(self, exp_refresh_ratio): tokens = [] mock_provider = Mock(spec=IdentityProviderInterface) mock_provider.request_token.side_effect = [ @@ -39,14 +39,14 @@ def test_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 130, - (datetime.now(timezone.utc).timestamp() * 1000) + 30, + (datetime.now(timezone.utc).timestamp() * 1000) + 150, + (datetime.now(timezone.utc).timestamp() * 1000) + 50, {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 160, - (datetime.now(timezone.utc).timestamp() * 1000) + 60, + (datetime.now(timezone.utc).timestamp() * 1000) + 170, + (datetime.now(timezone.utc).timestamp() * 1000) + 70, {"oid": "test"}, ), SimpleToken( @@ -70,7 +70,7 @@ def on_next(token): mgr.start(mock_listener) sleep(0.1) - assert len(tokens) == tokens_refreshed + assert len(tokens) > 0 @pytest.mark.parametrize( "exp_refresh_ratio,tokens_refreshed", @@ -176,19 +176,13 @@ def test_token_renewal_with_skip_initial(self): mock_provider.request_token.side_effect = [ SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000) + 50, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 120, - (datetime.now(timezone.utc).timestamp() * 1000), - {"oid": "test"}, - ), - SimpleToken( - "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 140, + (datetime.now(timezone.utc).timestamp() * 1000) + 150, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), @@ -207,9 +201,9 @@ def on_next(token): mgr.start(mock_listener, skip_initial=True) # Should be less than a 0.1, or it will be flacky due to # additional token renewal. - sleep(0.2) + sleep(0.1) - assert len(tokens) == 2 + assert len(tokens) == 1 @pytest.mark.asyncio async def test_async_token_renewal_with_skip_initial(self): diff --git a/tests/test_commands.py b/tests/test_commands.py index 4cad4c14b6..0f5a9c7b16 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -4345,7 +4345,7 @@ def test_xgroup_create_entriesread(self, r: redis.Redis): "pending": 0, "last-delivered-id": b"0-0", "entries-read": 7, - "lag": -6, + "lag": 1, } ] assert r.xinfo_groups(stream) == expected diff --git a/tests/test_connection.py b/tests/test_connection.py index 7683a1416d..6c1498a329 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,8 +1,10 @@ import copy import platform import socket +import sys import threading import types +from errno import ECONNREFUSED from typing import Any from unittest import mock from unittest.mock import call, patch @@ -43,9 +45,8 @@ def test_invalid_response(r): raw = b"x" parser = r.connection._parser with mock.patch.object(parser._buffer, "readline", return_value=raw): - with pytest.raises(InvalidResponse) as cm: + with pytest.raises(InvalidResponse, match=f"Protocol Error: {raw!r}"): parser.read_response() - assert str(cm.value) == f"Protocol Error: {raw!r}" @skip_if_server_version_lt("4.0.0") @@ -140,10 +141,9 @@ def test_connect_timeout_error_without_retry(self): conn._connect = mock.Mock() conn._connect.side_effect = socket.timeout - with pytest.raises(TimeoutError) as e: + with pytest.raises(TimeoutError, match="Timeout connecting to server"): conn.connect() assert conn._connect.call_count == 1 - assert str(e.value) == "Timeout connecting to server" self.clear(conn) @@ -249,6 +249,7 @@ def get_redis_connection(): r1.close() +@pytest.mark.skipif(sys.version_info == (3, 9), reason="Flacky test on Python 3.9") @pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) def test_redis_connection_pool(request, from_url): """Verify that basic Redis instances using `connection_pool` @@ -347,20 +348,17 @@ def test_format_error_message(conn, error, expected_message): def test_network_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = f"Error {ECONNREFUSED} connecting to localhost:9999. Connection refused." + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(port=9999) redis.set("a", "b") - assert str(e.value) == "Error 111 connecting to localhost:9999. Connection refused." def test_unix_socket_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(unix_socket_path="unix:///tmp/a.sock") redis.set("a", "b") - assert ( - str(e.value) - == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." - ) class TestUnitConnectionPool: @@ -499,9 +497,9 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): ) proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) assert proxy_connection.read_response() == b"bar" + assert proxy_connection._current_command_cache_key is None assert proxy_connection.read_response() == b"bar" - mock_connection.read_response.assert_called_once() mock_cache.set.assert_has_calls( [ call( @@ -528,9 +526,6 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): call(CacheKey(command="GET", redis_keys=("foo",))), call(CacheKey(command="GET", redis_keys=("foo",))), call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), ] ) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index dee7c554d3..118294ee1b 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -7,10 +7,16 @@ import pytest import redis -from redis.connection import to_bool -from redis.utils import SSL_AVAILABLE - -from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt +from redis.cache import CacheConfig +from redis.connection import CacheProxyConnection, Connection, to_bool +from redis.utils import HIREDIS_AVAILABLE, SSL_AVAILABLE + +from .conftest import ( + _get_client, + skip_if_redis_enterprise, + skip_if_resp_version, + skip_if_server_version_lt, +) from .test_pubsub import wait_for_message @@ -196,6 +202,20 @@ def test_repr_contains_db_info_unix(self): expected = "path=abc,db=0,client_name=test-client" assert expected in repr(pool) + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + @pytest.mark.onlynoncluster + @skip_if_resp_version(2) + @skip_if_server_version_lt("7.4.0") + def test_initialise_pool_with_cache(self, master_host): + pool = redis.BlockingConnectionPool( + connection_class=Connection, + host=master_host[0], + port=master_host[1], + protocol=3, + cache_config=CacheConfig(), + ) + assert isinstance(pool.get_connection("_"), CacheProxyConnection) + class TestConnectionPoolURLParsing: def test_hostname(self): diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 5cda3190a6..116d20dab0 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -1,5 +1,6 @@ import contextlib import multiprocessing +import sys import pytest import redis @@ -8,6 +9,9 @@ from .conftest import _get_client +if sys.platform == "darwin": + multiprocessing.set_start_method("fork", force=True) + @contextlib.contextmanager def exit_callback(callback, *args):