# fix: update SK adapter stream tool call processing. (#5449)

## Why are these changes needed?

The SK model adapter's current stream processing returns on the first
function call chunk, which is incorrect and ends up returning an
incomplete function call. In practice, the function name and arguments
are split across multiple chunks; this update accumulates and merges
those chunks so that complete function calls are returned.
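
To illustrate the behavior being fixed, here is a minimal sketch of how a single tool call can arrive split across streaming chunks. The chunk values and the `get_weather` name are hypothetical; only the `FunctionCallContent` fields mirror those used by the adapter:

```python
# Illustrative sketch only: hypothetical chunk contents.
from semantic_kernel.contents.function_call_content import FunctionCallContent

# The first chunk typically carries the call id and function name, no arguments.
chunk1 = FunctionCallContent(id="call_abc", function_name="get_weather", arguments="")
# Later chunks often arrive with id=None and partial JSON argument fragments.
chunk2 = FunctionCallContent(id=None, arguments='{"city": ')
chunk3 = FunctionCallContent(id=None, arguments='"Paris"}')

# Returning after chunk1 (the old behavior) yields a call with empty arguments.
# The fix accumulates chunks per call id, concatenating string arguments:
merged = chunk1
for chunk in (chunk2, chunk3):
    if isinstance(merged.arguments, str) and isinstance(chunk.arguments, str):
        merged.arguments += chunk.arguments

assert merged.arguments == '{"city": "Paris"}'
```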

## Related issue number

Fixes the problem reported in a reply to #5420.

## Checks

- [ ] I've included any doc changes needed for
https://microsoft.github.io/autogen/. See
https://microsoft.github.io/autogen/docs/Contribute#documentation to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed.

---------

Co-authored-by: Leonardo Pinheiro <[email protected]>
lspinheiro and lpinheiroms authored Feb 9, 2025
1 parent b5eaab8 commit b868e32
Showing 2 changed files with 399 additions and 46 deletions.
@@ -1,5 +1,6 @@
import json
from typing import Any, Literal, Mapping, Optional, Sequence
import warnings

from autogen_core import FunctionCall
from autogen_core._cancellation_token import CancellationToken
@@ -18,7 +19,6 @@
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from semantic_kernel.kernel import Kernel
from typing_extensions import AsyncGenerator, Union
@@ -427,6 +427,28 @@ async def create(
thought=thought,
)

@staticmethod
def _merge_function_call_content(existing_call: FunctionCallContent, new_chunk: FunctionCallContent) -> None:
"""Helper to merge partial argument chunks from new_chunk into existing_call."""
if isinstance(existing_call.arguments, str) and isinstance(new_chunk.arguments, str):
existing_call.arguments += new_chunk.arguments
elif isinstance(existing_call.arguments, dict) and isinstance(new_chunk.arguments, dict):
existing_call.arguments.update(new_chunk.arguments)
elif not existing_call.arguments or existing_call.arguments in ("{}", ""):
# If existing had no arguments yet, just take the new one
existing_call.arguments = new_chunk.arguments
else:
# If there's a mismatch (str vs dict), handle as needed
warnings.warn("Mismatch in argument types during merge. Existing arguments retained.", stacklevel=2)

# Optionally update name/function_name if newly provided
if new_chunk.name:
existing_call.name = new_chunk.name
if new_chunk.plugin_name:
existing_call.plugin_name = new_chunk.plugin_name
if new_chunk.function_name:
existing_call.function_name = new_chunk.function_name

async def create_stream(
self,
messages: Sequence[LLMMessage],
@@ -460,6 +482,7 @@ async def create_stream(
Yields:
Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
"""

kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)
user_settings = self._get_prompt_settings(extra_create_args)
@@ -468,54 +491,105 @@

prompt_tokens = 0
completion_tokens = 0
accumulated_content = ""
accumulated_text = ""

# Keep track of in-progress function calls. Keyed by ID
# because partial chunks for the same function call might arrive separately.
function_calls_in_progress: dict[str, FunctionCallContent] = {}

# Track the ID of the last function call we saw so we can continue
# accumulating chunk arguments for that call if new items have id=None
last_function_call_id: Optional[str] = None

async for streaming_messages in self._sk_client.get_streaming_chat_message_contents(
chat_history, settings=settings, kernel=kernel
):
for msg in streaming_messages:
if not isinstance(msg, StreamingChatMessageContent):
continue

# Track token usage
if msg.metadata and "usage" in msg.metadata:
usage = msg.metadata["usage"]
prompt_tokens = getattr(usage, "prompt_tokens", 0)
completion_tokens = getattr(usage, "completion_tokens", 0)

# Check for function calls
if any(isinstance(item, FunctionCallContent) for item in msg.items):
function_calls = self._process_tool_calls(msg)
# Process function call deltas
for item in msg.items:
if isinstance(item, FunctionCallContent):
# If the chunk has a valid ID, we start or continue that ID explicitly
if item.id:
last_function_call_id = item.id
if last_function_call_id not in function_calls_in_progress:
function_calls_in_progress[last_function_call_id] = item
else:
# Merge partial arguments into existing call
existing_call = function_calls_in_progress[last_function_call_id]
self._merge_function_call_content(existing_call, item)
else:
# item.id is None, so we assume it belongs to the last known ID
if not last_function_call_id:
# No call in progress means we can't merge
# You could either skip or raise an error here
warnings.warn(
"Received function call chunk with no ID and no call in progress.", stacklevel=2
)
continue

existing_call = function_calls_in_progress[last_function_call_id]
# Merge partial chunk
self._merge_function_call_content(existing_call, item)

# Check if the model signaled tool_calls finished
if msg.finish_reason == "tool_calls" and function_calls_in_progress:
calls_to_yield: list[FunctionCall] = []
for _, call_content in function_calls_in_progress.items():
plugin_name = call_content.plugin_name or ""
function_name = call_content.function_name
if plugin_name:
full_name = f"{plugin_name}-{function_name}"
else:
full_name = function_name

if isinstance(call_content.arguments, dict):
arguments = json.dumps(call_content.arguments)
else:
assert isinstance(call_content.arguments, str)
arguments = call_content.arguments or "{}"

calls_to_yield.append(
FunctionCall(
id=call_content.id or "unknown_id",
name=full_name,
arguments=arguments,
)
)
# Yield all function calls in progress
yield CreateResult(
content=function_calls,
content=calls_to_yield,
finish_reason="function_calls",
usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
cached=False,
)
return

# Handle text content
# Handle any plain text in the message
if msg.content:
accumulated_content += msg.content
accumulated_text += msg.content
yield msg.content

# Final yield if there was text content
if accumulated_content:
self._total_prompt_tokens += prompt_tokens
self._total_completion_tokens += completion_tokens

if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
thought, accumulated_content = parse_r1_content(accumulated_content)
else:
thought = None

yield CreateResult(
content=accumulated_content,
finish_reason="stop",
usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
cached=False,
thought=thought,
)
# If we exit the loop without tool calls finishing, yield whatever text was accumulated
self._total_prompt_tokens += prompt_tokens
self._total_completion_tokens += completion_tokens

thought = None
if isinstance(accumulated_text, str) and self._model_info["family"] == ModelFamily.R1:
thought, accumulated_text = parse_r1_content(accumulated_text)

yield CreateResult(
content=accumulated_text,
finish_reason="stop",
usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
cached=False,
thought=thought,
)

def actual_usage(self) -> RequestUsage:
return RequestUsage(prompt_tokens=self._total_prompt_tokens, completion_tokens=self._total_completion_tokens)
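
For reviewers, a minimal consumption sketch of the updated `create_stream`. It is hypothetical (the `consume_stream` helper and the `adapter`/`messages` setup are assumptions, not part of this diff), but the yield types follow the docstring above: `str` chunks for text deltas, then a final `CreateResult`.

```python
# Hypothetical usage sketch. `adapter` is assumed to be a configured instance
# of this file's SK adapter, and `messages` a prepared sequence of LLM messages.
from typing import Any, Sequence

from autogen_core.models import CreateResult, LLMMessage


async def consume_stream(adapter: Any, messages: Sequence[LLMMessage]) -> None:
    async for chunk in adapter.create_stream(messages):
        if isinstance(chunk, str):
            print(chunk, end="")  # incremental text delta
        elif isinstance(chunk, CreateResult):
            # With this fix, a "function_calls" result only arrives once the
            # name and arguments of each FunctionCall are fully merged.
            print(chunk.finish_reason, chunk.content)
```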
