Skip to content

Commit

Permalink
Add ability to specify custom serializer (#764)
Browse files Browse the repository at this point in the history
Allow users to define a custom serializer
  • Loading branch information
eyurtsev authored Sep 14, 2024
1 parent c747e20 commit 1e24edc
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 21 deletions.
7 changes: 5 additions & 2 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
PublicTraceLink,
PublicTraceLinkCreateRequest,
)
from langserve.serialization import WellKnownLCSerializer
from langserve.serialization import Serializer, WellKnownLCSerializer
from langserve.validation import (
BatchBaseResponse,
BatchRequestShallowValidator,
Expand Down Expand Up @@ -536,6 +536,7 @@ def __init__(
stream_log_name_allow_list: Optional[Sequence[str]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
serializer: Optional[Serializer] = None,
) -> None:
"""Create an API handler for the given runnable.
Expand Down Expand Up @@ -600,6 +601,8 @@ def __init__(
TODO: Introduce deprecation for this parameter to rename it
astream_events_version: version of the stream events endpoint to use.
By default "v2".
serializer: optional serializer to use for serializing the output.
If not provided, the default serializer will be used.
"""
if importlib.util.find_spec("sse_starlette") is None:
raise ImportError(
Expand Down Expand Up @@ -638,7 +641,7 @@ def __init__(
)
self._include_callback_events = include_callback_events
self._per_req_config_modifier = per_req_config_modifier
self._serializer = WellKnownLCSerializer()
self._serializer = serializer or WellKnownLCSerializer()
self._enable_feedback_endpoint = enable_feedback_endpoint
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
self._names_in_stream_allow_list = stream_log_name_allow_list
Expand Down
5 changes: 4 additions & 1 deletion langserve/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def __init__(
cert: Optional[CertTypes] = None,
client_kwargs: Optional[Dict[str, Any]] = None,
use_server_callback_events: bool = True,
serializer: Optional[Serializer] = None,
) -> None:
"""Initialize the client.
Expand All @@ -300,6 +301,8 @@ def __init__(
and async httpx clients
use_server_callback_events: Whether to invoke callbacks on any
callback events returned by the server.
serializer: The serializer to use for serializing and deserializing
data. If not provided, a default serializer will be used.
"""
_client_kwargs = client_kwargs or {}
# Enforce trailing slash
Expand Down Expand Up @@ -327,7 +330,7 @@ def __init__(

# Register cleanup handler once RemoteRunnable is garbage collected
weakref.finalize(self, _close_clients, self.sync_client, self.async_client)
self._lc_serializer = WellKnownLCSerializer()
self._lc_serializer = serializer or WellKnownLCSerializer()
self._use_server_callback_events = use_server_callback_events

def _invoke(
Expand Down
42 changes: 26 additions & 16 deletions langserve/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,39 +157,49 @@ def _decode_event_data(value: Any) -> Any:


class Serializer(abc.ABC):
@abc.abstractmethod
def dumpd(self, obj: Any) -> Any:
"""Convert the given object to a JSON serializable object."""
return orjson.loads(self.dumps(obj))

def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
return self.loadd(orjson.loads(s))

@abc.abstractmethod
def dumps(self, obj: Any) -> bytes:
"""Dump the given object as a JSON string."""
"""Dump the given object to a JSON byte string."""

@abc.abstractmethod
def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
def loadd(self, s: bytes) -> Any:
"""Given a python object, load it into a well known object.
@abc.abstractmethod
def loadd(self, obj: Any) -> Any:
"""Load the given object."""
The obj represents content that was json loaded from a string, but
not yet validated or converted into a well known object.
"""


class WellKnownLCSerializer(Serializer):
def dumpd(self, obj: Any) -> Any:
"""Convert the given object to a JSON serializable object."""
return orjson.loads(orjson.dumps(obj, default=default))
"""A pre-defined serializer for well known LangChain objects.
This is the default serializer used by LangServe for serializing and
de-serializing well known LangChain objects.
If you need to extend the serialization capabilities for your own application,
feel free to create a new instance of the Serializer class and implement
the abstract methods dumps and loadd.
"""

def dumps(self, obj: Any) -> bytes:
"""Dump the given object as a JSON string."""
"""Dump the given object to a JSON byte string."""
return orjson.dumps(obj, default=default)

def loadd(self, obj: Any) -> Any:
"""Load the given object."""
return _decode_lc_objects(obj)
"""Given a python object, load it into a well known object.
def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
return self.loadd(orjson.loads(s))
The obj represents content that was json loaded from a string, but
not yet validated or converted into a well known object.
"""
return _decode_lc_objects(obj)


def _project_top_level(model: BaseModel) -> Dict[str, Any]:
Expand Down
5 changes: 5 additions & 0 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TokenFeedbackConfig,
_is_hosted,
)
from langserve.serialization import Serializer

try:
from fastapi import APIRouter, Depends, FastAPI, Request, Response
Expand Down Expand Up @@ -263,6 +264,7 @@ def add_routes(
dependencies: Optional[Sequence[Depends]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
serializer: Optional[Serializer] = None,
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.
Expand Down Expand Up @@ -383,6 +385,8 @@ def add_routes(
which message types are supported etc.)
astream_events_version: version of the stream events endpoint to use.
By default "v2".
serializer: The serializer to use for serializing the output. If not provided,
the default serializer will be used.
""" # noqa: E501
if not isinstance(runnable, Runnable):
raise TypeError(
Expand Down Expand Up @@ -447,6 +451,7 @@ def add_routes(
stream_log_name_allow_list=stream_log_name_allow_list,
playground_type=playground_type,
astream_events_version=astream_events_version,
serializer=serializer,
)

namespace = path or ""
Expand Down
52 changes: 50 additions & 2 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from langsmith import schemas as ls_schemas
from langsmith.client import Client
from langsmith.schemas import FeedbackIngestToken
from orjson import orjson
from pydantic import BaseModel, Field, __version__
from pytest import MonkeyPatch
from pytest_mock import MockerFixture
Expand Down Expand Up @@ -244,13 +245,17 @@ async def get_async_test_client(

@asynccontextmanager
async def get_async_remote_runnable(
server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True
server: FastAPI,
*,
path: Optional[str] = None,
raise_app_exceptions: bool = True,
**kwargs: Any,
) -> RemoteRunnable:
"""Get an async client."""
url = "http://localhost:9999"
if path:
url += path
remote_runnable_client = RemoteRunnable(url=url)
remote_runnable_client = RemoteRunnable(url=url, **kwargs)

async with get_async_test_client(
server, path=path, raise_app_exceptions=raise_app_exceptions
Expand Down Expand Up @@ -2280,6 +2285,49 @@ async def check_types(inputs: VariousTypes) -> int:
)


async def test_custom_serialization() -> None:
    """Test that a user-supplied serializer round-trips a custom object.

    The same ``CustomSerializer`` is passed to ``add_routes`` on the server
    side and to the remote runnable on the client side. The object returned
    by the server-side runnable is then expected to come back from
    ``ainvoke`` as an equal ``CustomObject`` instance.
    """
    from langserve.serialization import Serializer

    class CustomObject:
        """Minimal value object that is not JSON serializable by default."""

        def __init__(self, x: int) -> None:
            self.x = x

        def __eq__(self, other: object) -> bool:
            # Defer to the other operand instead of raising AttributeError
            # when compared against something that is not a CustomObject.
            if not isinstance(other, CustomObject):
                return NotImplemented
            return self.x == other.x

    class CustomSerializer(Serializer):
        """Serializer that encodes ``CustomObject`` as ``{"x": <value>}``."""

        def dumps(self, obj: Any) -> bytes:
            """Dump the given object to a JSON byte string."""
            if isinstance(obj, CustomObject):
                return orjson.dumps({"x": obj.x})
            return orjson.dumps(obj)

        def loadd(self, obj: Any) -> Any:
            """Rebuild a ``CustomObject`` from an already JSON-decoded value.

            Anything that is not a dict carrying an ``"x"`` key (ints, lists,
            strings, other dicts, ...) is returned unchanged.
            """
            # Membership rather than truthiness, so that {"x": 0} is also
            # restored; the isinstance guard keeps non-dict values from
            # failing on a missing ``.get``.
            if isinstance(obj, dict) and "x" in obj:
                return CustomObject(x=obj["x"])
            return obj

    def foo(x: int) -> Any:
        """Ignore the input and return a fixed ``CustomObject``."""
        return CustomObject(x=5)

    app = FastAPI()
    server_runnable = RunnableLambda(foo)
    add_routes(app, server_runnable, serializer=CustomSerializer())

    async with get_async_remote_runnable(
        app, raise_app_exceptions=True, serializer=CustomSerializer()
    ) as runnable:
        result = await runnable.ainvoke(5)
        assert isinstance(result, CustomObject)
        assert result == CustomObject(x=5)


async def test_endpoint_configurations() -> None:
"""Test enabling/disabling endpoints."""
app = FastAPI()
Expand Down

0 comments on commit 1e24edc

Please sign in to comment.