Skip to content

Commit 387ea22

Browse files
committed
Refactor event pipe handling into context manager and iterable
1 parent f376709 commit 387ea22

File tree

3 files changed

+101
-50
lines changed

3 files changed

+101
-50
lines changed

src/blueapi/service/interface.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from collections.abc import Mapping
3+
from dataclasses import dataclass
34
from functools import cache
45
from multiprocessing.connection import Connection
56
from typing import Any
@@ -283,27 +284,35 @@ def get_python_env(
283284
return get_python_environment(config=scratch, name=name, source=source)
284285

285286

286-
SubHandle = tuple[int, int, int]
287+
@dataclass
288+
class SubHandles:
289+
worker: int
290+
progress: int
291+
data: int
287292

288293

289-
def pipe_events(tx: Connection) -> SubHandle:
294+
def pipe_events(tx: Connection) -> SubHandles:
295+
tw = worker()
290296

291297
def handler(
292298
worker_event: WorkerEvent | DataEvent | ProgressEvent,
293299
_cor_id: str | None,
294300
) -> None:
295-
tx.send(worker_event)
296301

297-
task_worker = worker()
298-
w_id = task_worker.worker_events.subscribe(handler)
299-
d_id = task_worker.data_events.subscribe(handler)
300-
p_id = task_worker.progress_events.subscribe(handler)
301-
return (w_id, d_id, p_id)
302+
try:
303+
tx.send(worker_event)
304+
except BrokenPipeError:
305+
LOGGER.warning("Sending event to broken pipe")
306+
pass
302307

308+
w = tw.worker_events.subscribe(handler)
309+
d = tw.data_events.subscribe(handler)
310+
p = tw.progress_events.subscribe(handler)
311+
return SubHandles(worker=w, data=d, progress=p)
303312

304-
def unpipe_events(hnd: SubHandle) -> None:
305-
task_worker = worker()
306-
w, d, p = hnd
307-
task_worker.worker_events.unsubscribe(w)
308-
task_worker.data_events.unsubscribe(d)
309-
task_worker.progress_events.unsubscribe(p)
313+
314+
def unpipe_events(hnd: SubHandles):
315+
tw = worker()
316+
tw.worker_events.unsubscribe(hnd.worker)
317+
tw.data_events.unsubscribe(hnd.data)
318+
tw.progress_events.unsubscribe(hnd.progress)

src/blueapi/service/main.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import urllib.parse
33
from collections.abc import Awaitable, Callable
44
from contextlib import asynccontextmanager
5-
from multiprocessing import Pipe
65
from typing import Annotated, Any
76

87
import jwt
@@ -16,9 +15,9 @@
1615
Request,
1716
Response,
1817
WebSocket,
18+
WebSocketDisconnect,
1919
status,
2020
)
21-
from fastapi.concurrency import run_in_threadpool
2221
from fastapi.datastructures import Address
2322
from fastapi.middleware.cors import CORSMiddleware
2423
from fastapi.responses import RedirectResponse, StreamingResponse
@@ -555,44 +554,45 @@ async def run_plan(
555554
):
556555
user = "alice"
557556

558-
# ack ws
557+
LOGGER.info("Starting WS plan")
559558
await ws.accept()
560-
# accept task request through socket
561559
rq = await ws.receive_json()
562-
# submit task to runner
560+
LOGGER.info("Raw request: %s", rq)
563561
try:
564562
task_request: TaskRequest = TaskRequest.model_validate(rq)
563+
LOGGER.info("Plan request: %s", task_request)
565564
task_id: str = runner.run(interface.submit_task, task_request, {"user": user})
565+
LOGGER.info("Task ID: %s", task_id)
566566
except ValidationError:
567+
LOGGER.error("Args not valid", exc_info=True)
567568
await ws.close(code=1003, reason="invalid args")
568569
return
569570
except KeyError:
571+
LOGGER.error("Plan not found", exc_info=True)
570572
await ws.close(code=1003, reason="unknown plan")
571573
return
572574

