From 20a3371beb84e4194aa1e484b6aec008fd18fa18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awek=20Ehlert?= Date: Wed, 9 Jul 2025 17:34:04 +0200 Subject: [PATCH 1/2] Add failing test The one for sync fails ATM. For async is OK. --- tests/receiver/test_receiver.py | 64 ++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 57637e9..2fd3f83 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -1,8 +1,9 @@ import asyncio +import contextvars import random import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, ClassVar, List, Optional +from typing import Any, ClassVar, Generator, List, Optional import pytest from taskiq_dependencies import Depends @@ -472,3 +473,64 @@ async def task_no_result() -> str: assert resp.return_value is None assert not broker._running_tasks assert isinstance(resp.error, ValueError) + + +EXPECTED_CTX_VALUE = 42 + + +@pytest.fixture() +def ctxvar() -> Generator[contextvars.ContextVar[int], None, None]: + _ctx_variable: contextvars.ContextVar[int] = contextvars.ContextVar( + "taskiq_test_ctx_var", + ) + token = _ctx_variable.set(EXPECTED_CTX_VALUE) + yield _ctx_variable + _ctx_variable.reset(token) + + +@pytest.mark.anyio +async def test_run_task_successful_sync_preserve_contextvars( + ctxvar: contextvars.ContextVar[int], +) -> None: + """Running sync tasks should preserve context vars.""" + + def test_func() -> int: + return ctxvar.get() + + receiver = get_receiver() + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={}, + ), + ) + assert result.return_value == EXPECTED_CTX_VALUE + + +@pytest.mark.anyio +async def test_run_task_successful_async_preserve_contextvars( + ctxvar: contextvars.ContextVar[int], +) -> None: + """Running async tasks should preserve context vars.""" + + async def test_func() -> int: + return ctxvar.get() + + receiver = get_receiver() + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={}, + ), + ) + assert result.return_value == EXPECTED_CTX_VALUE From ae700b08f6e2ffc8ffc848d59c363b6b331316e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awek=20Ehlert?= Date: Wed, 9 Jul 2025 17:37:47 +0200 Subject: [PATCH 2/2] Preserve context for sync tasks Run sync tasks within copied context. This is heavily inspired by starlette's `run_in_threadpool` function introduced in https://github.com/encode/starlette/pull/192. --- taskiq/receiver/receiver.py | 33 ++++++++------------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index c15fb93..16c7bd3 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -1,9 +1,11 @@ import asyncio +import contextvars +import functools import inspect from concurrent.futures import Executor from logging import getLogger from time import time -from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints +from typing import Any, Callable, Dict, Optional, Set, Union, get_type_hints import anyio from taskiq_dependencies import DependencyGraph @@ -23,25 +25,6 @@ QUEUE_DONE = b"-1" -def _run_sync( - target: Callable[..., Any], - args: List[Any], - kwargs: Dict[str, Any], -) -> Any: - """ - Runs function synchronously. - - We use this function, because - we cannot pass kwargs in loop.run_with_executor(). - - :param target: function to execute. - :param args: list of function's args. - :param kwargs: dict of function's kwargs. - :return: result of function's execution. - """ - return target(*args, **kwargs) - - class Receiver: """Class that uses as a callback handler.""" @@ -255,13 +238,13 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 else: is_coroutine = False # If this is a synchronous function, we - # run it in executor. + # run it in executor and preserve the context. + ctx = contextvars.copy_context() + func = functools.partial(target, *message.args, **kwargs) target_future = loop.run_in_executor( self.executor, - _run_sync, - target, - message.args, - kwargs, + ctx.run, + func, ) timeout = message.labels.get("timeout") if timeout is not None: