diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index 51cd09e66..d431e82d6 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -74,6 +74,7 @@
 from .models.openai_provider import OpenAIProvider
 from .models.openai_responses import OpenAIResponsesModel
 from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
+from .rate_limit import RateLimitConfig
 from .repl import run_demo_loop
 from .result import RunResult, RunResultStreaming
 from .run import RunConfig, Runner
@@ -298,6 +299,7 @@ def enable_verbose_stdout_logging():
     "RunResult",
     "RunResultStreaming",
     "RunConfig",
+    "RateLimitConfig",
     "RawResponsesStreamEvent",
     "RunItemStreamEvent",
     "AgentUpdatedStreamEvent",
diff --git a/src/agents/rate_limit.py b/src/agents/rate_limit.py
new file mode 100644
index 000000000..8970a1296
--- /dev/null
+++ b/src/agents/rate_limit.py
@@ -0,0 +1,180 @@
+"""Rate limiting utilities for the Agents SDK.
+
+This module provides rate limiting functionality to help users stay within
+API rate limits when using free or low-budget LLM providers.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import time
+from dataclasses import dataclass
+from typing import Any, Callable, TypeVar
+
+from .logger import logger
+
+T = TypeVar("T")
+
+
+@dataclass
+class RateLimitConfig:
+    """Configuration for rate limiting LLM requests.
+
+    Use this to prevent 429 (rate limit) errors when using providers with
+    strict rate limits (e.g., free tiers with 3 requests/minute).
+
+    Example:
+        ```python
+        run_config = RunConfig(
+            model="groq/llama-3.1-8b-instant",
+            rate_limit=RateLimitConfig(
+                requests_per_minute=3,
+                retry_on_rate_limit=True,
+            )
+        )
+        ```
+    """
+
+    requests_per_minute: int | None = None
+    """Maximum number of requests allowed per minute. If set, the SDK will
+    automatically pace requests to stay under this limit."""
+
+    retry_on_rate_limit: bool = True
+    """If True, automatically retry requests that receive a 429 response
+    with exponential backoff."""
+
+    max_retries: int = 3
+    """Maximum number of retry attempts for rate-limited requests."""
+
+    initial_retry_delay: float = 1.0
+    """Initial delay in seconds before the first retry attempt."""
+
+    backoff_multiplier: float = 2.0
+    """Multiplier for exponential backoff between retries."""
+
+    max_retry_delay: float = 60.0
+    """Maximum delay in seconds between retry attempts."""
+
+
+class RateLimiter:
+    """A simple rate limiter using a sliding-window algorithm.
+
+    This class helps pace requests to stay within a specified rate limit.
+    It tracks request timestamps and waits if necessary before allowing
+    new requests.
+    """
+
+    def __init__(self, config: RateLimitConfig):
+        """Initialize the rate limiter.
+
+        Args:
+            config: The rate limit configuration.
+        """
+        self._config = config
+        self._request_times: list[float] = []
+        self._lock = asyncio.Lock()
+
+    @property
+    def is_enabled(self) -> bool:
+        """Check if rate limiting is enabled."""
+        return self._config.requests_per_minute is not None
+
+    async def acquire(self) -> None:
+        """Wait until a request slot is available.
+
+        This method blocks until it's safe to make a new request without
+        exceeding the configured rate limit.
+        """
+        if not self.is_enabled:
+            return
+
+        async with self._lock:
+            requests_per_minute = self._config.requests_per_minute
+            assert requests_per_minute is not None
+
+            now = time.monotonic()
+            window_start = now - 60.0  # 1 minute window
+
+            # Remove requests outside the current window
+            self._request_times = [t for t in self._request_times if t > window_start]
+
+            # If we're at the limit, wait until a slot opens up
+            if len(self._request_times) >= requests_per_minute:
+                # Calculate how long to wait
+                oldest_request = self._request_times[0]
+                wait_time = oldest_request - window_start
+                if wait_time > 0:
+                    logger.debug(
+                        f"Rate limit: waiting {wait_time:.2f}s "
+                        f"({len(self._request_times)}/{requests_per_minute} requests in window)"
+                    )
+                    await asyncio.sleep(wait_time)
+                    # Clean up again after waiting
+                    now = time.monotonic()
+                    window_start = now - 60.0
+                    self._request_times = [t for t in self._request_times if t > window_start]
+
+            # Record this request
+            self._request_times.append(time.monotonic())
+
+    async def execute_with_retry(
+        self,
+        func: Callable[..., Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> T:
+        """Execute a function with rate limiting and automatic retry on 429 errors.
+
+        Args:
+            func: The async function to execute.
+            *args: Positional arguments to pass to the function.
+            **kwargs: Keyword arguments to pass to the function.
+
+        Returns:
+            The return value of the function.
+
+        Raises:
+            The last exception if all retries are exhausted.
+        """
+        # First, wait for rate limit slot
+        await self.acquire()
+
+        if not self._config.retry_on_rate_limit:
+            return await func(*args, **kwargs)
+
+        last_exception: Exception | None = None
+        delay = self._config.initial_retry_delay
+
+        for attempt in range(self._config.max_retries + 1):
+            try:
+                return await func(*args, **kwargs)
+            except Exception as e:
+                # Check if this is a rate limit error (429)
+                error_str = str(e).lower()
+                is_rate_limit = (
+                    "429" in str(e)
+                    or "rate" in error_str
+                    or "too many requests" in error_str
+                    or "rate_limit" in error_str
+                )
+
+                if not is_rate_limit:
+                    raise
+
+                last_exception = e
+
+                if attempt < self._config.max_retries:
+                    logger.warning(
+                        f"Rate limit hit (attempt {attempt + 1}/{self._config.max_retries + 1}). "
+                        f"Retrying in {delay:.1f}s..."
+                    )
+                    await asyncio.sleep(delay)
+                    delay = min(
+                        delay * self._config.backoff_multiplier, self._config.max_retry_delay
+                    )
+                    # Wait for a rate limit slot before retrying
+                    await self.acquire()
+
+        # All retries exhausted
+        assert last_exception is not None
+        raise last_exception
diff --git a/src/agents/run.py b/src/agents/run.py
index 5b5e6fdfa..df365ba8d 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -63,6 +63,7 @@
 from .model_settings import ModelSettings
 from .models.interface import Model, ModelProvider
 from .models.multi_provider import MultiProvider
+from .rate_limit import RateLimitConfig, RateLimiter
 from .result import RunResult, RunResultStreaming
 from .run_context import AgentHookContext, RunContextWrapper, TContext
 from .stream_events import (
@@ -270,6 +271,19 @@ class RunConfig:
     For example, you can use this to add a system prompt to the input.
     """
 
+    rate_limit: RateLimitConfig | None = None
+    """
+    Optional rate limiting configuration for LLM requests. Use this when working with
+    providers that have strict rate limits (e.g., free tiers with 3 requests/minute).
+
+    Example:
+        ```python
+        run_config = RunConfig(
+            rate_limit=RateLimitConfig(requests_per_minute=3)
+        )
+        ```
+    """
+
 
 class RunOptions(TypedDict, Generic[TContext]):
     """Arguments for ``AgentRunner`` methods."""
@@ -570,6 +584,11 @@ async def run(
 
         tool_use_tracker = AgentToolUseTracker()
 
+        # Create rate limiter if configured
+        rate_limiter: RateLimiter | None = None
+        if run_config.rate_limit is not None:
+            rate_limiter = RateLimiter(run_config.rate_limit)
+
         with TraceCtxManager(
             workflow_name=run_config.workflow_name,
             trace_id=run_config.trace_id,
@@ -679,6 +698,7 @@ async def run(
                                 should_run_agent_start_hooks=should_run_agent_start_hooks,
                                 tool_use_tracker=tool_use_tracker,
                                 server_conversation_tracker=server_conversation_tracker,
+                                rate_limiter=rate_limiter,
                             ),
                         )
 
@@ -696,6 +716,7 @@ async def run(
                             should_run_agent_start_hooks=should_run_agent_start_hooks,
                             tool_use_tracker=tool_use_tracker,
                             server_conversation_tracker=server_conversation_tracker,
+                            rate_limiter=rate_limiter,
                         )
 
                     should_run_agent_start_hooks = False
@@ -1593,6 +1614,7 @@ async def _run_single_turn(
         should_run_agent_start_hooks: bool,
         tool_use_tracker: AgentToolUseTracker,
         server_conversation_tracker: _ServerConversationTracker | None = None,
+        rate_limiter: RateLimiter | None = None,
     ) -> SingleStepResult:
         # Ensure we run the hooks before anything else
         if should_run_agent_start_hooks:
@@ -1636,6 +1658,7 @@ async def _run_single_turn(
             tool_use_tracker,
             server_conversation_tracker,
             prompt_config,
+            rate_limiter=rate_limiter,
         )
 
         return await cls._get_single_step_result_from_response(
@@ -1842,6 +1865,7 @@ async def _get_new_response(
         tool_use_tracker: AgentToolUseTracker,
         server_conversation_tracker: _ServerConversationTracker | None,
         prompt_config: ResponsePromptParam | None,
+        rate_limiter: RateLimiter | None = None,
     ) -> ModelResponse:
         # Allow user to modify model input right before the call, if configured
         filtered = await cls._maybe_filter_model_input(
@@ -1881,20 +1905,28 @@ async def _get_new_response(
             server_conversation_tracker.conversation_id if server_conversation_tracker else None
         )
 
-        new_response = await model.get_response(
-            system_instructions=filtered.instructions,
-            input=filtered.input,
-            model_settings=model_settings,
-            tools=all_tools,
-            output_schema=output_schema,
-            handoffs=handoffs,
-            tracing=get_model_tracing_impl(
-                run_config.tracing_disabled, run_config.trace_include_sensitive_data
-            ),
-            previous_response_id=previous_response_id,
-            conversation_id=conversation_id,
-            prompt=prompt_config,
-        )
+        # Define the model call as a coroutine function for rate limiting
+        async def _call_model() -> ModelResponse:
+            return await model.get_response(
+                system_instructions=filtered.instructions,
+                input=filtered.input,
+                model_settings=model_settings,
+                tools=all_tools,
+                output_schema=output_schema,
+                handoffs=handoffs,
+                tracing=get_model_tracing_impl(
+                    run_config.tracing_disabled, run_config.trace_include_sensitive_data
+                ),
+                previous_response_id=previous_response_id,
+                conversation_id=conversation_id,
+                prompt=prompt_config,
+            )
+
+        # Apply rate limiting if configured
+        if rate_limiter is not None and rate_limiter.is_enabled:
+            new_response = await rate_limiter.execute_with_retry(_call_model)
+        else:
+            new_response = await _call_model()
 
         context_wrapper.usage.add(new_response.usage)
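Usage sketch (not part of the diff): the snippet below shows one way the new `RunConfig.rate_limit` option introduced above could be wired into a run. It uses only names added in this change (`RateLimitConfig`, `RunConfig.rate_limit`) together with the SDK's existing `Agent` and `Runner.run` APIs; the agent name, instructions, and prompt are illustrative placeholders.

```python
import asyncio

from agents import Agent, RateLimitConfig, RunConfig, Runner


async def main() -> None:
    # A throwaway agent purely for illustration.
    agent = Agent(name="Assistant", instructions="Reply concisely.")

    # Pace requests for a hypothetical 3 requests/minute free tier and
    # retry 429 responses with exponential backoff (both from this diff).
    run_config = RunConfig(
        rate_limit=RateLimitConfig(
            requests_per_minute=3,
            retry_on_rate_limit=True,
            max_retries=3,
        )
    )

    result = await Runner.run(agent, "Hello!", run_config=run_config)
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```

With this configuration, each model call first waits for a slot in the 60-second sliding window maintained by `RateLimiter.acquire`, and a call that still fails with a rate-limit error is retried up to `max_retries` times via `execute_with_retry`.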