diff --git a/src/guardrails/resources/chat/chat.py b/src/guardrails/resources/chat/chat.py index a76d9b7..8821976 100644 --- a/src/guardrails/resources/chat/chat.py +++ b/src/guardrails/resources/chat/chat.py @@ -3,6 +3,8 @@ import asyncio from collections.abc import AsyncIterator from concurrent.futures import ThreadPoolExecutor +from contextvars import copy_context +from functools import partial from typing import Any from ..._base_client import GuardrailsBaseClient @@ -93,10 +95,10 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals if supports_safety_identifier(self._client._resource_client): llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER - llm_future = executor.submit( - self._client._resource_client.chat.completions.create, - **llm_kwargs, - ) + llm_call_fn = partial(self._client._resource_client.chat.completions.create, **llm_kwargs) + ctx = copy_context() + llm_future = executor.submit(ctx.run, llm_call_fn) + input_results = self._client._run_stage_guardrails( "input", latest_message, diff --git a/src/guardrails/resources/responses/responses.py b/src/guardrails/resources/responses/responses.py index 4df5f46..262529f 100644 --- a/src/guardrails/resources/responses/responses.py +++ b/src/guardrails/resources/responses/responses.py @@ -3,6 +3,8 @@ import asyncio from collections.abc import AsyncIterator from concurrent.futures import ThreadPoolExecutor +from contextvars import copy_context +from functools import partial from typing import Any from pydantic import BaseModel @@ -75,10 +77,10 @@ def create( if supports_safety_identifier(self._client._resource_client): llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER - llm_future = executor.submit( - self._client._resource_client.responses.create, - **llm_kwargs, - ) + llm_call_fn = partial(self._client._resource_client.responses.create, **llm_kwargs) + ctx = copy_context() + llm_future = executor.submit(ctx.run, llm_call_fn) + input_results = self._client._run_stage_guardrails( "input", latest_message, @@ -141,10 +143,10 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM if supports_safety_identifier(self._client._resource_client): llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER - llm_future = executor.submit( - self._client._resource_client.responses.parse, - **llm_kwargs, - ) + llm_call_fn = partial(self._client._resource_client.responses.parse, **llm_kwargs) + ctx = copy_context() + llm_future = executor.submit(ctx.run, llm_call_fn) + input_results = self._client._run_stage_guardrails( "input", latest_message, diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index bd34790..cf7dd54 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -2,6 +2,8 @@ from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar, copy_context from dataclasses import FrozenInstanceError import pytest @@ -34,4 +36,79 @@ def test_context_is_immutable() -> None: context = GuardrailsContext(guardrail_llm=_StubClient()) with pytest.raises(FrozenInstanceError): - context.guardrail_llm = None # type: ignore[misc] + context.guardrail_llm = None + + +def test_contextvar_propagates_with_copy_context() -> None: + test_var: ContextVar[str | None] = ContextVar("test_var", default=None) + test_var.set("test_value") + + def get_contextvar(): + return test_var.get() + + ctx = copy_context() + result = ctx.run(get_contextvar) + assert result == "test_value" # noqa: S101 + + +def test_contextvar_propagates_with_threadpool() -> None: + test_var: ContextVar[str | None] = ContextVar("test_var", default=None) + test_var.set("thread_test") + + def get_contextvar(): + return test_var.get() + + ctx = copy_context() + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(ctx.run, get_contextvar) + result = future.result() + + assert result == "thread_test" # noqa: S101 + + +def test_guardrails_context_propagates_with_copy_context() -> None: + context = GuardrailsContext(guardrail_llm=_StubClient()) + set_context(context) + + def get_guardrails_context(): + return get_context() + + ctx = copy_context() + result = ctx.run(get_guardrails_context) + assert result is context # noqa: S101 + + clear_context() + + +def test_guardrails_context_propagates_with_threadpool() -> None: + context = GuardrailsContext(guardrail_llm=_StubClient()) + set_context(context) + + def get_guardrails_context(): + return get_context() + + ctx = copy_context() + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(ctx.run, get_guardrails_context) + result = future.result() + + assert result is context # noqa: S101 + + clear_context() + + +def test_multiple_contextvars_propagate_with_threadpool() -> None: + var1: ContextVar[str | None] = ContextVar("var1", default=None) + var2: ContextVar[int | None] = ContextVar("var2", default=None) + var1.set("value1") + var2.set(42) + + def get_multiple_contextvars(): + return (var1.get(), var2.get()) + + ctx = copy_context() + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(ctx.run, get_multiple_contextvars) + result = future.result() + + assert result == ("value1", 42) # noqa: S101