From cf36cb3fb7d72aaca31fbfee99865bac14aa0a54 Mon Sep 17 00:00:00 2001 From: David Perl Date: Thu, 16 Oct 2025 09:46:30 +0200 Subject: [PATCH 01/10] refactor: make active workers private --- .../tests/end-2-end/test_procedures_e2e.py | 12 +++++----- .../scan_server/procedures/manager.py | 22 +++++++++---------- .../scan_server/procedures/worker_base.py | 1 - .../tests_scan_server/test_procedures.py | 14 ++++++------ 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py index 409c11f92..5a9daeb8e 100644 --- a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py @@ -62,7 +62,7 @@ def test_procedure_runner_spawns_worker( client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager], ): client, _, manager = client_logtool_and_manager - assert manager.active_workers == {} + assert manager._active_workers == {} endpoint = MessageEndpoints.procedure_request() msg = messages.ProcedureRequestMessage( identifier="sleep", args_kwargs=((), {"time_s": 2}), queue="test" @@ -77,8 +77,8 @@ def cb(worker: ContainerProcedureWorker): manager.add_callback("test", cb) client.connector.xadd(topic=endpoint, msg_dict=msg.model_dump()) - _wait_while(lambda: manager.active_workers == {}, 5) - _wait_while(lambda: manager.active_workers != {}, 20) + _wait_while(lambda: manager._active_workers == {}, 5) + _wait_while(lambda: manager._active_workers != {}, 20) assert logs != [] @@ -91,7 +91,7 @@ def test_happy_path_container_procedure_runner( test_args = (1, 2, 3) test_kwargs = {"a": "b", "c": "d"} client, logtool, manager = client_logtool_and_manager - assert manager.active_workers == {} + assert manager._active_workers == {} conn = client.connector endpoint = MessageEndpoints.procedure_request() msg = messages.ProcedureRequestMessage( @@ -99,8 +99,8 @@ def 
test_happy_path_container_procedure_runner( ) conn.xadd(topic=endpoint, msg_dict=msg.model_dump()) - _wait_while(lambda: manager.active_workers == {}, 5) - _wait_while(lambda: manager.active_workers != {}, 20) + _wait_while(lambda: manager._active_workers == {}, 5) + _wait_while(lambda: manager._active_workers != {}, 20) logtool.fetch() assert logtool.is_present_in_any_message("procedure accepted: True, message:") diff --git a/bec_server/bec_server/scan_server/procedures/manager.py b/bec_server/bec_server/scan_server/procedures/manager.py index d402d92e3..7ce78dba3 100644 --- a/bec_server/bec_server/scan_server/procedures/manager.py +++ b/bec_server/bec_server/scan_server/procedures/manager.py @@ -45,7 +45,7 @@ def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]): self._parent = parent self.lock = RLock() - self.active_workers: dict[str, ProcedureWorkerEntry] = {} + self._active_workers: dict[str, ProcedureWorkerEntry] = {} self.executor = ThreadPoolExecutor( max_workers=PROCEDURE.WORKER.MAX_WORKERS, thread_name_prefix="user_procedure_" ) @@ -86,7 +86,7 @@ def add_callback(self, queue: str, cb: Callable[[ProcedureWorker], Any]): self._callbacks[queue].append(cb) def _run_callbacks(self, queue: str): - if (worker := self.active_workers[queue]["worker"]) is None: + if (worker := self._active_workers[queue]["worker"]) is None: return for cb in self._callbacks.get(queue, []): cb(worker) @@ -106,7 +106,7 @@ def process_queue_request(self, msg: dict[str, Any]): self._ack(True, f"Running procedure {message_obj.identifier}") queue = message_obj.queue or PROCEDURE.WORKER.DEFAULT_QUEUE endpoint = MessageEndpoints.procedure_execution(queue) - logger.debug(f"active workers: {self.active_workers}, worker requested: {queue}") + logger.debug(f"active workers: {self._active_workers}, worker requested: {queue}") self._conn.rpush( endpoint, endpoint.message_type( @@ -120,14 +120,14 @@ def cleanup_worker(fut): with self.lock: logger.debug(f"cleaning up worker 
{fut} for queue {queue}...") self._run_callbacks(queue) - del self.active_workers[queue] + del self._active_workers[queue] with self.lock: - if queue not in self.active_workers: + if queue not in self._active_workers: new_worker = self.executor.submit(self.spawn, queue=queue) new_worker.add_done_callback(_log_on_end) new_worker.add_done_callback(cleanup_worker) - self.active_workers[queue] = {"worker": None, "future": new_worker} + self._active_workers[queue] = {"worker": None, "future": new_worker} def spawn(self, queue: str): """Spawn a procedure worker future which listens to a given queue, i.e. procedure queue list in Redis. @@ -135,13 +135,13 @@ def spawn(self, queue: str): Args: queue (str): name of the queue to spawn a worker for""" - if queue in self.active_workers and self.active_workers[queue]["worker"] is not None: + if queue in self._active_workers and self._active_workers[queue]["worker"] is not None: raise WorkerAlreadyExists( - f"Queue {queue} already has an active worker in {self.active_workers}!" + f"Queue {queue} already has an active worker in {self._active_workers}!" 
) with self._worker_cls(self._server, queue, PROCEDURE.WORKER.QUEUE_TIMEOUT_S) as worker: with self.lock: - self.active_workers[queue]["worker"] = worker + self._active_workers[queue]["worker"] = worker worker.work() def shutdown(self): @@ -152,7 +152,7 @@ def shutdown(self): ) self._conn.shutdown() # cancel futures by hand to give us the opportunity to detatch them from redis if they have started - for entry in self.active_workers.values(): + for entry in self._active_workers.values(): cancelled = entry["future"].cancel() if not cancelled: # unblock any waiting workers and let them shutdown @@ -160,7 +160,7 @@ def shutdown(self): # redis unblock executor.client_id worker.abort() futures.wait( - (entry["future"] for entry in self.active_workers.values()), + (entry["future"] for entry in self._active_workers.values()), timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S, ) self.executor.shutdown() diff --git a/bec_server/bec_server/scan_server/procedures/worker_base.py b/bec_server/bec_server/scan_server/procedures/worker_base.py index a18b0dcf8..d422d6e2b 100644 --- a/bec_server/bec_server/scan_server/procedures/worker_base.py +++ b/bec_server/bec_server/scan_server/procedures/worker_base.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from enum import Enum, auto from typing import cast from bec_lib.endpoints import MessageEndpoints diff --git a/bec_server/tests/tests_scan_server/test_procedures.py b/bec_server/tests/tests_scan_server/test_procedures.py index 0fda944c9..4ad5f7038 100644 --- a/bec_server/tests/tests_scan_server/test_procedures.py +++ b/bec_server/tests/tests_scan_server/test_procedures.py @@ -112,7 +112,7 @@ def test_process_request_happy_paths(process_request_manager, message: Procedure assert queue in endpoint.endpoint assert execution_msg.identifier == message.identifier process_request_manager.spawn.assert_called() - assert queue in process_request_manager.active_workers.keys() + assert queue in 
process_request_manager._active_workers.keys() def test_process_request_failure(process_request_manager): @@ -120,7 +120,7 @@ def test_process_request_failure(process_request_manager): process_request_manager._ack.assert_not_called() process_request_manager._conn.rpush.assert_not_called() process_request_manager.spawn.assert_not_called() - assert process_request_manager.active_workers == {} + assert process_request_manager._active_workers == {} class UnlockableWorker(ProcedureWorker): @@ -162,13 +162,13 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager): procedure_manager._validate_request = MagicMock(side_effect=lambda msg: msg) # trigger the running of the test message procedure_manager.process_queue_request(message) # type: ignore - assert queue in procedure_manager.active_workers.keys() + assert queue in procedure_manager._active_workers.keys() # spawn method should be added as a future - _wait_until(procedure_manager.active_workers[queue]["future"].running) + _wait_until(procedure_manager._active_workers[queue]["future"].running) # and then create the worker - _wait_until(lambda: procedure_manager.active_workers[queue].get("worker") is not None) - worker = procedure_manager.active_workers[queue]["worker"] + _wait_until(lambda: procedure_manager._active_workers[queue].get("worker") is not None) + worker = procedure_manager._active_workers[queue]["worker"] assert isinstance(worker, UnlockableWorker) _wait_until(lambda: worker.status == ProcedureWorkerStatus.RUNNING) @@ -185,7 +185,7 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager): worker.event_2.set() _wait_until(lambda: worker.status == ProcedureWorkerStatus.FINISHED) # spawn deletes the worker queue - _wait_until(lambda: len(procedure_manager.active_workers) == 0) + _wait_until(lambda: len(procedure_manager._active_workers) == 0) @patch("bec_server.scan_server.procedures.worker_base.RedisConnector", MagicMock()) From e368ba260d9e4f6295e6810e3715df8f3c2b14d6 Mon Sep 
17 00:00:00 2001 From: David Perl Date: Fri, 17 Oct 2025 11:38:29 +0200 Subject: [PATCH 02/10] feat: add status API to procedure manager --- bec_lib/bec_lib/messages.py | 1 + .../scan_server/procedures/manager.py | 13 +++++++- .../tests_scan_server/test_procedures.py | 32 +++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 28c12943b..318f11b05 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -19,6 +19,7 @@ class ProcedureWorkerStatus(Enum): IDLE = auto() FINISHED = auto() DEAD = auto() # worker lost communication with the container + NONE = auto() # worker doesn't exist in manager, caught during creation and cleanup class BECStatus(Enum): diff --git a/bec_server/bec_server/scan_server/procedures/manager.py b/bec_server/bec_server/scan_server/procedures/manager.py index 7ce78dba3..271422ca4 100644 --- a/bec_server/bec_server/scan_server/procedures/manager.py +++ b/bec_server/bec_server/scan_server/procedures/manager.py @@ -10,7 +10,7 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_lib.messages import ProcedureRequestMessage, RequestResponseMessage +from bec_lib.messages import ProcedureRequestMessage, ProcedureWorkerStatus, RequestResponseMessage from bec_lib.redis_connector import RedisConnector from bec_server.scan_server.procedures import procedure_registry from bec_server.scan_server.procedures.constants import PROCEDURE, WorkerAlreadyExists @@ -164,3 +164,14 @@ def shutdown(self): timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S, ) self.executor.shutdown() + + def active_workers(self) -> list[str]: + with self.lock: + return list(self._active_workers.keys()) + + def worker_statuses(self) -> dict[str, ProcedureWorkerStatus]: + with self.lock: + return { + q: w["worker"].status if w["worker"] is not None else ProcedureWorkerStatus.NONE + for q, w in self._active_workers.items() + } diff --git 
a/bec_server/tests/tests_scan_server/test_procedures.py b/bec_server/tests/tests_scan_server/test_procedures.py index 4ad5f7038..694b7145d 100644 --- a/bec_server/tests/tests_scan_server/test_procedures.py +++ b/bec_server/tests/tests_scan_server/test_procedures.py @@ -267,3 +267,35 @@ def test_register_rejects_already_registered(): with pytest.raises(ProcedureRegistryError) as e: register("run scan", lambda *_, **__: None) assert e.match("already registered") + + +def _yield_once(): + yield "value" + while True: + yield None + + +@patch( + "bec_server.scan_server.procedures.worker_base.RedisConnector", + side_effect=lambda *_: MagicMock( + blocking_list_pop_to_set_add=MagicMock(side_effect=_yield_once()) + ), +) +def test_manager_status_api(_conn, procedure_manager): + procedure_manager._worker_cls = UnlockableWorker + for message in PROCESS_REQUEST_TEST_CASES: + procedure_manager.process_queue_request(message) + _wait_until(lambda: procedure_manager.active_workers() == ["primary", "queue2"]) + _wait_until( + lambda: procedure_manager.worker_statuses() + == {"primary": ProcedureWorkerStatus.RUNNING, "queue2": ProcedureWorkerStatus.RUNNING} + ) + for w in procedure_manager._active_workers.values(): + w["worker"].event_1.set() + _wait_until( + lambda: procedure_manager.worker_statuses() + == {"primary": ProcedureWorkerStatus.IDLE, "queue2": ProcedureWorkerStatus.IDLE} + ) + for w in procedure_manager._active_workers.values(): + w["worker"].event_2.set() + _wait_until(lambda: procedure_manager.active_workers() == []) From 03cc330ecb477d4722137254a7335ab0950f3793 Mon Sep 17 00:00:00 2001 From: perl_d Date: Tue, 21 Oct 2025 10:14:52 +0200 Subject: [PATCH 03/10] docs: fix message docstring --- bec_server/bec_server/scan_server/procedures/manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bec_server/bec_server/scan_server/procedures/manager.py b/bec_server/bec_server/scan_server/procedures/manager.py index 271422ca4..b171a5427 100644 --- 
a/bec_server/bec_server/scan_server/procedures/manager.py +++ b/bec_server/bec_server/scan_server/procedures/manager.py @@ -6,12 +6,12 @@ from threading import RLock from typing import Any, Callable, TypedDict -from pydantic import ValidationError - from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from bec_lib.messages import ProcedureRequestMessage, ProcedureWorkerStatus, RequestResponseMessage from bec_lib.redis_connector import RedisConnector +from pydantic import ValidationError + from bec_server.scan_server.procedures import procedure_registry from bec_server.scan_server.procedures.constants import PROCEDURE, WorkerAlreadyExists from bec_server.scan_server.procedures.worker_base import ProcedureWorker From a0f622e90649d765a59a99281271b77e25fdae33 Mon Sep 17 00:00:00 2001 From: David Perl Date: Thu, 23 Oct 2025 11:08:42 +0200 Subject: [PATCH 04/10] feat: add lrem redis operation --- bec_lib/bec_lib/endpoints.py | 2 +- bec_lib/bec_lib/redis_connector.py | 17 +++++++++++++++++ bec_lib/tests/test_redis_connector_fakeredis.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index cc4f56b90..3e46747c7 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -41,7 +41,7 @@ class MessageOp(list[str], enum.Enum): SET_PUBLISH = ["register", "set_and_publish", "delete", "get", "keys"] SEND = ["send", "register"] STREAM = ["xadd", "xrange", "xread", "register_stream", "keys", "get_last", "delete"] - LIST = ["lpush", "lrange", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"] + LIST = ["lpush", "lrange", "lrem", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"] KEY_VALUE = ["set", "get", "delete", "keys"] SET = ["remove_from_set", "get_set_members"] diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index 115d7b820..e302b691d 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ 
b/bec_lib/bec_lib/redis_connector.py @@ -1105,6 +1105,23 @@ def lrange(self, topic: str, start: int, end: int, pipe: Pipeline | None = None) ret.append(msg) return ret + @validate_endpoint("topic") + def lrem(self, topic: str, count: int, msg, pipe: Pipeline | None = None): + """Removes the first count occurrences of elements equal to element from the list stored at key. + The count argument influences the operation in the following ways: + count > 0: Remove elements equal to element moving from head to tail. + count < 0: Remove elements equal to element moving from tail to head. + count = 0: Remove all elements equal to element. + For example, LREM list -2 "hello" will remove the last two occurrences of "hello" in the list stored at list. + + Returns the number of items removed + """ + + client = pipe if pipe is not None else self._redis_conn + if isinstance(msg, BECMessage): + msg = MsgpackSerialization.dumps(msg) + return client.lrem(topic, count, msg) + @validate_endpoint("topic") def set_and_publish( self, topic: str, msg, pipe: Pipeline | None = None, expire: int | None = None diff --git a/bec_lib/tests/test_redis_connector_fakeredis.py b/bec_lib/tests/test_redis_connector_fakeredis.py index beb4173d8..777d8b300 100644 --- a/bec_lib/tests/test_redis_connector_fakeredis.py +++ b/bec_lib/tests/test_redis_connector_fakeredis.py @@ -568,3 +568,17 @@ def test_redis_connector_message_alternated_pass(connected_connector): assert msg_received == [data_to_check] assert msg_received == [data_original] + + +def test_lrem(connected_connector: RedisConnector): + conn, ep = connected_connector, MessageEndpoints.procedure_execution + msgs = [ + messages.ProcedureExecutionMessage(identifier=_id, queue="primary", args_kwargs=((), {})) + for _id in ("a", "b", "c") + ] + for msg in msgs: + conn.rpush(ep(msg.queue), msg) + assert len(conn.lrange(ep("primary"), 0, -1)) == 3 + conn.lrem(ep(msgs[1].queue), 0, msgs[1]) + list_contents = conn.lrange(ep("primary"), 0, -1) + assert 
list_contents == [msgs[0], msgs[2]] From fe73256df09cf8cb49e5ff3bd09a3fb787829b19 Mon Sep 17 00:00:00 2001 From: David Perl Date: Mon, 20 Oct 2025 14:29:54 +0200 Subject: [PATCH 05/10] feat: extend procedure manager and API - add helper for Redis communication - add abort functionality - add removal from aborted queues - add status API - fix bugs in manager to update redis state correctly --- .../tests/end-2-end/test_procedures_e2e.py | 7 +- bec_lib/bec_lib/endpoints.py | 70 ++++++- bec_lib/bec_lib/messages.py | 63 +++++- bec_lib/bec_lib/redis_connector.py | 7 +- bec_lib/tests/test_bec_messages.py | 25 +++ .../scan_server/procedures/_dev_runner.py | 1 - .../scan_server/procedures/constants.py | 1 + .../scan_server/procedures/container_utils.py | 35 +++- .../procedures/container_worker.py | 59 ++++-- .../scan_server/procedures/helper.py | 190 ++++++++++++++++++ .../procedures/in_process_worker.py | 6 +- .../scan_server/procedures/manager.py | 162 ++++++++++++--- .../scan_server/procedures/protocol.py | 2 +- .../scan_server/procedures/worker_base.py | 19 +- .../tests_scan_server/test_container_utils.py | 3 +- .../test_procedure_container_worker.py | 27 +-- .../tests_scan_server/test_procedures.py | 150 +++++++++++++- 17 files changed, 736 insertions(+), 91 deletions(-) create mode 100644 bec_server/bec_server/scan_server/procedures/helper.py diff --git a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py index 5a9daeb8e..370d83792 100644 --- a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py @@ -11,7 +11,6 @@ from bec_lib import messages from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_server.scan_server.procedures.constants import PROCEDURE from bec_server.scan_server.procedures.container_utils import get_backend from bec_server.scan_server.procedures.container_worker import 
ContainerProcedureWorker from bec_server.scan_server.procedures.manager import ProcedureManager @@ -52,7 +51,7 @@ def _wait_while(cond: Callable[[], bool], timeout_s): def test_building_worker_image(): podman_utils = get_backend() build = podman_utils.build_worker_image() - assert len(build._command_output.splitlines()[-1]) == 64 + assert len(build._command_output.splitlines()[-1]) == 64 # type: ignore assert podman_utils.image_exists(f"bec_procedure_worker:v{version('bec_lib')}") @@ -75,7 +74,7 @@ def cb(worker: ContainerProcedureWorker): logs = worker._backend.logs(worker._container_id) manager.add_callback("test", cb) - client.connector.xadd(topic=endpoint, msg_dict=msg.model_dump()) + client.connector.xadd(topic=endpoint, msg_dict=msg) _wait_while(lambda: manager._active_workers == {}, 5) _wait_while(lambda: manager._active_workers != {}, 20) @@ -97,7 +96,7 @@ def test_happy_path_container_procedure_runner( msg = messages.ProcedureRequestMessage( identifier="log execution message args", args_kwargs=(test_args, test_kwargs) ) - conn.xadd(topic=endpoint, msg_dict=msg.model_dump()) + conn.xadd(topic=endpoint, msg_dict=msg) _wait_while(lambda: manager._active_workers == {}, 5) _wait_while(lambda: manager._active_workers != {}, 20) diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 3e46747c7..d554c452a 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -43,7 +43,7 @@ class MessageOp(list[str], enum.Enum): STREAM = ["xadd", "xrange", "xread", "register_stream", "keys", "get_last", "delete"] LIST = ["lpush", "lrange", "lrem", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"] KEY_VALUE = ["set", "get", "delete", "keys"] - SET = ["remove_from_set", "get_set_members"] + SET = ["remove_from_set", "get_set_members", "delete"] MessageType = TypeVar("MessageType", bound="type[messages.BECMessage]") @@ -1420,8 +1420,8 @@ def available_procedures() -> EndpointInfo: @staticmethod def procedure_request() -> 
EndpointInfo: """ - Endpoint for scan queue request. This endpoint is used to request the new scans. - The request is sent using a messages.ScanQueueMessage message. + Endpoint for requesting new procedures. + The request is sent using a messages.ProcedureRequestMessage message. Returns: EndpointInfo: Endpoint for scan queue request. @@ -1467,6 +1467,25 @@ def procedure_execution(queue_id: str): message_op=MessageOp.LIST, ) + @staticmethod + def unhandled_procedure_execution(queue_id: str): + """ + Endpoint for procedure executions which were pending when the manager was shutdown. + Messages from procedure_execution are moved here on manager startup. + The request is sent using a messages.ProcedureExecutionMessage message. + + Returns: + EndpointInfo: Endpoint for scan queue request. + """ + endpoint = ( + f"{EndpointType.INTERNAL.value}/procedures/unhandled_procedure_execution/{queue_id}" + ) + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureExecutionMessage, + message_op=MessageOp.LIST, + ) + @staticmethod def active_procedure_executions(): """ @@ -1483,6 +1502,36 @@ def active_procedure_executions(): message_op=MessageOp.SET, ) + @staticmethod + def procedure_abort(): + """ + Endpoint to request aborting a running procedure + + Returns: + EndpointInfo: Endpoint for set of active procedure executions. + """ + endpoint = f"{EndpointType.INFO.value}/procedures/abort" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureAbortMessage, + message_op=MessageOp.STREAM, + ) + + @staticmethod + def procedure_clear_unhandled(): + """ + Endpoint to request aborting a running procedure + + Returns: + EndpointInfo: Endpoint for set of active procedure executions. 
+ """ + endpoint = f"{EndpointType.INFO.value}/procedures/clear_unhandled" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureClearUnhandledMessage, + message_op=MessageOp.STREAM, + ) + @staticmethod def procedure_worker_status_update(queue_id: str): """ @@ -1498,6 +1547,21 @@ def procedure_worker_status_update(queue_id: str): message_op=MessageOp.LIST, ) + @staticmethod + def procedure_queue_notif(): + """ + PubSub channel for a consumer (e.g. BEC widgets) to be notified of changes to a procedure queue + + Returns: + EndpointInfo: Endpoint for procedure queue updates for given queue. + """ + endpoint = f"{EndpointType.INFO.value}/procedures/queue_notif" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureQNotifMessage, + message_op=MessageOp.SEND, + ) + @staticmethod def gui_registry_state(gui_id: str): """ diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 318f11b05..39bdc51c9 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -5,7 +5,8 @@ import warnings from copy import deepcopy from enum import Enum, auto -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Literal, Self +from uuid import uuid4 import numpy as np from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator @@ -1213,18 +1214,67 @@ class ProcedureRequestMessage(BECMessage): queue: str | None = None +class ProcedureQNotifMessage(BECMessage): + """Message type for notifying watchers of changes to queues""" + + msg_type: ClassVar[str] = "procedure_queue_notif_message" + queue_name: str + queue_type: Literal["execution", "unhandled"] + + class ProcedureExecutionMessage(BECMessage): """Message type for sending procedure execution instructions to the scheduler Sent by the user to the procedure_request topic. It will be consumed by the scan server. 
Args: identifier (str): name of the procedure registered with the server + queue (str): the procedure queue this execution belongs to + args_kwargs (tuple[tuple[Any, ...], dict[str, Any]]): arguments for the procedure function """ msg_type: ClassVar[str] = "procedure_execution_message" identifier: str queue: str args_kwargs: tuple[tuple[Any, ...], dict[str, Any]] = (), {} + execution_id: str = Field(default_factory=lambda: str(uuid4())) + + +class ProcedureAbortMessage(BECMessage): + """Message type to request aborting a procedure or procedure queue + + One and only one of the args should be supplied. + Args: + queue (str | None): the procedure queue to abort + execution_id (str | None): the procedure execution to abort + abort_all (bool | None): abort all procedures if true + """ + + msg_type: ClassVar[str] = "procedure_abort_message" + queue: str | None = None + execution_id: str | None = None + abort_all: bool | None = None + + @model_validator(mode="after") + def mutually_exclusive(self) -> Self: + if (self.queue, self.execution_id, self.abort_all).count(None) != 2: + raise ValueError( + "Please only supply one argument! Supplied: \n" + f" {self.queue=}, {self.execution_id=}, {self.abort_all=}" + ) + return self + + +class ProcedureClearUnhandledMessage(ProcedureAbortMessage): + """Message type to request clearing an unhandled procedure or procedure queue + + One and only one of the args should be supplied. + Args: + queue (str | None): the procedure queue to abort + execution_id (str | None): the procedure execution to abort + abort_all (bool | None): abort all procedures if true + """ + + ... 
class ProcedureWorkerStatusMessage(BECMessage): @@ -1233,12 +1283,21 @@ class ProcedureWorkerStatusMessage(BECMessage): Args: worker_queue (str): Worker queue ID status (str): Worker status - + current_execution_id (str | None): ID of the current job, only allowed for RUNNING """ msg_type: ClassVar[str] = "procedure_worker_status_message" worker_queue: str status: ProcedureWorkerStatus + current_execution_id: str | None = None + + @model_validator(mode="after") + def check_id(self) -> Self: + if self.current_execution_id is not None and self.status != ProcedureWorkerStatus.RUNNING: + raise ValueError("Adding an execution ID is only valid for the RUNNING status") + if self.current_execution_id is None and self.status == ProcedureWorkerStatus.RUNNING: + raise ValueError("Adding an execution ID is mandatory for the RUNNING status") + return self class LoginInfoMessage(BECMessage): diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index e302b691d..0eb336bd1 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -1179,7 +1179,7 @@ def mget(self, topics: list[str], pipe: Pipeline | None = None): def xadd( self, topic: str, - msg_dict: dict, + msg_dict: dict | BECMessage, max_size=None, pipe: Pipeline | None = None, expire: int | None = None, @@ -1189,7 +1189,7 @@ def xadd( Args: topic (str): redis topic - msg_dict (dict): message to add + msg_dict (dict | BECMessage): message to add max_size (int, optional): max size of stream. Defaults to None. pipe (Pipeline, optional): redis pipe. Defaults to None. expire (int, optional): expire time. Defaults to None. 
@@ -1205,7 +1205,8 @@ def xadd( else: client = self._redis_conn - msg_dict = {key: MsgpackSerialization.dumps(val) for key, val in msg_dict.items()} + msg = msg_dict.model_dump() if isinstance(msg_dict, BECMessage) else msg_dict + msg_dict = {key: MsgpackSerialization.dumps(val) for key, val in msg.items()} if max_size: client.xadd(topic, msg_dict, maxlen=max_size) diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index e94ba9196..a1a7cf2d4 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -243,6 +243,31 @@ def test_ProcedureWorkerStatusMessage(): assert res_loaded == msg +def test_ProcedureWorkerStatusMessage_validation(): + with pytest.raises(pydantic.ValidationError) as e: + messages.ProcedureWorkerStatusMessage( + worker_queue="background tasks", + status=messages.ProcedureWorkerStatus.RUNNING, + metadata={"RID": "1234"}, + ) + assert e.match("Adding an execution ID is mandatory") + with pytest.raises(pydantic.ValidationError) as e: + messages.ProcedureWorkerStatusMessage( + worker_queue="background tasks", + status=messages.ProcedureWorkerStatus.IDLE, + metadata={"RID": "1234"}, + current_execution_id="test", + ) + assert e.match("Adding an execution ID is only valid") + + +def test_ProcedureAbortMessage_validation(): + with pytest.raises(pydantic.ValidationError) as e: + messages.ProcedureAbortMessage(queue="test", execution_id="test") + assert e.match("only supply one argument") + messages.ProcedureAbortMessage(queue="test") + + def test_FileMessage(): msg = messages.FileMessage( device_name="samx", diff --git a/bec_server/bec_server/scan_server/procedures/_dev_runner.py b/bec_server/bec_server/scan_server/procedures/_dev_runner.py index e87ae4b78..b95db1c27 100644 --- a/bec_server/bec_server/scan_server/procedures/_dev_runner.py +++ b/bec_server/bec_server/scan_server/procedures/_dev_runner.py @@ -7,7 +7,6 @@ from bec_lib.logger import bec_logger from 
bec_server.scan_server.procedures.container_worker import ContainerProcedureWorker - from bec_server.scan_server.procedures.in_process_worker import InProcessProcedureWorker from bec_server.scan_server.procedures.manager import ProcedureManager logger = bec_logger.logger diff --git a/bec_server/bec_server/scan_server/procedures/constants.py b/bec_server/bec_server/scan_server/procedures/constants.py index 376266275..6bdfeee85 100644 --- a/bec_server/bec_server/scan_server/procedures/constants.py +++ b/bec_server/bec_server/scan_server/procedures/constants.py @@ -77,4 +77,5 @@ class PodmanContainerStates(str, Enum): RUNNING = "running" PAUSED = "paused" STOPPED = "stopped" + STOPPING = "stopping" EXITED = "exited" diff --git a/bec_server/bec_server/scan_server/procedures/container_utils.py b/bec_server/bec_server/scan_server/procedures/container_utils.py index 83e23c31e..766c11901 100644 --- a/bec_server/bec_server/scan_server/procedures/container_utils.py +++ b/bec_server/bec_server/scan_server/procedures/container_utils.py @@ -96,6 +96,7 @@ def run( volumes: list[VolumeSpec], command: str, pod_name: str | None = None, + container_name: str | None = None, ) -> str: with PodmanClient(base_url=self.uri) as client: try: @@ -106,6 +107,7 @@ def run( environment=environment, mounts=volumes, pod=pod_name, + name=container_name, ) # type: ignore # running with detach returns container object except APIError as e: if e.status_code == HTTPStatus.INTERNAL_SERVER_ERROR: @@ -151,8 +153,9 @@ def pretty_print(self) -> str: class PodmanCliUtils(_PodmanUtilsBase): - def _run_and_capture_error(self, *args: str): - logger.debug(f"Running {args}") + def _run_and_capture_error(self, *args: str, log: bool = True): + if log: + logger.debug(f"Running {args}") output = subprocess.run([*args], capture_output=True) if output.returncode != 0: raise ProcedureWorkerError( @@ -166,7 +169,7 @@ def _run_and_capture_error(self, *args: str): def _podman_ls_json(self, subcom: Literal["image", 
"container"] = "container"): return json.loads( self._run_and_capture_error( - "podman", subcom, "list", "--all", "--format", "json" + "podman", subcom, "list", "--all", "--format", "json", log=False ).stdout ) @@ -193,6 +196,7 @@ def run( volumes: list[VolumeSpec], command: str, pod_name: str | None = None, + container_name: str | None = None, ) -> str: _volumes = [ f"{vol['source']}:{vol['target']}{':ro' if vol['read_only'] else ''}" for vol in volumes @@ -200,22 +204,39 @@ def run( _volume_args = list(chain(*(("-v", vol) for vol in _volumes))) _environment = _multi_args_from_dict("-e", environment) # type: ignore # this is actually a dict[str, str] _pod_arg = ["--pod", pod_name] if pod_name else [] + _name_arg = ["--replace", "--name", container_name] if container_name else [] return ( self._run_and_capture_error( - "podman", "run", *_environment, "-d", *_volume_args, *_pod_arg, image_tag, command + "podman", + "run", + *_environment, + "-d", + *_name_arg, + *_volume_args, + *_pod_arg, + image_tag, + command, ) .stdout.decode() .strip() ) def kill(self, id: str): - self._run_and_capture_error("podman", "kill", id) - self._run_and_capture_error("podman", "rm", id) + try: + self._run_and_capture_error("podman", "kill", id) + except ProcedureWorkerError as e: + logger.error(e) def logs(self, id: str) -> list[str]: - return self._run_and_capture_error("podman", "logs", id).stderr.decode().splitlines() + try: + return self._run_and_capture_error("podman", "logs", id).stderr.decode().splitlines() + except ProcedureWorkerError as e: + logger.error(e) + return [f"No logs found for container {id}\n"] def state(self, id: str) -> PodmanContainerStates | None: for container in self._podman_ls_json(): if container["Id"] == id or container["Id"].startswith(id): + if names := container.get("Names"): + logger.debug(f"Container {names[0]} status: {container['State']}") return PodmanContainerStates(container["State"]) diff --git 
a/bec_server/bec_server/scan_server/procedures/container_worker.py b/bec_server/bec_server/scan_server/procedures/container_worker.py index 0d6c2e5b1..53cc109d9 100644 --- a/bec_server/bec_server/scan_server/procedures/container_worker.py +++ b/bec_server/bec_server/scan_server/procedures/container_worker.py @@ -15,6 +15,7 @@ ProcedureWorkerError, ) from bec_server.scan_server.procedures.container_utils import get_backend +from bec_server.scan_server.procedures.helper import BackendProcedureHelper from bec_server.scan_server.procedures.protocol import ContainerCommandBackend from bec_server.scan_server.procedures.worker_base import ProcedureWorker @@ -34,7 +35,7 @@ def _worker_environment(self) -> ContainerWorkerEnv: minimum necessary, or things which are only necessary for the functioning of the worker, and other information should be passed through redis""" return { - "redis_server": f"{self._conn.host}:{self._conn.port}", + "redis_server": f"redis:{self._conn.port}", "queue": self._queue, "timeout_s": str(self._lifetime_s), } @@ -42,6 +43,7 @@ def _worker_environment(self) -> ContainerWorkerEnv: def _setup_execution_environment(self): self._backend: ContainerCommandBackend = get_backend() image_tag = f"{PROCEDURE.CONTAINER.IMAGE_NAME}:v{PROCEDURE.BEC_VERSION}" + self.container_name = f"bec_procedure_{PROCEDURE.BEC_VERSION}_{self._queue}" if not self._backend.image_exists(image_tag): self._backend.build_worker_image() self._container_id = self._backend.run( @@ -57,6 +59,7 @@ def _setup_execution_environment(self): ], PROCEDURE.CONTAINER.COMMAND, pod_name=PROCEDURE.CONTAINER.POD_NAME, + container_name=self.container_name, ) def _run_task(self, item: ProcedureExecutionMessage): @@ -64,12 +67,16 @@ def _run_task(self, item: ProcedureExecutionMessage): f"Container worker _run_task() called with {item} - this should never happen!" 
) - def _kill_process(self): - if self._backend.state(self._container_id) not in [ + def _ending_or_ended(self): + return self._backend.state(self._container_id) in [ PodmanContainerStates.EXITED, PodmanContainerStates.STOPPED, - ]: - self._backend.kill(self._container_id) + PodmanContainerStates.STOPPING, + ] + + def _kill_process(self): + if not self._ending_or_ended(): + self._backend.kill(self.container_name) def work(self): """block until the container is finished, listen for status updates in the meantime""" @@ -77,27 +84,37 @@ def work(self): # on timeout check if container is still running status_update = None - while self._backend.state(self._container_id) not in [ - PodmanContainerStates.EXITED, - PodmanContainerStates.STOPPED, - ]: + while not self._ending_or_ended(): status_update = self._conn.blocking_list_pop( - MessageEndpoints.procedure_worker_status_update(self._queue), timeout_s=1 + MessageEndpoints.procedure_worker_status_update(self._queue), timeout_s=0.2 ) if status_update is not None: if not isinstance(status_update, messages.ProcedureWorkerStatusMessage): raise ProcedureWorkerError(f"Received unexpected message {status_update}") self.status = status_update.status + self._current_execution_id = status_update.current_execution_id logger.info( f"Container worker '{self._queue}' status update: {status_update.status.name}" ) # TODO: we probably do want to handle some kind of timeout here but we don't know how # long a running procedure should actually take - it could theoretically be infinite + if self.status != ProcedureWorkerStatus.FINISHED: + self.status = ProcedureWorkerStatus.DEAD + + def abort_execution(self, execution_id: str): + """Abort the execution with the given id. 
Has no effect if the given ID is not the current job""" + if execution_id == self._current_execution_id: + self._backend.kill(self._container_id) + self._helper.remove_from_active.by_exec_id(execution_id) + logger.info( + f"Aborting execution {execution_id}, restarting worker for queue: {self._queue}" + ) + self._setup_execution_environment() def main(): """Replaces the main contents of Worker.work() - should be called as the container entrypoint or command""" - logger.info(f"Container worker starting up") + logger.info("Container worker starting up") try: needed_keys = ContainerWorkerEnv.__annotations__.keys() logger.debug(f"Checking for environment variables: {needed_keys}") @@ -125,6 +142,7 @@ def main(): endpoint_info = MessageEndpoints.procedure_execution(env["queue"]) conn = RedisConnector(env["redis_server"]) + helper = BackendProcedureHelper(conn) active_procs_endpoint = MessageEndpoints.active_procedure_executions() status_endpoint = MessageEndpoints.procedure_worker_status_update(env["queue"]) @@ -138,11 +156,13 @@ def main(): ) timeout_s = PROCEDURE.WORKER.QUEUE_TIMEOUT_S - def _push_status(status: ProcedureWorkerStatus): + def _push_status(status: ProcedureWorkerStatus, id: str | None = None): logger.debug(f"Updating container worker status to {status.name}") conn.rpush( status_endpoint, - messages.ProcedureWorkerStatusMessage(worker_queue=env["queue"], status=status), + messages.ProcedureWorkerStatusMessage( + worker_queue=env["queue"], status=status, current_execution_id=id + ), ) def _run_task(item: ProcedureExecutionMessage): @@ -159,9 +179,16 @@ def _run_task(item: ProcedureExecutionMessage): endpoint_info, active_procs_endpoint, timeout_s=timeout_s ) ) is not None: - _push_status(ProcedureWorkerStatus.RUNNING) + _push_status(ProcedureWorkerStatus.RUNNING, item.execution_id) + helper.notify_watchers(env["queue"], queue_type="execution") logger.debug(f"running task {item!r}") - _run_task(item) + try: + _run_task(item) + except Exception as e: + 
logger.error(f"Encountered error running procedure {item}") + logger.error(e) + finally: + helper.remove_from_active.by_exec_id(item.execution_id) _push_status(ProcedureWorkerStatus.IDLE) except Exception as e: logger.error(e) # don't stop ProcedureManager.spawn from cleaning up @@ -170,7 +197,7 @@ def _run_task(item: ProcedureExecutionMessage): _push_status(ProcedureWorkerStatus.FINISHED) client.shutdown() if item is not None: # in this case we are here due to an exception, not a timeout - conn.remove_from_set(active_procs_endpoint, item) + helper.remove_from_active.by_exec_id(item.execution_id) if __name__ == "__main__": diff --git a/bec_server/bec_server/scan_server/procedures/helper.py b/bec_server/bec_server/scan_server/procedures/helper.py new file mode 100644 index 000000000..080628fe8 --- /dev/null +++ b/bec_server/bec_server/scan_server/procedures/helper.py @@ -0,0 +1,190 @@ +from typing import Literal + +from bec_lib.endpoints import MessageEndpoints as ME +from bec_lib.logger import bec_logger +from bec_lib.messages import ProcedureAbortMessage as AbrtMsg +from bec_lib.messages import ProcedureClearUnhandledMessage as ClrMsg +from bec_lib.messages import ProcedureExecutionMessage as ExecMsg +from bec_lib.messages import ProcedureQNotifMessage as QNotifMsg +from bec_lib.messages import ProcedureRequestMessage as ReqMsg +from bec_lib.redis_connector import RedisConnector + +logger = bec_logger.logger + + +class _HelperBase: + def __init__(self, conn: RedisConnector) -> None: + self._conn = conn + + +class _Request(_HelperBase): + def procedure(self, msg: ReqMsg): + self._conn.xadd(ME.procedure_request(), msg) + + def abort_execution(self, execution_id: str): + """Send a message requesting an abort of execution_id""" + return self._conn.xadd(ME.procedure_abort(), AbrtMsg(execution_id=execution_id)) + + def abort_queue(self, queue: str): + """Send a message requesting an abort of execution_id""" + return self._conn.xadd(ME.procedure_abort(), 
AbrtMsg(queue=queue)) + + def abort_all(self): + """Send a message requesting an abort of all procedures""" + return self._conn.xadd(ME.procedure_abort(), AbrtMsg(abort_all=True)) + + def clear_unhandled_execution(self, execution_id: str): + """Send a message requesting clearing of the unhandled execution execution_id""" + return self._conn.xadd(ME.procedure_clear_unhandled(), ClrMsg(execution_id=execution_id)) + + def clear_unhandled_queue(self, queue: str): + """Send a message requesting clearing of the unhandled queue `queue`""" + return self._conn.xadd(ME.procedure_clear_unhandled(), ClrMsg(queue=queue)) + + def clear_all_unhandled(self): + """Send a message requesting clearing of all unhandled executions""" + return self._conn.xadd(ME.procedure_clear_unhandled(), ClrMsg(abort_all=True)) + + +class _Get(_HelperBase): + def running_procedures(self) -> set[ExecMsg]: + """Get all the running procedures""" + return self._conn.get_set_members(ME.active_procedure_executions()) + + def exec_queue(self, queue: str) -> list[ExecMsg]: + """Get all the ProcedureExecutionMessages from a given execution queue""" + return self._conn.lrange(ME.procedure_execution(queue), 0, -1) + + def unhandled_queue(self, queue: str) -> list[ExecMsg]: + """Get all the ProcedureExecutionMessages from a given unhandled execution queue""" + return self._conn.lrange(ME.unhandled_procedure_execution(queue), 0, -1) + + def active_and_pending_queue_names(self) -> list[str]: + """Get the names of all pending queues and queues of currently running procedures""" + return list(set(self.queue_names()) | set(self.active_queue_names())) + + def active_queue_names(self) -> list[str]: + """Get the names of all queues of currently running procedures""" + return list({msg.queue for msg in self.running_procedures()}) + + def queue_names(self, queue_type: Literal["execution", "unhandled"] = "execution") -> list[str]: + """Get the names of queues currently containing pending ProcedureExecutionMessages + + Args: + queue_type (Literal["execution", "unhandled"]):
Type of queue, default "execution" for currently active executions, "unhandled" for aborted executions + """ + ep = ( + ME.procedure_execution + if queue_type == "execution" + else ME.unhandled_procedure_execution + ) + raw: list[str] = [s.decode() for s in self._conn.keys(ep("*"))] + return [s.split("/")[-1] for s in raw] + + +class FrontendProcedureHelper(_HelperBase): + + def __init__(self, conn: RedisConnector) -> None: + super().__init__(conn) + self.request = _Request(conn) + self.get = _Get(conn) + + +class _BackenHelperBase(_HelperBase): + def __init__(self, conn: RedisConnector, parent: "BackendProcedureHelper") -> None: + self._conn = conn + self._parent = parent + + +class _Push(_BackenHelperBase): + def exec(self, queue: str, msg: ExecMsg): + """Push execution message `msg` to execution queue `queue`""" + self._conn.rpush(ME.procedure_execution(queue), msg) + self._parent.notify_watchers(queue, "execution") + + def unhandled(self, queue: str, msg: ExecMsg): + """Push execution message `msg` to unhandled execution queue `queue`""" + self._conn.rpush(ME.unhandled_procedure_execution(queue), msg) + self._parent.notify_watchers(queue, "unhandled") + + +class _Clear(_BackenHelperBase): + def all_unhandled(self): + """Remove all unhandled execution queues""" + for queue in self._parent.get.queue_names("unhandled"): + self.unhandled_queue(queue) + + def unhandled_queue(self, queue: str): + """Remove an unhandled execution queue""" + self._conn.delete(ME.unhandled_procedure_execution(queue)) + self._parent.notify_watchers(queue, "unhandled") + + def unhandled_execution(self, execution_id: str): + """Remove a ProcedureExecutionMessage from its unhandled queue by its execution ID""" + for queue in self._parent.get.queue_names("unhandled"): + for msg in self._parent.get.unhandled_queue(queue): + if msg.execution_id == execution_id: + if self._conn.lrem(ME.unhandled_procedure_execution(msg.queue), 0, msg) > 0: + logger.debug(f"Removed execution {msg} from queue.") 
+ self._parent.notify_watchers(queue, "unhandled") + return + logger.debug(f"Execution {execution_id} not found in any unhandled queue.") + + +class _Move(_BackenHelperBase): + def all_active_to_unhandled(self): + """Move all messages in the active executions set to unhandled""" + for msg in self._parent.get.running_procedures(): + self._parent.push.unhandled(msg.queue, msg) + self._conn.delete(ME.active_procedure_executions()) + + def execution_queue_to_unhandled(self, queue: str): + """Move all messages from execution queue to unhandled execution queue of the same name""" + for msg in self._parent.get.exec_queue(queue): + self._parent.push.unhandled(queue, msg) + self._conn.delete(ME.procedure_execution(queue)) + + def all_execution_queues_to_unhandled(self): + """Move all messages from all execution queues to unhandled execution queues of the same name""" + for queue in self._parent.get.queue_names(): + self.execution_queue_to_unhandled(queue) + + +class _RemoveFromActive(_BackenHelperBase): + def by_exec_id(self, execution_id: str): + """Remove a message from the set of currently active executions""" + for msg in self._conn.get_set_members(ME.active_procedure_executions()): + if msg.execution_id == execution_id: + self._conn.remove_from_set(ME.active_procedure_executions(), msg) + logger.debug(f"removed active procedure {execution_id}") + return + logger.debug(f"No active procedure {execution_id} to remove") + + def by_queue(self, queue: str): + """Remove all messages with the given queue from the set of currently active executions""" + removed = False + for msg in self._conn.get_set_members(ME.active_procedure_executions()): + if msg.queue == queue: + self._conn.remove_from_set(ME.active_procedure_executions(), msg) + logger.debug(f"removed active procedure {msg} with queue {queue}") + removed = True + if removed: + return + logger.debug(f"No active procedure with queue {queue} to remove") + + +class BackendProcedureHelper(FrontendProcedureHelper): + def __init__(self, conn: RedisConnector) -> None:
+ super().__init__(conn) + self.push = _Push(conn, self) + self.clear = _Clear(conn, self) + self.move = _Move(conn, self) + self.remove_from_active = _RemoveFromActive(conn, self) + + def notify_watchers(self, queue: str, queue_type: Literal["execution", "unhandled"]): + return self._conn.send( + ME.procedure_queue_notif(), QNotifMsg(queue_name=queue, queue_type=queue_type) + ) + + def notify_all(self, queue_type: Literal["execution", "unhandled"]): + for queue in self.get.queue_names(queue_type): + self.notify_watchers(queue, queue_type) diff --git a/bec_server/bec_server/scan_server/procedures/in_process_worker.py b/bec_server/bec_server/scan_server/procedures/in_process_worker.py index e225a88b4..d56c9eebe 100644 --- a/bec_server/bec_server/scan_server/procedures/in_process_worker.py +++ b/bec_server/bec_server/scan_server/procedures/in_process_worker.py @@ -12,7 +12,7 @@ class InProcessProcedureWorker(ProcedureWorker): """A simple in-process procedure worker. Be careful with this, it should only run trusted code. - Intended for built-in procedures like those to run a single scan.""" + Intended for built-in procedures like those to run a single scan, or testing.""" def _setup_execution_environment(self): from bec_lib.logger import bec_logger @@ -45,3 +45,7 @@ def _kill_process(self): self.logger.info( f"In-process procedure worker for queue {self.key.endpoint} timed out after {self._lifetime_s} s, shutting down" ) + + def abort_execution(self, execution_id: str): + """Abort the execution with the given id""" + ... 
# No nice way to abort a running function, don't use this class for such things diff --git a/bec_server/bec_server/scan_server/procedures/manager.py b/bec_server/bec_server/scan_server/procedures/manager.py index b171a5427..e0b002368 100644 --- a/bec_server/bec_server/scan_server/procedures/manager.py +++ b/bec_server/bec_server/scan_server/procedures/manager.py @@ -4,22 +4,34 @@ from concurrent import futures from concurrent.futures import Future, ThreadPoolExecutor from threading import RLock -from typing import Any, Callable, TypedDict +from typing import Any, Callable, TypedDict, TypeVar + +from pydantic import ValidationError from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_lib.messages import ProcedureRequestMessage, ProcedureWorkerStatus, RequestResponseMessage +from bec_lib.messages import ( + BECMessage, + ProcedureAbortMessage, + ProcedureClearUnhandledMessage, + ProcedureExecutionMessage, + ProcedureRequestMessage, + ProcedureWorkerStatus, + RequestResponseMessage, +) from bec_lib.redis_connector import RedisConnector -from pydantic import ValidationError - from bec_server.scan_server.procedures import procedure_registry from bec_server.scan_server.procedures.constants import PROCEDURE, WorkerAlreadyExists +from bec_server.scan_server.procedures.helper import BackendProcedureHelper from bec_server.scan_server.procedures.worker_base import ProcedureWorker from bec_server.scan_server.scan_server import ScanServer logger = bec_logger.logger +# TODO garbage collect IDs + + class ProcedureWorkerEntry(TypedDict): worker: ProcedureWorker | None future: Future @@ -33,6 +45,15 @@ def _log_on_end(future: Future): logger.success(f"Procedure worker {future} shut down gracefully") +_T = TypeVar("_T", bound=BECMessage) + + +def _resolve_dict(msg: dict[str, Any] | _T, MsgType: type[_T]) -> _T: + if isinstance(msg, dict): + return MsgType.model_validate(msg.get("data")) + return msg + + class ProcedureManager: def 
__init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]): @@ -42,10 +63,17 @@ def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]): Args: parent (ScanServer): the scan server to get the Redis server address from. worker_type (type[ProcedureWorker]): which kind of worker to use.""" - self._parent = parent self.lock = RLock() + + logger.success("Initialising procedure manager...") + + self._conn = RedisConnector([self._parent.bootstrap_server]) + self._helper = BackendProcedureHelper(self._conn) + self._startup() + self._active_workers: dict[str, ProcedureWorkerEntry] = {} + self._messages_by_ids: dict[str, ProcedureExecutionMessage] = {} self.executor = ThreadPoolExecutor( max_workers=PROCEDURE.WORKER.MAX_WORKERS, thread_name_prefix="user_procedure_" ) @@ -53,11 +81,26 @@ def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]): self._callbacks: dict[str, list[Callable[[ProcedureWorker], Any]]] = {} self._worker_cls = worker_type - self._conn = RedisConnector([self._parent.bootstrap_server]) self._reply_endpoint = MessageEndpoints.procedure_request_response() self._server = f"{self._conn.host}:{self._conn.port}" - self._conn.register(MessageEndpoints.procedure_request(), None, self.process_queue_request) + self._conn.register(MessageEndpoints.procedure_abort(), None, self._process_abort) + self._conn.register( + MessageEndpoints.procedure_clear_unhandled(), None, self._process_clear_unhandled + ) + self._conn.register(MessageEndpoints.procedure_request(), None, self._process_queue_request) + logger.success("Done initialising procedure manager.") + + def _startup(self): + # If the server is restarted, clear any pending requests, they'll have to be resubmitted + self._conn.delete(MessageEndpoints.procedure_request()) + previous_queues = self._helper.get.active_and_pending_queue_names() + logger.debug(f"Clearing previous procedure queues {previous_queues}...") + self._helper.move.all_execution_queues_to_unhandled() 
+ self._helper.move.all_active_to_unhandled() + for queue in previous_queues: + self._helper.notify_watchers(queue, "execution") + self._helper.notify_all("unhandled") def _ack(self, accepted: bool, msg: str): logger.info(f"procedure accepted: {accepted}, message: {msg}") @@ -65,9 +108,9 @@ def _ack(self, accepted: bool, msg: str): self._reply_endpoint, RequestResponseMessage(accepted=accepted, message=msg) ) - def _validate_request(self, msg: dict[str, Any]): + def _validate_request(self, msg: dict[str, Any] | ProcedureRequestMessage): try: - message_obj = ProcedureRequestMessage.model_validate(msg) + message_obj = _resolve_dict(msg, ProcedureRequestMessage) if not procedure_registry.is_registered(message_obj.identifier): self._ack( False, @@ -86,13 +129,17 @@ def add_callback(self, queue: str, cb: Callable[[ProcedureWorker], Any]): self._callbacks[queue].append(cb) def _run_callbacks(self, queue: str): - if (worker := self._active_workers[queue]["worker"]) is None: - return + with self.lock: + if queue not in self._active_workers: + logger.error(f"Attempted to run callbacks for nonexistent worker {queue}") + return + if (worker := self._active_workers[queue]["worker"]) is None: + return for cb in self._callbacks.get(queue, []): cb(worker) self._callbacks[queue] = [] - def process_queue_request(self, msg: dict[str, Any]): + def _process_queue_request(self, msg: dict[str, Any] | ProcedureRequestMessage): """Read a `ProcedureRequestMessage` and if it is valid, create a corresponding `ProcedureExecutionMessage`. If there is already a worker for the queue for that request message, add the execution message to that queue, otherwise create a new queue and a new worker. 
@@ -101,25 +148,22 @@ def process_queue_request(self, msg: dict[str, Any]): msg (dict[str, Any]): dict corresponding to a ProcedureRequestMessage""" logger.debug(f"Procedure manager got request message {msg}") - if (message_obj := self._validate_request(msg)) is None: + if (message := self._validate_request(msg)) is None: return - self._ack(True, f"Running procedure {message_obj.identifier}") - queue = message_obj.queue or PROCEDURE.WORKER.DEFAULT_QUEUE - endpoint = MessageEndpoints.procedure_execution(queue) - logger.debug(f"active workers: {self._active_workers}, worker requested: {queue}") - self._conn.rpush( - endpoint, - endpoint.message_type( - identifier=message_obj.identifier, - queue=queue, - args_kwargs=message_obj.args_kwargs or ((), {}), - ), + self._ack(True, f"Running procedure {message.identifier}") + queue = message.queue or PROCEDURE.WORKER.DEFAULT_QUEUE + exec_message = ProcedureExecutionMessage( + identifier=message.identifier, queue=queue, args_kwargs=message.args_kwargs or ((), {}) ) + logger.debug(f"active workers: {self._active_workers}, worker requested: {queue}") + self._helper.push.exec(queue, exec_message) def cleanup_worker(fut): with self.lock: logger.debug(f"cleaning up worker {fut} for queue {queue}...") + self._helper.remove_from_active.by_queue(queue) self._run_callbacks(queue) + self._helper.notify_watchers(queue, "execution") del self._active_workers[queue] with self.lock: @@ -128,6 +172,69 @@ def cleanup_worker(fut): new_worker.add_done_callback(_log_on_end) new_worker.add_done_callback(cleanup_worker) self._active_workers[queue] = {"worker": None, "future": new_worker} + self._messages_by_ids[exec_message.execution_id] = exec_message + + def _process_abort(self, msg: dict[str, Any] | ProcedureAbortMessage): + message = _resolve_dict(msg, ProcedureAbortMessage) + with self.lock: + if message.abort_all: + self._abort_all() + if message.queue is not None: + self._abort_queue(message.queue) + if message.execution_id is not None: + 
self._abort_execution(message.execution_id) + + def _abort_execution(self, execution_id: str): + if (msg := self._messages_by_ids.get(execution_id)) is None: + logger.warning(f"Procedure execution with ID {execution_id} not known.") + return + # Remove it from the queue if not yet started + if self._conn.lrem(MessageEndpoints.procedure_execution(msg.queue), 0, msg) > 0: + logger.debug(f"Removed execution {msg} from queue.") + self._helper.notify_watchers(msg.queue, "execution") + # Otherwise try to remove it from whichever worker has it + for entry in self._active_workers.values(): + if (worker := entry["worker"]) is not None and worker: + worker.abort_execution(execution_id) + # Move it to unhandled and stop tracking + self._helper.push.unhandled(msg.queue, msg) + del self._messages_by_ids[execution_id] + + def _abort_queue(self, queue: str): + self._helper.move.execution_queue_to_unhandled(queue) + if (entry := self._active_workers.get(queue)) is not None: + if entry["worker"] is not None: + entry["worker"].abort() + entry["future"].cancel() + futures.wait((entry["future"],), PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S) + else: + logger.warning(f"Received abort request for unknown queue {queue}!") + self._helper.notify_watchers(queue, "execution") + + def _abort_all(self): + for entry in self._active_workers.values(): + if entry["worker"] is not None: + entry["worker"].abort() + for entry in self._active_workers.values(): + entry["future"].cancel() + self._helper.move.all_execution_queues_to_unhandled() + self._wait_for_all_futures() + + def _process_clear_unhandled(self, msg: dict[str, Any] | ProcedureClearUnhandledMessage): + message = _resolve_dict(msg, ProcedureClearUnhandledMessage) + with self.lock: + if message.abort_all: + self._helper.clear.all_unhandled() + if message.queue is not None: + self._helper.clear.unhandled_queue(message.queue) + if message.execution_id is not None: + self._helper.clear.unhandled_execution(message.execution_id) + + def 
_wait_for_all_futures(self): + futures.wait( + (entry["future"] for entry in self._active_workers.values()), + timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S, + ) def spawn(self, queue: str): """Spawn a procedure worker future which listens to a given queue, i.e. procedure queue list in Redis. @@ -148,7 +255,7 @@ def shutdown(self): """Shutdown the procedure manager. Unregisters from the request endpoint, cancel any procedure workers which haven't started, and abort any which have.""" self._conn.unregister( - MessageEndpoints.procedure_request(), None, self.process_queue_request + MessageEndpoints.procedure_request(), None, self._process_queue_request ) self._conn.shutdown() # cancel futures by hand to give us the opportunity to detatch them from redis if they have started @@ -159,10 +266,7 @@ def shutdown(self): if worker := entry["worker"]: # redis unblock executor.client_id worker.abort() - futures.wait( - (entry["future"] for entry in self._active_workers.values()), - timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S, - ) + self._wait_for_all_futures() self.executor.shutdown() def active_workers(self) -> list[str]: diff --git a/bec_server/bec_server/scan_server/procedures/protocol.py b/bec_server/bec_server/scan_server/procedures/protocol.py index 6d79b80a6..06d40c818 100644 --- a/bec_server/bec_server/scan_server/procedures/protocol.py +++ b/bec_server/bec_server/scan_server/procedures/protocol.py @@ -11,7 +11,6 @@ class VolumeSpec(TypedDict): class ContainerCommandOutput(Protocol): - def pretty_print(self) -> str: ... @@ -30,6 +29,7 @@ def run( volumes: list[VolumeSpec], command: str, pod_name: str | None = None, + container_name: str | None = None, ) -> str: ... def kill(self, id: str): ... def logs(self, id: str) -> list[str]: ... 
diff --git a/bec_server/bec_server/scan_server/procedures/worker_base.py b/bec_server/bec_server/scan_server/procedures/worker_base.py index d422d6e2b..dd5060f73 100644 --- a/bec_server/bec_server/scan_server/procedures/worker_base.py +++ b/bec_server/bec_server/scan_server/procedures/worker_base.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from threading import Event from typing import cast from bec_lib.endpoints import MessageEndpoints @@ -8,6 +9,7 @@ from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus from bec_lib.redis_connector import RedisConnector from bec_server.scan_server.procedures.constants import PROCEDURE +from bec_server.scan_server.procedures.helper import BackendProcedureHelper logger = bec_logger.logger @@ -35,8 +37,11 @@ def __init__(self, server: str, queue: str, lifetime_s: float | None = None): self._active_procs_endpoint = MessageEndpoints.active_procedure_executions() self.status = ProcedureWorkerStatus.IDLE self._conn = RedisConnector([server]) + self._helper = BackendProcedureHelper(self._conn) self._lifetime_s = lifetime_s or PROCEDURE.WORKER.QUEUE_TIMEOUT_S self.client_id = self._conn.client_id() + self._current_execution_id: str | None = None + self._aborted = Event() self._setup_execution_environment() @@ -61,8 +66,15 @@ def _run_task(self, item: ProcedureExecutionMessage): def _setup_execution_environment(self): ... def abort(self): + """Abort the entire worker""" + self._aborted.set() self._kill_process() + @abstractmethod + def abort_execution(self, execution_id: str): + """Abort the execution with the given id""" + ... 
+ def __exit__(self, exc_type, exc_val, exc_tb): self._kill_process() @@ -75,9 +87,10 @@ def work(self): self.key, self._active_procs_endpoint, timeout_s=self._lifetime_s ) ) is not None: - self.status = ProcedureWorkerStatus.RUNNING - self._run_task(cast(ProcedureExecutionMessage, item)) - self.status = ProcedureWorkerStatus.IDLE + if not self._aborted.is_set(): + self.status = ProcedureWorkerStatus.RUNNING + self._run_task(cast(ProcedureExecutionMessage, item)) + self.status = ProcedureWorkerStatus.IDLE except Exception as e: logger.error(e) # don't stop ProcedureManager.spawn from cleaning up finally: diff --git a/bec_server/tests/tests_scan_server/test_container_utils.py b/bec_server/tests/tests_scan_server/test_container_utils.py index 55fb81408..31c948f6b 100644 --- a/bec_server/tests/tests_scan_server/test_container_utils.py +++ b/bec_server/tests/tests_scan_server/test_container_utils.py @@ -138,6 +138,7 @@ def test_api_utils_run(api_utils: tuple[PodmanApiUtils, MagicMock]): environment={"a": "b"}, mounts=[{"source": "a", "target": "b", "read_only": True, "type": "bind"}], pod=None, + name=None, ) @@ -185,7 +186,7 @@ def test_cli_kill(cli_utils: tuple[PodmanCliUtils, MagicMock]): utils, run_mock = cli_utils run_mock.reset_mock() utils.kill("test") - assert run_mock.call_count == 2 + assert run_mock.call_count == 1 def test_cli_get_state(cli_utils_with_fake_container_json: PodmanCliUtils): diff --git a/bec_server/tests/tests_scan_server/test_procedure_container_worker.py b/bec_server/tests/tests_scan_server/test_procedure_container_worker.py index b80ffabc9..a0de0aa4d 100644 --- a/bec_server/tests/tests_scan_server/test_procedure_container_worker.py +++ b/bec_server/tests/tests_scan_server/test_procedure_container_worker.py @@ -45,20 +45,20 @@ def test_container_worker_work(redis_mock, logger_mock): redis_mock().host = "server" redis_mock().port = "port" + msgs = [ + ProcedureWorkerStatusMessage( + worker_queue="test_queue", + 
status=ProcedureWorkerStatus.RUNNING, + current_execution_id="test", + ), + ProcedureWorkerStatusMessage( + worker_queue="test_queue", status=ProcedureWorkerStatus.FINISHED + ), + ] + def _mock_pop(): - msgs = [ - ProcedureWorkerStatusMessage( - worker_queue="test_queue", status=ProcedureWorkerStatus.RUNNING - ), - ProcedureWorkerStatusMessage( - worker_queue="test_queue", status=ProcedureWorkerStatus.FINISHED - ), - ] - for msg in msgs: - sleep(0.05) - yield msg + yield from msgs while True: - sleep(0.05) yield from repeat(None) mock_pop = _mock_pop() @@ -74,7 +74,7 @@ def cleanup(): t.start() start = time.monotonic() - while time.monotonic() < start + 5: + while time.monotonic() < start + 1000: try: assert ( logger_mock.info.call_args_list[0].args[0] @@ -125,6 +125,7 @@ def __init__(self, name: str, args: tuple = (), kwargs: dict = {}): self.identifier = name self.args = args self.kwargs = kwargs + self.execution_id = f"test_{name}" def __repr__(self) -> str: return self.identifier diff --git a/bec_server/tests/tests_scan_server/test_procedures.py b/bec_server/tests/tests_scan_server/test_procedures.py index 694b7145d..1eee31e31 100644 --- a/bec_server/tests/tests_scan_server/test_procedures.py +++ b/bec_server/tests/tests_scan_server/test_procedures.py @@ -1,6 +1,7 @@ import threading import time from functools import partial +from itertools import starmap from typing import Any, Callable from unittest.mock import MagicMock, patch @@ -8,6 +9,7 @@ import pytest from bec_lib.client import BECClient, RedisConnector +from bec_lib.endpoints import MessageEndpoints from bec_lib.messages import ( ProcedureExecutionMessage, ProcedureRequestMessage, @@ -15,6 +17,7 @@ RequestResponseMessage, ) from bec_lib.serialization import MsgpackSerialization +from bec_lib.service_config import ServiceConfig from bec_server.scan_server.procedures.constants import PROCEDURE, BecProcedure, WorkerAlreadyExists from bec_server.scan_server.procedures.in_process_worker import 
InProcessProcedureWorker from bec_server.scan_server.procedures.manager import ProcedureManager, ProcedureWorker @@ -32,11 +35,16 @@ LOG_MSG_PROC_NAME = "log execution message args" +FAKEREDIS_HOST = "127.0.0.1" +FAKEREDIS_PORT = 6380 @pytest.fixture(autouse=True) def shutdown_client(): - bec_client = BECClient() + bec_client = BECClient( + config=ServiceConfig(config={"redis": {"host": FAKEREDIS_HOST, "port": FAKEREDIS_PORT}}), + connector_cls=partial(RedisConnector, redis_cls=fakeredis.FakeRedis), + ) bec_client.start() yield bec_client.shutdown() @@ -45,7 +53,7 @@ def shutdown_client(): @pytest.fixture def procedure_manager(): server = MagicMock() - server.bootstrap_server = "localhost:1" + server.bootstrap_server = f"{FAKEREDIS_HOST}:{FAKEREDIS_PORT}" with patch( "bec_server.scan_server.procedures.manager.RedisConnector", partial(RedisConnector, redis_cls=fakeredis.FakeRedis), # type: ignore @@ -104,7 +112,7 @@ def process_request_manager(procedure_manager: ProcedureManager): @pytest.mark.parametrize("message", PROCESS_REQUEST_TEST_CASES) def test_process_request_happy_paths(process_request_manager, message: ProcedureRequestMessage): - process_request_manager.process_queue_request(message) + process_request_manager._process_queue_request(message) process_request_manager._ack.assert_called_with(True, f"Running procedure {message.identifier}") process_request_manager._conn.rpush.assert_called() endpoint, execution_msg = process_request_manager._conn.rpush.call_args.args @@ -116,7 +124,7 @@ def test_process_request_happy_paths(process_request_manager, message: Procedure def test_process_request_failure(process_request_manager): - process_request_manager.process_queue_request(None) + process_request_manager._process_queue_request(None) process_request_manager._ack.assert_not_called() process_request_manager._conn.rpush.assert_not_called() process_request_manager.spawn.assert_not_called() @@ -131,6 +139,7 @@ def __init__(self, server: str, queue: str, lifetime_s: 
int | None = None): self.event_1 = threading.Event() self.event_2 = threading.Event() + def abort_execution(self, execution_id: str): ... def _setup_execution_environment(self): ... def _kill_process(self): ... def _run_task(self, item): @@ -161,7 +170,7 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager): queue = message.queue or PROCEDURE.WORKER.DEFAULT_QUEUE procedure_manager._validate_request = MagicMock(side_effect=lambda msg: msg) # trigger the running of the test message - procedure_manager.process_queue_request(message) # type: ignore + procedure_manager._process_queue_request(message) # type: ignore assert queue in procedure_manager._active_workers.keys() # spawn method should be added as a future @@ -259,7 +268,7 @@ def test_callable_from_message(): def test_register_rejects_wrong_type(): with pytest.raises(ProcedureRegistryError) as e: - register("test", "test") + register("test", "test") # type: ignore assert e.match("not a valid procedure") @@ -284,7 +293,7 @@ def _yield_once(): def test_manager_status_api(_conn, procedure_manager): procedure_manager._worker_cls = UnlockableWorker for message in PROCESS_REQUEST_TEST_CASES: - procedure_manager.process_queue_request(message) + procedure_manager._process_queue_request(message) _wait_until(lambda: procedure_manager.active_workers() == ["primary", "queue2"]) _wait_until( lambda: procedure_manager.worker_statuses() @@ -299,3 +308,130 @@ def test_manager_status_api(_conn, procedure_manager): for w in procedure_manager._active_workers.values(): w["worker"].event_2.set() _wait_until(lambda: procedure_manager.active_workers() == []) + + +_ManagerWithMsgs = tuple[ProcedureManager, list[ProcedureExecutionMessage]] + + +@pytest.fixture +def manager_with_test_msgs(procedure_manager: ProcedureManager): + procedure_manager._conn._redis_conn.flushdb() + contents = [ + ("test_identifier_1", "queue1", ((), {})), + ("test_identifier_2", "queue1", ((), {})), + ("test_identifier_1", "queue2", ((), {})), 
+ ("test_identifier_2", "queue2", ((), {})), + ] + msgs = iter( + ProcedureRequestMessage(identifier=c[0], queue=c[1], args_kwargs=c[2]) for c in contents + ) + procedure_manager._validate_request = lambda msg: next(msgs) + for _ in range(len(contents)): + procedure_manager._process_queue_request({}) + return ( + procedure_manager, + [ + ProcedureExecutionMessage(metadata={}, identifier=c[0], queue=c[1], args_kwargs=c[2]) + for c in contents + ], + ) + + +def _eq_except_id(a: ProcedureExecutionMessage, b: ProcedureExecutionMessage): + return a.identifier == b.identifier and a.queue == b.queue and a.args_kwargs == b.args_kwargs + + +def _all_eq_except_id(a: list[ProcedureExecutionMessage], b: list[ProcedureExecutionMessage]): + if len(a) != len(b): + return False + return all(starmap(_eq_except_id, zip(a, b))) + + +@pytest.mark.parametrize("queue", ["queue1", "queue2"]) +@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock()) +def test_startup(manager_with_test_msgs: _ManagerWithMsgs, queue: str): + procedure_manager, expected = manager_with_test_msgs + queue_expected = list(filter(lambda msg: msg.queue == queue, expected)) + + execution_list = procedure_manager._helper.get.exec_queue(queue) + assert _all_eq_except_id(execution_list, queue_expected) + + procedure_manager._startup() + + # on startup, the manager should move active queues to unhandled queues + execution_list = procedure_manager._helper.get.exec_queue(queue) + unhandled_execution_list = procedure_manager._conn.lrange( + MessageEndpoints.unhandled_procedure_execution(queue), 0, -1 + ) + assert execution_list == [] + assert _all_eq_except_id(unhandled_execution_list, queue_expected) + + +@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock()) +def test_abort_queue(manager_with_test_msgs: _ManagerWithMsgs): + procedure_manager, expected = manager_with_test_msgs + remaining_expected = list(filter(lambda msg: msg.queue == "queue2", expected)) + 
aborted_expected = list(filter(lambda msg: msg.queue == "queue1", expected)) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + assert _all_eq_except_id(q1_execution_list, aborted_expected) + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + assert _all_eq_except_id(q2_execution_list, remaining_expected) + + procedure_manager._process_abort({"queue": "queue1"}) + + # on abort, the manager should move active queues to unhandled queues + # this should happen for q1 and not q2 + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + unhandled_execution_list = procedure_manager._conn.lrange( + MessageEndpoints.unhandled_procedure_execution("queue1"), 0, -1 + ) + assert _all_eq_except_id(q2_execution_list, remaining_expected) + assert _all_eq_except_id(unhandled_execution_list, aborted_expected) + + +@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock()) +def test_abort_individual(manager_with_test_msgs: _ManagerWithMsgs): + procedure_manager, expected = manager_with_test_msgs + q1_expected = list(filter(lambda msg: msg.queue == "queue1", expected)) + q2_expected = list(filter(lambda msg: msg.queue == "queue2", expected)) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + assert _all_eq_except_id( + q1_execution_list, list(filter(lambda msg: msg.queue == "queue1", q1_expected)) + ) + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + assert _all_eq_except_id(q2_execution_list, q2_expected) + + procedure_manager._process_abort({"execution_id": q2_execution_list[1].execution_id}) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + assert _all_eq_except_id(q1_execution_list, q1_expected) + assert _all_eq_except_id(q2_execution_list, [q2_expected[0]]) + + +@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", 
MagicMock()) +def test_abort_all(manager_with_test_msgs: _ManagerWithMsgs): + procedure_manager, expected = manager_with_test_msgs + q1_expected = list(filter(lambda msg: msg.queue == "queue1", expected)) + q2_expected = list(filter(lambda msg: msg.queue == "queue2", expected)) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + assert _all_eq_except_id( + q1_execution_list, list(filter(lambda msg: msg.queue == "queue1", q1_expected)) + ) + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + assert _all_eq_except_id(q2_execution_list, q2_expected) + + procedure_manager._process_abort({"abort_all": True}) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + q1_unhandled_list = procedure_manager._helper.get.unhandled_queue("queue1") + q2_unhandled_list = procedure_manager._helper.get.unhandled_queue("queue2") + + assert q1_execution_list == [] + assert q2_execution_list == [] + assert _all_eq_except_id(q1_unhandled_list, q1_expected) + assert _all_eq_except_id(q2_unhandled_list, q2_expected) From 77ce5b7c9dfd844dfe6e56e516253340cc6dfeee Mon Sep 17 00:00:00 2001 From: David Perl Date: Fri, 31 Oct 2025 14:33:36 +0100 Subject: [PATCH 06/10] feat: push container worker stdout to redis --- bec_lib/bec_lib/endpoints.py | 25 +++++-- bec_lib/bec_lib/logger.py | 13 ++-- bec_lib/bec_lib/redis_connector.py | 5 +- .../procedures/container_worker.py | 67 +++++++++++++++---- .../scan_server/procedures/helper.py | 19 ++++-- .../scan_server/procedures/manager.py | 2 +- .../test_procedure_container_worker.py | 4 +- 7 files changed, 99 insertions(+), 36 deletions(-) diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index d554c452a..989be3c99 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -1460,7 +1460,7 @@ def procedure_execution(queue_id: str): Returns: EndpointInfo: Endpoint for scan 
queue request. """ - endpoint = f"{EndpointType.INTERNAL.value}/procedures/procedure_execution/{queue_id}" + endpoint = f"{EndpointType.INFO.value}/procedures/procedure_execution/{queue_id}" return EndpointInfo( endpoint=endpoint, message_type=messages.ProcedureExecutionMessage, @@ -1477,9 +1477,7 @@ def unhandled_procedure_execution(queue_id: str): Returns: EndpointInfo: Endpoint for scan queue request. """ - endpoint = ( - f"{EndpointType.INTERNAL.value}/procedures/unhandled_procedure_execution/{queue_id}" - ) + endpoint = f"{EndpointType.INFO.value}/procedures/unhandled_procedure_execution/{queue_id}" return EndpointInfo( endpoint=endpoint, message_type=messages.ProcedureExecutionMessage, @@ -1510,7 +1508,7 @@ def procedure_abort(): Returns: EndpointInfo: Endpoint for set of active procedure executions. """ - endpoint = f"{EndpointType.INFO.value}/procedures/abort" + endpoint = f"{EndpointType.USER.value}/procedures/abort" return EndpointInfo( endpoint=endpoint, message_type=messages.ProcedureAbortMessage, @@ -1520,12 +1518,12 @@ def procedure_abort(): @staticmethod def procedure_clear_unhandled(): """ - Endpoint to request aborting a running procedure + Endpoint to request removing an aborted procedure Returns: EndpointInfo: Endpoint for set of active procedure executions. """ - endpoint = f"{EndpointType.INFO.value}/procedures/clear_unhandled" + endpoint = f"{EndpointType.USER.value}/procedures/clear_unhandled" return EndpointInfo( endpoint=endpoint, message_type=messages.ProcedureClearUnhandledMessage, @@ -1562,6 +1560,19 @@ def procedure_queue_notif(): message_op=MessageOp.SEND, ) + @staticmethod + def procedure_logs(queue: str): + """ + Endpoint for logs for a given procedure queue + + Returns: + EndpointInfo: Endpoint for procedure queue updates for given queue. 
+ """ + endpoint = f"{EndpointType.INFO.value}/procedures/logs/{queue}" + return EndpointInfo( + endpoint=endpoint, message_type=messages.RawMessage, message_op=MessageOp.STREAM + ) + @staticmethod def gui_registry_state(gui_id: str): """ diff --git a/bec_lib/bec_lib/logger.py b/bec_lib/bec_lib/logger.py index 3e02f305a..201efb3a8 100644 --- a/bec_lib/bec_lib/logger.py +++ b/bec_lib/bec_lib/logger.py @@ -64,6 +64,7 @@ class BECLogger: "{service_name} | {{time:YYYY-MM-DD HH:mm:ss.SSS}} | {{level}} |" " {{thread.name}} ({{thread.id}}) | {{extra[stack]}} - {{message}}\n" ) + CONTAINER_FORMAT = "{{time:YYYY-MM-DD HH:mm:ss.SSS}} | {{level}} | {{message}}\n" LOGLEVEL = LogLevel _logger = None @@ -224,18 +225,21 @@ def _logger_callback(self, msg): # because it depends on the connector pass - def get_format(self, level: LogLevel = None, is_stderr=False) -> str: + def get_format(self, level: LogLevel = None, is_stderr=False, is_container=False) -> str: """ Get the format for a specific log level. Args: level (LogLevel, optional): Log level. Defaults to None. If None, the current log level will be used. is_stderr (bool, optional): Whether the log is for stderr. Defaults to False. + is_container (bool, optional): Simple logging for procedure container. Defaults to False. Returns: str: Log format. """ service_name = self.service_name if self.service_name else "" + if is_container: + return self.CONTAINER_FORMAT.format() if level is None: level = self.level if level > self.LOGLEVEL.DEBUG: @@ -246,15 +250,16 @@ def get_format(self, level: LogLevel = None, is_stderr=False) -> str: return self.DEBUG_FORMAT.format(service_name=service_name) return self.TRACE_FORMAT.format(service_name=service_name) - def formatting(self, is_stderr=False): + def formatting(self, is_stderr=False, is_container=False): """ Format the log message. Args: record (dict): Log record. + is_container (bool, optional): Simple logging for procedure container. Defaults to False. 
Returns: - dict: Formatted log record. + str: Log format. """ def _update_record(record): @@ -269,7 +274,7 @@ def _update_record(record): def _format(record): level = _update_record(record) - return self.get_format(level) + return self.get_format(level, is_container=is_container) def _format_stderr(record): level = _update_record(record) diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index 0eb336bd1..8d8764f1e 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -1179,7 +1179,7 @@ def mget(self, topics: list[str], pipe: Pipeline | None = None): def xadd( self, topic: str, - msg_dict: dict | BECMessage, + msg_dict: dict, max_size=None, pipe: Pipeline | None = None, expire: int | None = None, @@ -1205,8 +1205,7 @@ def xadd( else: client = self._redis_conn - msg = msg_dict.model_dump() if isinstance(msg_dict, BECMessage) else msg_dict - msg_dict = {key: MsgpackSerialization.dumps(val) for key, val in msg.items()} + msg_dict = {key: MsgpackSerialization.dumps(val) for key, val in msg_dict.items()} if max_size: client.xadd(topic, msg_dict, maxlen=max_size) diff --git a/bec_server/bec_server/scan_server/procedures/container_worker.py b/bec_server/bec_server/scan_server/procedures/container_worker.py index 53cc109d9..f75917bf1 100644 --- a/bec_server/bec_server/scan_server/procedures/container_worker.py +++ b/bec_server/bec_server/scan_server/procedures/container_worker.py @@ -1,10 +1,14 @@ import os +import sys +from contextlib import redirect_stdout +from typing import AnyStr, TextIO +from bec_ipython_client.main import BECIPythonClient from bec_lib import messages from bec_lib.client import BECClient from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import LogLevel, bec_logger -from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus +from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus, RawMessage from bec_lib.redis_connector 
import RedisConnector from bec_lib.service_config import ServiceConfig from bec_server.scan_server.procedures import procedure_registry @@ -22,6 +26,28 @@ logger = bec_logger.logger +class RedisOutputDiverter(TextIO): + def __init__(self, conn: RedisConnector, queue: str): + + self._conn = conn + self._ep = MessageEndpoints.procedure_logs(queue) + self._conn.delete(self._ep) + + def write(self, data: AnyStr): + if data: + self._conn.xadd(self._ep, {"data": RawMessage(data=str(data))}) + return len(data) + + def flush(self): ... + + @property + def encoding(self): + return "utf-8" + + def close(self): + return + + class ContainerProcedureWorker(ProcedureWorker): """A worker which runs scripts in a container with a full BEC environment, mounted from the filesystem, and only access to Redis""" @@ -112,8 +138,7 @@ def abort_execution(self, execution_id: str): self._setup_execution_environment() -def main(): - """Replaces the main contents of Worker.work() - should be called as the container entrypoint or command""" +def _setup(): logger.info("Container worker starting up") try: needed_keys = ContainerWorkerEnv.__annotations__.keys() @@ -121,7 +146,7 @@ def main(): env: ContainerWorkerEnv = {k: os.environ[k] for k in needed_keys} # type: ignore except KeyError as e: logger.error(f"Missing environment variable needed by container worker: {e}") - return + exit(1) bec_logger.level = LogLevel.DEBUG bec_logger._console_log = True @@ -134,20 +159,27 @@ def main(): host, port = env["redis_server"].split(":") redis = {"host": host, "port": port} - client = BECClient(config=ServiceConfig(redis=redis)) + + client = BECIPythonClient(config=ServiceConfig(redis=redis)) + logger.debug("starting client") client.start() logger.info(f"ContainerWorker started container for queue {env['queue']}") logger.debug(f"ContainerWorker environment: {env}") - endpoint_info = MessageEndpoints.procedure_execution(env["queue"]) conn = RedisConnector(env["redis_server"]) + 
logger.debug(f"ContainerWorker {env['queue']} connected to Redis at {conn.host}:{conn.port}") helper = BackendProcedureHelper(conn) + + return env, helper, client, conn + + +def _main(env, helper, client, conn): + + exec_endpoint = MessageEndpoints.procedure_execution(env["queue"]) active_procs_endpoint = MessageEndpoints.active_procedure_executions() status_endpoint = MessageEndpoints.procedure_worker_status_update(env["queue"]) - logger.debug(f"ContainerWorker connecting to Redis at {conn.host}:{conn.port}") - try: timeout_s = int(env["timeout_s"]) except ValueError as e: @@ -173,10 +205,10 @@ def _run_task(item: ProcedureExecutionMessage): _push_status(ProcedureWorkerStatus.IDLE) item = None try: - logger.debug(f"ContainerWorker waiting for instructions on {endpoint_info}") + logger.debug(f"ContainerWorker waiting for instructions on {exec_endpoint}") while ( item := conn.blocking_list_pop_to_set_add( - endpoint_info, active_procs_endpoint, timeout_s=timeout_s + exec_endpoint, active_procs_endpoint, timeout_s=timeout_s ) ) is not None: _push_status(ProcedureWorkerStatus.RUNNING, item.execution_id) @@ -200,5 +232,16 @@ def _run_task(item: ProcedureExecutionMessage): helper.remove_from_active.by_exec_id(item.execution_id) -if __name__ == "__main__": - main() +def main(): + """Replaces the main contents of Worker.work() - should be called as the container entrypoint or command""" + + env, helper, client, conn = _setup() + output_diverter = RedisOutputDiverter(conn, env["queue"]) + with redirect_stdout(output_diverter): + logger.add( + output_diverter, + level=LogLevel.SUCCESS, + format=bec_logger.formatting(is_container=True), + filter=bec_logger.filter(), + ) + _main(env, helper, client, conn) diff --git a/bec_server/bec_server/scan_server/procedures/helper.py b/bec_server/bec_server/scan_server/procedures/helper.py index 080628fe8..b68bafa31 100644 --- a/bec_server/bec_server/scan_server/procedures/helper.py +++ 
b/bec_server/bec_server/scan_server/procedures/helper.py @@ -1,7 +1,9 @@ from typing import Literal +from bec_lib.endpoints import EndpointInfo from bec_lib.endpoints import MessageEndpoints as ME from bec_lib.logger import bec_logger +from bec_lib.messages import BECMessage from bec_lib.messages import ProcedureAbortMessage as AbrtMsg from bec_lib.messages import ProcedureClearUnhandledMessage as ClrMsg from bec_lib.messages import ProcedureExecutionMessage as ExecMsg @@ -18,32 +20,35 @@ def __init__(self, conn: RedisConnector) -> None: class _Request(_HelperBase): + def _xadd(self, ep: EndpointInfo, msg: BECMessage): + self._conn.xadd(ep, msg.model_dump()) + def procedure(self, msg: ReqMsg): - self._conn.xadd(ME.procedure_request(), msg) + self._xadd(ME.procedure_request(), msg) def abort_execution(self, execution_id: str): """Send a message requesting an abort of execution_id""" - return self._conn.xadd(ME.procedure_abort(), AbrtMsg(execution_id=execution_id)) + return self._xadd(ME.procedure_abort(), AbrtMsg(execution_id=execution_id)) def abort_queue(self, queue: str): """Send a message requesting an abort of execution_id""" - return self._conn.xadd(ME.procedure_abort(), AbrtMsg(queue=queue)) + return self._xadd(ME.procedure_abort(), AbrtMsg(queue=queue)) def abort_all(self): """Send a message requesting an abort of execution_id""" - return self._conn.xadd(ME.procedure_abort(), AbrtMsg(abort_all=True)) + return self._xadd(ME.procedure_abort(), AbrtMsg(abort_all=True)) def clear_unhandled_execution(self, execution_id: str): """Send a message requesting an abort of execution_id""" - return self._conn.xadd(ME.procedure_clear_unhandled(), ClrMsg(execution_id=execution_id)) + return self._xadd(ME.procedure_clear_unhandled(), ClrMsg(execution_id=execution_id)) def clear_unhandled_queue(self, queue: str): """Send a message requesting an abort of execution_id""" - return self._conn.xadd(ME.procedure_clear_unhandled(), ClrMsg(queue=queue)) + return 
self._xadd(ME.procedure_clear_unhandled(), ClrMsg(queue=queue)) def clear_all_unhandled(self): """Send a message requesting an abort of execution_id""" - return self._conn.xadd(ME.procedure_clear_unhandled(), ClrMsg(abort_all=True)) + return self._xadd(ME.procedure_clear_unhandled(), ClrMsg(abort_all=True)) class _Get(_HelperBase): diff --git a/bec_server/bec_server/scan_server/procedures/manager.py b/bec_server/bec_server/scan_server/procedures/manager.py index e0b002368..1b913699b 100644 --- a/bec_server/bec_server/scan_server/procedures/manager.py +++ b/bec_server/bec_server/scan_server/procedures/manager.py @@ -50,7 +50,7 @@ def _log_on_end(future: Future): def _resolve_dict(msg: dict[str, Any] | _T, MsgType: type[_T]) -> _T: if isinstance(msg, dict): - return MsgType.model_validate(msg.get("data")) + return MsgType.model_validate(msg) return msg diff --git a/bec_server/tests/tests_scan_server/test_procedure_container_worker.py b/bec_server/tests/tests_scan_server/test_procedure_container_worker.py index a0de0aa4d..5b686a04a 100644 --- a/bec_server/tests/tests_scan_server/test_procedure_container_worker.py +++ b/bec_server/tests/tests_scan_server/test_procedure_container_worker.py @@ -29,7 +29,7 @@ def test_container_worker_init(redis_mock): redis_mock().port = "port" worker = ContainerProcedureWorker(server="server:port", queue="test_queue", lifetime_s=1) assert worker._worker_environment() == { - "redis_server": "server:port", + "redis_server": "redis:port", "queue": "test_queue", "timeout_s": "1", } @@ -100,7 +100,7 @@ def cleanup(): @patch("bec_server.scan_server.procedures.container_worker.logger") def test_main_exits_without_env_variables(logger_mock): - with patch.dict(os.environ, clear=True): + with patch.dict(os.environ, clear=True), pytest.raises(SystemExit): container_worker_main() assert "Missing environment variable " in logger_mock.error.call_args.args[0] From 7cdb5d066734c168953b21899f47c0b0ddb3141d Mon Sep 17 00:00:00 2001 From: David Perl Date: 
Tue, 4 Nov 2025 20:57:47 +0100 Subject: [PATCH 07/10] feat: add procedure to run IDE script --- .github/actions/bec_e2e_install/action.yml | 2 +- .../tests/end-2-end/test_procedures_e2e.py | 9 ++++++--- .../scan_server/procedures/builtin_procedures.py | 4 ++++ .../scan_server/procedures/container_worker.py | 16 +++++++++++++--- .../bec_server/scan_server/procedures/manager.py | 3 +++ .../scan_server/procedures/procedure_registry.py | 4 +++- .../scan_server/procedures/worker_base.py | 3 +++ 7 files changed, 33 insertions(+), 8 deletions(-) diff --git a/.github/actions/bec_e2e_install/action.yml b/.github/actions/bec_e2e_install/action.yml index 8a0c3bd61..ffa9b0c15 100644 --- a/.github/actions/bec_e2e_install/action.yml +++ b/.github/actions/bec_e2e_install/action.yml @@ -62,7 +62,7 @@ runs: cd ./_e2e_test_checkout_/bec source ./bin/install_bec_dev.sh -t pip install -e ../ophyd_devices - podman pod create --net host local_bec + podman pod create --add-host redis:host-gateway local_bec python ./bec_ipython_client/tests/end-2-end/_ensure_requirements_container.py pytest -v --files-path ./ --start-servers --random-order ./bec_ipython_client/tests/end-2-end/ diff --git a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py index 370d83792..5175de066 100644 --- a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py @@ -74,7 +74,7 @@ def cb(worker: ContainerProcedureWorker): logs = worker._backend.logs(worker._container_id) manager.add_callback("test", cb) - client.connector.xadd(topic=endpoint, msg_dict=msg) + client.connector.xadd(topic=endpoint, msg_dict=msg.model_dump()) _wait_while(lambda: manager._active_workers == {}, 5) _wait_while(lambda: manager._active_workers != {}, 20) @@ -96,14 +96,17 @@ def test_happy_path_container_procedure_runner( msg = messages.ProcedureRequestMessage( identifier="log execution message args", 
args_kwargs=(test_args, test_kwargs) ) - conn.xadd(topic=endpoint, msg_dict=msg) + conn.xadd(topic=endpoint, msg_dict=msg.model_dump()) _wait_while(lambda: manager._active_workers == {}, 5) _wait_while(lambda: manager._active_workers != {}, 20) logtool.fetch() assert logtool.is_present_in_any_message("procedure accepted: True, message:") - assert logtool.is_present_in_any_message("ContainerWorker started container for queue primary") + assert "test string" in "\n".join(manager._logs) + assert logtool.is_present_in_any_message( + "ContainerWorker started container for queue primary" + ), f"Log content relating to procedures: {manager}" res, msg = logtool.are_present_in_order( [ "Container worker 'primary' status update: IDLE", diff --git a/bec_server/bec_server/scan_server/procedures/builtin_procedures.py b/bec_server/bec_server/scan_server/procedures/builtin_procedures.py index 863b76544..cf708797e 100644 --- a/bec_server/bec_server/scan_server/procedures/builtin_procedures.py +++ b/bec_server/bec_server/scan_server/procedures/builtin_procedures.py @@ -31,3 +31,7 @@ def run_scan(scan_name: str, args: tuple, parameters: dict, *, bec: BECClient): raise ValueError(f"Scan {scan_name} doesn't exist in this client!") scan_report: ScanReport = scan(*args, **parameters) scan_report.wait() + + +def run_script(script_id: str, *, bec: BECClient): + bec._run_script(script_id) diff --git a/bec_server/bec_server/scan_server/procedures/container_worker.py b/bec_server/bec_server/scan_server/procedures/container_worker.py index f75917bf1..47965eb59 100644 --- a/bec_server/bec_server/scan_server/procedures/container_worker.py +++ b/bec_server/bec_server/scan_server/procedures/container_worker.py @@ -1,3 +1,4 @@ +import inspect import os import sys from contextlib import redirect_stdout @@ -137,6 +138,11 @@ def abort_execution(self, execution_id: str): ) self._setup_execution_environment() + def logs(self): + if self._container_id is None: + return [""] + return 
self._backend.logs(self._container_id) + def _setup(): logger.info("Container worker starting up") @@ -198,9 +204,13 @@ def _push_status(status: ProcedureWorkerStatus, id: str | None = None): ) def _run_task(item: ProcedureExecutionMessage): - procedure_registry.callable_from_execution_message(item)( - *item.args_kwargs[0], **item.args_kwargs[1] - ) + kwargs = item.args_kwargs[1] + proc_func = procedure_registry.callable_from_execution_message(item) + if bec_arg := inspect.signature(proc_func).parameters.get("bec"): + if bec_arg.kind == bec_arg.KEYWORD_ONLY and bec_arg.annotation.__name__ == "BECClient": + logger.debug(f"Injecting BEC client argument for {item}") + kwargs["bec"] = client + procedure_registry.callable_from_execution_message(item)(*item.args_kwargs[0], **kwargs) _push_status(ProcedureWorkerStatus.IDLE) item = None diff --git a/bec_server/bec_server/scan_server/procedures/manager.py b/bec_server/bec_server/scan_server/procedures/manager.py index 1b913699b..faa401ac7 100644 --- a/bec_server/bec_server/scan_server/procedures/manager.py +++ b/bec_server/bec_server/scan_server/procedures/manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import atexit +from collections import deque from concurrent import futures from concurrent.futures import Future, ThreadPoolExecutor from threading import RLock @@ -68,6 +69,7 @@ def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]): logger.success("Initialising procedure manager...") + self._logs = deque([], maxlen=1000) self._conn = RedisConnector([self._parent.bootstrap_server]) self._helper = BackendProcedureHelper(self._conn) self._startup() @@ -250,6 +252,7 @@ def spawn(self, queue: str): with self.lock: self._active_workers[queue]["worker"] = worker worker.work() + self._logs.extend(worker.logs()) def shutdown(self): """Shutdown the procedure manager. 
Unregisters from the request endpoint, cancel any diff --git a/bec_server/bec_server/scan_server/procedures/procedure_registry.py b/bec_server/bec_server/scan_server/procedures/procedure_registry.py index 8a2582f56..4ed3ff114 100644 --- a/bec_server/bec_server/scan_server/procedures/procedure_registry.py +++ b/bec_server/bec_server/scan_server/procedures/procedure_registry.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, Iterable, Iterator +from typing import Iterable from bec_lib.messages import ProcedureExecutionMessage from bec_server.scan_server.procedures.builtin_procedures import ( log_message_args_kwargs, run_scan, + run_script, sleep, ) from bec_server.scan_server.procedures.constants import BecProcedure @@ -12,6 +13,7 @@ "log execution message args": log_message_args_kwargs, "run scan": run_scan, "sleep": sleep, + "run_script": run_script, } _PROCEDURE_REGISTRY: dict[str, BecProcedure] = {} | _BUILTIN_PROCEDURES diff --git a/bec_server/bec_server/scan_server/procedures/worker_base.py b/bec_server/bec_server/scan_server/procedures/worker_base.py index dd5060f73..7b07256db 100644 --- a/bec_server/bec_server/scan_server/procedures/worker_base.py +++ b/bec_server/bec_server/scan_server/procedures/worker_base.py @@ -65,6 +65,9 @@ def _run_task(self, item: ProcedureExecutionMessage): @abstractmethod def _setup_execution_environment(self): ... 
+ def logs(self) -> list[str]: + return [""] + def abort(self): """Abort the entire worker""" self._aborted.set() From 84fcb7c726e54424ae486d38710b9164b2764648 Mon Sep 17 00:00:00 2001 From: perl_d Date: Wed, 12 Nov 2025 10:37:49 +0100 Subject: [PATCH 08/10] fix: redis connection in e2e tests --- .github/actions/bec_e2e_install/action.yml | 2 +- .../tests/end-2-end/test_procedures_e2e.py | 27 ++++++++++++------- .../scan_server/procedures/constants.py | 1 + .../procedures/container_worker.py | 2 +- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/.github/actions/bec_e2e_install/action.yml b/.github/actions/bec_e2e_install/action.yml index ffa9b0c15..8f93f393b 100644 --- a/.github/actions/bec_e2e_install/action.yml +++ b/.github/actions/bec_e2e_install/action.yml @@ -62,7 +62,7 @@ runs: cd ./_e2e_test_checkout_/bec source ./bin/install_bec_dev.sh -t pip install -e ../ophyd_devices - podman pod create --add-host redis:host-gateway local_bec + podman pod create --network=host local_bec python ./bec_ipython_client/tests/end-2-end/_ensure_requirements_container.py pytest -v --files-path ./ --start-servers --random-order ./bec_ipython_client/tests/end-2-end/ diff --git a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py index 5175de066..48afa5550 100644 --- a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +from dataclasses import dataclass from importlib.metadata import version from typing import TYPE_CHECKING, Callable, Generator from unittest.mock import MagicMock, patch @@ -11,6 +12,7 @@ from bec_lib import messages from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger +from bec_server.scan_server.procedures.constants import _CONTAINER, _WORKER from bec_server.scan_server.procedures.container_utils import get_backend 
from bec_server.scan_server.procedures.container_worker import ContainerProcedureWorker from bec_server.scan_server.procedures.manager import ProcedureManager @@ -27,6 +29,15 @@ pytestmark = pytest.mark.random_order(disabled=True) +@dataclass(frozen=True) +class PATCHED_CONSTANTS: + WORKER = _WORKER() + CONTAINER = _CONTAINER() + MANAGER_SHUTDOWN_TIMEOUT_S = 2 + BEC_VERSION = version("bec_lib") + REDIS_HOST = "localhost" + + @pytest.fixture def client_logtool_and_manager( bec_ipython_client_fixture_with_logtool: tuple[BECIPythonClient, "LogTestTool"], @@ -84,6 +95,7 @@ def cb(worker: ContainerProcedureWorker): @pytest.mark.timeout(100) @patch("bec_server.scan_server.procedures.manager.procedure_registry.is_registered", lambda _: True) +@patch("bec_server.scan_server.procedures.container_worker.PROCEDURE", PATCHED_CONSTANTS()) def test_happy_path_container_procedure_runner( client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager], ): @@ -103,10 +115,10 @@ def test_happy_path_container_procedure_runner( logtool.fetch() assert logtool.is_present_in_any_message("procedure accepted: True, message:") - assert "test string" in "\n".join(manager._logs) assert logtool.is_present_in_any_message( "ContainerWorker started container for queue primary" - ), f"Log content relating to procedures: {manager}" + ), f"Log content relating to procedures: {manager._logs}" + res, msg = logtool.are_present_in_order( [ "Container worker 'primary' status update: IDLE", @@ -116,12 +128,7 @@ def test_happy_path_container_procedure_runner( ] ) assert res, f"failed on {msg}" - res, msg = logtool.are_present_in_order( - [ - "Container worker 'primary' status update: IDLE", - f"Builtin procedure log_message_args_kwargs called with args: {test_args} and kwargs: {test_kwargs}", - "Container worker 'primary' status update: IDLE", - "Container worker 'primary' status update: FINISHED", - ] + + assert logtool.is_present_in_any_message( + f"Builtin procedure 
log_message_args_kwargs called with args: {test_args} and kwargs: {test_kwargs}" ) - assert res, f"failed on {msg}" diff --git a/bec_server/bec_server/scan_server/procedures/constants.py b/bec_server/bec_server/scan_server/procedures/constants.py index 6bdfeee85..ed5d20cdd 100644 --- a/bec_server/bec_server/scan_server/procedures/constants.py +++ b/bec_server/bec_server/scan_server/procedures/constants.py @@ -67,6 +67,7 @@ class _PROCEDURE: CONTAINER = _CONTAINER() MANAGER_SHUTDOWN_TIMEOUT_S = 2 BEC_VERSION = version("bec_lib") + REDIS_HOST = "redis" PROCEDURE = _PROCEDURE() diff --git a/bec_server/bec_server/scan_server/procedures/container_worker.py b/bec_server/bec_server/scan_server/procedures/container_worker.py index 47965eb59..8a9f3a554 100644 --- a/bec_server/bec_server/scan_server/procedures/container_worker.py +++ b/bec_server/bec_server/scan_server/procedures/container_worker.py @@ -62,7 +62,7 @@ def _worker_environment(self) -> ContainerWorkerEnv: minimum necessary, or things which are only necessary for the functioning of the worker, and other information should be passed through redis""" return { - "redis_server": f"redis:{self._conn.port}", + "redis_server": f"{PROCEDURE.REDIS_HOST}:{self._conn.port}", "queue": self._queue, "timeout_s": str(self._lifetime_s), } From 86192b7b9b933814e27a0a70af486f1294d7ccf7 Mon Sep 17 00:00:00 2001 From: perl_d Date: Tue, 18 Nov 2025 13:04:03 +0100 Subject: [PATCH 09/10] docs: fix docstrings --- bec_lib/bec_lib/endpoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 989be3c99..fc11eb572 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -1506,7 +1506,7 @@ def procedure_abort(): Endpoint to request aborting a running procedure Returns: - EndpointInfo: Endpoint for set of active procedure executions. + EndpointInfo: Endpoint to request procedure abortion. 
""" endpoint = f"{EndpointType.USER.value}/procedures/abort" return EndpointInfo( @@ -1521,7 +1521,7 @@ def procedure_clear_unhandled(): Endpoint to request removing an aborted procedure Returns: - EndpointInfo: Endpoint for set of active procedure executions. + EndpointInfo: Endpoint to request removing aborted procedures. """ endpoint = f"{EndpointType.USER.value}/procedures/clear_unhandled" return EndpointInfo( From cae0c5dbdb6f5af8efcb05d3232526a330a55e43 Mon Sep 17 00:00:00 2001 From: David Perl Date: Wed, 26 Nov 2025 11:59:25 +0100 Subject: [PATCH 10/10] tests: mock ACLs for fakeredis --- bec_server/tests/tests_scan_server/test_procedures.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bec_server/tests/tests_scan_server/test_procedures.py b/bec_server/tests/tests_scan_server/test_procedures.py index 1eee31e31..7775e5281 100644 --- a/bec_server/tests/tests_scan_server/test_procedures.py +++ b/bec_server/tests/tests_scan_server/test_procedures.py @@ -40,6 +40,7 @@ @pytest.fixture(autouse=True) +@patch("bec_lib.bec_service.BECAccess", MagicMock) def shutdown_client(): bec_client = BECClient( config=ServiceConfig(config={"redis": {"host": FAKEREDIS_HOST, "port": FAKEREDIS_PORT}}),