diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py
index 83f19c294c..201eed7622 100644
--- a/pydantic_ai_slim/pydantic_ai/mcp.py
+++ b/pydantic_ai_slim/pydantic_ai/mcp.py
@@ -225,13 +225,28 @@ async def direct_call_tool(
         except McpError as e:
             raise exceptions.ModelRetry(e.error.message)
 
-        content = [await self._map_tool_result_part(part) for part in result.content]
-        if result.isError:
-            text = '\n'.join(str(part) for part in content)
-            raise exceptions.ModelRetry(text)
-        else:
-            return content[0] if len(content) == 1 else content
+        if result.isError:
+            message: str | None = None
+            if result.content:  # pragma: no branch
+                text_parts = [part.text for part in result.content if isinstance(part, mcp_types.TextContent)]
+                message = '\n'.join(text_parts)
+
+            raise exceptions.ModelRetry(message or 'MCP tool call failed')
+
+        # Prefer structured content when the unstructured content is only text parts, which per the docs contain the JSON-encoded structured content for backward compatibility.
+        # See https://github.com/modelcontextprotocol/python-sdk#structured-output
+        if (structured := result.structuredContent) and all(
+            isinstance(part, mcp_types.TextContent) for part in result.content
+        ):
+            # The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function.
+            # See https://github.com/modelcontextprotocol/python-sdk#structured-output
+            if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured:
+                return structured['result']
+            return structured
+
+        mapped = [await self._map_tool_result_part(part) for part in result.content]
+        return mapped[0] if len(mapped) == 1 else mapped
 
     async def call_tool(
         self,
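Reviewer note: the unwrapping rule added above can be exercised in isolation. A minimal sketch of the same logic (the helper name `unwrap_structured` and the sample payloads are illustrative, not part of this PR):

```python
from typing import Any


def unwrap_structured(structured: Any) -> Any:
    """Mirror the rule above: the MCP SDK wraps primitives and generic types
    like list in a single `result` key, so a one-key {'result': ...} dict is
    unwrapped back to the raw value the tool function returned."""
    if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured:
        return structured['result']
    return structured


assert unwrap_structured({'result': 32.0}) == 32.0  # wrapped primitive, e.g. celsius_to_fahrenheit
assert unwrap_structured({'result': [1, 2]}) == [1, 2]  # wrapped generic type
assert unwrap_structured({'foo': 'bar', 'baz': 123}) == {'foo': 'bar', 'baz': 123}  # real dict output
```

This is why `test_stdio_server` below now snapshots the float `32.0` instead of the stringified `'32.0'`.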
diff --git a/tests/cassettes/test_mcp/test_tool_returning_unstructured_dict.yaml b/tests/cassettes/test_mcp/test_tool_returning_unstructured_dict.yaml
new file mode 100644
index 0000000000..df70c1f98e
--- /dev/null
+++ b/tests/cassettes/test_mcp/test_tool_returning_unstructured_dict.yaml
@@ -0,0 +1,520 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '3845'
+      content-type:
+      - application/json
+      host:
+      - api.openai.com
+    method: POST
+    parsed_body:
+      messages:
+      - content: Get me an unstructured dict, respond on one line
+        role: user
+      model: gpt-4o
+      stream: false
+      tool_choice: auto
+      tools:
+      - function:
+          description: "Convert Celsius to Fahrenheit.\n\n    Args:\n        celsius: Temperature in Celsius\n\n    Returns:\n
+            \        Temperature in Fahrenheit\n    "
+          name: celsius_to_fahrenheit
+          parameters:
+            additionalProperties: false
+            properties:
+              celsius:
+                type: number
+            required:
+            - celsius
+            type: object
+          strict: true
+        type: function
+      - function:
+          description: "Get the weather forecast for a location.\n\n    Args:\n        location: The location to get the weather
+            forecast for.\n\n    Returns:\n        The weather forecast for the location.\n    "
+          name: get_weather_forecast
+          parameters:
+            additionalProperties: false
+            properties:
+              location:
+                type: string
+            required:
+            - location
+            type: object
+          strict: true
+        type: function
+      - function:
+          description: ''
+          name: get_image_resource
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_image_resource_link
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_audio_resource
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_audio_resource_link
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_product_name
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_product_name_link
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_image
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_dict
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_unstructured_dict
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_error
+          parameters:
+            additionalProperties: false
+            properties:
+              value:
+                default: false
+                type: boolean
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_none
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_multiple_items
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: "Get the current log level.\n\n    Returns:\n        The current log level.\n    "
+          name: get_log_level
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: "Echo the run context.\n\n    Args:\n        ctx: Context object containing request and session information.\n\n
+            \    Returns:\n        Dictionary with an echo message and the deps.\n    "
+          name: echo_deps
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: Use sampling callback.
+          name: use_sampling
+          parameters:
+            additionalProperties: false
+            properties:
+              foo:
+                type: string
+            required:
+            - foo
+            type: object
+          strict: true
+        type: function
+      - function:
+          description: Use elicitation callback to ask the user a question.
+ name: use_elicitation + parameters: + additionalProperties: false + properties: + question: + type: string + required: + - question + type: object + strict: true + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1073' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1521' + openai-project: + - proj_dKobscVY9YJxeEaDJen54e3d + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_unstructured_dict + id: call_R0n2R7S9vL2aZOX25T9jahTd + type: function + created: 1759264706 + id: chatcmpl-CLbP82ODQMEznhobUKdq6Rjn9Aa12 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f33640a400 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 343 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 355 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '4119' + content-type: + - application/json + cookie: + - __cf_bm=tGZk0hevqXH_kU7UScGt.EqwPxg4OTuqy4hBf.hGmwo-1759264708-1.0.1.1-dHK7SUjNTLSpKF9cvNn7RNdg6UBSqsWYbU7k7Dq5oE9NdDkkv4LRlBEYWPjFafQKjfrS_JgwrAIM5Je9qCJeucIatzS2M02wLccvrfU2c6k; + _cfuvid=d8Rl.NMqIC4giHLO2QsuslT6sliz_xnT6LK0PLEZEsA-1759264708042-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: Get me an unstructured dict, respond on one line + role: user + - content: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_unstructured_dict + id: call_R0n2R7S9vL2aZOX25T9jahTd + type: function + - content: '{"foo":"bar","baz":123}' + role: tool + tool_call_id: call_R0n2R7S9vL2aZOX25T9jahTd + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + additionalProperties: false + properties: + celsius: + type: number + required: + - celsius + type: object + strict: true + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + additionalProperties: false + properties: + location: + type: string + required: + - location + type: object + strict: true + type: function + - function: + description: '' + name: get_image_resource + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_image_resource_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + additionalProperties: 
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_audio_resource_link
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_product_name
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_product_name_link
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_image
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_dict
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_unstructured_dict
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_error
+          parameters:
+            additionalProperties: false
+            properties:
+              value:
+                default: false
+                type: boolean
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_none
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: ''
+          name: get_multiple_items
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: "Get the current log level.\n\n    Returns:\n        The current log level.\n    "
+          name: get_log_level
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: "Echo the run context.\n\n    Args:\n        ctx: Context object containing request and session information.\n\n
+            \    Returns:\n        Dictionary with an echo message and the deps.\n    "
+          name: echo_deps
+          parameters:
+            additionalProperties: false
+            properties: {}
+            type: object
+        type: function
+      - function:
+          description: Use sampling callback.
+          name: use_sampling
+          parameters:
+            additionalProperties: false
+            properties:
+              foo:
+                type: string
+            required:
+            - foo
+            type: object
+          strict: true
+        type: function
+      - function:
+          description: Use elicitation callback to ask the user a question.
+ name: use_elicitation + parameters: + additionalProperties: false + properties: + question: + type: string + required: + - question + type: object + strict: true + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '833' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '318' + openai-project: + - proj_dKobscVY9YJxeEaDJen54e3d + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"foo":"bar","baz":123}' + refusal: null + role: assistant + created: 1759264708 + id: chatcmpl-CLbPAOYN3jPYdvYeD8JNOOXF5N554 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f33640a400 + usage: + completion_tokens: 10 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 374 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 384 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/mcp_server.py b/tests/mcp_server.py index 641f396be0..c68c17d8b2 100644 --- a/tests/mcp_server.py +++ b/tests/mcp_server.py @@ -6,6 +6,7 @@ from mcp.server.session import ServerSession from mcp.types import ( BlobResourceContents, + CreateMessageResult, EmbeddedResource, ResourceLink, SamplingMessage, @@ -134,6 +135,11 @@ async def get_dict() -> dict[str, Any]: return {'foo': 'bar', 'baz': 123} +@mcp.tool(structured_output=False) +async def get_unstructured_dict() -> dict[str, Any]: + return {'foo': 'bar', 'baz': 123} + + @mcp.tool() async def get_error(value: bool = False): if value: @@ -185,7 +191,7 @@ async def echo_deps(ctx: Context[ServerSession, None]) -> dict[str, Any]: @mcp.tool() -async def use_sampling(ctx: Context[ServerSession, None], foo: str) -> str: +async def use_sampling(ctx: Context[ServerSession, None], foo: str) -> CreateMessageResult: """Use sampling callback.""" result = await ctx.session.create_message( @@ -198,7 +204,7 @@ async def use_sampling(ctx: Context[ServerSession, None], foo: str) -> str: temperature=0.5, stop_sequences=['potato'], ) - return result.model_dump_json(indent=2) + return result class UserResponse(BaseModel): diff --git a/tests/test_mcp.py b/tests/test_mcp.py index b96ca9bf78..c9f2dcd2d7 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -77,14 +77,14 @@ async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] - assert len(tools) == snapshot(17) + assert len(tools) == snapshot(18) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') # Test calling the temperature conversion tool result = await server.direct_call_tool('celsius_to_fahrenheit', {'celsius': 0}) - assert result == snapshot('32.0') + assert result == snapshot(32.0) async def test_reentrant_context_manager(): @@ -130,7 +130,7 @@ async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): result = await 
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index b96ca9bf78..c9f2dcd2d7 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -77,14 +77,14 @@ async def test_stdio_server(run_context: RunContext[int]):
     server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
     async with server:
         tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
-        assert len(tools) == snapshot(17)
+        assert len(tools) == snapshot(18)
         assert tools[0].name == 'celsius_to_fahrenheit'
         assert isinstance(tools[0].description, str)
         assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')
 
         # Test calling the temperature conversion tool
         result = await server.direct_call_tool('celsius_to_fahrenheit', {'celsius': 0})
-        assert result == snapshot('32.0')
+        assert result == snapshot(32.0)
 
 
 async def test_reentrant_context_manager():
@@ -130,7 +130,7 @@ async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]):
         result = await server.call_tool(
             'foo_celsius_to_fahrenheit', {'celsius': 0}, run_context, tools['foo_celsius_to_fahrenheit']
         )
-        assert result == snapshot('32.0')
+        assert result == snapshot(32.0)
 
 
 async def test_stdio_server_with_cwd(run_context: RunContext[int]):
@@ -138,7 +138,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]):
     server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir)
     async with server:
         tools = await server.get_tools(run_context)
-        assert len(tools) == snapshot(17)
+        assert len(tools) == snapshot(18)
 
 
 async def test_process_tool_call(run_context: RunContext[int]) -> int:
@@ -237,7 +237,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent)
                 parts=[
                     ToolReturnPart(
                         tool_name='celsius_to_fahrenheit',
-                        content='32.0',
+                        content=32.0,
                         tool_call_id='call_QssdxTGkPblTYHmyVES1tKBj',
                         timestamp=IsDatetime(),
                     )
@@ -310,10 +310,6 @@ async def test_log_level_unset(run_context: RunContext[int]):
     server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
     assert server.log_level is None
     async with server:
-        tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
-        assert len(tools) == snapshot(17)
-        assert tools[13].name == 'get_log_level'
-
         result = await server.direct_call_tool('get_log_level', {})
         assert result == snapshot('unset')
 
@@ -983,6 +979,76 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent):
     )
 
 
+async def test_tool_returning_unstructured_dict(allow_model_requests: None, agent: Agent):
+    async with agent:
+        result = await agent.run('Get me an unstructured dict, respond on one line')
+        assert result.output == snapshot('{"foo":"bar","baz":123}')
+        assert result.all_messages() == snapshot(
+            [
+                ModelRequest(
+                    parts=[
+                        UserPromptPart(
+                            content='Get me an unstructured dict, respond on one line',
+                            timestamp=IsDatetime(),
+                        )
+                    ]
+                ),
+                ModelResponse(
+                    parts=[
+                        ToolCallPart(
+                            tool_name='get_unstructured_dict', args='{}', tool_call_id='call_R0n2R7S9vL2aZOX25T9jahTd'
+                        )
+                    ],
+                    usage=RequestUsage(
+                        input_tokens=343,
+                        output_tokens=12,
+                        details={
+                            'accepted_prediction_tokens': 0,
+                            'audio_tokens': 0,
+                            'reasoning_tokens': 0,
+                            'rejected_prediction_tokens': 0,
+                        },
+                    ),
+                    model_name='gpt-4o-2024-08-06',
+                    timestamp=IsDatetime(),
+                    provider_name='openai',
+                    provider_details={'finish_reason': 'tool_calls'},
+                    provider_response_id='chatcmpl-CLbP82ODQMEznhobUKdq6Rjn9Aa12',
+                    finish_reason='tool_call',
+                ),
+                ModelRequest(
+                    parts=[
+                        ToolReturnPart(
+                            tool_name='get_unstructured_dict',
+                            content={'foo': 'bar', 'baz': 123},
+                            tool_call_id='call_R0n2R7S9vL2aZOX25T9jahTd',
+                            timestamp=IsDatetime(),
+                        )
+                    ]
+                ),
+                ModelResponse(
+                    parts=[TextPart(content='{"foo":"bar","baz":123}')],
+                    usage=RequestUsage(
+                        input_tokens=374,
+                        output_tokens=10,
+                        details={
+                            'accepted_prediction_tokens': 0,
+                            'audio_tokens': 0,
+                            'reasoning_tokens': 0,
+                            'rejected_prediction_tokens': 0,
+                        },
+                    ),
+                    model_name='gpt-4o-2024-08-06',
+                    timestamp=IsDatetime(),
+                    provider_name='openai',
+                    provider_details={'finish_reason': 'stop'},
+                    provider_response_id='chatcmpl-CLbPAOYN3jPYdvYeD8JNOOXF5N554',
+                    finish_reason='stop',
+                ),
+            ]
+        )
+
+
 async def test_tool_returning_error(allow_model_requests: None, agent: Agent):
     async with agent:
         result = await agent.run('Get me an error, pass False as a value, unless the tool tells you otherwise')
@@ -1259,9 +1325,9 @@ async def test_client_sampling(run_context: RunContext[int]):
         result = await server.direct_call_tool('use_sampling', {'foo': 'bar'})
         assert result == snapshot(
             {
-                'meta': None,
+                '_meta': None,
                 'role': 'assistant',
-                'content': {'type': 'text', 'text': 'sampling model response', 'annotations': None, 'meta': None},
+                'content': {'type': 'text', 'text': 'sampling model response', 'annotations': None, '_meta': None},
                 'model': 'test',
                 'stopReason': None,
             }
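Reviewer note on the `_meta` key change: returning the `CreateMessageResult` model directly (instead of `model_dump_json`) means the value now round-trips through the SDK's structured content, where pydantic serializes the `meta` field under its `_meta` alias. A quick illustrative check (this assumes the alias behaviour of the `mcp.types` models; not part of this PR):

```python
from mcp.types import CreateMessageResult, TextContent

result = CreateMessageResult(
    role='assistant',
    content=TextContent(type='text', text='sampling model response'),
    model='test',
)
# The `meta` field is declared with the alias `_meta`, so an alias dump,
# the form the MCP SDK sends over the wire, uses the `_meta` key.
assert 'meta' in result.model_dump()
assert '_meta' in result.model_dump(by_alias=True)
```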