46 changes: 46 additions & 0 deletions docs/agents_sdk_integration.md
@@ -81,6 +81,52 @@ from guardrails import JsonString
agent = GuardrailAgent(config=JsonString('{"version": 1, ...}'), ...)
```

## Token Usage Tracking

Track token usage from LLM-based guardrails using the unified `total_guardrail_token_usage` function:

```python
from guardrails import GuardrailAgent, total_guardrail_token_usage
from agents import Runner

agent = GuardrailAgent(config="config.json", name="Assistant", instructions="...")
result = await Runner.run(agent, "Hello")

# Get aggregated token usage from all guardrails
tokens = total_guardrail_token_usage(result)
print(f"Guardrail tokens used: {tokens['total_tokens']}")
```

### Per-Stage Token Usage

For per-stage token usage, access the guardrail results directly on the `RunResult`:

```python
# Input guardrails (agent-level)
for gr in result.input_guardrail_results:
    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
    if usage:
        print(f"Input guardrail: {usage['total_tokens']} tokens")

# Output guardrails (agent-level)
for gr in result.output_guardrail_results:
    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
    if usage:
        print(f"Output guardrail: {usage['total_tokens']} tokens")

# Tool input guardrails (per-tool)
for gr in result.tool_input_guardrail_results:
    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
    if usage:
        print(f"Tool input guardrail: {usage['total_tokens']} tokens")

# Tool output guardrails (per-tool)
for gr in result.tool_output_guardrail_results:
    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
    if usage:
        print(f"Tool output guardrail: {usage['total_tokens']} tokens")
```
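
If you want a per-stage rollup instead of per-result lines, the same pattern folds into a small helper. This is an illustrative sketch (the `stage_token_total` helper is not part of the SDK); it only relies on the `output.output_info["token_usage"]` field shown above:

```python
def stage_token_total(results) -> int:
    """Sum total_tokens across one stage's guardrail results (0 if none report usage)."""
    total = 0
    for gr in results:
        usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
        if usage:
            total += usage["total_tokens"]
    return total

print("Input:", stage_token_total(result.input_guardrail_results))
print("Output:", stage_token_total(result.output_guardrail_results))
print("Tool input:", stage_token_total(result.tool_input_guardrail_results))
print("Tool output:", stage_token_total(result.tool_output_guardrail_results))
```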

## Next Steps

- Use the [Guardrails Wizard](https://guardrails.openai.com/) to generate your configuration
81 changes: 81 additions & 0 deletions docs/quickstart.md
@@ -203,6 +203,87 @@ client = GuardrailsAsyncOpenAI(
)
```

## Token Usage Tracking

LLM-based guardrails (Jailbreak, Custom Prompt Check, etc.) consume tokens. You can track token usage across all guardrail calls using the unified `total_guardrail_token_usage` function:

```python
from guardrails import GuardrailsAsyncOpenAI, total_guardrail_token_usage

client = GuardrailsAsyncOpenAI(config="config.json")
response = await client.responses.create(model="gpt-4o", input="Hello")

# Get aggregated token usage from all guardrails
tokens = total_guardrail_token_usage(response)
print(f"Guardrail tokens used: {tokens['total_tokens']}")
# Output: Guardrail tokens used: 425
```

The function returns a dictionary:
```python
{
"prompt_tokens": 300, # Sum of prompt tokens across all LLM guardrails
"completion_tokens": 125, # Sum of completion tokens
"total_tokens": 425, # Total tokens used by guardrails
}
```
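
If no LLM-based guardrails contributed usage data, the summed values may be `None` rather than `0`, so it can be worth guarding before formatting output. A minimal sketch:

```python
tokens = total_guardrail_token_usage(response)

# total_tokens is None when no LLM-based guardrail reported usage
if tokens["total_tokens"] is not None:
    print(f"Guardrail tokens used: {tokens['total_tokens']}")
else:
    print("No LLM-based guardrail token usage recorded")
```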

### Works Across All Surfaces

`total_guardrail_token_usage` works with any guardrails result type:

```python
# OpenAI client responses
response = await client.responses.create(...)
tokens = total_guardrail_token_usage(response)

# Streaming (use the last chunk)
async for chunk in stream:
    last_chunk = chunk
tokens = total_guardrail_token_usage(last_chunk)

# Agents SDK
result = await Runner.run(agent, input)
tokens = total_guardrail_token_usage(result)
```

### Per-Guardrail Token Usage

Each guardrail result includes its own token usage in the `info` dict:

**OpenAI Clients (GuardrailsAsyncOpenAI, etc.)**:

```python
response = await client.responses.create(model="gpt-4.1", input="Hello")

for gr in response.guardrail_results.all_results:
    usage = gr.info.get("token_usage")
    if usage:
        print(f"{gr.info['guardrail_name']}: {usage['total_tokens']} tokens")
```

**Agents SDK** - access token usage per stage via `RunResult`:

```python
result = await Runner.run(agent, "Hello")

# Input guardrails
for gr in result.input_guardrail_results:
    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
    if usage:
        print(f"Input: {usage['total_tokens']} tokens")

# Output guardrails
for gr in result.output_guardrail_results:
    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
    if usage:
        print(f"Output: {usage['total_tokens']} tokens")

# Tool guardrails: result.tool_input_guardrail_results, result.tool_output_guardrail_results
```

Non-LLM guardrails (URL Filter, Moderation, PII) don't consume tokens and won't include `token_usage` in their `info` dict.
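
If you only want the results that actually report usage, filtering on the presence of `token_usage` is enough. A small sketch building on the client example above:

```python
# Keep only guardrail results that report token usage (i.e., LLM-based checks)
llm_results = [
    gr for gr in response.guardrail_results.all_results
    if gr.info.get("token_usage")
]
for gr in llm_results:
    usage = gr.info["token_usage"]
    print(f"{gr.info['guardrail_name']}: {usage['total_tokens']} tokens")
```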

## Next Steps

- Explore [examples](./examples.md) for advanced patterns
16 changes: 13 additions & 3 deletions examples/basic/hello_world.py
@@ -6,17 +6,24 @@
from rich.console import Console
from rich.panel import Panel

from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered
from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered, total_guardrail_token_usage

console = Console()

# Pipeline configuration with pre_flight and input guardrails
# Define your pipeline configuration
PIPELINE_CONFIG = {
"version": 1,
"pre_flight": {
"version": 1,
"guardrails": [
{"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}},
{"name": "Moderation", "config": {"categories": ["hate", "violence"]}},
{
"name": "Jailbreak",
"config": {
"model": "gpt-4.1-mini",
"confidence_threshold": 0.7,
},
},
],
},
"input": {
Expand Down Expand Up @@ -52,6 +59,9 @@ async def process_input(
# Show guardrail results if any were run
if response.guardrail_results.all_results:
console.print(f"[dim]Guardrails checked: {len(response.guardrail_results.all_results)}[/dim]")
# Use unified function - works with any guardrails response type
tokens = total_guardrail_token_usage(response)
console.print(f"[dim]Token usage: {tokens}[/dim]")

return response.id

3 changes: 2 additions & 1 deletion examples/basic/local_model.py
@@ -7,7 +7,7 @@
from rich.console import Console
from rich.panel import Panel

from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered
from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered, total_guardrail_token_usage

console = Console()

@@ -50,6 +50,7 @@ async def process_input(
# Access response content using standard OpenAI API
response_content = response.choices[0].message.content
console.print(f"\nAssistant output: {response_content}", end="\n\n")
console.print(f"Token usage: {total_guardrail_token_usage(response)}")

# Add to conversation history
input_data.append({"role": "user", "content": user_input})
22 changes: 18 additions & 4 deletions examples/basic/multi_bundle.py
@@ -7,7 +7,7 @@
from rich.live import Live
from rich.panel import Panel

from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered
from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered, total_guardrail_token_usage

console = Console()

@@ -22,6 +22,13 @@
"name": "URL Filter",
"config": {"url_allow_list": ["example.com", "baz.com"]},
},
{
"name": "Jailbreak",
"config": {
"model": "gpt-4.1-mini",
"confidence_threshold": 0.7,
},
},
],
},
"input": {
@@ -63,19 +70,26 @@ async def process_input(

# Stream the assistant's output inside a Rich Live panel
output_text = "Assistant output: "
last_chunk = None
with Live(output_text, console=console, refresh_per_second=10) as live:
try:
async for chunk in stream:
last_chunk = chunk
# Access streaming response exactly like native OpenAI API (flattened)
if hasattr(chunk, "delta") and chunk.delta:
output_text += chunk.delta
live.update(output_text)

# Get the response ID from the final chunk
response_id_to_return = None
if hasattr(chunk, "response") and hasattr(chunk.response, "id"):
response_id_to_return = chunk.response.id

if last_chunk and hasattr(last_chunk, "response") and hasattr(last_chunk.response, "id"):
response_id_to_return = last_chunk.response.id

# Print token usage from guardrail results (unified interface)
if last_chunk:
tokens = total_guardrail_token_usage(last_chunk)
if tokens["total_tokens"]:
console.print(f"[dim]📊 Guardrail tokens: {tokens['total_tokens']}[/dim]")
return response_id_to_return

except GuardrailTripwireTriggered:
3 changes: 2 additions & 1 deletion src/guardrails/__init__.py
@@ -40,7 +40,7 @@
run_guardrails,
)
from .spec import GuardrailSpecMetadata
from .types import GuardrailResult
from .types import GuardrailResult, total_guardrail_token_usage

__all__ = [
"ConfiguredGuardrail", # configured, executable object
@@ -64,6 +64,7 @@
"load_pipeline_bundles",
"default_spec_registry",
"resources", # resource modules
"total_guardrail_token_usage", # unified token usage aggregation
]

__version__: str = _m.version("openai-guardrails")
30 changes: 20 additions & 10 deletions src/guardrails/_base_client.py
@@ -19,7 +19,7 @@

from .context import has_context
from .runtime import load_pipeline_bundles
from .types import GuardrailLLMContextProto, GuardrailResult
from .types import GuardrailLLMContextProto, GuardrailResult, aggregate_token_usage_from_infos
from .utils.context import validate_guardrail_context
from .utils.conversation import append_assistant_response, normalize_conversation

@@ -77,6 +77,23 @@ def triggered_results(self) -> list[GuardrailResult]:
"""Get only the guardrail results that triggered tripwires."""
return [r for r in self.all_results if r.tripwire_triggered]

@property
def total_token_usage(self) -> dict[str, Any]:
"""Aggregate token usage across all LLM-based guardrails.

Sums prompt_tokens, completion_tokens, and total_tokens from all
guardrail results that include token_usage in their info dict.
Non-LLM guardrails (which don't have token_usage) are skipped.

Returns:
Dictionary with:
- prompt_tokens: Sum of all prompt tokens (or None if no data)
- completion_tokens: Sum of all completion tokens (or None if no data)
- total_tokens: Sum of all total tokens (or None if no data)
"""
infos = (result.info for result in self.all_results)
return aggregate_token_usage_from_infos(infos)


@dataclass(frozen=True, slots=True, weakref_slot=True)
class GuardrailsResponse:
@@ -427,8 +444,7 @@ def _mask_text(text: str) -> str:
or (
len(candidate_lower) >= 3
and any( # Any 3-char chunk overlaps
candidate_lower[i : i + 3] in detected_lower
for i in range(len(candidate_lower) - 2)
candidate_lower[i : i + 3] in detected_lower for i in range(len(candidate_lower) - 2)
)
)
)
@@ -459,13 +475,7 @@ def _mask_text(text: str) -> str:
modified_content.append(part)
else:
# Handle object-based content parts
if (
hasattr(part, "type")
and hasattr(part, "text")
and part.type in _TEXT_CONTENT_TYPES
and isinstance(part.text, str)
and part.text
):
if hasattr(part, "type") and hasattr(part, "text") and part.type in _TEXT_CONTENT_TYPES and isinstance(part.text, str) and part.text:
try:
part.text = _mask_text(part.text)
except Exception: