Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ async def _notify_thread_of_new_messages(
thread: AgentThread,
input_messages: ChatMessage | Sequence[ChatMessage],
response_messages: ChatMessage | Sequence[ChatMessage],
**kwargs: Any,
) -> None:
"""Notify the thread of new messages.

Expand All @@ -346,13 +347,14 @@ async def _notify_thread_of_new_messages(
thread: The thread to notify of new messages.
input_messages: The input messages to notify about.
response_messages: The response messages to notify about.
**kwargs: Any extra arguments to pass from the agent run.
"""
if isinstance(input_messages, ChatMessage) or len(input_messages) > 0:
await thread.on_new_messages(input_messages)
if isinstance(response_messages, ChatMessage) or len(response_messages) > 0:
await thread.on_new_messages(response_messages)
if thread.context_provider:
await thread.context_provider.invoked(input_messages, response_messages)
await thread.context_provider.invoked(input_messages, response_messages, **kwargs)

@property
def display_name(self) -> str:
Expand Down Expand Up @@ -947,7 +949,7 @@ async def run_stream(
"""
input_messages = self._normalize_messages(messages)
thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages(
thread=thread, input_messages=input_messages
thread=thread, input_messages=input_messages, **kwargs
)
agent_name = self._get_agent_name()
# Resolve final tool list (runtime provided tools + local MCP server tools)
Expand Down Expand Up @@ -1011,7 +1013,7 @@ async def run_stream(

response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format)
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
await self._notify_thread_of_new_messages(thread, input_messages, response.messages)
await self._notify_thread_of_new_messages(thread, input_messages, response.messages, **kwargs)

@override
def get_new_thread(
Expand Down Expand Up @@ -1206,6 +1208,7 @@ async def _prepare_thread_and_messages(
*,
thread: AgentThread | None,
input_messages: list[ChatMessage] | None = None,
**kwargs: Any,
) -> tuple[AgentThread, ChatOptions, list[ChatMessage]]:
"""Prepare the thread and messages for agent execution.

Expand All @@ -1215,6 +1218,7 @@ async def _prepare_thread_and_messages(
Keyword Args:
thread: The conversation thread.
input_messages: Messages to process.
**kwargs: Any extra arguments to pass from the agent run.

Returns:
A tuple containing:
Expand All @@ -1235,7 +1239,7 @@ async def _prepare_thread_and_messages(
context: Context | None = None
if self.context_provider:
async with self.context_provider:
context = await self.context_provider.invoking(input_messages or [])
context = await self.context_provider.invoking(input_messages or [], **kwargs)
if context:
if context.messages:
thread_messages.extend(context.messages)
Expand Down