Commit be5cda6

samuelcolvin and Kludex authored
Support MCP sampling (#1884)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent f2646de commit be5cda6

File tree

8 files changed (+607, -128 lines)


Makefile

Lines changed: 3 additions & 3 deletions
@@ -64,7 +64,7 @@ test: ## Run tests and collect coverage data
 	@uv run coverage report
 
 .PHONY: test-fast
-test-fast: ## Same as test except no coverage. ~1/4th the time depending on hardware.
+test-fast: ## Same as test except no coverage and 4x faster depending on hardware
 	uv run pytest -n auto --dist=loadgroup
 
 .PHONY: test-all-python
@@ -78,12 +78,12 @@ test-all-python: ## Run tests on Python 3.9 to 3.13
 	@uv run coverage report
 
 .PHONY: testcov
-testcov: test ## Run tests and generate a coverage report
+testcov: test ## Run tests and generate an HTML coverage report
 	@echo "building coverage html"
 	@uv run coverage html
 
 .PHONY: test-mrp
-test-mrp: ## Build and tests of mcp-run-python
+test-mrp: ## Build and tests of mcp-run-python
 	cd mcp-run-python && deno task build
 	uv run --package mcp-run-python pytest mcp-run-python -v

pydantic_ai_slim/pydantic_ai/_mcp.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import base64
from collections.abc import Sequence
from typing import Literal

from . import exceptions, messages

try:
    from mcp import types as mcp_types
except ImportError as _import_error:
    raise ImportError(
        'Please install the `mcp` package to use the MCP server, '
        'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
    ) from _import_error


def map_from_mcp_params(params: mcp_types.CreateMessageRequestParams) -> list[messages.ModelMessage]:
    """Convert from MCP create message request parameters to pydantic-ai messages."""
    pai_messages: list[messages.ModelMessage] = []
    request_parts: list[messages.ModelRequestPart] = []
    if params.systemPrompt:
        request_parts.append(messages.SystemPromptPart(content=params.systemPrompt))
    response_parts: list[messages.ModelResponsePart] = []
    for msg in params.messages:
        content = msg.content
        if msg.role == 'user':
            # if there are any response parts, add a response message wrapping them
            if response_parts:
                pai_messages.append(messages.ModelResponse(parts=response_parts))
                response_parts = []

            # TODO(Marcelo): We can reuse the `_map_tool_result_part` from the mcp module here.
            if isinstance(content, mcp_types.TextContent):
                user_part_content: str | Sequence[messages.UserContent] = content.text
            else:
                # image content
                user_part_content = [
                    messages.BinaryContent(data=base64.b64decode(content.data), media_type=content.mimeType)
                ]

            request_parts.append(messages.UserPromptPart(content=user_part_content))
        else:
            # role is assistant
            # if there are any request parts, add a request message wrapping them
            if request_parts:
                pai_messages.append(messages.ModelRequest(parts=request_parts))
                request_parts = []

            response_parts.append(map_from_sampling_content(content))

    if response_parts:
        pai_messages.append(messages.ModelResponse(parts=response_parts))
    if request_parts:
        pai_messages.append(messages.ModelRequest(parts=request_parts))
    return pai_messages


def map_from_pai_messages(pai_messages: list[messages.ModelMessage]) -> tuple[str, list[mcp_types.SamplingMessage]]:
    """Convert from pydantic-ai messages to MCP sampling messages.

    Returns:
        A tuple containing the system prompt and a list of sampling messages.
    """
    sampling_msgs: list[mcp_types.SamplingMessage] = []

    def add_msg(
        role: Literal['user', 'assistant'],
        content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent,
    ):
        sampling_msgs.append(mcp_types.SamplingMessage(role=role, content=content))

    system_prompt: list[str] = []
    for pai_message in pai_messages:
        if isinstance(pai_message, messages.ModelRequest):
            if pai_message.instructions is not None:
                system_prompt.append(pai_message.instructions)

            for part in pai_message.parts:
                if isinstance(part, messages.SystemPromptPart):
                    system_prompt.append(part.content)
                if isinstance(part, messages.UserPromptPart):
                    if isinstance(part.content, str):
                        add_msg('user', mcp_types.TextContent(type='text', text=part.content))
                    else:
                        for chunk in part.content:
                            if isinstance(chunk, str):
                                add_msg('user', mcp_types.TextContent(type='text', text=chunk))
                            elif isinstance(chunk, messages.BinaryContent) and chunk.is_image:
                                add_msg(
                                    'user',
                                    mcp_types.ImageContent(
                                        type='image',
                                        data=base64.b64decode(chunk.data).decode(),
                                        mimeType=chunk.media_type,
                                    ),
                                )
                            # TODO(Marcelo): Add support for audio content.
                            else:
                                raise NotImplementedError(f'Unsupported content type: {type(chunk)}')
        else:
            add_msg('assistant', map_from_model_response(pai_message))
    return ''.join(system_prompt), sampling_msgs


def map_from_model_response(model_response: messages.ModelResponse) -> mcp_types.TextContent:
    """Convert from a model response to MCP text content."""
    text_parts: list[str] = []
    for part in model_response.parts:
        if isinstance(part, messages.TextPart):
            text_parts.append(part.content)
        # TODO(Marcelo): We should ignore ThinkingPart here.
        else:
            raise exceptions.UnexpectedModelBehavior(f'Unexpected part type: {type(part).__name__}, expected TextPart')
    return mcp_types.TextContent(type='text', text=''.join(text_parts))


def map_from_sampling_content(
    content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent,
) -> messages.TextPart:
    """Convert from sampling content to a pydantic-ai text part."""
    if isinstance(content, mcp_types.TextContent):  # pragma: no branch
        return messages.TextPart(content=content.text)
    else:
        raise NotImplementedError('Image and Audio responses in sampling are not yet supported')
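
For context, here is a rough, illustrative sketch (not part of the commit) of how the new helpers could be exercised: it hand-builds a minimal `CreateMessageRequestParams`, converts it to pydantic-ai messages with `map_from_mcp_params`, maps the history back to MCP sampling messages with `map_from_pai_messages`, and turns a model response into MCP text content with `map_from_model_response`. The prompt text and `maxTokens` value are made-up example inputs, and `_mcp` is a private module, so this usage may change.

# Illustrative only — assumes `pydantic-ai-slim[mcp]` is installed; the prompt
# text and maxTokens value are arbitrary example inputs, not from the commit.
from mcp import types as mcp_types

from pydantic_ai import _mcp, messages

params = mcp_types.CreateMessageRequestParams(
    messages=[
        mcp_types.SamplingMessage(
            role='user',
            content=mcp_types.TextContent(type='text', text='What is MCP sampling?'),
        )
    ],
    systemPrompt='Reply in one sentence.',
    maxTokens=100,
)

# MCP sampling request -> pydantic-ai message history
pai_messages = _mcp.map_from_mcp_params(params)
assert isinstance(pai_messages[0], messages.ModelRequest)

# pydantic-ai message history -> (system prompt, MCP sampling messages)
system_prompt, sampling_msgs = _mcp.map_from_pai_messages(pai_messages)

# pydantic-ai model output -> MCP text content for the sampling result
response = messages.ModelResponse(parts=[messages.TextPart(content='It lets an MCP server ask the client to run an LLM call.')])
content = _mcp.map_from_model_response(response)
print(system_prompt, content.text)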

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 10 additions & 1 deletion
@@ -1691,14 +1691,23 @@ def is_end_node(
         return isinstance(node, End)
 
     @asynccontextmanager
-    async def run_mcp_servers(self) -> AsyncIterator[None]:
+    async def run_mcp_servers(
+        self, model: models.Model | models.KnownModelName | str | None = None
+    ) -> AsyncIterator[None]:
         """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
 
         Returns: a context manager to start and shutdown the servers.
         """
+        try:
+            sampling_model: models.Model | None = self._get_model(model)
+        except exceptions.UserError:  # pragma: no cover
+            sampling_model = None
+
         exit_stack = AsyncExitStack()
         try:
             for mcp_server in self._mcp_servers:
+                if sampling_model is not None:  # pragma: no branch
+                    mcp_server.sampling_model = sampling_model
                 await exit_stack.enter_async_context(mcp_server)
             yield
         finally:
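
And a hedged sketch of how the extended `run_mcp_servers()` signature might be used: passing a model sets `sampling_model` on each registered MCP server before it starts, so sampling requests from that server are answered by the given model, while omitting the argument falls back to the agent's own model. The server command, script name, and model names below are placeholders, not taken from the commit.

# Illustrative only — server script and model names are placeholders.
import asyncio

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

server = MCPServerStdio('python', args=['my_mcp_server.py'])  # hypothetical server
agent = Agent('openai:gpt-4o', mcp_servers=[server])


async def main() -> None:
    # The model passed here becomes the sampling model for every MCP server;
    # with no argument, the agent's own model is used instead.
    async with agent.run_mcp_servers(model='openai:gpt-4o-mini'):
        result = await agent.run('Use the server tools to answer the question.')
    print(result.output)


asyncio.run(main())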

0 commit comments
