|
19 | 19 | from taskiq.abc.result_backend import TaskiqResult
|
20 | 20 | from taskiq.abc.serializer import TaskiqSerializer
|
21 | 21 | from taskiq.compat import model_dump, model_validate
|
| 22 | +from taskiq.depends.progress_tracker import TaskProgress |
22 | 23 | from taskiq.serializers import PickleSerializer
|
23 | 24 |
|
24 | 25 | from taskiq_redis.exceptions import (
|
|
41 | 42 |
|
42 | 43 | _ReturnType = TypeVar("_ReturnType")
|
43 | 44 |
|
| 45 | +PROGRESS_KEY_SUFFIX = "__progress" |
| 46 | + |
44 | 47 |
|
45 | 48 | class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
46 | 49 | """Async result based on redis."""
|
@@ -174,6 +177,55 @@ async def get_result(
|
174 | 177 |
|
175 | 178 | return taskiq_result
|
176 | 179 |
|
| 180 | + async def set_progress( |
| 181 | + self, |
| 182 | + task_id: str, |
| 183 | + progress: TaskProgress[_ReturnType], |
| 184 | + ) -> None: |
| 185 | + """ |
| 186 | + Sets task progress in redis. |
| 187 | +
|
| 188 | + Dumps TaskProgress instance into the bytes and writes |
| 189 | + it to redis with a standard suffix on the task_id as the key |
| 190 | +
|
| 191 | + :param task_id: ID of the task. |
| 192 | + :param result: task's TaskProgress instance. |
| 193 | + """ |
| 194 | + redis_set_params: Dict[str, Union[str, int, bytes]] = { |
| 195 | + "name": task_id + PROGRESS_KEY_SUFFIX, |
| 196 | + "value": self.serializer.dumpb(model_dump(progress)), |
| 197 | + } |
| 198 | + if self.result_ex_time: |
| 199 | + redis_set_params["ex"] = self.result_ex_time |
| 200 | + elif self.result_px_time: |
| 201 | + redis_set_params["px"] = self.result_px_time |
| 202 | + |
| 203 | + async with Redis(connection_pool=self.redis_pool) as redis: |
| 204 | + await redis.set(**redis_set_params) # type: ignore |
| 205 | + |
| 206 | + async def get_progress( |
| 207 | + self, |
| 208 | + task_id: str, |
| 209 | + ) -> Union[TaskProgress[_ReturnType], None]: |
| 210 | + """ |
| 211 | + Gets progress results from the task. |
| 212 | +
|
| 213 | + :param task_id: task's id. |
| 214 | + :return: task's TaskProgress instance. |
| 215 | + """ |
| 216 | + async with Redis(connection_pool=self.redis_pool) as redis: |
| 217 | + result_value = await redis.get( |
| 218 | + name=task_id + PROGRESS_KEY_SUFFIX, |
| 219 | + ) |
| 220 | + |
| 221 | + if result_value is None: |
| 222 | + return None |
| 223 | + |
| 224 | + return model_validate( |
| 225 | + TaskProgress[_ReturnType], |
| 226 | + self.serializer.loadb(result_value), |
| 227 | + ) |
| 228 | + |
177 | 229 |
|
178 | 230 | class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
179 | 231 | """Async result backend based on redis cluster."""
|
@@ -301,6 +353,53 @@ async def get_result(
|
301 | 353 |
|
302 | 354 | return taskiq_result
|
303 | 355 |
|
| 356 | + async def set_progress( |
| 357 | + self, |
| 358 | + task_id: str, |
| 359 | + progress: TaskProgress[_ReturnType], |
| 360 | + ) -> None: |
| 361 | + """ |
| 362 | + Sets task progress in redis. |
| 363 | +
|
| 364 | + Dumps TaskProgress instance into the bytes and writes |
| 365 | + it to redis with a standard suffix on the task_id as the key |
| 366 | +
|
| 367 | + :param task_id: ID of the task. |
| 368 | + :param result: task's TaskProgress instance. |
| 369 | + """ |
| 370 | + redis_set_params: Dict[str, Union[str, int, bytes]] = { |
| 371 | + "name": task_id + PROGRESS_KEY_SUFFIX, |
| 372 | + "value": self.serializer.dumpb(model_dump(progress)), |
| 373 | + } |
| 374 | + if self.result_ex_time: |
| 375 | + redis_set_params["ex"] = self.result_ex_time |
| 376 | + elif self.result_px_time: |
| 377 | + redis_set_params["px"] = self.result_px_time |
| 378 | + |
| 379 | + await self.redis.set(**redis_set_params) # type: ignore |
| 380 | + |
| 381 | + async def get_progress( |
| 382 | + self, |
| 383 | + task_id: str, |
| 384 | + ) -> Union[TaskProgress[_ReturnType], None]: |
| 385 | + """ |
| 386 | + Gets progress results from the task. |
| 387 | +
|
| 388 | + :param task_id: task's id. |
| 389 | + :return: task's TaskProgress instance. |
| 390 | + """ |
| 391 | + result_value = await self.redis.get( # type: ignore[attr-defined] |
| 392 | + name=task_id + PROGRESS_KEY_SUFFIX, |
| 393 | + ) |
| 394 | + |
| 395 | + if result_value is None: |
| 396 | + return None |
| 397 | + |
| 398 | + return model_validate( |
| 399 | + TaskProgress[_ReturnType], |
| 400 | + self.serializer.loadb(result_value), |
| 401 | + ) |
| 402 | + |
304 | 403 |
|
305 | 404 | class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
306 | 405 | """Async result based on redis sentinel."""
|
@@ -439,6 +538,55 @@ async def get_result(
|
439 | 538 |
|
440 | 539 | return taskiq_result
|
441 | 540 |
|
| 541 | + async def set_progress( |
| 542 | + self, |
| 543 | + task_id: str, |
| 544 | + progress: TaskProgress[_ReturnType], |
| 545 | + ) -> None: |
| 546 | + """ |
| 547 | + Sets task progress in redis. |
| 548 | +
|
| 549 | + Dumps TaskProgress instance into the bytes and writes |
| 550 | + it to redis with a standard suffix on the task_id as the key |
| 551 | +
|
| 552 | + :param task_id: ID of the task. |
| 553 | + :param result: task's TaskProgress instance. |
| 554 | + """ |
| 555 | + redis_set_params: Dict[str, Union[str, int, bytes]] = { |
| 556 | + "name": task_id + PROGRESS_KEY_SUFFIX, |
| 557 | + "value": self.serializer.dumpb(model_dump(progress)), |
| 558 | + } |
| 559 | + if self.result_ex_time: |
| 560 | + redis_set_params["ex"] = self.result_ex_time |
| 561 | + elif self.result_px_time: |
| 562 | + redis_set_params["px"] = self.result_px_time |
| 563 | + |
| 564 | + async with self._acquire_master_conn() as redis: |
| 565 | + await redis.set(**redis_set_params) # type: ignore |
| 566 | + |
| 567 | + async def get_progress( |
| 568 | + self, |
| 569 | + task_id: str, |
| 570 | + ) -> Union[TaskProgress[_ReturnType], None]: |
| 571 | + """ |
| 572 | + Gets progress results from the task. |
| 573 | +
|
| 574 | + :param task_id: task's id. |
| 575 | + :return: task's TaskProgress instance. |
| 576 | + """ |
| 577 | + async with self._acquire_master_conn() as redis: |
| 578 | + result_value = await redis.get( |
| 579 | + name=task_id + PROGRESS_KEY_SUFFIX, |
| 580 | + ) |
| 581 | + |
| 582 | + if result_value is None: |
| 583 | + return None |
| 584 | + |
| 585 | + return model_validate( |
| 586 | + TaskProgress[_ReturnType], |
| 587 | + self.serializer.loadb(result_value), |
| 588 | + ) |
| 589 | + |
442 | 590 | async def shutdown(self) -> None:
|
443 | 591 | """Shutdown sentinel connections."""
|
444 | 592 | for sentinel in self.sentinel.sentinels:
|
|
0 commit comments