Skip to content

Commit 31845cd

Browse files
authored
adds model preference support for mcp server sampling requests (microsoft#373)
Example:

```python
from mcp.types import ModelPreferences
from mcp_extensions import send_sampling_request

# Set model preferences
model_preferences = ModelPreferences(
    # Can use hints to prefer models with specific names
    # hints=[
    #     ModelHint(
    #         # Prefer models where name starts with `name` value (so `o3` would _include_ `o3-mini`)')
    #         name="gpt-4o",
    #     )
    # ],
    # Setting speed priority to 1 to choose a faster model, like gpt-4o
    speedPriority=1,
    # If needing a reasoning model, set intelligence priority to 1 instead
    # intelligencePriority=1,
)

sampling_result = await send_sampling_request(
    fastmcp_server_context=ctx,
    system_prompt="<your prompt>",
    messages=messages,
    model_preferences=model_preferences,
    max_tokens=1024,
)
```
1 parent 3758cf5 commit 31845cd

4 files changed

Lines changed: 158 additions & 56 deletions

File tree

assistants/codespace-assistant/assistant/response/response.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,27 @@ async def error_handler(server_config: MCPServerConfig, error: Exception) -> Non
5252
)
5353
)
5454

55+
# Get the AI client configurations for this assistant
56+
generative_ai_client_config = get_ai_client_configs(config, "generative")
57+
reasoning_ai_client_config = get_ai_client_configs(config, "reasoning")
58+
5559
# TODO: This is a temporary hack to allow directing the request to the reasoning model
60+
# Currently we will only use the requested AI client configuration for the turn
5661
request_type = "reasoning" if message.content.startswith("reason:") else "generative"
57-
58-
# Get the AI client configuration based on the request type
59-
request_config, service_config = get_ai_client_configs(config, request_type)
62+
# Set a default AI client configuration based on the request type
63+
default_ai_client_config = (
64+
reasoning_ai_client_config if request_type == "reasoning" else generative_ai_client_config
65+
)
66+
# Set the service and request configurations for the AI client
67+
service_config = default_ai_client_config.service_config
68+
request_config = default_ai_client_config.request_config
6069

6170
# Create a sampling handler for handling requests from the MCP servers
6271
sampling_handler = OpenAISamplingHandler(
63-
service_config=service_config,
64-
request_config=request_config,
72+
ai_client_configs=[
73+
generative_ai_client_config,
74+
reasoning_ai_client_config,
75+
]
6576
)
6677

