diff --git a/src/kaggle_benchmarks/messages.py b/src/kaggle_benchmarks/messages.py index fcacc88..5446dda 100644 --- a/src/kaggle_benchmarks/messages.py +++ b/src/kaggle_benchmarks/messages.py @@ -17,6 +17,8 @@ import warnings from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar +import pydantic + from kaggle_benchmarks import events, utils if TYPE_CHECKING: @@ -75,6 +77,8 @@ def usage(self): def payload(self) -> str | list[dict]: if hasattr(self.content, "get_payload"): return self.content.get_payload() + if isinstance(self.content, pydantic.BaseModel): + return self.content.model_dump_json() if dataclasses.is_dataclass(self.content) and not isinstance( self.content, type ): diff --git a/tests/test_messages.py b/tests/test_messages.py index e5ce409..3e836d7 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -15,6 +15,7 @@ import json from dataclasses import dataclass +import pydantic import pytest from kaggle_benchmarks import actors, chats, messages, user @@ -54,6 +55,15 @@ class Point: assert json.loads(msg.payload) == {"x": 1, "y": 2} +def test_pydantic_payload(): + class Point(pydantic.BaseModel): + x: float + y: float + + msg = messages.Message(Point(x=1.5, y=2.5), sender=user) + assert json.loads(msg.payload) == {"x": 1.5, "y": 2.5} + + def test_class_payload(): class Point: def __init__(self, x):