@@ -40,9 +40,14 @@ class Scope(TypedDict):
4040 state : NotRequired [Dict [str , Any ]]
4141 extensions : Optional [Dict [str , Dict [object , object ]]]
4242
43+ class RestateEvent (TypedDict ):
44+ """An event that represents a run completion"""
45+ type : Literal ["restate.run_completed" ]
46+ data : Optional [Dict [str , Any ]]
47+
4348class HTTPRequestEvent (TypedDict ):
4449 """ASGI Request event"""
45- type : Literal ["http.request" , "restate.run_completed" ]
50+ type : Literal ["http.request" ]
4651 body : bytes
4752 more_body : bool
4853
@@ -91,38 +96,35 @@ def binary_to_header(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[str,
9196class ReceiveChannel :
9297 """ASGI receive channel."""
9398
94- def __init__ (self , receive : Receive ):
95- self .queue = asyncio .Queue [ASGIReceiveEvent ]()
99+ def __init__ (self , receive : Receive ) -> None :
100+ self ._queue = asyncio .Queue [Union [ ASGIReceiveEvent , RestateEvent ] ]()
96101
97102 async def loop ():
98103 """Receive loop."""
99104 while True :
100105 event = await receive ()
101- await self .queue .put (event )
106+ await self ._queue .put (event )
102107 if event .get ('type' ) == 'http.disconnect' :
103108 break
104109
105- self .task = asyncio .create_task (loop ())
110+ self ._task = asyncio .create_task (loop ())
106111
107- async def rx (self ) -> ASGIReceiveEvent :
112+ async def __call__ (self ) -> ASGIReceiveEvent | RestateEvent :
108113 """Get the next message."""
109- what = await self .queue .get ()
110- self .queue .task_done ()
114+ what = await self ._queue .get ()
115+ self ._queue .task_done ()
111116 return what
112117
113- async def __call__ (self ):
114- """Get the next message."""
115- return await self .rx ()
116-
117- async def tx (self , what : ASGIReceiveEvent ):
118+ async def enqueue_restate_event (self , what : RestateEvent ):
118119 """Add a message."""
119- await self .queue .put (what )
120+ await self ._queue .put (what )
120121
121122 async def close (self ):
122123 """Close the channel."""
123- if self .task and not self .task .done ():
124- self .task .cancel ()
125- try :
126- await self .task
127- except asyncio .CancelledError :
128- pass
124+ if self ._task .done ():
125+ return
126+ self ._task .cancel ()
127+ try :
128+ await self ._task
129+ except asyncio .CancelledError :
130+ pass
0 commit comments