Skip to content

Commit 254fae4

Browse files
authored
Merge pull request #67 from sminnee/feature/progress
2 parents e507271 + 8a97ced commit 254fae4

File tree

3 files changed

+272
-2
lines changed

3 files changed

+272
-2
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ repos:
1818
hooks:
1919
- id: black
2020
name: Format with Black
21-
entry: black
21+
entry: poetry run black
2222
language: system
2323
types: [python]
2424

@@ -36,6 +36,6 @@ repos:
3636

3737
- id: mypy
3838
name: Validate types with MyPy
39-
entry: mypy
39+
entry: poetry run mypy
4040
language: system
4141
types: [ python ]

taskiq_redis/redis_backend.py

+148
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from taskiq.abc.result_backend import TaskiqResult
2020
from taskiq.abc.serializer import TaskiqSerializer
2121
from taskiq.compat import model_dump, model_validate
22+
from taskiq.depends.progress_tracker import TaskProgress
2223
from taskiq.serializers import PickleSerializer
2324

2425
from taskiq_redis.exceptions import (
@@ -41,6 +42,8 @@
4142

4243
_ReturnType = TypeVar("_ReturnType")
4344

45+
PROGRESS_KEY_SUFFIX = "__progress"
46+
4447

4548
class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
4649
"""Async result based on redis."""
@@ -174,6 +177,55 @@ async def get_result(
174177

175178
return taskiq_result
176179

180+
async def set_progress(
181+
self,
182+
task_id: str,
183+
progress: TaskProgress[_ReturnType],
184+
) -> None:
185+
"""
186+
Sets task progress in redis.
187+
188+
Dumps TaskProgress instance into the bytes and writes
189+
it to redis with a standard suffix on the task_id as the key
190+
191+
:param task_id: ID of the task.
192+
:param result: task's TaskProgress instance.
193+
"""
194+
redis_set_params: Dict[str, Union[str, int, bytes]] = {
195+
"name": task_id + PROGRESS_KEY_SUFFIX,
196+
"value": self.serializer.dumpb(model_dump(progress)),
197+
}
198+
if self.result_ex_time:
199+
redis_set_params["ex"] = self.result_ex_time
200+
elif self.result_px_time:
201+
redis_set_params["px"] = self.result_px_time
202+
203+
async with Redis(connection_pool=self.redis_pool) as redis:
204+
await redis.set(**redis_set_params) # type: ignore
205+
206+
async def get_progress(
207+
self,
208+
task_id: str,
209+
) -> Union[TaskProgress[_ReturnType], None]:
210+
"""
211+
Gets progress results from the task.
212+
213+
:param task_id: task's id.
214+
:return: task's TaskProgress instance.
215+
"""
216+
async with Redis(connection_pool=self.redis_pool) as redis:
217+
result_value = await redis.get(
218+
name=task_id + PROGRESS_KEY_SUFFIX,
219+
)
220+
221+
if result_value is None:
222+
return None
223+
224+
return model_validate(
225+
TaskProgress[_ReturnType],
226+
self.serializer.loadb(result_value),
227+
)
228+
177229

178230
class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
179231
"""Async result backend based on redis cluster."""
@@ -301,6 +353,53 @@ async def get_result(
301353

302354
return taskiq_result
303355

356+
async def set_progress(
357+
self,
358+
task_id: str,
359+
progress: TaskProgress[_ReturnType],
360+
) -> None:
361+
"""
362+
Sets task progress in redis.
363+
364+
Dumps TaskProgress instance into the bytes and writes
365+
it to redis with a standard suffix on the task_id as the key
366+
367+
:param task_id: ID of the task.
368+
:param result: task's TaskProgress instance.
369+
"""
370+
redis_set_params: Dict[str, Union[str, int, bytes]] = {
371+
"name": task_id + PROGRESS_KEY_SUFFIX,
372+
"value": self.serializer.dumpb(model_dump(progress)),
373+
}
374+
if self.result_ex_time:
375+
redis_set_params["ex"] = self.result_ex_time
376+
elif self.result_px_time:
377+
redis_set_params["px"] = self.result_px_time
378+
379+
await self.redis.set(**redis_set_params) # type: ignore
380+
381+
async def get_progress(
382+
self,
383+
task_id: str,
384+
) -> Union[TaskProgress[_ReturnType], None]:
385+
"""
386+
Gets progress results from the task.
387+
388+
:param task_id: task's id.
389+
:return: task's TaskProgress instance.
390+
"""
391+
result_value = await self.redis.get( # type: ignore[attr-defined]
392+
name=task_id + PROGRESS_KEY_SUFFIX,
393+
)
394+
395+
if result_value is None:
396+
return None
397+
398+
return model_validate(
399+
TaskProgress[_ReturnType],
400+
self.serializer.loadb(result_value),
401+
)
402+
304403

305404
class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
306405
"""Async result based on redis sentinel."""
@@ -439,6 +538,55 @@ async def get_result(
439538

440539
return taskiq_result
441540

541+
async def set_progress(
542+
self,
543+
task_id: str,
544+
progress: TaskProgress[_ReturnType],
545+
) -> None:
546+
"""
547+
Sets task progress in redis.
548+
549+
Dumps TaskProgress instance into the bytes and writes
550+
it to redis with a standard suffix on the task_id as the key
551+
552+
:param task_id: ID of the task.
553+
:param result: task's TaskProgress instance.
554+
"""
555+
redis_set_params: Dict[str, Union[str, int, bytes]] = {
556+
"name": task_id + PROGRESS_KEY_SUFFIX,
557+
"value": self.serializer.dumpb(model_dump(progress)),
558+
}
559+
if self.result_ex_time:
560+
redis_set_params["ex"] = self.result_ex_time
561+
elif self.result_px_time:
562+
redis_set_params["px"] = self.result_px_time
563+
564+
async with self._acquire_master_conn() as redis:
565+
await redis.set(**redis_set_params) # type: ignore
566+
567+
async def get_progress(
568+
self,
569+
task_id: str,
570+
) -> Union[TaskProgress[_ReturnType], None]:
571+
"""
572+
Gets progress results from the task.
573+
574+
:param task_id: task's id.
575+
:return: task's TaskProgress instance.
576+
"""
577+
async with self._acquire_master_conn() as redis:
578+
result_value = await redis.get(
579+
name=task_id + PROGRESS_KEY_SUFFIX,
580+
)
581+
582+
if result_value is None:
583+
return None
584+
585+
return model_validate(
586+
TaskProgress[_ReturnType],
587+
self.serializer.loadb(result_value),
588+
)
589+
442590
async def shutdown(self) -> None:
443591
"""Shutdown sentinel connections."""
444592
for sentinel in self.sentinel.sentinels:

tests/test_result_backend.py

+122
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from taskiq import TaskiqResult
7+
from taskiq.depends.progress_tracker import TaskProgress, TaskState
78

89
from taskiq_redis import (
910
RedisAsyncClusterResultBackend,
@@ -438,3 +439,124 @@ async def test_keep_results_after_reading_sentinel(
438439
res2 = await result_backend.get_result(task_id=task_id)
439440
assert res1 == res2
440441
await result_backend.shutdown()
442+
443+
444+
@pytest.mark.anyio
445+
async def test_set_progress(redis_url: str) -> None:
446+
"""
447+
Test that set_progress/get_progress works.
448+
449+
:param redis_url: redis URL.
450+
"""
451+
result_backend = RedisAsyncResultBackend( # type: ignore
452+
redis_url=redis_url,
453+
)
454+
task_id = uuid.uuid4().hex
455+
456+
test_progress_1 = TaskProgress(
457+
state=TaskState.STARTED,
458+
meta={"message": "quarter way", "pct": 25},
459+
)
460+
test_progress_2 = TaskProgress(
461+
state=TaskState.STARTED,
462+
meta={"message": "half way", "pct": 50},
463+
)
464+
465+
# Progress starts as None
466+
assert await result_backend.get_progress(task_id=task_id) is None
467+
468+
# Setting the first time persists
469+
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)
470+
471+
fetched_result = await result_backend.get_progress(task_id=task_id)
472+
assert fetched_result == test_progress_1
473+
474+
# Setting the second time replaces the first
475+
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)
476+
477+
fetched_result = await result_backend.get_progress(task_id=task_id)
478+
assert fetched_result == test_progress_2
479+
480+
await result_backend.shutdown()
481+
482+
483+
@pytest.mark.anyio
484+
async def test_set_progress_cluster(redis_cluster_url: str) -> None:
485+
"""
486+
Test that set_progress/get_progress works in cluster mode.
487+
488+
:param redis_url: redis URL.
489+
"""
490+
result_backend = RedisAsyncClusterResultBackend( # type: ignore
491+
redis_url=redis_cluster_url,
492+
)
493+
task_id = uuid.uuid4().hex
494+
495+
test_progress_1 = TaskProgress(
496+
state=TaskState.STARTED,
497+
meta={"message": "quarter way", "pct": 25},
498+
)
499+
test_progress_2 = TaskProgress(
500+
state=TaskState.STARTED,
501+
meta={"message": "half way", "pct": 50},
502+
)
503+
504+
# Progress starts as None
505+
assert await result_backend.get_progress(task_id=task_id) is None
506+
507+
# Setting the first time persists
508+
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)
509+
510+
fetched_result = await result_backend.get_progress(task_id=task_id)
511+
assert fetched_result == test_progress_1
512+
513+
# Setting the second time replaces the first
514+
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)
515+
516+
fetched_result = await result_backend.get_progress(task_id=task_id)
517+
assert fetched_result == test_progress_2
518+
519+
await result_backend.shutdown()
520+
521+
522+
@pytest.mark.anyio
523+
async def test_set_progress_sentinel(
524+
redis_sentinels: List[Tuple[str, int]],
525+
redis_sentinel_master_name: str,
526+
) -> None:
527+
"""
528+
Test that set_progress/get_progress works in cluster mode.
529+
530+
:param redis_url: redis URL.
531+
"""
532+
result_backend = RedisAsyncSentinelResultBackend( # type: ignore
533+
sentinels=redis_sentinels,
534+
master_name=redis_sentinel_master_name,
535+
)
536+
task_id = uuid.uuid4().hex
537+
538+
test_progress_1 = TaskProgress(
539+
state=TaskState.STARTED,
540+
meta={"message": "quarter way", "pct": 25},
541+
)
542+
test_progress_2 = TaskProgress(
543+
state=TaskState.STARTED,
544+
meta={"message": "half way", "pct": 50},
545+
)
546+
547+
# Progress starts as None
548+
assert await result_backend.get_progress(task_id=task_id) is None
549+
550+
# Setting the first time persists
551+
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)
552+
553+
fetched_result = await result_backend.get_progress(task_id=task_id)
554+
assert fetched_result == test_progress_1
555+
556+
# Setting the second time replaces the first
557+
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)
558+
559+
fetched_result = await result_backend.get_progress(task_id=task_id)
560+
assert fetched_result == test_progress_2
561+
562+
await result_backend.shutdown()

0 commit comments

Comments
 (0)