Skip to content

Commit 5588428

Browse files
authored
feat(serialization): introduce dedicated serialization module (#118)
Separates serialization logic from API clients (OpenAI, GoogleGenAI) and makes it more customizable per model or provider. - Introduced a unified serialization pipeline through the BaseSerializer and its implementations: OpenAICompletionSerializer, ModelProxyOpenAISerializer, and GenAISerializer. - Removed inline message serialization loops (like _get_raw_messages) from LLM classes, delegating this work directly to the corresponding serializer instances.
1 parent 048cb8a commit 5588428

File tree

9 files changed

+867
-72
lines changed

9 files changed

+867
-72
lines changed

src/kaggle_benchmarks/actors/llms.py

Lines changed: 14 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,9 @@ def invoke(self, messages, system: str = ""):
8585
8686
"""
8787

88-
import base64
8988
import dataclasses
9089
import enum
9190
import json
92-
import mimetypes
9391
import typing
9492
from typing import TYPE_CHECKING, Any, Iterator, TypeVar
9593

@@ -100,6 +98,8 @@ def invoke(self, messages, system: str = ""):
10098
from kaggle_benchmarks import actors, chats, messages, prompting, utils
10199
from kaggle_benchmarks._config import config
102100
from kaggle_benchmarks.content_types import images, videos
101+
from kaggle_benchmarks.serializers import genai as genai_serializer
102+
from kaggle_benchmarks.serializers import openai as openai_serializer
103103

104104
if TYPE_CHECKING:
105105
from kaggle_benchmarks import llm_messages
@@ -278,6 +278,9 @@ def __init__(self, client: openai.OpenAI, model: str, **kwargs):
278278
super().__init__(**kwargs)
279279
self.model = model
280280
self.client = client
281+
self.serializer = openai_serializer.ModelProxyOpenAISerializer(
282+
roles_mapping={"tool": "system"}
283+
)
281284

282285
def _get_usage_meta(
283286
self, usage: openai.types.CompletionUsage | None
@@ -298,9 +301,12 @@ def _should_remove_seed(self) -> bool:
298301
def invoke(
299302
self, messages: list[messages.Message], system: str | None, **kwargs
300303
) -> LLMResponse | Iterator[LLMResponse]:
301-
raw_messages = self._get_raw_messages(messages)
302304
if system:
303-
raw_messages = [{"role": "system", "content": system}] + raw_messages
305+
from kaggle_benchmarks.messages import Message
306+
307+
messages = [Message(sender=actors.system, content=system)] + messages
308+
309+
raw_messages = list(self.serializer.dump_messages(messages))
304310

305311
if self._should_remove_seed():
306312
# TODO(b/430112500): Remove once model proxy supports it for AIS backends.
@@ -309,17 +315,6 @@ def invoke(
309315

310316
return self._call_api(raw_messages, **kwargs)
311317

312-
def _get_raw_messages(self, messages: list[messages.Message]):
313-
return [
314-
{
315-
"role": message.sender.role
316-
if message.sender.role != "tool"
317-
else "system", # TODO: Remove this renaming once ModelProxy supports tools
318-
"content": message.payload,
319-
}
320-
for message in messages
321-
]
322-
323318
def _get_stream_response(
324319
self, response_stream: openai.Stream
325320
) -> Iterator[LLMResponse]:
@@ -390,6 +385,9 @@ def __init__(self, client: genai.Client, model: str, **kwargs):
390385
super().__init__(**kwargs)
391386
self.model = model
392387
self.client = client
388+
self.serializer = genai_serializer.GenAISerializer(
389+
roles_mapping={"assistant": "model", "system": "user", "tool": "user"}
390+
)
393391

394392
def _get_usage_meta(self, usage: types.UsageMetadata | None) -> dict[str, Any]:
395393
if usage is None:
@@ -400,60 +398,6 @@ def _get_usage_meta(self, usage: types.UsageMetadata | None) -> dict[str, Any]:
400398
**_extract_extra_usage_metadata(usage),
401399
}
402400

403-
def _get_raw_messages(self, messages: list[messages.Message]):
404-
"""Converts benchmark messages to Google GenAI's Content format."""
405-
raw_messages = []
406-
for message in messages:
407-
role = "model" if message.sender.role == "assistant" else "user"
408-
content = message.content
409-
payload = message.payload
410-
411-
parts = []
412-
413-
# Video URLs are passed through directly for the model provider to resolve.
414-
if isinstance(content, videos.VideoContent):
415-
parts.append(
416-
types.Part.from_uri(
417-
file_uri=content.url, mime_type=content.mime_type
418-
)
419-
)
420-
421-
elif isinstance(payload, str):
422-
parts.append(types.Part(text=payload))
423-
424-
# Note: The Gemini API is smart enough to process image data URLs even when they are passed as part of a plain text string.
425-
elif isinstance(payload, list) and payload and isinstance(payload[0], dict):
426-
for item in payload:
427-
if item.get("type") == "image_url":
428-
url = item["image_url"]["url"]
429-
430-
image_bytes = None
431-
mime_type = "image/jpeg"
432-
if url.startswith("data:"):
433-
# Handle base64 data URLs
434-
header, b64_string = url.split(",", 1)
435-
mime_type = header.split(";")[0].split(":")[1]
436-
image_bytes = base64.b64decode(b64_string)
437-
else:
438-
# Handle remote http/https URLs
439-
b64_string = images.image_url_to_base64(url)
440-
image_bytes = base64.b64decode(b64_string)
441-
mime_type = mimetypes.guess_type(url)[0] or "image/jpeg"
442-
443-
if image_bytes:
444-
parts.append(
445-
types.Part.from_bytes(
446-
data=image_bytes, mime_type=mime_type
447-
)
448-
)
449-
else:
450-
# Fallback for any other unexpected payload types
451-
parts.append(types.Part(text=str(payload)))
452-
453-
raw_messages.append(types.Content(role=role, parts=parts))
454-
455-
return raw_messages
456-
457401
def _get_stream_response(
458402
self, response_stream: Iterator[types.GenerateContentResponse]
459403
) -> Iterator[LLMResponse]:
@@ -467,7 +411,7 @@ def _get_stream_response(
467411
def invoke(
468412
self, messages: list[messages.Message], system: str | None, **kwargs
469413
) -> LLMResponse | Iterator[LLMResponse]:
470-
raw_messages = self._get_raw_messages(messages)
414+
raw_messages = list(self.serializer.dump_messages(messages))
471415

472416
config_params = {}
473417
if system:
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2026 Kaggle Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2026 Kaggle Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import copy
15+
import dataclasses
16+
import itertools
17+
import json
18+
19+
import pydantic
20+
21+
from kaggle_benchmarks import actors, chats, llm_messages, tools
22+
from kaggle_benchmarks import messages as msg
23+
from kaggle_benchmarks.content_types import images, videos
24+
25+
26+
class UnsupportedMessageFormat(ValueError):
    """Raised when a serializer cannot handle a message's content format."""
28+
29+
30+
def _copy_replace(message, **new_fields):
31+
new = copy.copy(message)
32+
for k, v in new_fields.items():
33+
setattr(new, k, v)
34+
return new
35+
36+
37+
class BaseSerializer:
    """Base class for all message serializers.

    Provides the core logic to map generic benchmark messages to
    provider-specific formats. Subclasses must implement the specific
    `dump_*` methods for the content types they support.
    """

    def __init__(self, roles_mapping: dict[str, str] | None = None):
        """Initializes the serializer.

        Args:
          roles_mapping: Optional mapping from benchmark role names to
            provider-specific role names (e.g. {"assistant": "model"}).
            Roles absent from the mapping pass through unchanged.
        """
        self.roles_mapping = roles_mapping or {}

    def get_role(self, sender: actors.Actor) -> str:
        """Resolves the provider-specific role for a given sender using roles_mapping."""
        return self.roles_mapping.get(sender.role, sender.role)

    def dump_chat(self, chat: chats.Chat):
        """Serializes an entire chat history into a provider-specific format."""
        return self.dump_messages(chat.messages)

    def dump_messages(self, messages: list[msg.Message]):
        """Serializes a list of messages into a provider-specific format.

        Returns a lazy iterator over the flattened per-message outputs.
        """
        # chain.from_iterable keeps the pipeline fully lazy: each message's
        # generator is created only as the caller iterates, whereas
        # chain(*(...)) would eagerly instantiate every generator up front.
        return itertools.chain.from_iterable(
            self.dump_message(message) for message in messages
        )

    def dump_message(self, message: msg.Message):
        """Dynamically dispatches serialization based on the message content type."""
        if isinstance(message, llm_messages.LLMMessage):
            try:
                # NOTE(review): if a subclass's dump_llm_message raises
                # NotImplementedError *after* yielding items, those items have
                # already been emitted before we fall through to the generic
                # path. Subclasses should raise before yielding anything.
                yield from self.dump_llm_message(message)
                return
            except NotImplementedError:
                # Fallback if the subclass doesn't support explicit LLM messages.
                pass

        content = message.content
        if isinstance(content, str):
            yield from self.dump_text_message(message)
        elif isinstance(content, images.ImageContent):
            yield from self.dump_image(message)
        elif isinstance(content, videos.VideoContent):
            yield from self.dump_video(message)
        elif isinstance(content, dict):
            yield from self.dump_json_message(message)
        elif isinstance(content, tools.ToolInvocationResult):
            yield from self.dump_tool_invocation(message)
        elif isinstance(content, pydantic.BaseModel):
            yield from self.dump_json_message(
                _copy_replace(message, content=message.content.model_dump())
            )
        elif dataclasses.is_dataclass(content) and not isinstance(content, type):
            yield from self.dump_json_message(
                _copy_replace(message, content=dataclasses.asdict(content))
            )
        else:
            yield from self._dump_message(message)

    def _dump_message(self, message: msg.Message):
        """Fallback method for unsupported message types. Override in subclass to handle.

        Raises:
          UnsupportedMessageFormat: Always, unless overridden.
        """
        # Raise the module's dedicated error (a ValueError subclass) so callers
        # can distinguish "this content type is unsupported" from the
        # NotImplementedError markers used for the abstract dump_* hooks —
        # dump_message deliberately catches NotImplementedError, so reusing it
        # here would conflate the two cases.
        raise UnsupportedMessageFormat(
            f"Unsupported message format for: {type(message.content)}"
        )

    def dump_tool_invocation(self, message: msg.Message[tools.ToolInvocationResult]):
        """Serializes a tool invocation result. Abstract; override in subclass."""
        raise NotImplementedError()

    def dump_llm_message(self, message: llm_messages.LLMMessage):
        """Serializes an LLM message containing tools and complex outputs."""
        raise NotImplementedError()

    def dump_text_message(self, message: msg.Message[str]):
        """Serializes a standard text message."""
        raise NotImplementedError()

    def dump_json_message(self, message: msg.Message[dict]):
        """Serializes a JSON dictionary message by stringifying it as text by default."""
        yield from self.dump_text_message(
            _copy_replace(message, content=json.dumps(message.content))
        )

    def dump_image(self, message: msg.Message[images.ImageContent]):
        """Serializes an image message."""
        raise NotImplementedError()

    def dump_video(self, message: msg.Message[videos.VideoContent]):
        """Serializes a video message."""
        raise NotImplementedError()

0 commit comments

Comments
 (0)