From ce3971e3fccc7fd363c618588c65b5baaaddba6c Mon Sep 17 00:00:00 2001 From: David Jadczyk Date: Fri, 24 Oct 2025 11:57:37 +0200 Subject: [PATCH] feat: pass agent run kwargs through to context providers `_prepare_thread_and_messages` and `_notify_thread_of_new_messages` now are passing through the agent's `run` kwargs, sothat they can provide the context_provider with said kwargs #1679 --- python/packages/core/agent_framework/_agents.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 0125adb188..6fb28e5aa1 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -337,6 +337,7 @@ async def _notify_thread_of_new_messages( thread: AgentThread, input_messages: ChatMessage | Sequence[ChatMessage], response_messages: ChatMessage | Sequence[ChatMessage], + **kwargs: Any, ) -> None: """Notify the thread of new messages. @@ -346,13 +347,14 @@ async def _notify_thread_of_new_messages( thread: The thread to notify of new messages. input_messages: The input messages to notify about. response_messages: The response messages to notify about. + **kwargs: Any extra arguments to pass from the agent run. """ if isinstance(input_messages, ChatMessage) or len(input_messages) > 0: await thread.on_new_messages(input_messages) if isinstance(response_messages, ChatMessage) or len(response_messages) > 0: await thread.on_new_messages(response_messages) if thread.context_provider: - await thread.context_provider.invoked(input_messages, response_messages) + await thread.context_provider.invoked(input_messages, response_messages, **kwargs) @property def display_name(self) -> str: @@ -947,7 +949,7 @@ async def run_stream( """ input_messages = self._normalize_messages(messages) thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages( - thread=thread, input_messages=input_messages + thread=thread, input_messages=input_messages, **kwargs ) agent_name = self._get_agent_name() # Resolve final tool list (runtime provided tools + local MCP server tools) @@ -1011,7 +1013,7 @@ async def run_stream( response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format) await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) - await self._notify_thread_of_new_messages(thread, input_messages, response.messages) + await self._notify_thread_of_new_messages(thread, input_messages, response.messages, **kwargs) @override def get_new_thread( @@ -1206,6 +1208,7 @@ async def _prepare_thread_and_messages( *, thread: AgentThread | None, input_messages: list[ChatMessage] | None = None, + **kwargs: Any, ) -> tuple[AgentThread, ChatOptions, list[ChatMessage]]: """Prepare the thread and messages for agent execution. @@ -1215,6 +1218,7 @@ async def _prepare_thread_and_messages( Keyword Args: thread: The conversation thread. input_messages: Messages to process. + **kwargs: Any extra arguments to pass from the agent run. Returns: A tuple containing: @@ -1235,7 +1239,7 @@ async def _prepare_thread_and_messages( context: Context | None = None if self.context_provider: async with self.context_provider: - context = await self.context_provider.invoking(input_messages or []) + context = await self.context_provider.invoking(input_messages or [], **kwargs) if context: if context.messages: thread_messages.extend(context.messages)