2929from restate .exceptions import TerminalError
3030from restate .handler import Handler , handler_from_callable , invoke_handler
3131from restate .serde import BytesSerde , DefaultSerde , JsonSerde , Serde
32- from restate .server_types import Receive , Send
32+ from restate .server_types import ReceiveChannel , Send
3333from restate .vm import Failure , Invocation , NotReady , SuspendedException , VMWrapper , RunRetryConfig # pylint: disable=line-too-long
3434from restate .vm import DoProgressAnyCompleted , DoProgressCancelSignalReceived , DoProgressReadFromInput , DoProgressExecuteRun , DoWaitPendingRun
3535
@@ -220,25 +220,6 @@ def peek(self) -> Awaitable[Any | None]:
220220# disable too many public method
221221# pylint: disable=R0904
222222
223- class SyncPoint :
224- """
225- This class implements a synchronization point.
226- """
227-
228- def __init__ (self ) -> None :
229- self .cond : asyncio .Event | None = None
230-
231- def awaiter (self ):
232- """Wait for the sync point."""
233- if self .cond is None :
234- self .cond = asyncio .Event ()
235- return self .cond .wait ()
236-
237- async def arrive (self ):
238- """arrive at the sync point."""
239- if self .cond is not None :
240- self .cond .set ()
241-
242223class Tasks :
243224 """
244225 This class implements a list of tasks.
@@ -284,7 +265,8 @@ def __init__(self,
284265 invocation : Invocation ,
285266 attempt_headers : Dict [str , str ],
286267 send : Send ,
287- receive : Receive ) -> None :
268+ receive : ReceiveChannel
269+ ) -> None :
288270 super ().__init__ ()
289271 self .vm = vm
290272 self .handler = handler
@@ -293,7 +275,6 @@ def __init__(self,
293275 self .send = send
294276 self .receive = receive
295277 self .run_coros_to_execute : dict [int , Callable [[], Awaitable [None ]]] = {}
296- self .sync_point = SyncPoint ()
297278 self .request_finished_event = asyncio .Event ()
298279 self .tasks = Tasks ()
299280
@@ -365,18 +346,6 @@ def on_attempt_finished(self):
365346 # ignore the cancelled error
366347 pass
367348
368-
369- async def receive_and_notify_input (self ):
370- """Receive input from the state machine."""
371- chunk = await self .receive ()
372- if chunk .get ('type' ) == 'http.disconnect' :
373- raise DisconnectedException ()
374- if chunk .get ('body' , None ) is not None :
375- assert isinstance (chunk ['body' ], bytes )
376- self .vm .notify_input (chunk ['body' ])
377- if not chunk .get ('more_body' , False ):
378- self .vm .notify_input_closed ()
379-
380349 async def take_and_send_output (self ):
381350 """Take output from state machine and send it"""
382351 output = self .vm .take_output ()
@@ -417,21 +386,22 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
417386 async def wrapper (f ):
418387 await f ()
419388 await self .take_and_send_output ()
420- await self .sync_point . arrive ( )
389+ await self .receive . tx ({ 'type' : 'restate.run_completed' , 'body' : bytes (), 'more_body' : True } )
421390
422391 task = asyncio .create_task (wrapper (fn ))
423392 self .tasks .add (task )
424393 continue
425394 if isinstance (do_progress_response , (DoWaitPendingRun , DoProgressReadFromInput )):
426- sync_task = asyncio .create_task (self .sync_point .awaiter ())
427- self .tasks .add (sync_task )
428-
429- read_task = asyncio .create_task (self .receive_and_notify_input ())
430- self .tasks .add (read_task )
431-
432- done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
433- if read_task in done :
434- _ = read_task .result () # propagate exception
395+ chunk = await self .receive ()
396+ if chunk .get ('type' ) == 'restate.run_completed' :
397+ continue
398+ if chunk .get ('type' ) == 'http.disconnect' :
399+ raise DisconnectedException ()
400+ if chunk .get ('body' , None ) is not None :
401+ assert isinstance (chunk ['body' ], bytes )
402+ self .vm .notify_input (chunk ['body' ])
403+ if not chunk .get ('more_body' , False ):
404+ self .vm .notify_input_closed ()
435405
436406 def _create_fetch_result_coroutine (self , handle : int , serde : Serde [T ] | None = None ):
437407 """Create a coroutine that fetches a result from a notification handle."""
0 commit comments