Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from datetime import datetime
from typing import Any, Literal, cast, overload

from openai.types.chat.chat_completion_named_tool_choice_param import Function
from openai.types.responses.tool_choice_allowed_param import ToolChoiceAllowedParam
from openai.types.responses.tool_choice_function_param import ToolChoiceFunctionParam
from pydantic import ValidationError
from typing_extensions import assert_never, deprecated

Expand Down Expand Up @@ -56,6 +59,8 @@
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)
from openai.types.chat.chat_completion_allowed_tool_choice_param import ChatCompletionAllowedToolChoiceParam
from openai.types.chat.chat_completion_allowed_tools_param import ChatCompletionAllowedToolsParam
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
from openai.types.chat.chat_completion_content_part_param import File, FileFile
Expand All @@ -64,14 +69,17 @@
from openai.types.chat.chat_completion_message_function_tool_call_param import (
ChatCompletionMessageFunctionToolCallParam,
)
from openai.types.chat.chat_completion_named_tool_choice_param import ChatCompletionNamedToolChoiceParam
from openai.types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam
from openai.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam
from openai.types.chat.completion_create_params import (
WebSearchOptions,
WebSearchOptionsUserLocation,
WebSearchOptionsUserLocationApproximate,
)
from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
from openai.types.responses.response_input_param import FunctionCallOutput, Message
from openai.types.responses.tool_choice_options import ToolChoiceOptions
from openai.types.shared import ReasoningEffort
from openai.types.shared_params import Reasoning
except ImportError as _import_error:
Expand Down Expand Up @@ -386,13 +394,33 @@ async def _completions_create(
tools = self._get_tools(model_request_parameters)
web_search_options = self._get_web_search_options(model_request_parameters)

if not tools:
tool_choice: Literal['none', 'required', 'auto'] | None = None
tool_choice: ChatCompletionToolChoiceOptionParam | None
model_settings_tool_choice = model_settings.get('tool_choice', None)
# Respect an explicit request to disable tool calls.
if model_settings_tool_choice == 'none':
tool_choice = 'none'
elif not tools:
tool_choice = None
elif (
not model_request_parameters.allow_text_output
and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
and not isinstance(model_settings_tool_choice, list)
):
tool_choice = 'required'
elif isinstance(model_settings_tool_choice, list):
if len(model_settings_tool_choice) == 1:
tool_choice = ChatCompletionNamedToolChoiceParam(
type='function',
function=Function(name=model_settings_tool_choice[0]),
)
else:
tool_choice = ChatCompletionAllowedToolChoiceParam(
type='allowed_tools',
allowed_tools=ChatCompletionAllowedToolsParam(
mode='required' if not model_request_parameters.allow_text_output else 'auto',
tools=[{'type': 'function', 'function': {'name': name}} for name in model_settings_tool_choice],
),
)
else:
tool_choice = 'auto'

Expand Down Expand Up @@ -883,10 +911,29 @@ async def _responses_create(
+ self._get_tools(model_request_parameters)
)

if not tools:
tool_choice: Literal['none', 'required', 'auto'] | None = None
elif not model_request_parameters.allow_text_output:
tool_choice: ToolChoiceOptions | ToolChoiceAllowedParam | ToolChoiceFunctionParam | None
model_settings_tool_choice = model_settings.get('tool_choice', None)
if model_settings_tool_choice == 'none':
tool_choice = 'none'
elif not tools:
tool_choice = None
elif (
not model_request_parameters.allow_text_output
and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
and not isinstance(model_settings_tool_choice, list)
):
tool_choice = 'required'
elif isinstance(model_settings_tool_choice, list):
if len(model_settings_tool_choice) == 1:
name = model_settings_tool_choice[0]
tool_choice = ToolChoiceFunctionParam(type='function', name=name)
else:
# https://github.com/openai/openai-python/issues/2537
tool_choice = ToolChoiceAllowedParam(
type='allowed_tools',
mode='required' if not model_request_parameters.allow_text_output else 'auto',
tools=[{'type': 'function', 'name': name} for name in model_settings_tool_choice],
)
else:
tool_choice = 'auto'

Expand Down
7 changes: 7 additions & 0 deletions pydantic_ai_slim/pydantic_ai/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Literal

from httpx import Timeout
from typing_extensions import TypedDict

Expand Down Expand Up @@ -75,6 +77,11 @@ class ModelSettings(TypedDict, total=False):
* Mistral
"""

tool_choice: Literal['none', 'required', 'auto'] | list[str] | None
"""
TODO(moritz)
"""

parallel_tool_calls: bool
"""Whether to allow parallel tool calls.

Expand Down