diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index f25f0cdb76..ac6c833a9f 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -154,8 +154,8 @@ def _safe_json_serialize(obj) -> str:
     return str(obj)
 
 
-def _content_to_message_param(
-    content: types.Content,
+async def _content_to_message_param(
+    content: types.Content, custom_llm_provider: Optional[str] = None
 ) -> Union[Message, list[Message]]:
   """Converts a types.Content to a litellm Message or list of Messages.
 
@@ -184,7 +184,9 @@ def _content_to_message_param(
 
   # Handle user or assistant messages
   role = _to_litellm_role(content.role)
-  message_content = _get_content(content.parts) or None
+  message_content = (
+      await _get_content(content.parts, custom_llm_provider) or None
+  )
 
   if role == "user":
     return ChatCompletionUserMessage(role="user", content=message_content)
@@ -223,8 +225,8 @@ def _content_to_message_param(
   )
 
 
-def _get_content(
-    parts: Iterable[types.Part],
+async def _get_content(
+    parts: Iterable[types.Part], custom_llm_provider: Optional[str] = None
 ) -> Union[OpenAIMessageContent, str]:
   """Converts a list of parts to litellm content.
 
@@ -251,6 +253,16 @@ def _get_content(
     ):
       base64_string = base64.b64encode(part.inline_data.data).decode("utf-8")
       data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}"
+      # For OpenAI/Azure, upload the raw bytes once via the provider's Files
+      # API and reference them by id instead of inlining base64 data.
+      if custom_llm_provider in ["openai", "azure"]:
+        open_ai_file_object = await litellm.acreate_file(
+            file=part.inline_data.data,
+            purpose="assistants",
+            custom_llm_provider=custom_llm_provider,  # type: ignore
+        )
+      else:
+        open_ai_file_object = None
 
       if part.inline_data.mime_type.startswith("image"):
         # Use full MIME type (e.g., "image/png") for providers that validate it
@@ -273,12 +285,32 @@ def _get_content(
             "type": "audio_url",
             "audio_url": {"url": data_uri, "format": format_type},
         })
-      elif part.inline_data.mime_type == "application/pdf":
+      elif (
+          part.inline_data.mime_type.startswith("text/")
+          or part.inline_data.mime_type
+          in {
+              "application/pdf",
+              "application/msword",
+              "application/json",
+              "application/x-sh",
+              "application/typescript",
+              "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+              "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+          }
+      ):
         format_type = part.inline_data.mime_type
-        content_objects.append({
-            "type": "file",
-            "file": {"file_data": data_uri, "format": format_type},
-        })
+        if open_ai_file_object:
+          content_objects.append({
+              "type": "file",
+              "file": {
+                  "file_id": open_ai_file_object.id,
+              },
+          })
+        else:
+          content_objects.append({
+              "type": "file",
+              "file": {"file_data": data_uri, "format": format_type},
+          })
 
       else:
         raise ValueError("LiteLlm(BaseLlm) does not support this content part.")
@@ -524,7 +556,7 @@ def _message_to_generate_content_response(
   )
 
 
-def _get_completion_inputs(
+async def _get_completion_inputs(
     llm_request: LlmRequest,
 ) -> Tuple[
     List[Message],
@@ -540,10 +572,27 @@ def _get_completion_inputs(
   Returns:
     The litellm inputs (message list, tool dictionary, response format and generation params).
   """
+  # 0. Infer the provider from the model name.
+  if llm_request.model is None:
+    custom_llm_provider = "UNK"
+  elif "gemini" in llm_request.model:
+    custom_llm_provider = "vertex_ai"
+  elif "azure" in llm_request.model:
+    custom_llm_provider = "azure"
+  elif "openai" in llm_request.model:
+    custom_llm_provider = "openai"
+  else:
+    custom_llm_provider = "UNK"
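+  # NOTE: this substring match is a deliberately simple heuristic; litellm's
+  # get_llm_provider() helper could resolve the provider more robustly if
+  # wider model-string coverage is ever needed.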
+
   # 1. Construct messages
   messages: List[Message] = []
   for content in llm_request.contents or []:
-    message_param_or_list = _content_to_message_param(content)
+    message_param_or_list = await _content_to_message_param(
+        content, custom_llm_provider
+    )
     if isinstance(message_param_or_list, list):
       messages.extend(message_param_or_list)
     elif message_param_or_list:  # Ensure it's not None before appending
@@ -803,7 +852,7 @@ async def generate_content_async(
     logger.debug(_build_request_log(llm_request))
 
     messages, tools, response_format, generation_params = (
-        _get_completion_inputs(llm_request)
+        await _get_completion_inputs(llm_request)
     )
 
     if "functions" in self._additional_args:
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index a46d0f7d55..d4f135691e 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -410,6 +410,13 @@ def lite_llm_instance(mock_client):
   return LiteLlm(model="test_model", llm_client=mock_client)
 
 
+@pytest.fixture
+def openai_instance(mock_client, model: str = None):
+  if model is None:
+    model = "openai/gpt-5"
+  return LiteLlm(model=model, llm_client=mock_client)
+
+
 class MockLLMClient(LiteLLMClient):
 
   def __init__(self, acompletion_mock, completion_mock):
@@ -925,16 +932,18 @@ async def test_generate_content_async_with_usage_metadata(
   mock_acompletion.assert_called_once()
 
 
-def test_content_to_message_param_user_message():
+@pytest.mark.asyncio
+async def test_content_to_message_param_user_message():
   content = types.Content(
       role="user", parts=[types.Part.from_text(text="Test prompt")]
   )
-  message = _content_to_message_param(content)
+  message = await _content_to_message_param(content)
   assert message["role"] == "user"
   assert message["content"] == "Test prompt"
 
 
-def test_content_to_message_param_multi_part_function_response():
+@pytest.mark.asyncio
+async def test_content_to_message_param_multi_part_function_response():
   part1 = types.Part.from_function_response(
       name="function_one",
       response={"result": "result_one"},
@@ -951,7 +960,7 @@ def test_content_to_message_param_multi_part_function_response():
       role="tool",
       parts=[part1, part2],
   )
-  messages = _content_to_message_param(content)
+  messages = await _content_to_message_param(content)
 
   assert isinstance(messages, list)
   assert len(messages) == 2
@@ -964,16 +973,18 @@ def test_content_to_message_param_multi_part_function_response():
   assert messages[1]["content"] == '{"value": 123}'
 
 
-def test_content_to_message_param_assistant_message():
+@pytest.mark.asyncio
+async def test_content_to_message_param_assistant_message():
   content = types.Content(
       role="assistant", parts=[types.Part.from_text(text="Test response")]
   )
-  message = _content_to_message_param(content)
+  message = await _content_to_message_param(content)
   assert message["role"] == "assistant"
   assert message["content"] == "Test response"
 
 
-def test_content_to_message_param_function_call():
+@pytest.mark.asyncio
+async def test_content_to_message_param_function_call():
   content = types.Content(
       role="assistant",
       parts=[
@@ -984,7 +995,7 @@ def test_content_to_message_param_function_call():
       ],
   )
   content.parts[1].function_call.id = "test_tool_call_id"
-  message = _content_to_message_param(content)
+  message = await _content_to_message_param(content)
 
   assert message["role"] == "assistant"
   assert message["content"] == "test response"
@@ -995,7 +1006,10 @@ def test_content_to_message_param_function_call():
   assert tool_call["function"]["name"] == "test_function"
   assert tool_call["function"]["arguments"] == '{"test_arg": "test_value"}'
 
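+# NOTE: _content_to_message_param and _get_content are coroutines now (they
+# may call litellm.acreate_file), so the remaining conversion tests are async.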
-def test_content_to_message_param_multipart_content():
+@pytest.mark.asyncio
+async def test_content_to_message_param_multipart_content():
   """Test handling of multipart content where final_content is a list with text objects."""
   content = types.Content(
       role="assistant",
@@ -1004,7 +1018,7 @@ def test_content_to_message_param_multipart_content():
           types.Part.from_bytes(data=b"test_image_data", mime_type="image/png"),
       ],
   )
-  message = _content_to_message_param(content)
+  message = await _content_to_message_param(content)
   assert message["role"] == "assistant"
   # When content is a list and the first element is a text object with type "text",
   # it should extract the text (for providers like ollama_chat that don't handle lists well)
@@ -1013,7 +1027,8 @@ def test_content_to_message_param_multipart_content():
   assert message["tool_calls"] is None
 
 
-def test_content_to_message_param_single_text_object_in_list():
+@pytest.mark.asyncio
+async def test_content_to_message_param_single_text_object_in_list():
   """Test extraction of text from single text object in list (for ollama_chat compatibility)."""
   from unittest.mock import patch
 
@@ -1025,7 +1040,7 @@ def test_content_to_message_param_single_text_object_in_list():
       role="assistant",
       parts=[types.Part.from_text(text="single text")],
   )
-  message = _content_to_message_param(content)
+  message = await _content_to_message_param(content)
   assert message["role"] == "assistant"
   # Should extract the text from the single text object
   assert message["content"] == "single text"
@@ -1067,17 +1082,19 @@ def test_message_to_generate_content_response_tool_call():
   assert response.content.parts[0].function_call.id == "test_tool_call_id"
 
 
-def test_get_content_text():
+@pytest.mark.asyncio
+async def test_get_content_text():
   parts = [types.Part.from_text(text="Test text")]
-  content = _get_content(parts)
+  content = await _get_content(parts)
   assert content == "Test text"
 
 
-def test_get_content_image():
+@pytest.mark.asyncio
+async def test_get_content_image():
   parts = [
       types.Part.from_bytes(data=b"test_image_data", mime_type="image/png")
   ]
-  content = _get_content(parts)
+  content = await _get_content(parts)
   assert content[0]["type"] == "image_url"
   assert (
       content[0]["image_url"]["url"]
@@ -1086,11 +1103,12 @@ def test_get_content_image():
   assert content[0]["image_url"]["format"] == "image/png"
 
 
-def test_get_content_video():
+@pytest.mark.asyncio
+async def test_get_content_video():
   parts = [
       types.Part.from_bytes(data=b"test_video_data", mime_type="video/mp4")
   ]
-  content = _get_content(parts)
+  content = await _get_content(parts)
   assert content[0]["type"] == "video_url"
   assert (
       content[0]["video_url"]["url"]
@@ -1099,11 +1117,12 @@ def test_get_content_video():
   assert content[0]["video_url"]["format"] == "video/mp4"
 
 
-def test_get_content_pdf():
+@pytest.mark.asyncio
+async def test_get_content_pdf():
   parts = [
       types.Part.from_bytes(data=b"test_pdf_data", mime_type="application/pdf")
   ]
-  content = _get_content(parts)
+  content = await _get_content(parts)
   assert content[0]["type"] == "file"
   assert (
       content[0]["file"]["file_data"]
@@ -1112,11 +1131,12 @@ def test_get_content_pdf():
   assert content[0]["file"]["format"] == "application/pdf"
 
 
-def test_get_content_audio():
+@pytest.mark.asyncio
+async def test_get_content_audio():
   parts = [
       types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg")
   ]
-  content = _get_content(parts)
+  content = await _get_content(parts)
   assert content[0]["type"] == "audio_url"
   assert (
       content[0]["audio_url"]["url"]
@@ -1592,7 +1612,7 @@ async def test_generate_content_async_non_compliant_multiple_function_calls(
 
 
 @pytest.mark.asyncio
-def test_get_completion_inputs_generation_params():
+async def test_get_completion_inputs_generation_params():
   # Test that generation_params are extracted and mapped correctly
   req = LlmRequest(
       contents=[
@@ -1610,7 +1630,7 @@ def test_get_completion_inputs_generation_params():
   )
 
   from google.adk.models.lite_llm import _get_completion_inputs
-  _, _, _, generation_params = _get_completion_inputs(req)
+  _, _, _, generation_params = await _get_completion_inputs(req)
   assert generation_params["temperature"] == 0.33
   assert generation_params["max_completion_tokens"] == 123
   assert generation_params["top_p"] == 0.88
@@ -1750,3 +1770,117 @@ def test_non_gemini_litellm_no_warning():
     # Test with non-Gemini model
     LiteLlm(model="openai/gpt-4o")
   assert len(w) == 0
+
+
+@pytest.mark.asyncio
+async def test_get_file_id_from_litellm_openai(
+    mocker,
+):
+  """Test that a PDF attachment is uploaded and referenced by file_id for OpenAI."""
+  from google.adk.models.lite_llm import _get_completion_inputs
+
+  mock_return = mocker.MagicMock()
+  mock_return.id = "test_file_id"
+  acreate_file_mock = AsyncMock(return_value=mock_return)
+  mocker.patch(
+      "google.adk.models.lite_llm.litellm.acreate_file",
+      new=acreate_file_mock,
+  )
+
+  data_part = types.Part.from_bytes(
+      data=b"test_pdf_data", mime_type="application/pdf"
+  )
+  data_part.inline_data.display_name = "test_file.pdf"
+
+  llm_request = LlmRequest(
+      model="openai/gpt-4o",
+      contents=[
+          types.Content(
+              role="user",
+              parts=[
+                  types.Part.from_text(text="Test attach PDF file"),
+                  data_part,
+              ],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[],
+      ),
+  )
+  messages, tools, response_format, generation_params = (
+      await _get_completion_inputs(llm_request)
+  )
+  assert messages
+  assert messages == [{
+      "role": "user",
+      "content": [
+          {"type": "text", "text": "Test attach PDF file"},
+          {
+              "type": "file",
+              "file": {"file_id": "test_file_id"},
+          },
+      ],
+  }]
+
+
+@pytest.mark.asyncio
+async def test_get_file_id_from_litellm_gemini(
+    mocker,
+):
+  """Test that the PDF attachment is **not** sent as a file_id for Gemini (or any provider other than OpenAI or Azure)."""
+
+  from google.adk.models.lite_llm import _get_completion_inputs
+
+  mock_return = mocker.MagicMock()
+  mock_return.id = "test_file_id"
+  acreate_file_mock = AsyncMock(return_value=mock_return)
+  mocker.patch(
+      "google.adk.models.lite_llm.litellm.acreate_file",
+      new=acreate_file_mock,
+  )
+
+  data_part = types.Part.from_bytes(
+      data=b"test_pdf_data", mime_type="application/pdf"
+  )
+  data_part.inline_data.display_name = "test_file.pdf"
+
+  llm_request = LlmRequest(
+      model="gemini",
+      contents=[
+          types.Content(
+              role="user",
+              parts=[
+                  types.Part.from_text(text="Test attach PDF file"),
+                  data_part,
+              ],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[],
+      ),
+  )
+
+  messages, tools, response_format, generation_params = (
+      await _get_completion_inputs(llm_request)
+  )
+  assert messages
+  assert messages == [{
+      "role": "user",
+      "content": [
+          {"type": "text", "text": "Test attach PDF file"},
+          {
+              "type": "file",
+              "file": {
+                  "file_data": (
+                      "data:application/pdf;base64,dGVzdF9wZGZfZGF0YQ=="
+                  ),
+                  "format": "application/pdf",
+              },
+          },
+      ],
+  }]
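+
+
+# NOTE: taken together, these two tests pin down the new attachment behavior:
+# openai/azure providers upload inline files via litellm.acreate_file and send
+# a file_id, while every other provider falls back to base64 file_data.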