@@ -85,11 +85,9 @@ def invoke(self, messages, system: str = ""):
8585
8686"""
8787
88- import base64
8988import dataclasses
9089import enum
9190import json
92- import mimetypes
9391import typing
9492from typing import TYPE_CHECKING , Any , Iterator , TypeVar
9593
@@ -100,6 +98,8 @@ def invoke(self, messages, system: str = ""):
10098from kaggle_benchmarks import actors , chats , messages , prompting , utils
10199from kaggle_benchmarks ._config import config
102100from kaggle_benchmarks .content_types import images , videos
101+ from kaggle_benchmarks .serializers import genai as genai_serializer
102+ from kaggle_benchmarks .serializers import openai as openai_serializer
103103
104104if TYPE_CHECKING :
105105 from kaggle_benchmarks import llm_messages
@@ -278,6 +278,9 @@ def __init__(self, client: openai.OpenAI, model: str, **kwargs):
278278 super ().__init__ (** kwargs )
279279 self .model = model
280280 self .client = client
281+ self .serializer = openai_serializer .ModelProxyOpenAISerializer (
282+ roles_mapping = {"tool" : "system" }
283+ )
281284
282285 def _get_usage_meta (
283286 self , usage : openai .types .CompletionUsage | None
@@ -298,9 +301,12 @@ def _should_remove_seed(self) -> bool:
298301 def invoke (
299302 self , messages : list [messages .Message ], system : str | None , ** kwargs
300303 ) -> LLMResponse | Iterator [LLMResponse ]:
301- raw_messages = self ._get_raw_messages (messages )
302304 if system :
303- raw_messages = [{"role" : "system" , "content" : system }] + raw_messages
305+ from kaggle_benchmarks .messages import Message
306+
307+ messages = [Message (sender = actors .system , content = system )] + messages
308+
309+ raw_messages = list (self .serializer .dump_messages (messages ))
304310
305311 if self ._should_remove_seed ():
306312 # TODO(b/430112500): Remove once model proxy supports it for AIS backends.
@@ -309,17 +315,6 @@ def invoke(
309315
310316 return self ._call_api (raw_messages , ** kwargs )
311317
312- def _get_raw_messages (self , messages : list [messages .Message ]):
313- return [
314- {
315- "role" : message .sender .role
316- if message .sender .role != "tool"
317- else "system" , # TODO: Remove this renaming once ModelProxy supports tools
318- "content" : message .payload ,
319- }
320- for message in messages
321- ]
322-
323318 def _get_stream_response (
324319 self , response_stream : openai .Stream
325320 ) -> Iterator [LLMResponse ]:
@@ -390,6 +385,9 @@ def __init__(self, client: genai.Client, model: str, **kwargs):
390385 super ().__init__ (** kwargs )
391386 self .model = model
392387 self .client = client
388+ self .serializer = genai_serializer .GenAISerializer (
389+ roles_mapping = {"assistant" : "model" , "system" : "user" , "tool" : "user" }
390+ )
393391
394392 def _get_usage_meta (self , usage : types .UsageMetadata | None ) -> dict [str , Any ]:
395393 if usage is None :
@@ -400,60 +398,6 @@ def _get_usage_meta(self, usage: types.UsageMetadata | None) -> dict[str, Any]:
400398 ** _extract_extra_usage_metadata (usage ),
401399 }
402400
403- def _get_raw_messages (self , messages : list [messages .Message ]):
404- """Converts benchmark messages to Google GenAI's Content format."""
405- raw_messages = []
406- for message in messages :
407- role = "model" if message .sender .role == "assistant" else "user"
408- content = message .content
409- payload = message .payload
410-
411- parts = []
412-
413- # Video URLs are passed through directly for the model provider to resolve.
414- if isinstance (content , videos .VideoContent ):
415- parts .append (
416- types .Part .from_uri (
417- file_uri = content .url , mime_type = content .mime_type
418- )
419- )
420-
421- elif isinstance (payload , str ):
422- parts .append (types .Part (text = payload ))
423-
424- # Note: The Gemini API is smart enough to process image data URLs even when they are passed as part of a plain text string.
425- elif isinstance (payload , list ) and payload and isinstance (payload [0 ], dict ):
426- for item in payload :
427- if item .get ("type" ) == "image_url" :
428- url = item ["image_url" ]["url" ]
429-
430- image_bytes = None
431- mime_type = "image/jpeg"
432- if url .startswith ("data:" ):
433- # Handle base64 data URLs
434- header , b64_string = url .split ("," , 1 )
435- mime_type = header .split (";" )[0 ].split (":" )[1 ]
436- image_bytes = base64 .b64decode (b64_string )
437- else :
438- # Handle remote http/https URLs
439- b64_string = images .image_url_to_base64 (url )
440- image_bytes = base64 .b64decode (b64_string )
441- mime_type = mimetypes .guess_type (url )[0 ] or "image/jpeg"
442-
443- if image_bytes :
444- parts .append (
445- types .Part .from_bytes (
446- data = image_bytes , mime_type = mime_type
447- )
448- )
449- else :
450- # Fallback for any other unexpected payload types
451- parts .append (types .Part (text = str (payload )))
452-
453- raw_messages .append (types .Content (role = role , parts = parts ))
454-
455- return raw_messages
456-
457401 def _get_stream_response (
458402 self , response_stream : Iterator [types .GenerateContentResponse ]
459403 ) -> Iterator [LLMResponse ]:
@@ -467,7 +411,7 @@ def _get_stream_response(
467411 def invoke (
468412 self , messages : list [messages .Message ], system : str | None , ** kwargs
469413 ) -> LLMResponse | Iterator [LLMResponse ]:
470- raw_messages = self ._get_raw_messages (messages )
414+ raw_messages = list ( self .serializer . dump_messages (messages ) )
471415
472416 config_params = {}
473417 if system :
0 commit comments