Skip to content

Commit ffed4db

Browse files
committed
Add connection_kwargs to result backends
1 parent 77a51b4 commit ffed4db

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

taskiq_redis/redis_backend.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pickle
2-
from typing import Dict, Optional, TypeVar, Union
2+
from typing import Any, Dict, Optional, TypeVar, Union
33

4-
from redis.asyncio import ConnectionPool, Redis
4+
from redis.asyncio import BlockingConnectionPool, Redis
55
from redis.asyncio.cluster import RedisCluster
66
from taskiq import AsyncResultBackend
77
from taskiq.abc.result_backend import TaskiqResult
@@ -24,6 +24,8 @@ def __init__(
2424
keep_results: bool = True,
2525
result_ex_time: Optional[int] = None,
2626
result_px_time: Optional[int] = None,
27+
max_connection_pool_size: Optional[int] = None,
28+
**connection_kwargs: Any,
2729
) -> None:
2830
"""
2931
Constructs a new result backend.
@@ -32,13 +34,19 @@ def __init__(
3234
:param keep_results: flag to not remove results from Redis after reading.
3335
:param result_ex_time: expire time in seconds for result.
3436
:param result_px_time: expire time in milliseconds for result.
37+
:param max_connection_pool_size: maximum number of connections in pool.
38+
:param connection_kwargs: additional arguments for redis BlockingConnectionPool.
3539
3640
:raises DuplicateExpireTimeSelectedError: if result_ex_time
3741
and result_px_time are selected.
3842
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
3943
and result_px_time are equal zero.
4044
"""
41-
self.redis_pool = ConnectionPool.from_url(redis_url)
45+
self.redis_pool = BlockingConnectionPool.from_url(
46+
url=redis_url,
47+
max_connections=max_connection_pool_size,
48+
**connection_kwargs,
49+
)
4250
self.keep_results = keep_results
4351
self.result_ex_time = result_ex_time
4452
self.result_px_time = result_px_time
@@ -146,6 +154,7 @@ def __init__(
146154
keep_results: bool = True,
147155
result_ex_time: Optional[int] = None,
148156
result_px_time: Optional[int] = None,
157+
**connection_kwargs: Any,
149158
) -> None:
150159
"""
151160
Constructs a new result backend.
@@ -154,13 +163,17 @@ def __init__(
154163
:param keep_results: flag to not remove results from Redis after reading.
155164
:param result_ex_time: expire time in seconds for result.
156165
:param result_px_time: expire time in milliseconds for result.
166+
:param connection_kwargs: additional arguments for RedisCluster.
157167
158168
:raises DuplicateExpireTimeSelectedError: if result_ex_time
159169
and result_px_time are selected.
160170
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
161171
and result_px_time are equal zero.
162172
"""
163-
self.redis: RedisCluster[bytes] = RedisCluster.from_url(redis_url)
173+
self.redis: RedisCluster[bytes] = RedisCluster.from_url(
174+
redis_url,
175+
**connection_kwargs,
176+
)
164177
self.keep_results = keep_results
165178
self.result_ex_time = result_ex_time
166179
self.result_px_time = result_px_time

tests/test_result_backend.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import uuid
23

34
import pytest
@@ -132,6 +133,38 @@ async def test_keep_results_after_reading(redis_url: str) -> None:
132133
await result_backend.shutdown()
133134

134135

136+
@pytest.mark.anyio
137+
async def test_set_result_max_connections(redis_url: str) -> None:
138+
"""
139+
Tests that asynchronous backend works with connection limit.
140+
141+
:param redis_url: redis URL.
142+
"""
143+
result_backend = RedisAsyncResultBackend( # type: ignore
144+
redis_url=redis_url,
145+
max_connection_pool_size=1,
146+
timeout=3,
147+
)
148+
149+
task_id = uuid.uuid4().hex
150+
result: "TaskiqResult[int]" = TaskiqResult(
151+
is_err=True,
152+
log="My Log",
153+
return_value=11,
154+
execution_time=112.2,
155+
)
156+
await result_backend.set_result(
157+
task_id=task_id,
158+
result=result,
159+
)
160+
161+
async def get_result() -> None:
162+
await result_backend.get_result(task_id=task_id, with_logs=True)
163+
164+
await asyncio.gather(*[get_result() for _ in range(10)])
165+
await result_backend.shutdown()
166+
167+
135168
@pytest.mark.anyio
136169
async def test_set_result_success_cluster(redis_cluster_url: str) -> None:
137170
"""

0 commit comments

Comments
 (0)