573-
# add listener to runner
574-
tx, rx = Pipe()
575-
h = runner.run(interface.pipe_events, tx=tx)
576-
# start task
577575
try:
578-
task = WorkerTask(task_id=task_id)
579-
runner.run(
580-
interface.begin_task,
581-
task=task,
582-
)
576+
with runner.event_pipe() as events:
577+
LOGGER.info("Created event pipe")
578+
runner.run(interface.begin_task, task=WorkerTask(task_id=task_id))
579+
async for evt in events:
580+
LOGGER.debug("Event: %s", evt)
581+
await ws.send_json(evt.model_dump(mode="json"))
582+
if isinstance(evt, WorkerEvent) and evt.is_complete():
583+
LOGGER.info("End of stream")
584+
break
583585
except WorkerBusyError:
586+
LOGGER.error("Worker was busy")
584587
await ws.close(code=1013, reason="Worker busy")
585-
return
586-
# pipe events to ws
587-
try:
588-
while True:
589-
event: AnyEvent = await run_in_threadpool(rx.recv)
590-
await ws.send_json(event.model_dump(mode="json"))
591-
if isinstance(event, WorkerEvent) and event.is_complete():
592-
break
593-
finally:
588+
except WebSocketDisconnect:
589+
LOGGER.info("Client disconnected")
590+
runner.run(
591+
interface.cancel_active_task, failure=True, reason="Client disconnected"
592+
)
593+
else:
594+
LOGGER.info("Plan complete")
594595
await ws.close()
595-
runner.run(interface.unpipe_events, hnd=h)
596596

597597

598598
@start_as_current_span(TRACER, "config")

src/blueapi/service/runner.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import asyncio
12
import inspect
23
import logging
34
import signal
45
import uuid
5-
from collections.abc import Callable
6+
from collections.abc import AsyncIterator, Callable
67
from importlib import import_module
78
from multiprocessing import Pool, set_start_method
9+
from multiprocessing.connection import Connection, Pipe
810
from multiprocessing.pool import Pool as PoolClass
911
from typing import Any, ParamSpec, TypeVar
1012

@@ -15,11 +17,13 @@
1517
)
1618
from opentelemetry.context import attach
1719
from opentelemetry.propagate import get_global_textmap
18-
from pydantic import TypeAdapter
1920

2021
from blueapi.config import ApplicationConfig
21-
from blueapi.service.interface import setup, teardown
22+
from blueapi.core.bluesky_types import DataEvent
23+
from blueapi.service import interface
24+
from blueapi.service.interface import SubHandles, setup, teardown
2225
from blueapi.service.model import EnvironmentResponse
26+
from blueapi.worker.event import ProgressEvent, WorkerEvent
2327

2428
# The default multiprocessing start method is fork
2529
set_start_method("spawn", force=True)
@@ -145,11 +149,57 @@ def run(
145149
kwargs,
146150
)
147151

152+
def event_pipe(self):
153+
return EventPipe(self)
154+
148155
@property
149156
def state(self) -> EnvironmentResponse:
150157
return self._state
151158

152159

160+
class EventStream:
161+
def __init__(self, rx: Connection):
162+
self._rx = rx
163+
164+
def __aiter__(self) -> AsyncIterator[WorkerEvent | DataEvent | ProgressEvent]:
165+
return self
166+
167+
async def __anext__(self) -> WorkerEvent | DataEvent | ProgressEvent:
168+
data_available = asyncio.Event()
169+
asyncio.get_event_loop().add_reader(self._rx.fileno(), data_available.set)
170+
try:
171+
while not self._rx.poll():
172+
await data_available.wait()
173+
data_available.clear()
174+
return self._rx.recv()
175+
except BrokenPipeError:
176+
raise StopAsyncIteration() from None
177+
finally:
178+
asyncio.get_event_loop().remove_reader(self._rx.fileno())
179+
180+
181+
class EventPipe:
182+
runner: WorkerDispatcher
183+
handles: list[tuple[SubHandles, Connection]]
184+
185+
def __init__(self, runner: WorkerDispatcher):
186+
self.runner = runner
187+
self.handles = []
188+
189+
def __enter__(self) -> EventStream:
190+
tx, rx = Pipe()
191+
hnd = self.runner.run(interface.pipe_events, tx)
192+
LOGGER.debug("Subscribing new event pipe: %s", hnd)
193+
self.handles.append((hnd, tx))
194+
return EventStream(rx)
195+
196+
def __exit__(self, *exc):
197+
hnd, conn = self.handles.pop()
198+
LOGGER.debug("Unsubscribing event pipe: %s", hnd)
199+
conn.close()
200+
self.runner.run(interface.unpipe_events, hnd)
201+
202+
153203
class InvalidRunnerStateError(Exception):
154204
def __init__(self, message):
155205
super().__init__(message)
@@ -173,15 +223,7 @@ def import_and_run_function(
173223
func: Callable[..., T] = _validate_function(
174224
mod.__dict__.get(function_name, None), function_name
175225
)
176-
value = func(*args, **kwargs)
177-
return _valid_return(value, expected_type)
178-
179-
180-
def _valid_return(value: Any, expected_type: type[T] | None = None) -> T:
181-
if expected_type is None:
182-
return value
183-
else:
184-
return TypeAdapter(expected_type).validate_python(value)
226+
return func(*args, **kwargs)
185227

186228

187229
def _validate_function(func: Any, function_name: str) -> Callable:

0 commit comments

Comments
 (0)