1616"""
1717
1818from dataclasses import dataclass
19+ from inspect import Signature
1920from typing import Any , Callable , Awaitable , Generic , Literal , Optional , TypeVar
2021
21- from restate .serde import Serde
22+ from restate .exceptions import TerminalError
23+ from restate .serde import JsonSerde , Serde , PydanticJsonSerde
2224
2325I = TypeVar ('I' )
2426O = TypeVar ('O' )
2527
2628# we will use this symbol to store the handler in the function
2729RESTATE_UNIQUE_HANDLER_SYMBOL = str (object ())
2830
31+
32+ def try_import_pydantic_base_model ():
33+ """
34+ Try to import PydanticBaseModel from Pydantic.
35+ """
36+ try :
37+ from pydantic import BaseModel # type: ignore # pylint: disable=import-outside-toplevel
38+ return BaseModel
39+ except ImportError :
40+ class Dummy : # pylint: disable=too-few-public-methods
41+ """a dummy class to use when Pydantic is not available"""
42+
43+ return Dummy
44+
45+ PYDANTIC_BASE_MODEL = try_import_pydantic_base_model ()
46+
2947@dataclass
3048class ServiceTag :
3149 """
@@ -42,13 +60,45 @@ class HandlerIO(Generic[I, O]):
4260 Attributes:
4361 accept (str): The accept header value for the handler.
4462 content_type (str): The content type header value for the handler.
45- serializer: The serializer function to convert output to bytes.
46- deserializer: The deserializer function to convert input type to bytes.
4763 """
4864 accept : str
4965 content_type : str
5066 input_serde : Serde [I ]
5167 output_serde : Serde [O ]
68+ pydantic_input_model : Optional [I ] = None
69+ pydantic_output_model : Optional [O ] = None
70+
71+ def is_pydantic (annotation ) -> bool :
72+ """
73+ Check if an object is a Pydantic model.
74+ """
75+ try :
76+ return issubclass (annotation , PYDANTIC_BASE_MODEL )
77+ except TypeError :
78+ # annotation is not a class or a type
79+ return False
80+
81+
82+ def infer_pydantic_io (handler_io : HandlerIO [I , O ], signature : Signature ):
83+ """
84+ Augment handler_io with Pydantic models when these are provided.
85+ This method will inspect the signature of an handler and will look for
86+ the input and the return types of a function, and will:
87+ * capture any Pydantic models (to be used later at discovery)
88+ * replace the default json serializer (is unchanged by a user) with a Pydantic serde
89+ """
90+ # check if the handlers I/O is a PydanticBaseModel
91+ annotation = list (signature .parameters .values ())[- 1 ].annotation
92+ if is_pydantic (annotation ):
93+ handler_io .pydantic_input_model = annotation
94+ if isinstance (handler_io .input_serde , JsonSerde ): # type: ignore
95+ handler_io .input_serde = PydanticJsonSerde (annotation )
96+
97+ annotation = signature .return_annotation
98+ if is_pydantic (annotation ):
99+ handler_io .pydantic_output_model = annotation
100+ if isinstance (handler_io .output_serde , JsonSerde ): # type: ignore
101+ handler_io .output_serde = PydanticJsonSerde (annotation )
52102
53103@dataclass
54104class Handler (Generic [I , O ]):
@@ -71,7 +121,7 @@ def make_handler(service_tag: ServiceTag,
71121 name : str | None ,
72122 kind : Optional [Literal ["exclusive" , "shared" , "workflow" ]],
73123 wrapped : Any ,
74- arity : int ) -> Handler [I , O ]:
124+ signature : Signature ) -> Handler [I , O ]:
75125 """
76126 Factory function to create a handler.
77127 """
@@ -82,12 +132,19 @@ def make_handler(service_tag: ServiceTag,
82132 if not handler_name :
83133 raise ValueError ("Handler name must be provided" )
84134
135+ if len (signature .parameters ) == 0 :
136+ raise ValueError ("Handler must have at least one parameter" )
137+
138+ arity = len (signature .parameters )
139+ infer_pydantic_io (handler_io , signature )
140+
85141 handler = Handler [I , O ](service_tag ,
86142 handler_io ,
87143 kind ,
88144 handler_name ,
89145 wrapped ,
90146 arity )
147+
91148 vars (wrapped )[RESTATE_UNIQUE_HANDLER_SYMBOL ] = handler
92149 return handler
93150
@@ -105,7 +162,10 @@ async def invoke_handler(handler: Handler[I, O], ctx: Any, in_buffer: bytes) ->
105162 Invoke the handler with the given context and input.
106163 """
107164 if handler .arity == 2 :
108- in_arg = handler .handler_io .input_serde .deserialize (in_buffer ) # type: ignore
165+ try :
166+ in_arg = handler .handler_io .input_serde .deserialize (in_buffer ) # type: ignore
167+ except Exception as e :
168+ raise TerminalError (message = f"Unable to parse an input argument. { e } " ) from e
109169 out_arg = await handler .fn (ctx , in_arg ) # type: ignore [call-arg, arg-type]
110170 else :
111171 out_arg = await handler .fn (ctx ) # type: ignore [call-arg]
0 commit comments