1111"""This module contains the ASGI server for the restate framework."""
1212
1313import asyncio
14- import inspect
15- from typing import Any , Dict , TypedDict , Literal
14+ from typing import Dict , TypedDict , Literal
1615import traceback
17- import typing
1816from restate .discovery import compute_discovery_json
1917from restate .endpoint import Endpoint
2018from restate .server_context import ServerInvocationContext , DisconnectedException
21- from restate .server_types import Receive , ReceiveChannel , Scope , Send , binary_to_header , header_to_binary , LifeSpan # pylint: disable=line-too-long
19+ from restate .server_types import Receive , ReceiveChannel , Scope , Send , binary_to_header , header_to_binary # pylint: disable=line-too-long
2220from restate .vm import VMWrapper
2321from restate ._internal import PyIdentityVerifier , IdentityVerificationException # pylint: disable=import-error,no-name-in-module
2422from restate ._internal import SDK_VERSION # pylint: disable=import-error,no-name-in-module
@@ -188,6 +186,10 @@ async def process_invocation_to_completion(vm: VMWrapper,
188186 finally :
189187 context .on_attempt_finished ()
190188
189+ class LifeSpanNotImplemented (ValueError ):
190+ """Signal to the asgi server that we didn't implement lifespans"""
191+
192+
191193class ParsedPath (TypedDict ):
192194 """Parsed path from the request."""
193195 type : Literal ["invocation" , "health" , "discover" , "unknown" ]
@@ -214,55 +216,8 @@ def parse_path(request: str) -> ParsedPath:
214216 # anything other than invoke is 404
215217 return { "type" : "unknown" , "service" : None , "handler" : None }
216218
217- def is_async_context_manager (obj : Any ):
218- """check if passed object is an async context manager"""
219- return (hasattr (obj , '__aenter__' ) and
220- hasattr (obj , '__aexit__' ) and
221- inspect .iscoroutinefunction (obj .__aenter__ ) and
222- inspect .iscoroutinefunction (obj .__aexit__ ))
223219
224-
225- async def lifespan_processor (
226- scope : Scope ,
227- receive : Receive ,
228- send : Send ,
229- lifespan : LifeSpan
230- ) -> None :
231- """Process lifespan context manager."""
232- started = False
233- assert scope ["type" ] in ["lifespan" , "lifespan.startup" , "lifespan.shutdown" ]
234- assert is_async_context_manager (lifespan ()), "lifespan must be an async context manager"
235- await receive ()
236- try :
237- async with lifespan () as maybe_state :
238- if maybe_state is not None :
239- if "state" not in scope :
240- raise RuntimeError ("The server does not support state in lifespan" )
241- scope ["state" ] = maybe_state
242- await send ({
243- "type" : "lifespan.startup.complete" , # type: ignore
244- })
245- started = True
246- await receive ()
247- except Exception :
248- exc_text = traceback .format_exc ()
249- if started :
250- await send ({
251- "type" : "lifespan.shutdown.failed" ,
252- "message" : exc_text
253- })
254- else :
255- await send ({
256- "type" : "lifespan.startup.failed" ,
257- "message" : exc_text
258- })
259- raise
260- await send ({
261- "type" : "lifespan.shutdown.complete" # type: ignore
262- })
263-
264- # pylint: disable=too-many-return-statements
265- def asgi_app (endpoint : Endpoint , lifespan : typing .Optional [LifeSpan ] = None ):
220+ def asgi_app (endpoint : Endpoint ):
266221 """Create an ASGI-3 app for the given endpoint."""
267222
268223 # Prepare request signer
@@ -271,17 +226,14 @@ def asgi_app(endpoint: Endpoint, lifespan: typing.Optional[LifeSpan] = None):
271226 async def app (scope : Scope , receive : Receive , send : Send ):
272227 try :
273228 if scope ['type' ] == 'lifespan' :
274- if lifespan is not None :
275- await lifespan_processor (scope , receive , send , lifespan )
276- return
277- return
278-
229+ raise LifeSpanNotImplemented ()
279230 if scope ['type' ] != 'http' :
280231 raise NotImplementedError (f"Unknown scope type { scope ['type' ]} " )
281232
282233 request_path = scope ['path' ]
283234 assert isinstance (request_path , str )
284235 request : ParsedPath = parse_path (request_path )
236+
285237 # Health check
286238 if request ['type' ] == 'health' :
287239 await send_health_check (send )
@@ -297,6 +249,7 @@ async def app(scope: Scope, receive: Receive, send: Send):
297249 # Identify verification failed, send back unauthorized and close
298250 await send_status (send , receive , 401 )
299251 return
252+
300253 # might be a discovery request
301254 if request ['type' ] == 'discover' :
302255 await send_discovery (scope , send , endpoint )
@@ -330,6 +283,8 @@ async def app(scope: Scope, receive: Receive, send: Send):
330283 send )
331284 finally :
332285 await receive_channel .close ()
286+ except LifeSpanNotImplemented as e :
287+ raise e
333288 except Exception as e :
334289 traceback .print_exc ()
335290 raise e
0 commit comments