Skip to content

Commit 946f418

Browse files
authored
Add support for pydantic (#26)
This commit adds an optional support for pydantic models. To use this simply add the pydantic dependency, and annotate a handler with it. ```py Greeting(BaseModel): name: str @svc.handler() async def greet(ctx, greeting: Greeting): .. ``` With this, any bad input (validation error) will result with a TerminalError thrown by the serializer.
1 parent ac0c995 commit 946f418

File tree

8 files changed

+123
-17
lines changed

8 files changed

+123
-17
lines changed

Justfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ mypy:
1818
# Recipe to run pylint for linting
1919
pylint:
2020
@echo "Running pylint..."
21-
{{python}} -m pylint python/restate
22-
{{python}} -m pylint examples/
21+
{{python}} -m pylint python/restate --ignore-paths='^.*.?venv.*$'
22+
{{python}} -m pylint examples/ --ignore-paths='^.*\.?venv.*$'
2323

2424
test:
2525
@echo "Running Python tests..."

python/restate/discovery.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ def compute_discovery_json(endpoint: RestateEndpoint,
113113
headers = {"content-type": "application/vnd.restate.endpointmanifest.v1+json"}
114114
return (headers, json_str)
115115

116+
def try_extract_json_schema(model: Any) -> typing.Optional[typing.Any]:
117+
"""
118+
Try to extract the JSON schema from a schema object
119+
"""
120+
if model:
121+
return model.model_json_schema(mode='serialization')
122+
return None
123+
116124
def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal["bidi", "request_response"]) -> Endpoint:
117125
"""
118126
return restate's discovery object for an endpoint
@@ -131,11 +139,11 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[
131139
# input
132140
inp = InputPayload(required=False,
133141
contentType=handler.handler_io.accept,
134-
jsonSchema=None)
142+
jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_input_model))
135143
# output
136144
out = OutputPayload(setContentTypeIfEmpty=False,
137145
contentType=handler.handler_io.content_type,
138-
jsonSchema=None)
146+
jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_output_model))
139147
# add the handler
140148
service_handlers.append(Handler(name=handler.name, ty=ty, input=inp, output=out))
141149

python/restate/handler.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,34 @@
1616
"""
1717

1818
from dataclasses import dataclass
19+
from inspect import Signature
1920
from typing import Any, Callable, Awaitable, Generic, Literal, Optional, TypeVar
2021

21-
from restate.serde import Serde
22+
from restate.exceptions import TerminalError
23+
from restate.serde import JsonSerde, Serde, PydanticJsonSerde
2224

2325
I = TypeVar('I')
2426
O = TypeVar('O')
2527

2628
# we will use this symbol to store the handler in the function
2729
RESTATE_UNIQUE_HANDLER_SYMBOL = str(object())
2830

31+
32+
def try_import_pydantic_base_model():
33+
"""
34+
Try to import PydanticBaseModel from Pydantic.
35+
"""
36+
try:
37+
from pydantic import BaseModel # type: ignore # pylint: disable=import-outside-toplevel
38+
return BaseModel
39+
except ImportError:
40+
class Dummy: # pylint: disable=too-few-public-methods
41+
"""a dummy class to use when Pydantic is not available"""
42+
43+
return Dummy
44+
45+
PYDANTIC_BASE_MODEL = try_import_pydantic_base_model()
46+
2947
@dataclass
3048
class ServiceTag:
3149
"""
@@ -42,13 +60,45 @@ class HandlerIO(Generic[I, O]):
4260
Attributes:
4361
accept (str): The accept header value for the handler.
4462
content_type (str): The content type header value for the handler.
45-
serializer: The serializer function to convert output to bytes.
46-
deserializer: The deserializer function to convert input type to bytes.
4763
"""
4864
accept: str
4965
content_type: str
5066
input_serde: Serde[I]
5167
output_serde: Serde[O]
68+
pydantic_input_model: Optional[I] = None
69+
pydantic_output_model: Optional[O] = None
70+
71+
def is_pydantic(annotation) -> bool:
72+
"""
73+
Check if an object is a Pydantic model.
74+
"""
75+
try:
76+
return issubclass(annotation, PYDANTIC_BASE_MODEL)
77+
except TypeError:
78+
# annotation is not a class or a type
79+
return False
80+
81+
82+
def infer_pydantic_io(handler_io: HandlerIO[I, O], signature: Signature):
83+
"""
84+
Augment handler_io with Pydantic models when these are provided.
85+
This method will inspect the signature of an handler and will look for
86+
the input and the return types of a function, and will:
87+
* capture any Pydantic models (to be used later at discovery)
88+
* replace the default json serializer (is unchanged by a user) with a Pydantic serde
89+
"""
90+
# check if the handlers I/O is a PydanticBaseModel
91+
annotation = list(signature.parameters.values())[-1].annotation
92+
if is_pydantic(annotation):
93+
handler_io.pydantic_input_model = annotation
94+
if isinstance(handler_io.input_serde, JsonSerde): # type: ignore
95+
handler_io.input_serde = PydanticJsonSerde(annotation)
96+
97+
annotation = signature.return_annotation
98+
if is_pydantic(annotation):
99+
handler_io.pydantic_output_model = annotation
100+
if isinstance(handler_io.output_serde, JsonSerde): # type: ignore
101+
handler_io.output_serde = PydanticJsonSerde(annotation)
52102

53103
@dataclass
54104
class Handler(Generic[I, O]):
@@ -71,7 +121,7 @@ def make_handler(service_tag: ServiceTag,
71121
name: str | None,
72122
kind: Optional[Literal["exclusive", "shared", "workflow"]],
73123
wrapped: Any,
74-
arity: int) -> Handler[I, O]:
124+
signature: Signature) -> Handler[I, O]:
75125
"""
76126
Factory function to create a handler.
77127
"""
@@ -82,12 +132,19 @@ def make_handler(service_tag: ServiceTag,
82132
if not handler_name:
83133
raise ValueError("Handler name must be provided")
84134

135+
if len(signature.parameters) == 0:
136+
raise ValueError("Handler must have at least one parameter")
137+
138+
arity = len(signature.parameters)
139+
infer_pydantic_io(handler_io, signature)
140+
85141
handler = Handler[I, O](service_tag,
86142
handler_io,
87143
kind,
88144
handler_name,
89145
wrapped,
90146
arity)
147+
91148
vars(wrapped)[RESTATE_UNIQUE_HANDLER_SYMBOL] = handler
92149
return handler
93150

@@ -105,7 +162,10 @@ async def invoke_handler(handler: Handler[I, O], ctx: Any, in_buffer: bytes) ->
105162
Invoke the handler with the given context and input.
106163
"""
107164
if handler.arity == 2:
108-
in_arg = handler.handler_io.input_serde.deserialize(in_buffer) # type: ignore
165+
try:
166+
in_arg = handler.handler_io.input_serde.deserialize(in_buffer) # type: ignore
167+
except Exception as e:
168+
raise TerminalError(message=f"Unable to parse an input argument. {e}") from e
109169
out_arg = await handler.fn(ctx, in_arg) # type: ignore [call-arg, arg-type]
110170
else:
111171
out_arg = await handler.fn(ctx) # type: ignore [call-arg]

python/restate/object.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def wrapper(fn):
8585
def wrapped(*args, **kwargs):
8686
return fn(*args, **kwargs)
8787

88-
arity = len(inspect.signature(fn).parameters)
89-
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, arity)
88+
signature = inspect.signature(fn)
89+
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature)
9090
self.handlers[handler.name] = handler
9191
return wrapped
9292

python/restate/serde.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,43 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
108108
return bytes(json.dumps(obj), "utf-8")
109109

110110

111+
class PydanticJsonSerde(Serde[I]):
112+
"""
113+
Serde for Pydantic models to/from JSON
114+
"""
115+
116+
def __init__(self, model):
117+
self.model = model
118+
119+
def deserialize(self, buf: bytes) -> typing.Optional[I]:
120+
"""
121+
Deserializes a bytearray to a Pydantic model.
122+
123+
Args:
124+
buf (bytearray): The bytearray to deserialize.
125+
126+
Returns:
127+
typing.Optional[I]: The deserialized Pydantic model.
128+
"""
129+
if not buf:
130+
return None
131+
return self.model.model_validate_json(buf)
132+
133+
def serialize(self, obj: typing.Optional[I]) -> bytes:
134+
"""
135+
Serializes a Pydantic model to a bytearray.
136+
137+
Args:
138+
obj (I): The Pydantic model to serialize.
139+
140+
Returns:
141+
bytearray: The serialized bytearray.
142+
"""
143+
if obj is None:
144+
return bytes()
145+
json_str = obj.model_dump_json() # type: ignore[attr-defined]
146+
return json_str.encode("utf-8")
147+
111148
def deserialize_json(buf: typing.ByteString) -> typing.Optional[O]:
112149
"""
113150
Deserializes a bytearray to a JSON object.

python/restate/service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def wrapper(fn):
8484
def wrapped(*args, **kwargs):
8585
return fn(*args, **kwargs)
8686

87-
arity = len(inspect.signature(fn).parameters)
88-
handler = make_handler(self.service_tag, handler_io, name, None, wrapped, arity)
87+
signature = inspect.signature(fn)
88+
handler = make_handler(self.service_tag, handler_io, name, None, wrapped, signature)
8989
self.handlers[handler.name] = handler
9090
return wrapped
9191

python/restate/workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def wrapper(fn):
114114
def wrapped(*args, **kwargs):
115115
return fn(*args, **kwargs)
116116

117-
arity = len(inspect.signature(fn).parameters)
118-
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, arity)
117+
signature = inspect.signature(fn)
118+
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature)
119119
self.handlers[handler.name] = handler
120120
return wrapped
121121

shell.nix

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{ pkgs ? import <nixpkgs> {} }:
22

33
(pkgs.buildFHSUserEnv {
4-
name = "my-python-env";
4+
name = "sdk-python";
55
targetPkgs = pkgs: (with pkgs; [
66
python3
77
python3Packages.pip
@@ -10,6 +10,7 @@
1010

1111
# rust
1212
rustup
13+
cargo
1314
clang
1415
llvmPackages.bintools
1516
protobuf
@@ -29,6 +30,6 @@
2930
LIBCLANG_PATH = pkgs.lib.makeLibraryPath [ pkgs.llvmPackages_latest.libclang.lib ];
3031

3132
runScript = ''
32-
bash
33+
bash
3334
'';
3435
}).env

0 commit comments

Comments
 (0)