1111"""This module contains the ASGI server for the restate framework."""
1212
1313import asyncio
14- from typing import Dict , TypedDict , Literal
14+ import inspect
15+ from typing import Any , Dict , TypedDict , Literal
1516import traceback
17+ import typing
1618from restate .discovery import compute_discovery_json
1719from restate .endpoint import Endpoint
1820from restate .server_context import ServerInvocationContext , DisconnectedException
19- from restate .server_types import Receive , ReceiveChannel , Scope , Send , binary_to_header , header_to_binary # pylint: disable=line-too-long
21+ from restate .server_types import Receive , ReceiveChannel , Scope , Send , binary_to_header , header_to_binary , LifeSpan # pylint: disable=line-too-long
2022from restate .vm import VMWrapper
2123from restate ._internal import PyIdentityVerifier , IdentityVerificationException # pylint: disable=import-error,no-name-in-module
2224from restate ._internal import SDK_VERSION # pylint: disable=import-error,no-name-in-module
@@ -186,10 +188,6 @@ async def process_invocation_to_completion(vm: VMWrapper,
186188 finally :
187189 context .on_attempt_finished ()
188190
189- class LifeSpanNotImplemented (ValueError ):
190- """Signal to the asgi server that we didn't implement lifespans"""
191-
192-
193191class ParsedPath (TypedDict ):
194192 """Parsed path from the request."""
195193 type : Literal ["invocation" , "health" , "discover" , "unknown" ]
@@ -216,8 +214,55 @@ def parse_path(request: str) -> ParsedPath:
216214 # anything other than invoke is 404
217215 return { "type" : "unknown" , "service" : None , "handler" : None }
218216
217+ def is_async_context_manager (obj : Any ):
218+ """check if passed object is an async context manager"""
219+ return (hasattr (obj , '__aenter__' ) and
220+ hasattr (obj , '__aexit__' ) and
221+ inspect .iscoroutinefunction (obj .__aenter__ ) and
222+ inspect .iscoroutinefunction (obj .__aexit__ ))
219223
220- def asgi_app (endpoint : Endpoint ):
224+
225+ async def lifespan_processor (
226+ scope : Scope ,
227+ receive : Receive ,
228+ send : Send ,
229+ lifespan : LifeSpan
230+ ) -> None :
231+ """Process lifespan context manager."""
232+ started = False
233+ assert scope ["type" ] in ["lifespan" , "lifespan.startup" , "lifespan.shutdown" ]
234+ assert is_async_context_manager (lifespan ()), "lifespan must be an async context manager"
235+ await receive ()
236+ try :
237+ async with lifespan () as maybe_state :
238+ if maybe_state is not None :
239+ if "state" not in scope :
240+ raise RuntimeError ("The server does not support state in lifespan" )
241+ scope ["state" ] = maybe_state
242+ await send ({
243+ "type" : "lifespan.startup.complete" , # type: ignore
244+ })
245+ started = True
246+ await receive ()
247+ except Exception :
248+ exc_text = traceback .format_exc ()
249+ if started :
250+ await send ({
251+ "type" : "lifespan.shutdown.failed" ,
252+ "message" : exc_text
253+ })
254+ else :
255+ await send ({
256+ "type" : "lifespan.startup.failed" ,
257+ "message" : exc_text
258+ })
259+ raise
260+ await send ({
261+ "type" : "lifespan.shutdown.complete" # type: ignore
262+ })
263+
264+ # pylint: disable=too-many-return-statements
265+ def asgi_app (endpoint : Endpoint , lifespan : typing .Optional [LifeSpan ] = None ):
221266 """Create an ASGI-3 app for the given endpoint."""
222267
223268 # Prepare request signer
@@ -226,14 +271,17 @@ def asgi_app(endpoint: Endpoint):
226271 async def app (scope : Scope , receive : Receive , send : Send ):
227272 try :
228273 if scope ['type' ] == 'lifespan' :
229- raise LifeSpanNotImplemented ()
274+ if lifespan is not None :
275+ await lifespan_processor (scope , receive , send , lifespan )
276+ return
277+ return
278+
230279 if scope ['type' ] != 'http' :
231280 raise NotImplementedError (f"Unknown scope type { scope ['type' ]} " )
232281
233282 request_path = scope ['path' ]
234283 assert isinstance (request_path , str )
235284 request : ParsedPath = parse_path (request_path )
236-
237285 # Health check
238286 if request ['type' ] == 'health' :
239287 await send_health_check (send )
@@ -249,7 +297,6 @@ async def app(scope: Scope, receive: Receive, send: Send):
249297 # Identify verification failed, send back unauthorized and close
250298 await send_status (send , receive , 401 )
251299 return
252-
253300 # might be a discovery request
254301 if request ['type' ] == 'discover' :
255302 await send_discovery (scope , send , endpoint )
@@ -283,8 +330,6 @@ async def app(scope: Scope, receive: Receive, send: Send):
283330 send )
284331 finally :
285332 await receive_channel .close ()
286- except LifeSpanNotImplemented as e :
287- raise e
288333 except Exception as e :
289334 traceback .print_exc ()
290335 raise e
0 commit comments