Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/guardrails/resources/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import asyncio
from collections.abc import AsyncIterator
from concurrent.futures import ThreadPoolExecutor
from contextvars import copy_context
from functools import partial
from typing import Any

from ..._base_client import GuardrailsBaseClient
Expand Down Expand Up @@ -93,10 +95,10 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals
if supports_safety_identifier(self._client._resource_client):
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER

llm_future = executor.submit(
self._client._resource_client.chat.completions.create,
**llm_kwargs,
)
llm_call_fn = partial(self._client._resource_client.chat.completions.create, **llm_kwargs)
ctx = copy_context()
llm_future = executor.submit(ctx.run, llm_call_fn)

input_results = self._client._run_stage_guardrails(
"input",
latest_message,
Expand Down
18 changes: 10 additions & 8 deletions src/guardrails/resources/responses/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import asyncio
from collections.abc import AsyncIterator
from concurrent.futures import ThreadPoolExecutor
from contextvars import copy_context
from functools import partial
from typing import Any

from pydantic import BaseModel
Expand Down Expand Up @@ -75,10 +77,10 @@ def create(
if supports_safety_identifier(self._client._resource_client):
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER

llm_future = executor.submit(
self._client._resource_client.responses.create,
**llm_kwargs,
)
llm_call_fn = partial(self._client._resource_client.responses.create, **llm_kwargs)
ctx = copy_context()
llm_future = executor.submit(ctx.run, llm_call_fn)

input_results = self._client._run_stage_guardrails(
"input",
latest_message,
Expand Down Expand Up @@ -141,10 +143,10 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM
if supports_safety_identifier(self._client._resource_client):
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER

llm_future = executor.submit(
self._client._resource_client.responses.parse,
**llm_kwargs,
)
llm_call_fn = partial(self._client._resource_client.responses.parse, **llm_kwargs)
ctx = copy_context()
llm_future = executor.submit(ctx.run, llm_call_fn)

input_results = self._client._run_stage_guardrails(
"input",
latest_message,
Expand Down
79 changes: 78 additions & 1 deletion tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar, copy_context
from dataclasses import FrozenInstanceError

import pytest
Expand Down Expand Up @@ -34,4 +36,79 @@ def test_context_is_immutable() -> None:
context = GuardrailsContext(guardrail_llm=_StubClient())

with pytest.raises(FrozenInstanceError):
context.guardrail_llm = None # type: ignore[misc]
context.guardrail_llm = None


def test_contextvar_propagates_with_copy_context() -> None:
    """A value set on a ContextVar is visible when run inside a copied context."""
    marker: ContextVar[str | None] = ContextVar("test_var", default=None)
    marker.set("test_value")

    # Snapshot the current context and read the var back from inside it.
    snapshot = copy_context()
    observed = snapshot.run(marker.get)
    assert observed == "test_value"  # noqa: S101


def test_contextvar_propagates_with_threadpool() -> None:
    """copy_context().run carries ContextVar values into a worker thread."""
    flag: ContextVar[str | None] = ContextVar("test_var", default=None)
    flag.set("thread_test")

    snapshot = copy_context()
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Submitting snapshot.run makes the worker see this thread's context.
        observed = pool.submit(snapshot.run, flag.get).result()

    assert observed == "thread_test"  # noqa: S101


def test_guardrails_context_propagates_with_copy_context() -> None:
    """get_context() inside a copied context returns the context set before the copy.

    clear_context() runs in a ``finally`` block so that a failing assertion
    cannot leak the global guardrails context into subsequent tests.
    """
    context = GuardrailsContext(guardrail_llm=_StubClient())
    set_context(context)
    try:
        ctx = copy_context()
        # get_context is called inside the snapshot; it must observe `context`.
        result = ctx.run(get_context)
        assert result is context  # noqa: S101
    finally:
        clear_context()


def test_guardrails_context_propagates_with_threadpool() -> None:
    """get_context() in a ThreadPoolExecutor worker sees the context via copy_context().

    clear_context() runs in a ``finally`` block so that a failing assertion
    cannot leak the global guardrails context into subsequent tests.
    """
    context = GuardrailsContext(guardrail_llm=_StubClient())
    set_context(context)
    try:
        ctx = copy_context()
        with ThreadPoolExecutor(max_workers=1) as executor:
            # ctx.run in the worker thread replays this thread's context.
            future = executor.submit(ctx.run, get_context)
            result = future.result()

        assert result is context  # noqa: S101
    finally:
        clear_context()


def test_multiple_contextvars_propagate_with_threadpool() -> None:
var1: ContextVar[str | None] = ContextVar("var1", default=None)
var2: ContextVar[int | None] = ContextVar("var2", default=None)
var1.set("value1")
var2.set(42)

def get_multiple_contextvars():
return (var1.get(), var2.get())

ctx = copy_context()
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(ctx.run, get_multiple_contextvars)
result = future.result()

assert result == ("value1", 42) # noqa: S101
Loading