1
1
import pickle
2
- from typing import TypeVar
2
+ from typing import Dict , Optional , TypeVar , Union
3
3
4
4
from redis .asyncio import ConnectionPool , Redis
5
5
from taskiq import AsyncResultBackend
6
6
from taskiq .abc .result_backend import TaskiqResult
7
7
8
+ from taskiq_redis .exceptions import (
9
+ DuplicateExpireTimeSelectedError ,
10
+ ExpireTimeMustBeMoreThanZeroError ,
11
+ )
12
+
8
13
_ReturnType = TypeVar ("_ReturnType" )
9
14
10
15
11
16
class RedisAsyncResultBackend (AsyncResultBackend [_ReturnType ]):
12
17
"""Async result based on redis."""
13
18
14
- def __init__ (self , redis_url : str , keep_results : bool = True ):
19
+ def __init__ (
20
+ self ,
21
+ redis_url : str ,
22
+ keep_results : bool = True ,
23
+ result_ex_time : Optional [int ] = None ,
24
+ result_px_time : Optional [int ] = None ,
25
+ ):
15
26
"""
16
27
Constructs a new result backend.
17
28
18
29
:param redis_url: url to redis.
19
30
:param keep_results: flag to not remove results from Redis after reading.
31
+ :param result_ex_time: expire time in seconds for result.
32
+ :param result_px_time: expire time in milliseconds for result.
33
+
34
+ :raises DuplicateExpireTimeSelectedError: if result_ex_time
35
+ and result_px_time are selected.
36
+ :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
37
+ and result_px_time are equal zero.
20
38
"""
21
39
self .redis_pool = ConnectionPool .from_url (redis_url )
22
40
self .keep_results = keep_results
41
+ self .result_ex_time = result_ex_time
42
+ self .result_px_time = result_px_time
43
+
44
+ if self .result_ex_time == 0 or self .result_px_time == 0 :
45
+ raise ExpireTimeMustBeMoreThanZeroError (
46
+ "You must select one expire time param and it must be more than zero." ,
47
+ )
48
+
49
+ if self .result_ex_time and self .result_px_time :
50
+ raise DuplicateExpireTimeSelectedError (
51
+ "Choose either result_ex_time or result_px_time." ,
52
+ )
53
+
54
+ if not self .result_ex_time and not self .result_px_time :
55
+ self .result_ex_time = 60
23
56
24
57
async def shutdown (self ) -> None :
25
58
"""Closes redis connection."""
26
59
await self .redis_pool .disconnect ()
60
+ await super ().shutdown ()
27
61
28
62
async def set_result (
29
63
self ,
@@ -39,19 +73,17 @@ async def set_result(
39
73
:param task_id: ID of the task.
40
74
:param result: TaskiqResult instance.
41
75
"""
42
- result_dict = result .dict (exclude = {"return_value" })
43
-
44
- for result_key , result_value in result_dict .items ():
45
- result_dict [result_key ] = pickle .dumps (result_value )
46
- # This trick will preserve original returned value.
47
- # It helps when you return not serializable classes.
48
- result_dict ["return_value" ] = pickle .dumps (result .return_value )
76
+ redis_set_params : Dict [str , Union [str , bytes , int ]] = {
77
+ "name" : task_id ,
78
+ "value" : pickle .dumps (result ),
79
+ }
80
+ if self .result_ex_time :
81
+ redis_set_params ["ex" ] = self .result_ex_time
82
+ elif self .result_px_time :
83
+ redis_set_params ["px" ] = self .result_px_time
49
84
50
85
async with Redis (connection_pool = self .redis_pool ) as redis :
51
- await redis .hset (
52
- task_id ,
53
- mapping = result_dict ,
54
- )
86
+ await redis .set (** redis_set_params )
55
87
56
88
async def is_result_ready (self , task_id : str ) -> bool :
57
89
"""
@@ -76,23 +108,19 @@ async def get_result( # noqa: WPS210
76
108
:param with_logs: if True it will download task's logs.
77
109
:return: task's return value.
78
110
"""
79
- fields = list (TaskiqResult .__fields__ .keys ())
80
-
81
- if not with_logs :
82
- fields .remove ("log" )
83
-
84
111
async with Redis (connection_pool = self .redis_pool ) as redis :
85
- result_values = await redis .hmget (
86
- name = task_id ,
87
- keys = fields ,
88
- )
112
+ if self .keep_results :
113
+ result_value = await redis .get (
114
+ name = task_id ,
115
+ )
116
+ else :
117
+ result_value = await redis .getdel (
118
+ name = task_id ,
119
+ )
89
120
90
- if not self .keep_results :
91
- await redis .delete (task_id )
121
+ taskiq_result : TaskiqResult [_ReturnType ] = pickle .loads (result_value )
92
122
93
- result = {
94
- result_key : pickle .loads (result_value )
95
- for result_value , result_key in zip (result_values , fields )
96
- }
123
+ if not with_logs :
124
+ taskiq_result .log = None
97
125
98
- return TaskiqResult ( ** result )
126
+ return taskiq_result
0 commit comments