fix: stream in litellm + adk and add corresponding integration tests #1387

Merged
merged 1 commit on Jun 16, 2025

3 changes: 2 additions & 1 deletion src/google/adk/models/lite_llm.py
@@ -739,11 +739,12 @@ async def generate_content_async(
_message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content="",
content=text,
tool_calls=tool_calls,
)
)
)
text = ""
function_calls.clear()
elif finish_reason == "stop" and text:
aggregated_llm_response = _message_to_generate_content_response(
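
The one-line source change above is the core of the fix: while streaming, text chunks and tool-call deltas are accumulated separately, and when a chunk finishes with reason "tool_calls" the assistant message used to be emitted with an empty content string, dropping any text produced before the tool call. The change attaches the accumulated text to that message and then resets the accumulator so the same text is not re-emitted on a later "stop" chunk. A minimal sketch of that aggregation pattern (the chunk dicts here are simplified assumptions, not the actual LiteLLM delta types):

```python
# Minimal sketch of the streaming aggregation pattern the fix restores.
# The chunk dicts are simplified assumptions, not LiteLLM's delta types.
async def aggregate_stream(chunks):
    text = ""
    tool_calls = []
    async for chunk in chunks:
        if chunk.get("text"):
            text += chunk["text"]  # accumulate partial text deltas
        if chunk.get("tool_call"):
            tool_calls.append(chunk["tool_call"])
        if chunk.get("finish_reason") == "tool_calls" and tool_calls:
            # Emit the text accumulated so far together with the tool calls ...
            yield {"role": "assistant", "content": text, "tool_calls": list(tool_calls)}
            # ... then reset both accumulators so nothing is emitted twice.
            text = ""
            tool_calls.clear()
        elif chunk.get("finish_reason") == "stop" and text:
            yield {"role": "assistant", "content": text}
```
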
109 changes: 105 additions & 4 deletions tests/integration/models/test_litellm_no_function.py
@@ -20,12 +20,26 @@
from google.genai.types import Part
import pytest

_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"

_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"

_SYSTEM_PROMPT = """You are a helpful assistant."""


def get_weather(city: str) -> str:
"""Simulates a web search. Use it get information on weather.

Args:
city: A string containing the location to get weather information for.

Returns:
A string with the simulated weather information for the queried city.
"""
if "sf" in city.lower() or "san francisco" in city.lower():
return "It's 70 degrees and foggy."
return "It's 80 degrees and sunny."


@pytest.fixture
def oss_llm():
return LiteLlm(model=_TEST_MODEL_NAME)
@@ -44,17 +58,57 @@ def llm_request():
)


@pytest.fixture
def llm_request_with_tools():
return LlmRequest(
model=_TEST_MODEL_NAME,
contents=[
Content(
role="user",
parts=[
Part.from_text(text="What is the weather in San Francisco?")
],
)
],
config=types.GenerateContentConfig(
temperature=0.1,
response_modalities=[types.Modality.TEXT],
system_instruction=_SYSTEM_PROMPT,
tools=[
types.Tool(
function_declarations=[
types.FunctionDeclaration(
name="get_weather",
description="Get the weather in a given location",
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"city": types.Schema(
type=types.Type.STRING,
description=(
"The city to get the weather for."
),
),
},
required=["city"],
),
)
]
)
],
),
)


@pytest.mark.asyncio
async def test_generate_content_async(oss_llm, llm_request):
async for response in oss_llm.generate_content_async(llm_request):
assert isinstance(response, LlmResponse)
assert response.content.parts[0].text


# Note that, this test disabled streaming because streaming is not supported
# properly in the current test model for now.
@pytest.mark.asyncio
async def test_generate_content_async_stream(oss_llm, llm_request):
async def test_generate_content_async(oss_llm, llm_request):
responses = [
resp
async for resp in oss_llm.generate_content_async(
@@ -63,3 +117,50 @@ async def test_generate_content_async_stream(oss_llm, llm_request):
]
part = responses[0].content.parts[0]
assert len(part.text) > 0


@pytest.mark.asyncio
async def test_generate_content_async_with_tools(
oss_llm, llm_request_with_tools
):
responses = [
resp
async for resp in oss_llm.generate_content_async(
llm_request_with_tools, stream=False
)
]
function_call = responses[0].content.parts[0].function_call
assert function_call.name == "get_weather"
assert function_call.args["city"] == "San Francisco"


@pytest.mark.asyncio
async def test_generate_content_async_stream(oss_llm, llm_request):
responses = [
resp
async for resp in oss_llm.generate_content_async(llm_request, stream=True)
]
text = ""
for i in range(len(responses) - 1):
assert responses[i].partial is True
assert responses[i].content.parts[0].text
text += responses[i].content.parts[0].text

# Last message should be accumulated text
assert responses[-1].content.parts[0].text == text
assert not responses[-1].partial


@pytest.mark.asyncio
async def test_generate_content_async_stream_with_tools(
oss_llm, llm_request_with_tools
):
responses = [
resp
async for resp in oss_llm.generate_content_async(
llm_request_with_tools, stream=True
)
]
function_call = responses[-1].content.parts[0].function_call
assert function_call.name == "get_weather"
assert function_call.args["city"] == "San Francisco"
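
The new streaming tests above encode the contract the fix establishes: every intermediate response carries partial=True with a text chunk, and the last response is the full aggregated text (or the aggregated function call). A minimal consumer sketch built on the same API the tests exercise; the model name, prompt, and config values are just the test values, not requirements:

```python
# Minimal consumer sketch for the streaming contract asserted above:
# intermediate responses have partial=True, the last one is the full aggregate.
from google.adk.models import LlmRequest
from google.adk.models.lite_llm import LiteLlm
from google.genai import types
from google.genai.types import Content
from google.genai.types import Part


async def stream_answer(prompt: str) -> str:
    llm = LiteLlm(model="vertex_ai/meta/llama-3.1-405b-instruct-maas")
    request = LlmRequest(
        model="vertex_ai/meta/llama-3.1-405b-instruct-maas",
        contents=[Content(role="user", parts=[Part.from_text(text=prompt)])],
        config=types.GenerateContentConfig(
            temperature=0.1,
            response_modalities=[types.Modality.TEXT],
            system_instruction="You are a helpful assistant.",
        ),
    )
    final_text = ""
    async for response in llm.generate_content_async(request, stream=True):
        if response.partial:
            # Intermediate chunk: print it as it arrives.
            print(response.content.parts[0].text, end="", flush=True)
        else:
            # Final response: the accumulated full answer.
            final_text = response.content.parts[0].text
    return final_text
```
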
25 changes: 18 additions & 7 deletions tests/integration/models/test_litellm_with_function.py
@@ -13,7 +13,6 @@
# limitations under the License.

from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.adk.models.lite_llm import LiteLlm
from google.genai import types
from google.genai.types import Content
@@ -23,12 +22,11 @@

litellm.add_function_to_prompt = True

_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"

_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"

_SYSTEM_PROMPT = """
You are a helpful assistant, and call tools optionally.
If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs.
If call tools, the tool format should be in json body, and the tool argument values should be parsed from users inputs.
"""


@@ -40,7 +38,7 @@
"properties": {
"city": {
"type": "string",
"description": "The city, e.g. San Francisco",
"description": "The city to get the weather for.",
},
},
"required": ["city"],
@@ -87,8 +85,6 @@ def llm_request():
)


# Note that, this test disabled streaming because streaming is not supported
# properly in the current test model for now.
@pytest.mark.asyncio
async def test_generate_content_asyn_with_function(
oss_llm_with_function, llm_request
@@ -102,3 +98,18 @@ async def test_generate_content_asyn_with_function(
function_call = responses[0].content.parts[0].function_call
assert function_call.name == "get_weather"
assert function_call.args["city"] == "San Francisco"


@pytest.mark.asyncio
async def test_generate_content_asyn_stream_with_function(
oss_llm_with_function, llm_request
):
responses = [
resp
async for resp in oss_llm_with_function.generate_content_async(
llm_request, stream=True
)
]
function_call = responses[-1].content.parts[0].function_call
assert function_call.name == "get_weather"
assert function_call.args["city"] == "San Francisco"
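
As in the no-function tests, the streamed case ends with a function_call rather than text; executing the tool is left to the caller. A small sketch of how such a call might be routed back to the local get_weather used in these tests, assuming function_call.args behaves like a plain mapping:

```python
# Hypothetical dispatch sketch: route the function_call asserted above back to
# the local get_weather implementation. Assumes function_call.args is a mapping.
def get_weather(city: str) -> str:
    if "sf" in city.lower() or "san francisco" in city.lower():
        return "It's 70 degrees and foggy."
    return "It's 80 degrees and sunny."


_TOOL_HANDLERS = {"get_weather": get_weather}


def dispatch_tool_call(function_call) -> str:
    handler = _TOOL_HANDLERS.get(function_call.name)
    if handler is None:
        raise ValueError(f"Unsupported tool: {function_call.name}")
    return handler(**dict(function_call.args))
```
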