diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 24afad6ab0..275b161594 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -8,6 +8,9 @@
 from datetime import datetime
 from typing import Any, Literal, cast, overload
 
+from openai.types.chat.chat_completion_named_tool_choice_param import Function
+from openai.types.responses.tool_choice_allowed_param import ToolChoiceAllowedParam
+from openai.types.responses.tool_choice_function_param import ToolChoiceFunctionParam
 from pydantic import ValidationError
 from typing_extensions import assert_never, deprecated
 
@@ -56,6 +59,8 @@
         ChatCompletionContentPartParam,
         ChatCompletionContentPartTextParam,
     )
+    from openai.types.chat.chat_completion_allowed_tool_choice_param import ChatCompletionAllowedToolChoiceParam
+    from openai.types.chat.chat_completion_allowed_tools_param import ChatCompletionAllowedToolsParam
     from openai.types.chat.chat_completion_content_part_image_param import ImageURL
     from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
     from openai.types.chat.chat_completion_content_part_param import File, FileFile
@@ -64,7 +69,9 @@
     from openai.types.chat.chat_completion_message_function_tool_call_param import (
         ChatCompletionMessageFunctionToolCallParam,
     )
+    from openai.types.chat.chat_completion_named_tool_choice_param import ChatCompletionNamedToolChoiceParam
     from openai.types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam
+    from openai.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam
     from openai.types.chat.completion_create_params import (
         WebSearchOptions,
         WebSearchOptionsUserLocation,
@@ -72,6 +79,7 @@
     )
     from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
     from openai.types.responses.response_input_param import FunctionCallOutput, Message
+    from openai.types.responses.tool_choice_options import ToolChoiceOptions
     from openai.types.shared import ReasoningEffort
     from openai.types.shared_params import Reasoning
 except ImportError as _import_error:
@@ -386,13 +394,33 @@ async def _completions_create(
         tools = self._get_tools(model_request_parameters)
         web_search_options = self._get_web_search_options(model_request_parameters)
 
-        if not tools:
-            tool_choice: Literal['none', 'required', 'auto'] | None = None
+        tool_choice: ChatCompletionToolChoiceOptionParam | None
+        model_settings_tool_choice = model_settings.get('tool_choice', None)
+        # Respect an explicit request to disable tool calls.
+        if model_settings_tool_choice == 'none':
+            tool_choice = 'none'
+        elif not tools:
+            tool_choice = None
         elif (
             not model_request_parameters.allow_text_output
             and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
+            and not isinstance(model_settings_tool_choice, list)
         ):
             tool_choice = 'required'
+        elif isinstance(model_settings_tool_choice, list):
+            if len(model_settings_tool_choice) == 1:
+                tool_choice = ChatCompletionNamedToolChoiceParam(
+                    type='function',
+                    function=Function(name=model_settings_tool_choice[0]),
+                )
+            else:
+                tool_choice = ChatCompletionAllowedToolChoiceParam(
+                    type='allowed_tools',
+                    allowed_tools=ChatCompletionAllowedToolsParam(
+                        mode='required' if not model_request_parameters.allow_text_output else 'auto',
+                        tools=[{'type': 'function', 'function': {'name': name}} for name in model_settings_tool_choice],
+                    ),
+                )
         else:
             tool_choice = 'auto'
 
@@ -883,10 +911,29 @@ async def _responses_create(
             + self._get_tools(model_request_parameters)
         )
 
-        if not tools:
-            tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_output:
+        tool_choice: ToolChoiceOptions | ToolChoiceAllowedParam | ToolChoiceFunctionParam | None
+        model_settings_tool_choice = model_settings.get('tool_choice', None)
+        if model_settings_tool_choice == 'none':
+            tool_choice = 'none'
+        elif not tools:
+            tool_choice = None
+        elif (
+            not model_request_parameters.allow_text_output
+            and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
+            and not isinstance(model_settings_tool_choice, list)
+        ):
             tool_choice = 'required'
+        elif isinstance(model_settings_tool_choice, list):
+            if len(model_settings_tool_choice) == 1:
+                name = model_settings_tool_choice[0]
+                tool_choice = ToolChoiceFunctionParam(type='function', name=name)
+            else:
+                # https://github.com/openai/openai-python/issues/2537
+                tool_choice = ToolChoiceAllowedParam(
+                    type='allowed_tools',
+                    mode='required' if not model_request_parameters.allow_text_output else 'auto',
+                    tools=[{'type': 'function', 'name': name} for name in model_settings_tool_choice],
+                )
         else:
             tool_choice = 'auto'
 
diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py
index f3d515ae69..ef7ab5ff29 100644
--- a/pydantic_ai_slim/pydantic_ai/settings.py
+++ b/pydantic_ai_slim/pydantic_ai/settings.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing import Literal
+
 from httpx import Timeout
 from typing_extensions import TypedDict
 
@@ -75,6 +77,11 @@ class ModelSettings(TypedDict, total=False):
     * Mistral
     """
 
+    tool_choice: Literal['none', 'required', 'auto'] | list[str] | None
+    """Constrain which tools the model may call: `'none'` disables tool calls, a list of tool names
+    restricts the model to those tools, and `'auto'`/`'required'` map to the provider's standard modes.
+    """
+
     parallel_tool_calls: bool
     """Whether to allow parallel tool calls.