6778
mcp_sessions = await establish_mcp_sessions(

assistants/codespace-assistant/assistant/response/utils/openai_utils.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import logging
44
from textwrap import dedent
5-
from typing import List, Literal, Tuple
5+
from typing import List, Literal, Tuple, Union
66

7+
from assistant_extensions.ai_clients.config import AzureOpenAIClientConfigModel, OpenAIClientConfigModel
78
from assistant_extensions.mcp import (
89
ExtendedCallToolRequestParams,
910
MCPSession,
@@ -28,11 +29,32 @@
2829

2930
def get_ai_client_configs(
3031
config: AssistantConfigModel, request_type: Literal["generative", "reasoning"] = "generative"
31-
) -> tuple[OpenAIRequestConfig, AzureOpenAIServiceConfig | OpenAIServiceConfig]:
32-
if request_type == "reasoning":
33-
return config.reasoning_ai_client_config.request_config, config.reasoning_ai_client_config.service_config
32+
) -> Union[AzureOpenAIClientConfigModel, OpenAIClientConfigModel]:
33+
def create_ai_client_config(
34+
service_config: AzureOpenAIServiceConfig | OpenAIServiceConfig,
35+
request_config: OpenAIRequestConfig,
36+
) -> AzureOpenAIClientConfigModel | OpenAIClientConfigModel:
37+
if isinstance(service_config, AzureOpenAIServiceConfig):
38+
return AzureOpenAIClientConfigModel(
39+
service_config=service_config,
40+
request_config=request_config,
41+
)
42+
43+
return OpenAIClientConfigModel(
44+
service_config=service_config,
45+
request_config=request_config,
46+
)
3447

35-
return config.generative_ai_client_config.request_config, config.generative_ai_client_config.service_config
48+
if request_type == "reasoning":
49+
return create_ai_client_config(
50+
config.reasoning_ai_client_config.service_config,
51+
config.reasoning_ai_client_config.request_config,
52+
)
53+
54+
return create_ai_client_config(
55+
config.generative_ai_client_config.service_config,
56+
config.generative_ai_client_config.request_config,
57+
)
3658

3759

3860
async def get_completion(

libraries/python/assistant-extensions/assistant_extensions/mcp/_openai_utils.py

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
import deepmerge
55
from mcp import ClientSession, CreateMessageResult, SamplingMessage
66
from mcp.shared.context import RequestContext
7-
from mcp.types import CreateMessageRequestParams, ErrorData, ImageContent, TextContent
7+
from mcp.types import (
8+
CreateMessageRequestParams,
9+
ErrorData,
10+
ImageContent,
11+
ModelPreferences,
12+
TextContent,
13+
)
814
from openai.types.chat import (
915
ChatCompletion,
1016
ChatCompletionAssistantMessageParam,
@@ -14,8 +20,9 @@
1420
ChatCompletionToolParam,
1521
ChatCompletionUserMessageParam,
1622
)
17-
from openai_client import OpenAIRequestConfig, ServiceConfig, create_client
23+
from openai_client import OpenAIRequestConfig, create_client
1824

25+
from ..ai_clients.config import AzureOpenAIClientConfigModel, OpenAIClientConfigModel
1926
from ._model import MCPSamplingMessageHandler
2027
from ._sampling_handler import SamplingHandler
2128

@@ -40,14 +47,14 @@ def message_handler(self) -> MCPSamplingMessageHandler:
4047

4148
def __init__(
4249
self,
43-
service_config: ServiceConfig | None = None,
44-
request_config: OpenAIRequestConfig | None = None,
50+
ai_client_configs: list[
51+
Union[AzureOpenAIClientConfigModel, OpenAIClientConfigModel]
52+
],
4553
assistant_mcp_tools: list[ChatCompletionToolParam] | None = None,
4654
message_processor: OpenAIMessageProcessor | None = None,
4755
handler: MCPSamplingMessageHandler | None = None,
4856
) -> None:
49-
self.service_config = service_config
50-
self.request_config = request_config
57+
self.ai_client_configs = ai_client_configs
5158
self.assistant_mcp_tools = assistant_mcp_tools
5259

5360
# set a default message processor that converts sampling messages to
@@ -82,27 +89,21 @@ async def _default_message_handler(
8289
) -> CreateMessageResult | ErrorData:
8390
logger.info(f"Sampling handler invoked with context: {context}")
8491

85-
if not self.service_config or not self.request_config:
86-
raise ValueError(
87-
"Service config and request config must be set before handling messages."
88-
)
92+
ai_client_config = self._ai_client_config_from_model_preferences(
93+
params.modelPreferences
94+
)
8995

90-
try:
91-
completion_args = await self._create_completion_request(
92-
request=params,
93-
request_config=self.request_config,
94-
template_processor=self.message_processor,
95-
)
96-
except Exception as e:
97-
logger.exception(f"Error creating completion request: {e}")
98-
return ErrorData(
99-
code=500,
100-
message="Error creating completion request.",
101-
data=e,
102-
)
96+
if not ai_client_config:
97+
raise ValueError("No AI client configs defined for sampling requests.")
98+
99+
completion_args = await self._create_completion_request(
100+
request=params,
101+
request_config=ai_client_config.request_config,
102+
template_processor=self.message_processor,
103+
)
103104

104105
completion: ChatCompletion | None = None
105-
async with create_client(self.service_config) as client:
106+
async with create_client(ai_client_config.service_config) as client:
106107
completion = await client.chat.completions.create(**completion_args)
107108

108109
if completion is None:
@@ -112,12 +113,6 @@ async def _default_message_handler(
112113
)
113114

114115
choice = completion.choices[0]
115-
if choice.message.content is None:
116-
return ErrorData(
117-
code=500,
118-
message="No content returned from completion choice.",
119-
)
120-
121116
content = choice.message.content
122117
if content is None:
123118
content = "[no content]"
@@ -141,7 +136,61 @@ async def handle_message(
141136
context: RequestContext[ClientSession, Any],
142137
params: CreateMessageRequestParams,
143138
) -> CreateMessageResult | ErrorData:
144-
return await self._message_handler(context, params)
139+
try:
140+
return await self._message_handler(context, params)
141+
except Exception as e:
142+
logger.error(f"Error handling sampling request: {e}")
143+
code = getattr(e, "status_code", 500)
144+
message = getattr(e, "message", "Error handling sampling request.")
145+
data = str(e)
146+
return ErrorData(code=code, message=message, data=data)
147+
148+
def _ai_client_config_from_model_preferences(
149+
self, model_preferences: ModelPreferences | None
150+
) -> Union[AzureOpenAIClientConfigModel, OpenAIClientConfigModel] | None:
151+
"""
152+
Returns an AI client config from model preferences.
153+
"""
154+
155+
# if no configs are provided, return None
156+
if not self.ai_client_configs or len(self.ai_client_configs) == 0:
157+
return None
158+
159+
# if not provided, return the first config
160+
if not model_preferences:
161+
return self.ai_client_configs[0]
162+
163+
# if hints are provided, return the first hint where the name value matches
164+
# the start of the model name
165+
if model_preferences.hints:
166+
for hint in model_preferences.hints:
167+
if not hint.name:
168+
continue
169+
for ai_client_config in self.ai_client_configs:
170+
if ai_client_config.request_config.model.startswith(hint.name):
171+
return ai_client_config
172+
173+
# if any of the priority values are set, return the first config that matches our
174+
# criteria: speedPriority equates to non-reasoning models, intelligencePriority
175+
# equates to reasoning models for now
176+
# note: we are ignoring costPriority for now
177+
speed_priority = model_preferences.speedPriority or 0
178+
intelligence_priority = model_preferences.intelligencePriority or 0
179+
# cost_priority = 0 # ignored for now
180+
181+
# later we will support more than just reasoning or non-reasoning choices, but
182+
# for now we can keep it simple
183+
use_reasoning_model = intelligence_priority > speed_priority
184+
185+
for ai_client_config in self.ai_client_configs:
186+
if (
187+
ai_client_config.request_config.is_reasoning_model
188+
== use_reasoning_model
189+
):
190+
return ai_client_config
191+
192+
# failing to find a config via preferences, return first config
193+
return self.ai_client_configs[0]
145194

146195
async def _create_completion_request(
147196
self,

mcp-servers/mcp-server-giphy/mcp_server/sampling.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Dict, List, Union
99

1010
from mcp.server.fastmcp import Context
11-
from mcp.types import ImageContent, SamplingMessage, TextContent
11+
from mcp.types import ImageContent, ModelPreferences, SamplingMessage, TextContent
1212
from mcp_extensions import send_sampling_request, send_tool_call_progress
1313

1414
from .utils import fetch_url
@@ -74,8 +74,8 @@ async def generate_sampling_messages(search_results: List[Dict]) -> List[Samplin
7474
text_content = get_text_content(result)
7575
if text_content is not None:
7676
messages.append(SamplingMessage(role="user", content=text_content))
77-
if image_content is not None:
78-
messages.append(SamplingMessage(role="user", content=image_content))
77+
# if image_content is not None:
78+
# messages.append(SamplingMessage(role="user", content=image_content))
7979
return messages
8080

8181

@@ -108,19 +108,39 @@ async def perform_sampling(
108108
# Generate sampling messages
109109
messages += await generate_sampling_messages(search_results)
110110

111+
# Set model preferences
112+
model_preferences = ModelPreferences(
113+
# Can use hints to prefer models with specific names
114+
# hints=[
115+
# ModelHint(
116+
# # Prefer models where name starts with `name` value (so `o3` would _include_ `o3-mini`)')
117+
# name="gpt-4o",
118+
# )
119+
# ],
120+
# Setting speed priority to 1 to choose a faster model, like gpt-4o
121+
speedPriority=1,
122+
# If needing a reasoning model, set intelligence priority to 1 instead
123+
# intelligencePriority=1,
124+
)
125+
111126
await send_tool_call_progress(ctx, "choosing image...")
112127

113128
# FIXME add support for structured output to enforce image selection
114129
# Send sampling request to FastMCP server
115-
sampling_result = await send_sampling_request(
116-
fastmcp_server_context=ctx,
117-
system_prompt=dedent(f"""
118-
Analyze these images and choose the best choice based on provided context.
119-
Context: {context}
120-
Return the url for the chosen image.
121-
""").strip(),
122-
messages=messages,
123-
max_tokens=100,
124-
)
125-
126-
return sampling_result.content
130+
try:
131+
sampling_result = await send_sampling_request(
132+
fastmcp_server_context=ctx,
133+
system_prompt=dedent(f"""
134+
Analyze these images and choose the best choice based on provided context.
135+
Context: {context}
136+
Return the url for the chosen image.
137+
""").strip(),
138+
messages=messages,
139+
model_preferences=model_preferences,
140+
max_tokens=100,
141+
)
142+
143+
return sampling_result.content
144+
except Exception as e:
145+
logger.error(f"Failed to perform sampling: {str(e)}")
146+
raise e

0 commit comments

Comments
 (0)