diff --git a/.github/actions/bec_e2e_install/action.yml b/.github/actions/bec_e2e_install/action.yml index 8a0c3bd61..8f93f393b 100644 --- a/.github/actions/bec_e2e_install/action.yml +++ b/.github/actions/bec_e2e_install/action.yml @@ -62,7 +62,7 @@ runs: cd ./_e2e_test_checkout_/bec source ./bin/install_bec_dev.sh -t pip install -e ../ophyd_devices - podman pod create --net host local_bec + podman pod create --network=host local_bec python ./bec_ipython_client/tests/end-2-end/_ensure_requirements_container.py pytest -v --files-path ./ --start-servers --random-order ./bec_ipython_client/tests/end-2-end/ diff --git a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py index 409c11f92..48afa5550 100644 --- a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +from dataclasses import dataclass from importlib.metadata import version from typing import TYPE_CHECKING, Callable, Generator from unittest.mock import MagicMock, patch @@ -11,7 +12,7 @@ from bec_lib import messages from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_server.scan_server.procedures.constants import PROCEDURE +from bec_server.scan_server.procedures.constants import _CONTAINER, _WORKER from bec_server.scan_server.procedures.container_utils import get_backend from bec_server.scan_server.procedures.container_worker import ContainerProcedureWorker from bec_server.scan_server.procedures.manager import ProcedureManager @@ -28,6 +29,15 @@ pytestmark = pytest.mark.random_order(disabled=True) +@dataclass(frozen=True) +class PATCHED_CONSTANTS: + WORKER = _WORKER() + CONTAINER = _CONTAINER() + MANAGER_SHUTDOWN_TIMEOUT_S = 2 + BEC_VERSION = version("bec_lib") + REDIS_HOST = "localhost" + + @pytest.fixture def client_logtool_and_manager( 
bec_ipython_client_fixture_with_logtool: tuple[BECIPythonClient, "LogTestTool"], @@ -52,7 +62,7 @@ def _wait_while(cond: Callable[[], bool], timeout_s): def test_building_worker_image(): podman_utils = get_backend() build = podman_utils.build_worker_image() - assert len(build._command_output.splitlines()[-1]) == 64 + assert len(build._command_output.splitlines()[-1]) == 64 # type: ignore assert podman_utils.image_exists(f"bec_procedure_worker:v{version('bec_lib')}") @@ -62,7 +72,7 @@ def test_procedure_runner_spawns_worker( client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager], ): client, _, manager = client_logtool_and_manager - assert manager.active_workers == {} + assert manager._active_workers == {} endpoint = MessageEndpoints.procedure_request() msg = messages.ProcedureRequestMessage( identifier="sleep", args_kwargs=((), {"time_s": 2}), queue="test" @@ -77,21 +87,22 @@ def cb(worker: ContainerProcedureWorker): manager.add_callback("test", cb) client.connector.xadd(topic=endpoint, msg_dict=msg.model_dump()) - _wait_while(lambda: manager.active_workers == {}, 5) - _wait_while(lambda: manager.active_workers != {}, 20) + _wait_while(lambda: manager._active_workers == {}, 5) + _wait_while(lambda: manager._active_workers != {}, 20) assert logs != [] @pytest.mark.timeout(100) @patch("bec_server.scan_server.procedures.manager.procedure_registry.is_registered", lambda _: True) +@patch("bec_server.scan_server.procedures.container_worker.PROCEDURE", PATCHED_CONSTANTS()) def test_happy_path_container_procedure_runner( client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager], ): test_args = (1, 2, 3) test_kwargs = {"a": "b", "c": "d"} client, logtool, manager = client_logtool_and_manager - assert manager.active_workers == {} + assert manager._active_workers == {} conn = client.connector endpoint = MessageEndpoints.procedure_request() msg = messages.ProcedureRequestMessage( @@ -99,12 +110,15 @@ def 
test_happy_path_container_procedure_runner( ) conn.xadd(topic=endpoint, msg_dict=msg.model_dump()) - _wait_while(lambda: manager.active_workers == {}, 5) - _wait_while(lambda: manager.active_workers != {}, 20) + _wait_while(lambda: manager._active_workers == {}, 5) + _wait_while(lambda: manager._active_workers != {}, 20) logtool.fetch() assert logtool.is_present_in_any_message("procedure accepted: True, message:") - assert logtool.is_present_in_any_message("ContainerWorker started container for queue primary") + assert logtool.is_present_in_any_message( + "ContainerWorker started container for queue primary" + ), f"Log content relating to procedures: {manager._logs}" + res, msg = logtool.are_present_in_order( [ "Container worker 'primary' status update: IDLE", @@ -114,12 +128,7 @@ def test_happy_path_container_procedure_runner( ] ) assert res, f"failed on {msg}" - res, msg = logtool.are_present_in_order( - [ - "Container worker 'primary' status update: IDLE", - f"Builtin procedure log_message_args_kwargs called with args: {test_args} and kwargs: {test_kwargs}", - "Container worker 'primary' status update: IDLE", - "Container worker 'primary' status update: FINISHED", - ] + + assert logtool.is_present_in_any_message( + f"Builtin procedure log_message_args_kwargs called with args: {test_args} and kwargs: {test_kwargs}" ) - assert res, f"failed on {msg}" diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index cc4f56b90..fc11eb572 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -41,9 +41,9 @@ class MessageOp(list[str], enum.Enum): SET_PUBLISH = ["register", "set_and_publish", "delete", "get", "keys"] SEND = ["send", "register"] STREAM = ["xadd", "xrange", "xread", "register_stream", "keys", "get_last", "delete"] - LIST = ["lpush", "lrange", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"] + LIST = ["lpush", "lrange", "lrem", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"] KEY_VALUE = ["set", "get", 
"delete", "keys"] - SET = ["remove_from_set", "get_set_members"] + SET = ["remove_from_set", "get_set_members", "delete"] MessageType = TypeVar("MessageType", bound="type[messages.BECMessage]") @@ -1420,8 +1420,8 @@ def available_procedures() -> EndpointInfo: @staticmethod def procedure_request() -> EndpointInfo: """ - Endpoint for scan queue request. This endpoint is used to request the new scans. - The request is sent using a messages.ScanQueueMessage message. + Endpoint for requesting new procedures. + The request is sent using a messages.ProcedureRequestMessage message. Returns: EndpointInfo: Endpoint for scan queue request. @@ -1460,7 +1460,24 @@ def procedure_execution(queue_id: str): Returns: EndpointInfo: Endpoint for scan queue request. """ - endpoint = f"{EndpointType.INTERNAL.value}/procedures/procedure_execution/{queue_id}" + endpoint = f"{EndpointType.INFO.value}/procedures/procedure_execution/{queue_id}" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureExecutionMessage, + message_op=MessageOp.LIST, + ) + + @staticmethod + def unhandled_procedure_execution(queue_id: str): + """ + Endpoint for procedure executions which were pending when the manager was shutdown. + Messages from procedure_execution are moved here on manager startup. + The request is sent using a messages.ProcedureExecutionMessage message. + + Returns: + EndpointInfo: Endpoint for unhandled procedure executions. + """ + endpoint = f"{EndpointType.INFO.value}/procedures/unhandled_procedure_execution/{queue_id}" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureExecutionMessage, + message_op=MessageOp.LIST, + ) @@ -1483,6 +1500,36 @@ def active_procedure_executions(): message_op=MessageOp.SET, ) + + @staticmethod + def procedure_abort(): + """ + Endpoint to request aborting a running procedure + + Returns: + EndpointInfo: Endpoint to request procedure abortion. 
+ """ + endpoint = f"{EndpointType.USER.value}/procedures/abort" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureAbortMessage, + message_op=MessageOp.STREAM, + ) + + @staticmethod + def procedure_clear_unhandled(): + """ + Endpoint to request removing an aborted procedure + + Returns: + EndpointInfo: Endpoint to request removing aborted procedures. + """ + endpoint = f"{EndpointType.USER.value}/procedures/clear_unhandled" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureClearUnhandledMessage, + message_op=MessageOp.STREAM, + ) + @staticmethod def procedure_worker_status_update(queue_id: str): """ @@ -1498,6 +1545,34 @@ def procedure_worker_status_update(queue_id: str): message_op=MessageOp.LIST, ) + @staticmethod + def procedure_queue_notif(): + """ + PubSub channel for a consumer (e.g. BEC widgets) to be notified of changes to a procedure queue + + Returns: + EndpointInfo: Endpoint for procedure queue updates for given queue. + """ + endpoint = f"{EndpointType.INFO.value}/procedures/queue_notif" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.ProcedureQNotifMessage, + message_op=MessageOp.SEND, + ) + + @staticmethod + def procedure_logs(queue: str): + """ + Endpoint for logs for a given procedure queue + + Returns: + EndpointInfo: Endpoint for logs of the given procedure queue. 
+ """ + endpoint = f"{EndpointType.INFO.value}/procedures/logs/{queue}" + return EndpointInfo( + endpoint=endpoint, message_type=messages.RawMessage, message_op=MessageOp.STREAM + ) + @staticmethod def gui_registry_state(gui_id: str): """ diff --git a/bec_lib/bec_lib/logger.py b/bec_lib/bec_lib/logger.py index 3e02f305a..201efb3a8 100644 --- a/bec_lib/bec_lib/logger.py +++ b/bec_lib/bec_lib/logger.py @@ -64,6 +64,7 @@ class BECLogger: "{service_name} | {{time:YYYY-MM-DD HH:mm:ss.SSS}} | {{level}} |" " {{thread.name}} ({{thread.id}}) | {{extra[stack]}} - {{message}}\n" ) + CONTAINER_FORMAT = "{{time:YYYY-MM-DD HH:mm:ss.SSS}} | {{level}} | {{message}}\n" LOGLEVEL = LogLevel _logger = None @@ -224,18 +225,21 @@ def _logger_callback(self, msg): # because it depends on the connector pass - def get_format(self, level: LogLevel = None, is_stderr=False) -> str: + def get_format(self, level: LogLevel = None, is_stderr=False, is_container=False) -> str: """ Get the format for a specific log level. Args: level (LogLevel, optional): Log level. Defaults to None. If None, the current log level will be used. is_stderr (bool, optional): Whether the log is for stderr. Defaults to False. + is_container (bool, optional): Simple logging for procedure container. Defaults to False. Returns: str: Log format. """ service_name = self.service_name if self.service_name else "" + if is_container: + return self.CONTAINER_FORMAT.format() if level is None: level = self.level if level > self.LOGLEVEL.DEBUG: @@ -246,15 +250,16 @@ def get_format(self, level: LogLevel = None, is_stderr=False) -> str: return self.DEBUG_FORMAT.format(service_name=service_name) return self.TRACE_FORMAT.format(service_name=service_name) - def formatting(self, is_stderr=False): + def formatting(self, is_stderr=False, is_container=False): """ Format the log message. Args: record (dict): Log record. + is_container (bool, optional): Simple logging for procedure container. Defaults to False. 
Returns: - dict: Formatted log record. + str: Log format. """ def _update_record(record): @@ -269,7 +274,7 @@ def _update_record(record): def _format(record): level = _update_record(record) - return self.get_format(level) + return self.get_format(level, is_container=is_container) def _format_stderr(record): level = _update_record(record) diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 28c12943b..39bdc51c9 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -5,7 +5,8 @@ import warnings from copy import deepcopy from enum import Enum, auto -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Literal, Self +from uuid import uuid4 import numpy as np from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator @@ -19,6 +20,7 @@ class ProcedureWorkerStatus(Enum): IDLE = auto() FINISHED = auto() DEAD = auto() # worker lost communication with the container + NONE = auto() # worker doesn't exist in manager, caught during creation and cleanup class BECStatus(Enum): @@ -1212,18 +1214,67 @@ class ProcedureRequestMessage(BECMessage): queue: str | None = None +class ProcedureQNotifMessage(BECMessage): + """Message type for notifying watchers of changes to queues""" + + msg_type: ClassVar[str] = "procedure_queue_notif_message" + queue_name: str + queue_type: Literal["execution", "unhandled"] + + class ProcedureExecutionMessage(BECMessage): """Message type for sending procedure execution instructions to the scheduler Sent by the user to the procedure_request topic. It will be consumed by the scan server. 
Args: identifier (str): name of the procedure registered with the server + queue (str): the procedure queue this execution belongs to + args_kwargs (tuple[tuple[Any, ...], dict[str, Any]]): arguments for the procedure function """ msg_type: ClassVar[str] = "procedure_execution_message" identifier: str queue: str args_kwargs: tuple[tuple[Any, ...], dict[str, Any]] = (), {} + execution_id: str = Field(default_factory=lambda: str(uuid4())) + + +class ProcedureAbortMessage(BECMessage): + """Message type to request aborting a procedure or procedure queue + + One and only one of the args should be supplied. + Args: + queue (str | None): the procedure queue to abort + execution_id (str | None): the procedure execution to abort + abort_all (bool | None): abort all procedures if true + """ + + msg_type: ClassVar[str] = "procedure_abort_message" + queue: str | None = None + execution_id: str | None = None + abort_all: bool | None = None + + @model_validator(mode="after") + def mutually_exclusive(self) -> Self: + if (self.queue, self.execution_id, self.abort_all).count(None) != 2: + raise ValueError( + "Please only supply one argument! Supplied: \n" + f" {self.queue=}, {self.execution_id=}, {self.abort_all=}" + ) + return self + + +class ProcedureClearUnhandledMessage(ProcedureAbortMessage): + """Message type to request clearing an unhandled procedure or procedure queue + + One and only one of the args should be supplied. + Args: + queue (str | None): the procedure queue to abort + execution_id (str | None): the procedure execution to abort + abort_all (bool | None): abort all procedures if true + """ + + ... 
class ProcedureWorkerStatusMessage(BECMessage): @@ -1232,12 +1283,21 @@ class ProcedureWorkerStatusMessage(BECMessage): Args: worker_queue (str): Worker queue ID status (str): Worker status - + current_execution_id (str | None): ID of the current job, only allowed for RUNNING """ msg_type: ClassVar[str] = "procedure_worker_status_message" worker_queue: str status: ProcedureWorkerStatus + current_execution_id: str | None = None + + @model_validator(mode="after") + def check_id(self) -> Self: + if self.current_execution_id is not None and self.status != ProcedureWorkerStatus.RUNNING: + raise ValueError("Adding an execution ID is only valid for the RUNNING status") + if self.current_execution_id is None and self.status == ProcedureWorkerStatus.RUNNING: + raise ValueError("Adding an execution ID is mandatory for the RUNNING status") + return self class LoginInfoMessage(BECMessage): diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index 115d7b820..8d8764f1e 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -1105,6 +1105,23 @@ def lrange(self, topic: str, start: int, end: int, pipe: Pipeline | None = None) ret.append(msg) return ret + @validate_endpoint("topic") + def lrem(self, topic: str, count: int, msg, pipe: Pipeline | None = None): + """Removes the first count occurrences of elements equal to element from the list stored at key. + The count argument influences the operation in the following ways: + count > 0: Remove elements equal to element moving from head to tail. + count < 0: Remove elements equal to element moving from tail to head. + count = 0: Remove all elements equal to element. + For example, LREM list -2 "hello" will remove the last two occurrences of "hello" in the list stored at list. 
+ + Returns the number of items removed + """ + + client = pipe if pipe is not None else self._redis_conn + if isinstance(msg, BECMessage): + msg = MsgpackSerialization.dumps(msg) + return client.lrem(topic, count, msg) + @validate_endpoint("topic") def set_and_publish( self, topic: str, msg, pipe: Pipeline | None = None, expire: int | None = None @@ -1172,7 +1189,7 @@ def xadd( Args: topic (str): redis topic - msg_dict (dict): message to add + msg_dict (dict | BECMessage): message to add max_size (int, optional): max size of stream. Defaults to None. pipe (Pipeline, optional): redis pipe. Defaults to None. expire (int, optional): expire time. Defaults to None. diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index e94ba9196..a1a7cf2d4 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -243,6 +243,31 @@ def test_ProcedureWorkerStatusMessage(): assert res_loaded == msg +def test_ProcedureWorkerStatusMessage_validation(): + with pytest.raises(pydantic.ValidationError) as e: + messages.ProcedureWorkerStatusMessage( + worker_queue="background tasks", + status=messages.ProcedureWorkerStatus.RUNNING, + metadata={"RID": "1234"}, + ) + assert e.match("Adding an execution ID is mandatory") + with pytest.raises(pydantic.ValidationError) as e: + messages.ProcedureWorkerStatusMessage( + worker_queue="background tasks", + status=messages.ProcedureWorkerStatus.IDLE, + metadata={"RID": "1234"}, + current_execution_id="test", + ) + assert e.match("Adding an execution ID is only valid") + + +def test_ProcedureAbortMessage_validation(): + with pytest.raises(pydantic.ValidationError) as e: + messages.ProcedureAbortMessage(queue="test", execution_id="test") + assert e.match("only supply one argument") + messages.ProcedureAbortMessage(queue="test") + + def test_FileMessage(): msg = messages.FileMessage( device_name="samx", diff --git a/bec_lib/tests/test_redis_connector_fakeredis.py 
b/bec_lib/tests/test_redis_connector_fakeredis.py index beb4173d8..777d8b300 100644 --- a/bec_lib/tests/test_redis_connector_fakeredis.py +++ b/bec_lib/tests/test_redis_connector_fakeredis.py @@ -568,3 +568,17 @@ def test_redis_connector_message_alternated_pass(connected_connector): assert msg_received == [data_to_check] assert msg_received == [data_original] + + +def test_lrem(connected_connector: RedisConnector): + conn, ep = connected_connector, MessageEndpoints.procedure_execution + msgs = [ + messages.ProcedureExecutionMessage(identifier=_id, queue="primary", args_kwargs=((), {})) + for _id in ("a", "b", "c") + ] + for msg in msgs: + conn.rpush(ep(msg.queue), msg) + assert len(conn.lrange(ep("primary"), 0, -1)) == 3 + conn.lrem(ep(msgs[1].queue), 0, msgs[1]) + list_contents = conn.lrange(ep("primary"), 0, -1) + assert list_contents == [msgs[0], msgs[2]] diff --git a/bec_server/bec_server/scan_server/procedures/_dev_runner.py b/bec_server/bec_server/scan_server/procedures/_dev_runner.py index e87ae4b78..b95db1c27 100644 --- a/bec_server/bec_server/scan_server/procedures/_dev_runner.py +++ b/bec_server/bec_server/scan_server/procedures/_dev_runner.py @@ -7,7 +7,6 @@ from bec_lib.logger import bec_logger from bec_server.scan_server.procedures.container_worker import ContainerProcedureWorker - from bec_server.scan_server.procedures.in_process_worker import InProcessProcedureWorker from bec_server.scan_server.procedures.manager import ProcedureManager logger = bec_logger.logger diff --git a/bec_server/bec_server/scan_server/procedures/builtin_procedures.py b/bec_server/bec_server/scan_server/procedures/builtin_procedures.py index 863b76544..cf708797e 100644 --- a/bec_server/bec_server/scan_server/procedures/builtin_procedures.py +++ b/bec_server/bec_server/scan_server/procedures/builtin_procedures.py @@ -31,3 +31,7 @@ def run_scan(scan_name: str, args: tuple, parameters: dict, *, bec: BECClient): raise ValueError(f"Scan {scan_name} doesn't exist in this client!") 
scan_report: ScanReport = scan(*args, **parameters) scan_report.wait() + + +def run_script(script_id: str, *, bec: BECClient): + bec._run_script(script_id) diff --git a/bec_server/bec_server/scan_server/procedures/constants.py b/bec_server/bec_server/scan_server/procedures/constants.py index 376266275..ed5d20cdd 100644 --- a/bec_server/bec_server/scan_server/procedures/constants.py +++ b/bec_server/bec_server/scan_server/procedures/constants.py @@ -67,6 +67,7 @@ class _PROCEDURE: CONTAINER = _CONTAINER() MANAGER_SHUTDOWN_TIMEOUT_S = 2 BEC_VERSION = version("bec_lib") + REDIS_HOST = "redis" PROCEDURE = _PROCEDURE() @@ -77,4 +78,5 @@ class PodmanContainerStates(str, Enum): RUNNING = "running" PAUSED = "paused" STOPPED = "stopped" + STOPPING = "stopping" EXITED = "exited" diff --git a/bec_server/bec_server/scan_server/procedures/container_utils.py b/bec_server/bec_server/scan_server/procedures/container_utils.py index 83e23c31e..766c11901 100644 --- a/bec_server/bec_server/scan_server/procedures/container_utils.py +++ b/bec_server/bec_server/scan_server/procedures/container_utils.py @@ -96,6 +96,7 @@ def run( volumes: list[VolumeSpec], command: str, pod_name: str | None = None, + container_name: str | None = None, ) -> str: with PodmanClient(base_url=self.uri) as client: try: @@ -106,6 +107,7 @@ def run( environment=environment, mounts=volumes, pod=pod_name, + name=container_name, ) # type: ignore # running with detach returns container object except APIError as e: if e.status_code == HTTPStatus.INTERNAL_SERVER_ERROR: @@ -151,8 +153,9 @@ def pretty_print(self) -> str: class PodmanCliUtils(_PodmanUtilsBase): - def _run_and_capture_error(self, *args: str): - logger.debug(f"Running {args}") + def _run_and_capture_error(self, *args: str, log: bool = True): + if log: + logger.debug(f"Running {args}") output = subprocess.run([*args], capture_output=True) if output.returncode != 0: raise ProcedureWorkerError( @@ -166,7 +169,7 @@ def _run_and_capture_error(self, *args: str): 
def _podman_ls_json(self, subcom: Literal["image", "container"] = "container"): return json.loads( self._run_and_capture_error( - "podman", subcom, "list", "--all", "--format", "json" + "podman", subcom, "list", "--all", "--format", "json", log=False ).stdout ) @@ -193,6 +196,7 @@ def run( volumes: list[VolumeSpec], command: str, pod_name: str | None = None, + container_name: str | None = None, ) -> str: _volumes = [ f"{vol['source']}:{vol['target']}{':ro' if vol['read_only'] else ''}" for vol in volumes @@ -200,22 +204,39 @@ def run( _volume_args = list(chain(*(("-v", vol) for vol in _volumes))) _environment = _multi_args_from_dict("-e", environment) # type: ignore # this is actually a dict[str, str] _pod_arg = ["--pod", pod_name] if pod_name else [] + _name_arg = ["--replace", "--name", container_name] if container_name else [] return ( self._run_and_capture_error( - "podman", "run", *_environment, "-d", *_volume_args, *_pod_arg, image_tag, command + "podman", + "run", + *_environment, + "-d", + *_name_arg, + *_volume_args, + *_pod_arg, + image_tag, + command, ) .stdout.decode() .strip() ) def kill(self, id: str): - self._run_and_capture_error("podman", "kill", id) - self._run_and_capture_error("podman", "rm", id) + try: + self._run_and_capture_error("podman", "kill", id) + except ProcedureWorkerError as e: + logger.error(e) def logs(self, id: str) -> list[str]: - return self._run_and_capture_error("podman", "logs", id).stderr.decode().splitlines() + try: + return self._run_and_capture_error("podman", "logs", id).stderr.decode().splitlines() + except ProcedureWorkerError as e: + logger.error(e) + return [f"No logs found for container {id}\n"] def state(self, id: str) -> PodmanContainerStates | None: for container in self._podman_ls_json(): if container["Id"] == id or container["Id"].startswith(id): + if names := container.get("Names"): + logger.debug(f"Container {names[0]} status: {container['State']}") return PodmanContainerStates(container["State"]) diff --git 
a/bec_server/bec_server/scan_server/procedures/container_worker.py b/bec_server/bec_server/scan_server/procedures/container_worker.py index 0d6c2e5b1..8a9f3a554 100644 --- a/bec_server/bec_server/scan_server/procedures/container_worker.py +++ b/bec_server/bec_server/scan_server/procedures/container_worker.py @@ -1,10 +1,15 @@ +import inspect import os +import sys +from contextlib import redirect_stdout +from typing import AnyStr, TextIO +from bec_ipython_client.main import BECIPythonClient from bec_lib import messages from bec_lib.client import BECClient from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import LogLevel, bec_logger -from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus +from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus, RawMessage from bec_lib.redis_connector import RedisConnector from bec_lib.service_config import ServiceConfig from bec_server.scan_server.procedures import procedure_registry @@ -15,12 +20,35 @@ ProcedureWorkerError, ) from bec_server.scan_server.procedures.container_utils import get_backend +from bec_server.scan_server.procedures.helper import BackendProcedureHelper from bec_server.scan_server.procedures.protocol import ContainerCommandBackend from bec_server.scan_server.procedures.worker_base import ProcedureWorker logger = bec_logger.logger +class RedisOutputDiverter(TextIO): + def __init__(self, conn: RedisConnector, queue: str): + + self._conn = conn + self._ep = MessageEndpoints.procedure_logs(queue) + self._conn.delete(self._ep) + + def write(self, data: AnyStr): + if data: + self._conn.xadd(self._ep, {"data": RawMessage(data=str(data))}) + return len(data) + + def flush(self): ... 
+ + @property + def encoding(self): + return "utf-8" + + def close(self): + return + + class ContainerProcedureWorker(ProcedureWorker): """A worker which runs scripts in a container with a full BEC environment, mounted from the filesystem, and only access to Redis""" @@ -34,7 +62,7 @@ def _worker_environment(self) -> ContainerWorkerEnv: minimum necessary, or things which are only necessary for the functioning of the worker, and other information should be passed through redis""" return { - "redis_server": f"{self._conn.host}:{self._conn.port}", + "redis_server": f"{PROCEDURE.REDIS_HOST}:{self._conn.port}", "queue": self._queue, "timeout_s": str(self._lifetime_s), } @@ -42,6 +70,7 @@ def _worker_environment(self) -> ContainerWorkerEnv: def _setup_execution_environment(self): self._backend: ContainerCommandBackend = get_backend() image_tag = f"{PROCEDURE.CONTAINER.IMAGE_NAME}:v{PROCEDURE.BEC_VERSION}" + self.container_name = f"bec_procedure_{PROCEDURE.BEC_VERSION}_{self._queue}" if not self._backend.image_exists(image_tag): self._backend.build_worker_image() self._container_id = self._backend.run( @@ -57,6 +86,7 @@ def _setup_execution_environment(self): ], PROCEDURE.CONTAINER.COMMAND, pod_name=PROCEDURE.CONTAINER.POD_NAME, + container_name=self.container_name, ) def _run_task(self, item: ProcedureExecutionMessage): @@ -64,12 +94,16 @@ def _run_task(self, item: ProcedureExecutionMessage): f"Container worker _run_task() called with {item} - this should never happen!" 
) - def _kill_process(self): - if self._backend.state(self._container_id) not in [ + def _ending_or_ended(self): + return self._backend.state(self._container_id) in [ PodmanContainerStates.EXITED, PodmanContainerStates.STOPPED, - ]: - self._backend.kill(self._container_id) + PodmanContainerStates.STOPPING, + ] + + def _kill_process(self): + if not self._ending_or_ended(): + self._backend.kill(self.container_name) def work(self): """block until the container is finished, listen for status updates in the meantime""" @@ -77,34 +111,48 @@ def work(self): # on timeout check if container is still running status_update = None - while self._backend.state(self._container_id) not in [ - PodmanContainerStates.EXITED, - PodmanContainerStates.STOPPED, - ]: + while not self._ending_or_ended(): status_update = self._conn.blocking_list_pop( - MessageEndpoints.procedure_worker_status_update(self._queue), timeout_s=1 + MessageEndpoints.procedure_worker_status_update(self._queue), timeout_s=0.2 ) if status_update is not None: if not isinstance(status_update, messages.ProcedureWorkerStatusMessage): raise ProcedureWorkerError(f"Received unexpected message {status_update}") self.status = status_update.status + self._current_execution_id = status_update.current_execution_id logger.info( f"Container worker '{self._queue}' status update: {status_update.status.name}" ) # TODO: we probably do want to handle some kind of timeout here but we don't know how # long a running procedure should actually take - it could theoretically be infinite + if self.status != ProcedureWorkerStatus.FINISHED: + self.status = ProcedureWorkerStatus.DEAD + + def abort_execution(self, execution_id: str): + """Abort the execution with the given id. 
Has no effect if the given ID is not the current job""" + if execution_id == self._current_execution_id: + self._backend.kill(self._container_id) + self._helper.remove_from_active.by_exec_id(execution_id) + logger.info( + f"Aborting execution {execution_id}, restarting worker for queue: {self._queue}" + ) + self._setup_execution_environment() + def logs(self): + if self._container_id is None: + return [""] + return self._backend.logs(self._container_id) -def main(): - """Replaces the main contents of Worker.work() - should be called as the container entrypoint or command""" - logger.info(f"Container worker starting up") + +def _setup(): + logger.info("Container worker starting up") try: needed_keys = ContainerWorkerEnv.__annotations__.keys() logger.debug(f"Checking for environment variables: {needed_keys}") env: ContainerWorkerEnv = {k: os.environ[k] for k in needed_keys} # type: ignore except KeyError as e: logger.error(f"Missing environment variable needed by container worker: {e}") - return + exit(1) bec_logger.level = LogLevel.DEBUG bec_logger._console_log = True @@ -117,19 +165,27 @@ def main(): host, port = env["redis_server"].split(":") redis = {"host": host, "port": port} - client = BECClient(config=ServiceConfig(redis=redis)) + + client = BECIPythonClient(config=ServiceConfig(redis=redis)) + logger.debug("starting client") client.start() logger.info(f"ContainerWorker started container for queue {env['queue']}") logger.debug(f"ContainerWorker environment: {env}") - endpoint_info = MessageEndpoints.procedure_execution(env["queue"]) conn = RedisConnector(env["redis_server"]) + logger.debug(f"ContainerWorker {env['queue']} connected to Redis at {conn.host}:{conn.port}") + helper = BackendProcedureHelper(conn) + + return env, helper, client, conn + + +def _main(env, helper, client, conn): + + exec_endpoint = MessageEndpoints.procedure_execution(env["queue"]) active_procs_endpoint = MessageEndpoints.active_procedure_executions() status_endpoint = 
MessageEndpoints.procedure_worker_status_update(env["queue"]) - logger.debug(f"ContainerWorker connecting to Redis at {conn.host}:{conn.port}") - try: timeout_s = int(env["timeout_s"]) except ValueError as e: @@ -138,30 +194,43 @@ def main(): ) timeout_s = PROCEDURE.WORKER.QUEUE_TIMEOUT_S - def _push_status(status: ProcedureWorkerStatus): + def _push_status(status: ProcedureWorkerStatus, id: str | None = None): logger.debug(f"Updating container worker status to {status.name}") conn.rpush( status_endpoint, - messages.ProcedureWorkerStatusMessage(worker_queue=env["queue"], status=status), + messages.ProcedureWorkerStatusMessage( + worker_queue=env["queue"], status=status, current_execution_id=id + ), ) def _run_task(item: ProcedureExecutionMessage): - procedure_registry.callable_from_execution_message(item)( - *item.args_kwargs[0], **item.args_kwargs[1] - ) + kwargs = item.args_kwargs[1] + proc_func = procedure_registry.callable_from_execution_message(item) + if bec_arg := inspect.signature(proc_func).parameters.get("bec"): + if bec_arg.kind == bec_arg.KEYWORD_ONLY and bec_arg.annotation.__name__ == "BECClient": + logger.debug(f"Injecting BEC client argument for {item}") + kwargs["bec"] = client + proc_func(*item.args_kwargs[0], **kwargs) _push_status(ProcedureWorkerStatus.IDLE) item = None try: - logger.debug(f"ContainerWorker waiting for instructions on {endpoint_info}") + logger.debug(f"ContainerWorker waiting for instructions on {exec_endpoint}") while ( item := conn.blocking_list_pop_to_set_add( - endpoint_info, active_procs_endpoint, timeout_s=timeout_s + exec_endpoint, active_procs_endpoint, timeout_s=timeout_s ) ) is not None: - _push_status(ProcedureWorkerStatus.RUNNING) + _push_status(ProcedureWorkerStatus.RUNNING, item.execution_id) + helper.notify_watchers(env["queue"], queue_type="execution") logger.debug(f"running task {item!r}") - _run_task(item) + try: + _run_task(item) + except Exception as e: + 
logger.error(f"Encountered error running procedure {item}") + logger.error(e) + finally: + helper.remove_from_active.by_exec_id(item.execution_id) _push_status(ProcedureWorkerStatus.IDLE) except Exception as e: logger.error(e) # don't stop ProcedureManager.spawn from cleaning up @@ -170,8 +239,19 @@ def _run_task(item: ProcedureExecutionMessage): _push_status(ProcedureWorkerStatus.FINISHED) client.shutdown() if item is not None: # in this case we are here due to an exception, not a timeout - conn.remove_from_set(active_procs_endpoint, item) + helper.remove_from_active.by_exec_id(item.execution_id) -if __name__ == "__main__": - main() +def main(): + """Replaces the main contents of Worker.work() - should be called as the container entrypoint or command""" + + env, helper, client, conn = _setup() + output_diverter = RedisOutputDiverter(conn, env["queue"]) + with redirect_stdout(output_diverter): + logger.add( + output_diverter, + level=LogLevel.SUCCESS, + format=bec_logger.formatting(is_container=True), + filter=bec_logger.filter(), + ) + _main(env, helper, client, conn) diff --git a/bec_server/bec_server/scan_server/procedures/helper.py b/bec_server/bec_server/scan_server/procedures/helper.py new file mode 100644 index 000000000..b68bafa31 --- /dev/null +++ b/bec_server/bec_server/scan_server/procedures/helper.py @@ -0,0 +1,195 @@ +from typing import Literal + +from bec_lib.endpoints import EndpointInfo +from bec_lib.endpoints import MessageEndpoints as ME +from bec_lib.logger import bec_logger +from bec_lib.messages import BECMessage +from bec_lib.messages import ProcedureAbortMessage as AbrtMsg +from bec_lib.messages import ProcedureClearUnhandledMessage as ClrMsg +from bec_lib.messages import ProcedureExecutionMessage as ExecMsg +from bec_lib.messages import ProcedureQNotifMessage as QNotifMsg +from bec_lib.messages import ProcedureRequestMessage as ReqMsg +from bec_lib.redis_connector import RedisConnector + +logger = bec_logger.logger + + +class _HelperBase: + 
def __init__(self, conn: RedisConnector) -> None: + self._conn = conn + + +class _Request(_HelperBase): + def _xadd(self, ep: EndpointInfo, msg: BECMessage): + self._conn.xadd(ep, msg.model_dump()) + + def procedure(self, msg: ReqMsg): + self._xadd(ME.procedure_request(), msg) + + def abort_execution(self, execution_id: str): + """Send a message requesting an abort of the execution with ID execution_id""" + return self._xadd(ME.procedure_abort(), AbrtMsg(execution_id=execution_id)) + + def abort_queue(self, queue: str): + """Send a message requesting an abort of all executions in the given queue""" + return self._xadd(ME.procedure_abort(), AbrtMsg(queue=queue)) + + def abort_all(self): + """Send a message requesting an abort of all procedure executions""" + return self._xadd(ME.procedure_abort(), AbrtMsg(abort_all=True)) + + def clear_unhandled_execution(self, execution_id: str): + """Send a message requesting removal of execution_id from its unhandled queue""" + return self._xadd(ME.procedure_clear_unhandled(), ClrMsg(execution_id=execution_id)) + + def clear_unhandled_queue(self, queue: str): + """Send a message requesting removal of all unhandled executions in the given queue""" + return self._xadd(ME.procedure_clear_unhandled(), ClrMsg(queue=queue)) + + def clear_all_unhandled(self): + """Send a message requesting removal of all unhandled executions""" + return self._xadd(ME.procedure_clear_unhandled(), ClrMsg(abort_all=True)) + + +class _Get(_HelperBase): + def running_procedures(self) -> set[ExecMsg]: + """Get all the running procedures""" + return self._conn.get_set_members(ME.active_procedure_executions()) + + def exec_queue(self, queue: str) -> list[ExecMsg]: + """Get all the ProcedureExecutionMessages from a given execution queue""" + return self._conn.lrange(ME.procedure_execution(queue), 0, -1) + + def unhandled_queue(self, queue: str) -> list[ExecMsg]: + """Get all the ProcedureExecutionMessages from a given unhandled execution queue""" + return self._conn.lrange(ME.unhandled_procedure_execution(queue), 0, -1) + + def active_and_pending_queue_names(self) -> list[str]: + 
"""Get the names of all pending queues and queues of currently running procedures""" + return list(set(self.queue_names()) | set(self.active_queue_names())) + + def active_queue_names(self) -> list[str]: + """Get the names of all queues of currently running procedures""" + return list({msg.queue for msg in self.running_procedures()}) + + def queue_names(self, queue_type: Literal["execution", "unhandled"] = "execution") -> list[str]: + """Get the names of queues currently containing pending ProcedureExecutionMessages + + Args: + queue_type (Literal["execution", "unhandled"]): Type of queue, default "execution" for currently active executions, "unhandled" for aborted executions + """ + ep = ( + ME.procedure_execution + if queue_type == "execution" + else ME.unhandled_procedure_execution + ) + raw: list[str] = [s.decode() for s in self._conn.keys(ep("*"))] + return [s.split("/")[-1] for s in raw] + + +class FrontendProcedureHelper(_HelperBase): + + def __init__(self, conn: RedisConnector) -> None: + super().__init__(conn) + self.request = _Request(conn) + self.get = _Get(conn) + + +class _BackenHelperBase(_HelperBase): + def __init__(self, conn: RedisConnector, parent: "BackendProcedureHelper") -> None: + self._conn = conn + self._parent = parent + + +class _Push(_BackenHelperBase): + def exec(self, queue: str, msg: ExecMsg): + """Push execution message `msg` to execution queue `queue`""" + self._conn.rpush(ME.procedure_execution(queue), msg) + self._parent.notify_watchers(queue, "execution") + + def unhandled(self, queue: str, msg: ExecMsg): + """Push execution message `msg` to unhandled execution queue `queue`""" + self._conn.rpush(ME.unhandled_procedure_execution(queue), msg) + self._parent.notify_watchers(queue, "unhandled") + + +class _Clear(_BackenHelperBase): + def all_unhandled(self): + """Remove all unhandled execution queues""" + for queue in self._parent.get.queue_names("unhandled"): + self.unhandled_queue(queue) + + def unhandled_queue(self, queue: str): + 
"""Remove an unhandled execution queue""" + self._conn.delete(ME.unhandled_procedure_execution(queue)) + self._parent.notify_watchers(queue, "unhandled") + + def unhandled_execution(self, execution_id: str): + """Remove a ProcedureExecutionMessage from its unhandled queue by its execution ID""" + for queue in self._parent.get.queue_names("unhandled"): + for msg in self._parent.get.unhandled_queue(queue): + if msg.execution_id == execution_id: + if self._conn.lrem(ME.unhandled_procedure_execution(msg.queue), 0, msg) > 0: + logger.debug(f"Removed execution {msg} from queue.") + self._parent.notify_watchers(queue, "unhandled") + return + logger.debug(f"Execution {execution_id} not found in any unhandled queue.") + + +class _Move(_BackenHelperBase): + def all_active_to_unhandled(self): + """Move all messages in the active executions set to unhandled""" + for msg in self._parent.get.running_procedures(): + self._parent.push.unhandled(msg.queue, msg) + self._conn.delete(ME.active_procedure_executions()) + + def execution_queue_to_unhandled(self, queue: str): + """Move all messages from execution queue to unhandled execution queue of the same name""" + for msg in self._parent.get.exec_queue(queue): + self._parent.push.unhandled(queue, msg) + self._conn.delete(ME.procedure_execution(queue)) + + def all_execution_queues_to_unhandled(self): + """Move all messages from all execution queues to unhandled execution queues of the same name""" + for queue in self._parent.get.queue_names(): + self.execution_queue_to_unhandled(queue) + + +class _RemoveFromActive(_BackenHelperBase): + def by_exec_id(self, execution_id: str): + """Remove a message from the set of currently active executions""" + for msg in self._conn.get_set_members(ME.active_procedure_executions()): + if msg.execution_id == execution_id: + self._conn.remove_from_set(ME.active_procedure_executions(), msg) + logger.debug(f"removed active procedure {execution_id}") + return + logger.debug(f"No active procedure 
{execution_id} to remove") + + def by_queue(self, queue: str): + """Remove all messages with the given queue from the set of currently active executions""" + removed = False + for msg in self._conn.get_set_members(ME.active_procedure_executions()): + if msg.queue == queue: + self._conn.remove_from_set(ME.active_procedure_executions(), msg) + logger.debug(f"removed active procedure {msg} with queue {queue}") + removed = True + if removed: + return + logger.debug(f"No active procedure with queue {queue} to remove") + + +class BackendProcedureHelper(FrontendProcedureHelper): + def __init__(self, conn: RedisConnector) -> None: + super().__init__(conn) + self.push = _Push(conn, self) + self.clear = _Clear(conn, self) + self.move = _Move(conn, self) + self.remove_from_active = _RemoveFromActive(conn, self) + + def notify_watchers(self, queue: str, queue_type: Literal["execution", "unhandled"]): + return self._conn.send( + ME.procedure_queue_notif(), QNotifMsg(queue_name=queue, queue_type=queue_type) + ) + + def notify_all(self, queue_type: Literal["execution", "unhandled"]): + for queue in self.get.queue_names(queue_type): + self.notify_watchers(queue, queue_type) diff --git a/bec_server/bec_server/scan_server/procedures/in_process_worker.py b/bec_server/bec_server/scan_server/procedures/in_process_worker.py index e225a88b4..d56c9eebe 100644 --- a/bec_server/bec_server/scan_server/procedures/in_process_worker.py +++ b/bec_server/bec_server/scan_server/procedures/in_process_worker.py @@ -12,7 +12,7 @@ class InProcessProcedureWorker(ProcedureWorker): """A simple in-process procedure worker. Be careful with this, it should only run trusted code. 
- Intended for built-in procedures like those to run a single scan.""" + Intended for built-in procedures like those to run a single scan, or testing.""" def _setup_execution_environment(self): from bec_lib.logger import bec_logger @@ -45,3 +45,7 @@ def _kill_process(self): self.logger.info( f"In-process procedure worker for queue {self.key.endpoint} timed out after {self._lifetime_s} s, shutting down" ) + + def abort_execution(self, execution_id: str): + """Abort the execution with the given id""" + ... # No nice way to abort a running function, don't use this class for such things diff --git a/bec_server/bec_server/scan_server/procedures/manager.py b/bec_server/bec_server/scan_server/procedures/manager.py index d402d92e3..faa401ac7 100644 --- a/bec_server/bec_server/scan_server/procedures/manager.py +++ b/bec_server/bec_server/scan_server/procedures/manager.py @@ -1,25 +1,38 @@ from __future__ import annotations import atexit +from collections import deque from concurrent import futures from concurrent.futures import Future, ThreadPoolExecutor from threading import RLock -from typing import Any, Callable, TypedDict +from typing import Any, Callable, TypedDict, TypeVar from pydantic import ValidationError from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_lib.messages import ProcedureRequestMessage, RequestResponseMessage +from bec_lib.messages import ( + BECMessage, + ProcedureAbortMessage, + ProcedureClearUnhandledMessage, + ProcedureExecutionMessage, + ProcedureRequestMessage, + ProcedureWorkerStatus, + RequestResponseMessage, +) from bec_lib.redis_connector import RedisConnector from bec_server.scan_server.procedures import procedure_registry from bec_server.scan_server.procedures.constants import PROCEDURE, WorkerAlreadyExists +from bec_server.scan_server.procedures.helper import BackendProcedureHelper from bec_server.scan_server.procedures.worker_base import ProcedureWorker from bec_server.scan_server.scan_server 
import ScanServer logger = bec_logger.logger +# TODO garbage collect IDs + + class ProcedureWorkerEntry(TypedDict): worker: ProcedureWorker | None future: Future @@ -33,6 +46,15 @@ def _log_on_end(future: Future): logger.success(f"Procedure worker {future} shut down gracefully") +_T = TypeVar("_T", bound=BECMessage) + + +def _resolve_dict(msg: dict[str, Any] | _T, MsgType: type[_T]) -> _T: + if isinstance(msg, dict): + return MsgType.model_validate(msg) + return msg + + class ProcedureManager: def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]): @@ -42,10 +64,18 @@ def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]): Args: parent (ScanServer): the scan server to get the Redis server address from. worker_type (type[ProcedureWorker]): which kind of worker to use.""" - self._parent = parent self.lock = RLock() - self.active_workers: dict[str, ProcedureWorkerEntry] = {} + + logger.success("Initialising procedure manager...") + + self._logs = deque([], maxlen=1000) + self._conn = RedisConnector([self._parent.bootstrap_server]) + self._helper = BackendProcedureHelper(self._conn) + self._startup() + + self._active_workers: dict[str, ProcedureWorkerEntry] = {} + self._messages_by_ids: dict[str, ProcedureExecutionMessage] = {} self.executor = ThreadPoolExecutor( max_workers=PROCEDURE.WORKER.MAX_WORKERS, thread_name_prefix="user_procedure_" ) @@ -53,11 +83,26 @@ def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]): self._callbacks: dict[str, list[Callable[[ProcedureWorker], Any]]] = {} self._worker_cls = worker_type - self._conn = RedisConnector([self._parent.bootstrap_server]) self._reply_endpoint = MessageEndpoints.procedure_request_response() self._server = f"{self._conn.host}:{self._conn.port}" - self._conn.register(MessageEndpoints.procedure_request(), None, self.process_queue_request) + self._conn.register(MessageEndpoints.procedure_abort(), None, self._process_abort) + self._conn.register( + 
MessageEndpoints.procedure_clear_unhandled(), None, self._process_clear_unhandled + ) + self._conn.register(MessageEndpoints.procedure_request(), None, self._process_queue_request) + logger.success("Done initialising procedure manager.") + + def _startup(self): + # If the server is restarted, clear any pending requests, they'll have to be resubmitted + self._conn.delete(MessageEndpoints.procedure_request()) + previous_queues = self._helper.get.active_and_pending_queue_names() + logger.debug(f"Clearing previous procedure queues {previous_queues}...") + self._helper.move.all_execution_queues_to_unhandled() + self._helper.move.all_active_to_unhandled() + for queue in previous_queues: + self._helper.notify_watchers(queue, "execution") + self._helper.notify_all("unhandled") def _ack(self, accepted: bool, msg: str): logger.info(f"procedure accepted: {accepted}, message: {msg}") @@ -65,9 +110,9 @@ def _ack(self, accepted: bool, msg: str): self._reply_endpoint, RequestResponseMessage(accepted=accepted, message=msg) ) - def _validate_request(self, msg: dict[str, Any]): + def _validate_request(self, msg: dict[str, Any] | ProcedureRequestMessage): try: - message_obj = ProcedureRequestMessage.model_validate(msg) + message_obj = _resolve_dict(msg, ProcedureRequestMessage) if not procedure_registry.is_registered(message_obj.identifier): self._ack( False, @@ -86,13 +131,17 @@ def add_callback(self, queue: str, cb: Callable[[ProcedureWorker], Any]): self._callbacks[queue].append(cb) def _run_callbacks(self, queue: str): - if (worker := self.active_workers[queue]["worker"]) is None: - return + with self.lock: + if queue not in self._active_workers: + logger.error(f"Attempted to run callbacks for nonexistent worker {queue}") + return + if (worker := self._active_workers[queue]["worker"]) is None: + return for cb in self._callbacks.get(queue, []): cb(worker) self._callbacks[queue] = [] - def process_queue_request(self, msg: dict[str, Any]): + def _process_queue_request(self, msg: 
dict[str, Any] | ProcedureRequestMessage): """Read a `ProcedureRequestMessage` and if it is valid, create a corresponding `ProcedureExecutionMessage`. If there is already a worker for the queue for that request message, add the execution message to that queue, otherwise create a new queue and a new worker. @@ -101,33 +150,93 @@ def process_queue_request(self, msg: dict[str, Any]): msg (dict[str, Any]): dict corresponding to a ProcedureRequestMessage""" logger.debug(f"Procedure manager got request message {msg}") - if (message_obj := self._validate_request(msg)) is None: + if (message := self._validate_request(msg)) is None: return - self._ack(True, f"Running procedure {message_obj.identifier}") - queue = message_obj.queue or PROCEDURE.WORKER.DEFAULT_QUEUE - endpoint = MessageEndpoints.procedure_execution(queue) - logger.debug(f"active workers: {self.active_workers}, worker requested: {queue}") - self._conn.rpush( - endpoint, - endpoint.message_type( - identifier=message_obj.identifier, - queue=queue, - args_kwargs=message_obj.args_kwargs or ((), {}), - ), + self._ack(True, f"Running procedure {message.identifier}") + queue = message.queue or PROCEDURE.WORKER.DEFAULT_QUEUE + exec_message = ProcedureExecutionMessage( + identifier=message.identifier, queue=queue, args_kwargs=message.args_kwargs or ((), {}) ) + logger.debug(f"active workers: {self._active_workers}, worker requested: {queue}") + self._helper.push.exec(queue, exec_message) def cleanup_worker(fut): with self.lock: logger.debug(f"cleaning up worker {fut} for queue {queue}...") + self._helper.remove_from_active.by_queue(queue) self._run_callbacks(queue) - del self.active_workers[queue] + self._helper.notify_watchers(queue, "execution") + del self._active_workers[queue] with self.lock: - if queue not in self.active_workers: + if queue not in self._active_workers: new_worker = self.executor.submit(self.spawn, queue=queue) new_worker.add_done_callback(_log_on_end) new_worker.add_done_callback(cleanup_worker) - 
self.active_workers[queue] = {"worker": None, "future": new_worker} + self._active_workers[queue] = {"worker": None, "future": new_worker} + self._messages_by_ids[exec_message.execution_id] = exec_message + + def _process_abort(self, msg: dict[str, Any] | ProcedureAbortMessage): + message = _resolve_dict(msg, ProcedureAbortMessage) + with self.lock: + if message.abort_all: + self._abort_all() + if message.queue is not None: + self._abort_queue(message.queue) + if message.execution_id is not None: + self._abort_execution(message.execution_id) + + def _abort_execution(self, execution_id: str): + if (msg := self._messages_by_ids.get(execution_id)) is None: + logger.warning(f"Procedure execution with ID {execution_id} not known.") + return + # Remove it from the queue if not yet started + if self._conn.lrem(MessageEndpoints.procedure_execution(msg.queue), 0, msg) > 0: + logger.debug(f"Removed execution {msg} from queue.") + self._helper.notify_watchers(msg.queue, "execution") + # Otherwise try to remove it from whichever worker has it + for entry in self._active_workers.values(): + if (worker := entry["worker"]) is not None and worker: + worker.abort_execution(execution_id) + # Move it to unhandled and stop tracking + self._helper.push.unhandled(msg.queue, msg) + del self._messages_by_ids[execution_id] + + def _abort_queue(self, queue: str): + self._helper.move.execution_queue_to_unhandled(queue) + if (entry := self._active_workers.get(queue)) is not None: + if entry["worker"] is not None: + entry["worker"].abort() + entry["future"].cancel() + futures.wait((entry["future"],), PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S) + else: + logger.warning(f"Received abort request for unknown queue {queue}!") + self._helper.notify_watchers(queue, "execution") + + def _abort_all(self): + for entry in self._active_workers.values(): + if entry["worker"] is not None: + entry["worker"].abort() + for entry in self._active_workers.values(): + entry["future"].cancel() + 
self._helper.move.all_execution_queues_to_unhandled() + self._wait_for_all_futures() + + def _process_clear_unhandled(self, msg: dict[str, Any] | ProcedureClearUnhandledMessage): + message = _resolve_dict(msg, ProcedureClearUnhandledMessage) + with self.lock: + if message.abort_all: + self._helper.clear.all_unhandled() + if message.queue is not None: + self._helper.clear.unhandled_queue(message.queue) + if message.execution_id is not None: + self._helper.clear.unhandled_execution(message.execution_id) + + def _wait_for_all_futures(self): + futures.wait( + (entry["future"] for entry in self._active_workers.values()), + timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S, + ) def spawn(self, queue: str): """Spawn a procedure worker future which listens to a given queue, i.e. procedure queue list in Redis. @@ -135,32 +244,41 @@ def spawn(self, queue: str): Args: queue (str): name of the queue to spawn a worker for""" - if queue in self.active_workers and self.active_workers[queue]["worker"] is not None: + if queue in self._active_workers and self._active_workers[queue]["worker"] is not None: raise WorkerAlreadyExists( - f"Queue {queue} already has an active worker in {self.active_workers}!" + f"Queue {queue} already has an active worker in {self._active_workers}!" ) with self._worker_cls(self._server, queue, PROCEDURE.WORKER.QUEUE_TIMEOUT_S) as worker: with self.lock: - self.active_workers[queue]["worker"] = worker + self._active_workers[queue]["worker"] = worker worker.work() + self._logs.extend(worker.logs()) def shutdown(self): """Shutdown the procedure manager. 
Unregisters from the request endpoint, cancel any procedure workers which haven't started, and abort any which have.""" self._conn.unregister( - MessageEndpoints.procedure_request(), None, self.process_queue_request + MessageEndpoints.procedure_request(), None, self._process_queue_request ) self._conn.shutdown() # cancel futures by hand to give us the opportunity to detatch them from redis if they have started - for entry in self.active_workers.values(): + for entry in self._active_workers.values(): cancelled = entry["future"].cancel() if not cancelled: # unblock any waiting workers and let them shutdown if worker := entry["worker"]: # redis unblock executor.client_id worker.abort() - futures.wait( - (entry["future"] for entry in self.active_workers.values()), - timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S, - ) + self._wait_for_all_futures() self.executor.shutdown() + + def active_workers(self) -> list[str]: + with self.lock: + return list(self._active_workers.keys()) + + def worker_statuses(self) -> dict[str, ProcedureWorkerStatus]: + with self.lock: + return { + q: w["worker"].status if w["worker"] is not None else ProcedureWorkerStatus.NONE + for q, w in self._active_workers.items() + } diff --git a/bec_server/bec_server/scan_server/procedures/procedure_registry.py b/bec_server/bec_server/scan_server/procedures/procedure_registry.py index 8a2582f56..4ed3ff114 100644 --- a/bec_server/bec_server/scan_server/procedures/procedure_registry.py +++ b/bec_server/bec_server/scan_server/procedures/procedure_registry.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, Iterable, Iterator +from typing import Iterable from bec_lib.messages import ProcedureExecutionMessage from bec_server.scan_server.procedures.builtin_procedures import ( log_message_args_kwargs, run_scan, + run_script, sleep, ) from bec_server.scan_server.procedures.constants import BecProcedure @@ -12,6 +13,7 @@ "log execution message args": log_message_args_kwargs, "run scan": run_scan, "sleep": sleep, + 
"run_script": run_script, } _PROCEDURE_REGISTRY: dict[str, BecProcedure] = {} | _BUILTIN_PROCEDURES diff --git a/bec_server/bec_server/scan_server/procedures/protocol.py b/bec_server/bec_server/scan_server/procedures/protocol.py index 6d79b80a6..06d40c818 100644 --- a/bec_server/bec_server/scan_server/procedures/protocol.py +++ b/bec_server/bec_server/scan_server/procedures/protocol.py @@ -11,7 +11,6 @@ class VolumeSpec(TypedDict): class ContainerCommandOutput(Protocol): - def pretty_print(self) -> str: ... @@ -30,6 +29,7 @@ def run( volumes: list[VolumeSpec], command: str, pod_name: str | None = None, + container_name: str | None = None, ) -> str: ... def kill(self, id: str): ... def logs(self, id: str) -> list[str]: ... diff --git a/bec_server/bec_server/scan_server/procedures/worker_base.py b/bec_server/bec_server/scan_server/procedures/worker_base.py index a18b0dcf8..7b07256db 100644 --- a/bec_server/bec_server/scan_server/procedures/worker_base.py +++ b/bec_server/bec_server/scan_server/procedures/worker_base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from enum import Enum, auto +from threading import Event from typing import cast from bec_lib.endpoints import MessageEndpoints @@ -9,6 +9,7 @@ from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus from bec_lib.redis_connector import RedisConnector from bec_server.scan_server.procedures.constants import PROCEDURE +from bec_server.scan_server.procedures.helper import BackendProcedureHelper logger = bec_logger.logger @@ -36,8 +37,11 @@ def __init__(self, server: str, queue: str, lifetime_s: float | None = None): self._active_procs_endpoint = MessageEndpoints.active_procedure_executions() self.status = ProcedureWorkerStatus.IDLE self._conn = RedisConnector([server]) + self._helper = BackendProcedureHelper(self._conn) self._lifetime_s = lifetime_s or PROCEDURE.WORKER.QUEUE_TIMEOUT_S self.client_id = self._conn.client_id() + 
self._current_execution_id: str | None = None + self._aborted = Event() self._setup_execution_environment() @@ -61,9 +65,19 @@ def _run_task(self, item: ProcedureExecutionMessage): @abstractmethod def _setup_execution_environment(self): ... + def logs(self) -> list[str]: + return [""] + def abort(self): + """Abort the entire worker""" + self._aborted.set() self._kill_process() + @abstractmethod + def abort_execution(self, execution_id: str): + """Abort the execution with the given id""" + ... + def __exit__(self, exc_type, exc_val, exc_tb): self._kill_process() @@ -76,9 +90,10 @@ def work(self): self.key, self._active_procs_endpoint, timeout_s=self._lifetime_s ) ) is not None: - self.status = ProcedureWorkerStatus.RUNNING - self._run_task(cast(ProcedureExecutionMessage, item)) - self.status = ProcedureWorkerStatus.IDLE + if not self._aborted.is_set(): + self.status = ProcedureWorkerStatus.RUNNING + self._run_task(cast(ProcedureExecutionMessage, item)) + self.status = ProcedureWorkerStatus.IDLE except Exception as e: logger.error(e) # don't stop ProcedureManager.spawn from cleaning up finally: diff --git a/bec_server/tests/tests_scan_server/test_container_utils.py b/bec_server/tests/tests_scan_server/test_container_utils.py index 55fb81408..31c948f6b 100644 --- a/bec_server/tests/tests_scan_server/test_container_utils.py +++ b/bec_server/tests/tests_scan_server/test_container_utils.py @@ -138,6 +138,7 @@ def test_api_utils_run(api_utils: tuple[PodmanApiUtils, MagicMock]): environment={"a": "b"}, mounts=[{"source": "a", "target": "b", "read_only": True, "type": "bind"}], pod=None, + name=None, ) @@ -185,7 +186,7 @@ def test_cli_kill(cli_utils: tuple[PodmanCliUtils, MagicMock]): utils, run_mock = cli_utils run_mock.reset_mock() utils.kill("test") - assert run_mock.call_count == 2 + assert run_mock.call_count == 1 def test_cli_get_state(cli_utils_with_fake_container_json: PodmanCliUtils): diff --git 
a/bec_server/tests/tests_scan_server/test_procedure_container_worker.py b/bec_server/tests/tests_scan_server/test_procedure_container_worker.py index b80ffabc9..5b686a04a 100644 --- a/bec_server/tests/tests_scan_server/test_procedure_container_worker.py +++ b/bec_server/tests/tests_scan_server/test_procedure_container_worker.py @@ -29,7 +29,7 @@ def test_container_worker_init(redis_mock): redis_mock().port = "port" worker = ContainerProcedureWorker(server="server:port", queue="test_queue", lifetime_s=1) assert worker._worker_environment() == { - "redis_server": "server:port", + "redis_server": "redis:port", "queue": "test_queue", "timeout_s": "1", } @@ -45,20 +45,20 @@ def test_container_worker_work(redis_mock, logger_mock): redis_mock().host = "server" redis_mock().port = "port" + msgs = [ + ProcedureWorkerStatusMessage( + worker_queue="test_queue", + status=ProcedureWorkerStatus.RUNNING, + current_execution_id="test", + ), + ProcedureWorkerStatusMessage( + worker_queue="test_queue", status=ProcedureWorkerStatus.FINISHED + ), + ] + def _mock_pop(): - msgs = [ - ProcedureWorkerStatusMessage( - worker_queue="test_queue", status=ProcedureWorkerStatus.RUNNING - ), - ProcedureWorkerStatusMessage( - worker_queue="test_queue", status=ProcedureWorkerStatus.FINISHED - ), - ] - for msg in msgs: - sleep(0.05) - yield msg + yield from msgs while True: - sleep(0.05) yield from repeat(None) mock_pop = _mock_pop() @@ -74,7 +74,7 @@ def cleanup(): t.start() start = time.monotonic() - while time.monotonic() < start + 5: + while time.monotonic() < start + 1000: try: assert ( logger_mock.info.call_args_list[0].args[0] @@ -100,7 +100,7 @@ def cleanup(): @patch("bec_server.scan_server.procedures.container_worker.logger") def test_main_exits_without_env_variables(logger_mock): - with patch.dict(os.environ, clear=True): + with patch.dict(os.environ, clear=True), pytest.raises(SystemExit): container_worker_main() assert "Missing environment variable " in 
logger_mock.error.call_args.args[0] @@ -125,6 +125,7 @@ def __init__(self, name: str, args: tuple = (), kwargs: dict = {}): self.identifier = name self.args = args self.kwargs = kwargs + self.execution_id = f"test_{name}" def __repr__(self) -> str: return self.identifier diff --git a/bec_server/tests/tests_scan_server/test_procedures.py b/bec_server/tests/tests_scan_server/test_procedures.py index 0fda944c9..7775e5281 100644 --- a/bec_server/tests/tests_scan_server/test_procedures.py +++ b/bec_server/tests/tests_scan_server/test_procedures.py @@ -1,6 +1,7 @@ import threading import time from functools import partial +from itertools import starmap from typing import Any, Callable from unittest.mock import MagicMock, patch @@ -8,6 +9,7 @@ import pytest from bec_lib.client import BECClient, RedisConnector +from bec_lib.endpoints import MessageEndpoints from bec_lib.messages import ( ProcedureExecutionMessage, ProcedureRequestMessage, @@ -15,6 +17,7 @@ RequestResponseMessage, ) from bec_lib.serialization import MsgpackSerialization +from bec_lib.service_config import ServiceConfig from bec_server.scan_server.procedures.constants import PROCEDURE, BecProcedure, WorkerAlreadyExists from bec_server.scan_server.procedures.in_process_worker import InProcessProcedureWorker from bec_server.scan_server.procedures.manager import ProcedureManager, ProcedureWorker @@ -32,11 +35,17 @@ LOG_MSG_PROC_NAME = "log execution message args" +FAKEREDIS_HOST = "127.0.0.1" +FAKEREDIS_PORT = 6380 @pytest.fixture(autouse=True) +@patch("bec_lib.bec_service.BECAccess", MagicMock) def shutdown_client(): - bec_client = BECClient() + bec_client = BECClient( + config=ServiceConfig(config={"redis": {"host": FAKEREDIS_HOST, "port": FAKEREDIS_PORT}}), + connector_cls=partial(RedisConnector, redis_cls=fakeredis.FakeRedis), + ) bec_client.start() yield bec_client.shutdown() @@ -45,7 +54,7 @@ def shutdown_client(): @pytest.fixture def procedure_manager(): server = MagicMock() - server.bootstrap_server = 
"localhost:1" + server.bootstrap_server = f"{FAKEREDIS_HOST}:{FAKEREDIS_PORT}" with patch( "bec_server.scan_server.procedures.manager.RedisConnector", partial(RedisConnector, redis_cls=fakeredis.FakeRedis), # type: ignore @@ -104,7 +113,7 @@ def process_request_manager(procedure_manager: ProcedureManager): @pytest.mark.parametrize("message", PROCESS_REQUEST_TEST_CASES) def test_process_request_happy_paths(process_request_manager, message: ProcedureRequestMessage): - process_request_manager.process_queue_request(message) + process_request_manager._process_queue_request(message) process_request_manager._ack.assert_called_with(True, f"Running procedure {message.identifier}") process_request_manager._conn.rpush.assert_called() endpoint, execution_msg = process_request_manager._conn.rpush.call_args.args @@ -112,15 +121,15 @@ def test_process_request_happy_paths(process_request_manager, message: Procedure assert queue in endpoint.endpoint assert execution_msg.identifier == message.identifier process_request_manager.spawn.assert_called() - assert queue in process_request_manager.active_workers.keys() + assert queue in process_request_manager._active_workers.keys() def test_process_request_failure(process_request_manager): - process_request_manager.process_queue_request(None) + process_request_manager._process_queue_request(None) process_request_manager._ack.assert_not_called() process_request_manager._conn.rpush.assert_not_called() process_request_manager.spawn.assert_not_called() - assert process_request_manager.active_workers == {} + assert process_request_manager._active_workers == {} class UnlockableWorker(ProcedureWorker): @@ -131,6 +140,7 @@ def __init__(self, server: str, queue: str, lifetime_s: int | None = None): self.event_1 = threading.Event() self.event_2 = threading.Event() + def abort_execution(self, execution_id: str): ... def _setup_execution_environment(self): ... def _kill_process(self): ... 
def _run_task(self, item): @@ -161,14 +171,14 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager): queue = message.queue or PROCEDURE.WORKER.DEFAULT_QUEUE procedure_manager._validate_request = MagicMock(side_effect=lambda msg: msg) # trigger the running of the test message - procedure_manager.process_queue_request(message) # type: ignore - assert queue in procedure_manager.active_workers.keys() + procedure_manager._process_queue_request(message) # type: ignore + assert queue in procedure_manager._active_workers.keys() # spawn method should be added as a future - _wait_until(procedure_manager.active_workers[queue]["future"].running) + _wait_until(procedure_manager._active_workers[queue]["future"].running) # and then create the worker - _wait_until(lambda: procedure_manager.active_workers[queue].get("worker") is not None) - worker = procedure_manager.active_workers[queue]["worker"] + _wait_until(lambda: procedure_manager._active_workers[queue].get("worker") is not None) + worker = procedure_manager._active_workers[queue]["worker"] assert isinstance(worker, UnlockableWorker) _wait_until(lambda: worker.status == ProcedureWorkerStatus.RUNNING) @@ -185,7 +195,7 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager): worker.event_2.set() _wait_until(lambda: worker.status == ProcedureWorkerStatus.FINISHED) # spawn deletes the worker queue - _wait_until(lambda: len(procedure_manager.active_workers) == 0) + _wait_until(lambda: len(procedure_manager._active_workers) == 0) @patch("bec_server.scan_server.procedures.worker_base.RedisConnector", MagicMock()) @@ -259,7 +269,7 @@ def test_callable_from_message(): def test_register_rejects_wrong_type(): with pytest.raises(ProcedureRegistryError) as e: - register("test", "test") + register("test", "test") # type: ignore assert e.match("not a valid procedure") @@ -267,3 +277,162 @@ def test_register_rejects_already_registered(): with pytest.raises(ProcedureRegistryError) as e: register("run scan", 
lambda *_, **__: None) assert e.match("already registered") + + +def _yield_once(): + yield "value" + while True: + yield None + + +@patch( + "bec_server.scan_server.procedures.worker_base.RedisConnector", + side_effect=lambda *_: MagicMock( + blocking_list_pop_to_set_add=MagicMock(side_effect=_yield_once()) + ), +) +def test_manager_status_api(_conn, procedure_manager): + procedure_manager._worker_cls = UnlockableWorker + for message in PROCESS_REQUEST_TEST_CASES: + procedure_manager._process_queue_request(message) + _wait_until(lambda: procedure_manager.active_workers() == ["primary", "queue2"]) + _wait_until( + lambda: procedure_manager.worker_statuses() + == {"primary": ProcedureWorkerStatus.RUNNING, "queue2": ProcedureWorkerStatus.RUNNING} + ) + for w in procedure_manager._active_workers.values(): + w["worker"].event_1.set() + _wait_until( + lambda: procedure_manager.worker_statuses() + == {"primary": ProcedureWorkerStatus.IDLE, "queue2": ProcedureWorkerStatus.IDLE} + ) + for w in procedure_manager._active_workers.values(): + w["worker"].event_2.set() + _wait_until(lambda: procedure_manager.active_workers() == []) + + +_ManagerWithMsgs = tuple[ProcedureManager, list[ProcedureExecutionMessage]] + + +@pytest.fixture +def manager_with_test_msgs(procedure_manager: ProcedureManager): + procedure_manager._conn._redis_conn.flushdb() + contents = [ + ("test_identifier_1", "queue1", ((), {})), + ("test_identifier_2", "queue1", ((), {})), + ("test_identifier_1", "queue2", ((), {})), + ("test_identifier_2", "queue2", ((), {})), + ] + msgs = iter( + ProcedureRequestMessage(identifier=c[0], queue=c[1], args_kwargs=c[2]) for c in contents + ) + procedure_manager._validate_request = lambda msg: next(msgs) + for _ in range(len(contents)): + procedure_manager._process_queue_request({}) + return ( + procedure_manager, + [ + ProcedureExecutionMessage(metadata={}, identifier=c[0], queue=c[1], args_kwargs=c[2]) + for c in contents + ], + ) + + +def _eq_except_id(a: 
ProcedureExecutionMessage, b: ProcedureExecutionMessage): + return a.identifier == b.identifier and a.queue == b.queue and a.args_kwargs == b.args_kwargs + + +def _all_eq_except_id(a: list[ProcedureExecutionMessage], b: list[ProcedureExecutionMessage]): + if len(a) != len(b): + return False + return all(starmap(_eq_except_id, zip(a, b))) + + +@pytest.mark.parametrize("queue", ["queue1", "queue2"]) +@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock()) +def test_startup(manager_with_test_msgs: _ManagerWithMsgs, queue: str): + procedure_manager, expected = manager_with_test_msgs + queue_expected = list(filter(lambda msg: msg.queue == queue, expected)) + + execution_list = procedure_manager._helper.get.exec_queue(queue) + assert _all_eq_except_id(execution_list, queue_expected) + + procedure_manager._startup() + + # on startup, the manager should move active queues to unhandled queues + execution_list = procedure_manager._helper.get.exec_queue(queue) + unhandled_execution_list = procedure_manager._conn.lrange( + MessageEndpoints.unhandled_procedure_execution(queue), 0, -1 + ) + assert execution_list == [] + assert _all_eq_except_id(unhandled_execution_list, queue_expected) + + +@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock()) +def test_abort_queue(manager_with_test_msgs: _ManagerWithMsgs): + procedure_manager, expected = manager_with_test_msgs + remaining_expected = list(filter(lambda msg: msg.queue == "queue2", expected)) + aborted_expected = list(filter(lambda msg: msg.queue == "queue1", expected)) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + assert _all_eq_except_id(q1_execution_list, aborted_expected) + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + assert _all_eq_except_id(q2_execution_list, remaining_expected) + + procedure_manager._process_abort({"queue": "queue1"}) + + # on abort, the manager should move active queues to unhandled queues 
+ # this should happen for q1 and not q2 + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + unhandled_execution_list = procedure_manager._conn.lrange( + MessageEndpoints.unhandled_procedure_execution("queue1"), 0, -1 + ) + assert _all_eq_except_id(q2_execution_list, remaining_expected) + assert _all_eq_except_id(unhandled_execution_list, aborted_expected) + + +@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock()) +def test_abort_individual(manager_with_test_msgs: _ManagerWithMsgs): + procedure_manager, expected = manager_with_test_msgs + q1_expected = list(filter(lambda msg: msg.queue == "queue1", expected)) + q2_expected = list(filter(lambda msg: msg.queue == "queue2", expected)) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + assert _all_eq_except_id( + q1_execution_list, list(filter(lambda msg: msg.queue == "queue1", q1_expected)) + ) + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + assert _all_eq_except_id(q2_execution_list, q2_expected) + + procedure_manager._process_abort({"execution_id": q2_execution_list[1].execution_id}) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + assert _all_eq_except_id(q1_execution_list, q1_expected) + assert _all_eq_except_id(q2_execution_list, [q2_expected[0]]) + + +@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock()) +def test_abort_all(manager_with_test_msgs: _ManagerWithMsgs): + procedure_manager, expected = manager_with_test_msgs + q1_expected = list(filter(lambda msg: msg.queue == "queue1", expected)) + q2_expected = list(filter(lambda msg: msg.queue == "queue2", expected)) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + assert _all_eq_except_id( + q1_execution_list, list(filter(lambda msg: msg.queue == "queue1", q1_expected)) + ) + q2_execution_list = 
procedure_manager._helper.get.exec_queue("queue2") + assert _all_eq_except_id(q2_execution_list, q2_expected) + + procedure_manager._process_abort({"abort_all": True}) + + q1_execution_list = procedure_manager._helper.get.exec_queue("queue1") + q2_execution_list = procedure_manager._helper.get.exec_queue("queue2") + q1_unhandled_list = procedure_manager._helper.get.unhandled_queue("queue1") + q2_unhandled_list = procedure_manager._helper.get.unhandled_queue("queue2") + + assert q1_execution_list == [] + assert q2_execution_list == [] + assert _all_eq_except_id(q1_unhandled_list, q1_expected) + assert _all_eq_except_id(q2_unhandled_list, q2_expected)