diff --git a/.github/actions/bec_e2e_install/action.yml b/.github/actions/bec_e2e_install/action.yml
index 8a0c3bd61..8f93f393b 100644
--- a/.github/actions/bec_e2e_install/action.yml
+++ b/.github/actions/bec_e2e_install/action.yml
@@ -62,7 +62,7 @@ runs:
cd ./_e2e_test_checkout_/bec
source ./bin/install_bec_dev.sh -t
pip install -e ../ophyd_devices
- podman pod create --net host local_bec
+ podman pod create --network=host local_bec
python ./bec_ipython_client/tests/end-2-end/_ensure_requirements_container.py
pytest -v --files-path ./ --start-servers --random-order ./bec_ipython_client/tests/end-2-end/
diff --git a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py
index 409c11f92..48afa5550 100644
--- a/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py
+++ b/bec_ipython_client/tests/end-2-end/test_procedures_e2e.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import time
+from dataclasses import dataclass
from importlib.metadata import version
from typing import TYPE_CHECKING, Callable, Generator
from unittest.mock import MagicMock, patch
@@ -11,7 +12,7 @@
from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
-from bec_server.scan_server.procedures.constants import PROCEDURE
+from bec_server.scan_server.procedures.constants import _CONTAINER, _WORKER
from bec_server.scan_server.procedures.container_utils import get_backend
from bec_server.scan_server.procedures.container_worker import ContainerProcedureWorker
from bec_server.scan_server.procedures.manager import ProcedureManager
@@ -28,6 +29,15 @@
pytestmark = pytest.mark.random_order(disabled=True)
+@dataclass(frozen=True)
+class PATCHED_CONSTANTS:
+ WORKER = _WORKER()
+ CONTAINER = _CONTAINER()
+ MANAGER_SHUTDOWN_TIMEOUT_S = 2
+ BEC_VERSION = version("bec_lib")
+ REDIS_HOST = "localhost"
+
+
@pytest.fixture
def client_logtool_and_manager(
bec_ipython_client_fixture_with_logtool: tuple[BECIPythonClient, "LogTestTool"],
@@ -52,7 +62,7 @@ def _wait_while(cond: Callable[[], bool], timeout_s):
def test_building_worker_image():
podman_utils = get_backend()
build = podman_utils.build_worker_image()
- assert len(build._command_output.splitlines()[-1]) == 64
+ assert len(build._command_output.splitlines()[-1]) == 64 # type: ignore
assert podman_utils.image_exists(f"bec_procedure_worker:v{version('bec_lib')}")
@@ -62,7 +72,7 @@ def test_procedure_runner_spawns_worker(
client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager],
):
client, _, manager = client_logtool_and_manager
- assert manager.active_workers == {}
+ assert manager._active_workers == {}
endpoint = MessageEndpoints.procedure_request()
msg = messages.ProcedureRequestMessage(
identifier="sleep", args_kwargs=((), {"time_s": 2}), queue="test"
@@ -77,21 +87,22 @@ def cb(worker: ContainerProcedureWorker):
manager.add_callback("test", cb)
client.connector.xadd(topic=endpoint, msg_dict=msg.model_dump())
- _wait_while(lambda: manager.active_workers == {}, 5)
- _wait_while(lambda: manager.active_workers != {}, 20)
+ _wait_while(lambda: manager._active_workers == {}, 5)
+ _wait_while(lambda: manager._active_workers != {}, 20)
assert logs != []
@pytest.mark.timeout(100)
@patch("bec_server.scan_server.procedures.manager.procedure_registry.is_registered", lambda _: True)
+@patch("bec_server.scan_server.procedures.container_worker.PROCEDURE", PATCHED_CONSTANTS())
def test_happy_path_container_procedure_runner(
client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager],
):
test_args = (1, 2, 3)
test_kwargs = {"a": "b", "c": "d"}
client, logtool, manager = client_logtool_and_manager
- assert manager.active_workers == {}
+ assert manager._active_workers == {}
conn = client.connector
endpoint = MessageEndpoints.procedure_request()
msg = messages.ProcedureRequestMessage(
@@ -99,12 +110,15 @@ def test_happy_path_container_procedure_runner(
)
conn.xadd(topic=endpoint, msg_dict=msg.model_dump())
- _wait_while(lambda: manager.active_workers == {}, 5)
- _wait_while(lambda: manager.active_workers != {}, 20)
+ _wait_while(lambda: manager._active_workers == {}, 5)
+ _wait_while(lambda: manager._active_workers != {}, 20)
logtool.fetch()
assert logtool.is_present_in_any_message("procedure accepted: True, message:")
- assert logtool.is_present_in_any_message("ContainerWorker started container for queue primary")
+ assert logtool.is_present_in_any_message(
+ "ContainerWorker started container for queue primary"
+ ), f"Log content relating to procedures: {manager._logs}"
+
res, msg = logtool.are_present_in_order(
[
"Container worker 'primary' status update: IDLE",
@@ -114,12 +128,7 @@ def test_happy_path_container_procedure_runner(
]
)
assert res, f"failed on {msg}"
- res, msg = logtool.are_present_in_order(
- [
- "Container worker 'primary' status update: IDLE",
- f"Builtin procedure log_message_args_kwargs called with args: {test_args} and kwargs: {test_kwargs}",
- "Container worker 'primary' status update: IDLE",
- "Container worker 'primary' status update: FINISHED",
- ]
+
+ assert logtool.is_present_in_any_message(
+ f"Builtin procedure log_message_args_kwargs called with args: {test_args} and kwargs: {test_kwargs}"
)
- assert res, f"failed on {msg}"
diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py
index cc4f56b90..fc11eb572 100644
--- a/bec_lib/bec_lib/endpoints.py
+++ b/bec_lib/bec_lib/endpoints.py
@@ -41,9 +41,9 @@ class MessageOp(list[str], enum.Enum):
SET_PUBLISH = ["register", "set_and_publish", "delete", "get", "keys"]
SEND = ["send", "register"]
STREAM = ["xadd", "xrange", "xread", "register_stream", "keys", "get_last", "delete"]
- LIST = ["lpush", "lrange", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"]
+ LIST = ["lpush", "lrange", "lrem", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"]
KEY_VALUE = ["set", "get", "delete", "keys"]
- SET = ["remove_from_set", "get_set_members"]
+ SET = ["remove_from_set", "get_set_members", "delete"]
MessageType = TypeVar("MessageType", bound="type[messages.BECMessage]")
@@ -1420,8 +1420,8 @@ def available_procedures() -> EndpointInfo:
@staticmethod
def procedure_request() -> EndpointInfo:
"""
- Endpoint for scan queue request. This endpoint is used to request the new scans.
- The request is sent using a messages.ScanQueueMessage message.
+ Endpoint for requesting new procedures.
+ The request is sent using a messages.ProcedureRequestMessage message.
Returns:
EndpointInfo: Endpoint for scan queue request.
@@ -1460,7 +1460,24 @@ def procedure_execution(queue_id: str):
Returns:
EndpointInfo: Endpoint for scan queue request.
"""
- endpoint = f"{EndpointType.INTERNAL.value}/procedures/procedure_execution/{queue_id}"
+ endpoint = f"{EndpointType.INFO.value}/procedures/procedure_execution/{queue_id}"
+ return EndpointInfo(
+ endpoint=endpoint,
+ message_type=messages.ProcedureExecutionMessage,
+ message_op=MessageOp.LIST,
+ )
+
+ @staticmethod
+ def unhandled_procedure_execution(queue_id: str):
+ """
+ Endpoint for procedure executions which were pending when the manager was shut down.
+ Messages from procedure_execution are moved here on manager startup.
+ The request is sent using a messages.ProcedureExecutionMessage message.
+
+ Returns:
+ EndpointInfo: Endpoint for unhandled procedure executions for the given queue.
+ """
+ endpoint = f"{EndpointType.INFO.value}/procedures/unhandled_procedure_execution/{queue_id}"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureExecutionMessage,
@@ -1483,6 +1500,36 @@ def active_procedure_executions():
message_op=MessageOp.SET,
)
+ @staticmethod
+ def procedure_abort():
+ """
+ Endpoint to request aborting a running procedure
+
+ Returns:
+ EndpointInfo: Endpoint to request procedure abortion.
+ """
+ endpoint = f"{EndpointType.USER.value}/procedures/abort"
+ return EndpointInfo(
+ endpoint=endpoint,
+ message_type=messages.ProcedureAbortMessage,
+ message_op=MessageOp.STREAM,
+ )
+
+ @staticmethod
+ def procedure_clear_unhandled():
+ """
+ Endpoint to request clearing an unhandled procedure execution
+
+ Returns:
+ EndpointInfo: Endpoint to request clearing unhandled procedure executions.
+ """
+ endpoint = f"{EndpointType.USER.value}/procedures/clear_unhandled"
+ return EndpointInfo(
+ endpoint=endpoint,
+ message_type=messages.ProcedureClearUnhandledMessage,
+ message_op=MessageOp.STREAM,
+ )
+
@staticmethod
def procedure_worker_status_update(queue_id: str):
"""
@@ -1498,6 +1545,34 @@ def procedure_worker_status_update(queue_id: str):
message_op=MessageOp.LIST,
)
+ @staticmethod
+ def procedure_queue_notif():
+ """
+ PubSub channel for a consumer (e.g. BEC widgets) to be notified of changes to a procedure queue
+
+ Returns:
+ EndpointInfo: Endpoint for procedure queue change notifications.
+ """
+ endpoint = f"{EndpointType.INFO.value}/procedures/queue_notif"
+ return EndpointInfo(
+ endpoint=endpoint,
+ message_type=messages.ProcedureQNotifMessage,
+ message_op=MessageOp.SEND,
+ )
+
+ @staticmethod
+ def procedure_logs(queue: str):
+ """
+ Endpoint for logs for a given procedure queue
+
+ Returns:
+ EndpointInfo: Endpoint for logs for the given procedure queue.
+ """
+ endpoint = f"{EndpointType.INFO.value}/procedures/logs/{queue}"
+ return EndpointInfo(
+ endpoint=endpoint, message_type=messages.RawMessage, message_op=MessageOp.STREAM
+ )
+
@staticmethod
def gui_registry_state(gui_id: str):
"""
diff --git a/bec_lib/bec_lib/logger.py b/bec_lib/bec_lib/logger.py
index 3e02f305a..201efb3a8 100644
--- a/bec_lib/bec_lib/logger.py
+++ b/bec_lib/bec_lib/logger.py
@@ -64,6 +64,7 @@ class BECLogger:
"{service_name} | {{time:YYYY-MM-DD HH:mm:ss.SSS}} | {{level}} |"
" {{thread.name}} ({{thread.id}}) | {{extra[stack]}} - {{message}}\n"
)
+ CONTAINER_FORMAT = "{{time:YYYY-MM-DD HH:mm:ss.SSS}} | {{level}} | {{message}}\n"
LOGLEVEL = LogLevel
_logger = None
@@ -224,18 +225,21 @@ def _logger_callback(self, msg):
# because it depends on the connector
pass
- def get_format(self, level: LogLevel = None, is_stderr=False) -> str:
+ def get_format(self, level: LogLevel = None, is_stderr=False, is_container=False) -> str:
"""
Get the format for a specific log level.
Args:
level (LogLevel, optional): Log level. Defaults to None. If None, the current log level will be used.
is_stderr (bool, optional): Whether the log is for stderr. Defaults to False.
+ is_container (bool, optional): Simple logging for procedure container. Defaults to False.
Returns:
str: Log format.
"""
service_name = self.service_name if self.service_name else ""
+ if is_container:
+ return self.CONTAINER_FORMAT.format()
if level is None:
level = self.level
if level > self.LOGLEVEL.DEBUG:
@@ -246,15 +250,16 @@ def get_format(self, level: LogLevel = None, is_stderr=False) -> str:
return self.DEBUG_FORMAT.format(service_name=service_name)
return self.TRACE_FORMAT.format(service_name=service_name)
- def formatting(self, is_stderr=False):
+ def formatting(self, is_stderr=False, is_container=False):
"""
Format the log message.
Args:
record (dict): Log record.
+ is_container (bool, optional): Simple logging for procedure container. Defaults to False.
Returns:
- dict: Formatted log record.
+ str: Log format.
"""
def _update_record(record):
@@ -269,7 +274,7 @@ def _update_record(record):
def _format(record):
level = _update_record(record)
- return self.get_format(level)
+ return self.get_format(level, is_container=is_container)
def _format_stderr(record):
level = _update_record(record)
diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py
index 28c12943b..39bdc51c9 100644
--- a/bec_lib/bec_lib/messages.py
+++ b/bec_lib/bec_lib/messages.py
@@ -5,7 +5,8 @@
import warnings
from copy import deepcopy
from enum import Enum, auto
-from typing import Any, ClassVar, Literal
+from typing import Any, ClassVar, Literal, Self
+from uuid import uuid4
import numpy as np
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
@@ -19,6 +20,7 @@ class ProcedureWorkerStatus(Enum):
IDLE = auto()
FINISHED = auto()
DEAD = auto() # worker lost communication with the container
+ NONE = auto() # worker doesn't exist in manager, caught during creation and cleanup
class BECStatus(Enum):
@@ -1212,18 +1214,67 @@ class ProcedureRequestMessage(BECMessage):
queue: str | None = None
+class ProcedureQNotifMessage(BECMessage):
+ """Message type for notifying watchers of changes to queues"""
+
+ msg_type: ClassVar[str] = "procedure_queue_notif_message"
+ queue_name: str
+ queue_type: Literal["execution", "unhandled"]
+
+
class ProcedureExecutionMessage(BECMessage):
"""Message type for sending procedure execution instructions to the scheduler
Sent by the user to the procedure_request topic. It will be consumed by the scan server.
Args:
identifier (str): name of the procedure registered with the server
+ queue (str): the procedure queue this execution belongs to
+ args_kwargs (tuple[tuple[Any, ...], dict[str, Any]]): arguments for the procedure function
"""
msg_type: ClassVar[str] = "procedure_execution_message"
identifier: str
queue: str
args_kwargs: tuple[tuple[Any, ...], dict[str, Any]] = (), {}
+ execution_id: str = Field(default_factory=lambda: str(uuid4()))
+
+
+class ProcedureAbortMessage(BECMessage):
+ """Message type to request aborting a procedure or procedure queue
+
+ One and only one of the args should be supplied.
+ Args:
+ queue (str | None): the procedure queue to abort
+ execution_id (str | None): the procedure execution to abort
+ abort_all (bool | None): abort all procedures if true
+ """
+
+ msg_type: ClassVar[str] = "procedure_abort_message"
+ queue: str | None = None
+ execution_id: str | None = None
+ abort_all: bool | None = None
+
+ @model_validator(mode="after")
+ def mutually_exclusive(self) -> Self:
+ if (self.queue, self.execution_id, self.abort_all).count(None) != 2:
+ raise ValueError(
+ "Please only supply one argument! Supplied: \n"
+ f" {self.queue=}, {self.execution_id=}, {self.abort_all=}"
+ )
+ return self
+
+
+class ProcedureClearUnhandledMessage(ProcedureAbortMessage):
+ """Message type to request clearing an unhandled procedure or procedure queue
+
+ One and only one of the args should be supplied.
+ Args:
+ queue (str | None): the procedure queue to clear
+ execution_id (str | None): the procedure execution to clear
+ abort_all (bool | None): clear all unhandled procedures if true
+ """
+
+ ...
class ProcedureWorkerStatusMessage(BECMessage):
@@ -1232,12 +1283,21 @@ class ProcedureWorkerStatusMessage(BECMessage):
Args:
worker_queue (str): Worker queue ID
status (str): Worker status
-
+ current_execution_id (str | None): ID of the current job, only allowed for RUNNING
"""
msg_type: ClassVar[str] = "procedure_worker_status_message"
worker_queue: str
status: ProcedureWorkerStatus
+ current_execution_id: str | None = None
+
+ @model_validator(mode="after")
+ def check_id(self) -> Self:
+ if self.current_execution_id is not None and self.status != ProcedureWorkerStatus.RUNNING:
+ raise ValueError("Adding an execution ID is only valid for the RUNNING status")
+ if self.current_execution_id is None and self.status == ProcedureWorkerStatus.RUNNING:
+ raise ValueError("Adding an execution ID is mandatory for the RUNNING status")
+ return self
class LoginInfoMessage(BECMessage):
diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py
index 115d7b820..8d8764f1e 100644
--- a/bec_lib/bec_lib/redis_connector.py
+++ b/bec_lib/bec_lib/redis_connector.py
@@ -1105,6 +1105,23 @@ def lrange(self, topic: str, start: int, end: int, pipe: Pipeline | None = None)
ret.append(msg)
return ret
+ @validate_endpoint("topic")
+ def lrem(self, topic: str, count: int, msg, pipe: Pipeline | None = None):
+ """Removes the first count occurrences of elements equal to element from the list stored at key.
+ The count argument influences the operation in the following ways:
+ count > 0: Remove elements equal to element moving from head to tail.
+ count < 0: Remove elements equal to element moving from tail to head.
+ count = 0: Remove all elements equal to element.
+ For example, LREM list -2 "hello" will remove the last two occurrences of "hello" in the list stored at list.
+
+ Returns the number of items removed
+ """
+
+ client = pipe if pipe is not None else self._redis_conn
+ if isinstance(msg, BECMessage):
+ msg = MsgpackSerialization.dumps(msg)
+ return client.lrem(topic, count, msg)
+
@validate_endpoint("topic")
def set_and_publish(
self, topic: str, msg, pipe: Pipeline | None = None, expire: int | None = None
@@ -1172,7 +1189,7 @@ def xadd(
Args:
topic (str): redis topic
- msg_dict (dict): message to add
+ msg_dict (dict | BECMessage): message to add
max_size (int, optional): max size of stream. Defaults to None.
pipe (Pipeline, optional): redis pipe. Defaults to None.
expire (int, optional): expire time. Defaults to None.
diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py
index e94ba9196..a1a7cf2d4 100644
--- a/bec_lib/tests/test_bec_messages.py
+++ b/bec_lib/tests/test_bec_messages.py
@@ -243,6 +243,31 @@ def test_ProcedureWorkerStatusMessage():
assert res_loaded == msg
+def test_ProcedureWorkerStatusMessage_validation():
+ with pytest.raises(pydantic.ValidationError) as e:
+ messages.ProcedureWorkerStatusMessage(
+ worker_queue="background tasks",
+ status=messages.ProcedureWorkerStatus.RUNNING,
+ metadata={"RID": "1234"},
+ )
+ assert e.match("Adding an execution ID is mandatory")
+ with pytest.raises(pydantic.ValidationError) as e:
+ messages.ProcedureWorkerStatusMessage(
+ worker_queue="background tasks",
+ status=messages.ProcedureWorkerStatus.IDLE,
+ metadata={"RID": "1234"},
+ current_execution_id="test",
+ )
+ assert e.match("Adding an execution ID is only valid")
+
+
+def test_ProcedureAbortMessage_validation():
+ with pytest.raises(pydantic.ValidationError) as e:
+ messages.ProcedureAbortMessage(queue="test", execution_id="test")
+ assert e.match("only supply one argument")
+ messages.ProcedureAbortMessage(queue="test")
+
+
def test_FileMessage():
msg = messages.FileMessage(
device_name="samx",
diff --git a/bec_lib/tests/test_redis_connector_fakeredis.py b/bec_lib/tests/test_redis_connector_fakeredis.py
index beb4173d8..777d8b300 100644
--- a/bec_lib/tests/test_redis_connector_fakeredis.py
+++ b/bec_lib/tests/test_redis_connector_fakeredis.py
@@ -568,3 +568,17 @@ def test_redis_connector_message_alternated_pass(connected_connector):
assert msg_received == [data_to_check]
assert msg_received == [data_original]
+
+
+def test_lrem(connected_connector: RedisConnector):
+ conn, ep = connected_connector, MessageEndpoints.procedure_execution
+ msgs = [
+ messages.ProcedureExecutionMessage(identifier=_id, queue="primary", args_kwargs=((), {}))
+ for _id in ("a", "b", "c")
+ ]
+ for msg in msgs:
+ conn.rpush(ep(msg.queue), msg)
+ assert len(conn.lrange(ep("primary"), 0, -1)) == 3
+ conn.lrem(ep(msgs[1].queue), 0, msgs[1])
+ list_contents = conn.lrange(ep("primary"), 0, -1)
+ assert list_contents == [msgs[0], msgs[2]]
diff --git a/bec_server/bec_server/scan_server/procedures/_dev_runner.py b/bec_server/bec_server/scan_server/procedures/_dev_runner.py
index e87ae4b78..b95db1c27 100644
--- a/bec_server/bec_server/scan_server/procedures/_dev_runner.py
+++ b/bec_server/bec_server/scan_server/procedures/_dev_runner.py
@@ -7,7 +7,6 @@
from bec_lib.logger import bec_logger
from bec_server.scan_server.procedures.container_worker import ContainerProcedureWorker
- from bec_server.scan_server.procedures.in_process_worker import InProcessProcedureWorker
from bec_server.scan_server.procedures.manager import ProcedureManager
logger = bec_logger.logger
diff --git a/bec_server/bec_server/scan_server/procedures/builtin_procedures.py b/bec_server/bec_server/scan_server/procedures/builtin_procedures.py
index 863b76544..cf708797e 100644
--- a/bec_server/bec_server/scan_server/procedures/builtin_procedures.py
+++ b/bec_server/bec_server/scan_server/procedures/builtin_procedures.py
@@ -31,3 +31,7 @@ def run_scan(scan_name: str, args: tuple, parameters: dict, *, bec: BECClient):
raise ValueError(f"Scan {scan_name} doesn't exist in this client!")
scan_report: ScanReport = scan(*args, **parameters)
scan_report.wait()
+
+
+def run_script(script_id: str, *, bec: BECClient):
+ bec._run_script(script_id)
diff --git a/bec_server/bec_server/scan_server/procedures/constants.py b/bec_server/bec_server/scan_server/procedures/constants.py
index 376266275..ed5d20cdd 100644
--- a/bec_server/bec_server/scan_server/procedures/constants.py
+++ b/bec_server/bec_server/scan_server/procedures/constants.py
@@ -67,6 +67,7 @@ class _PROCEDURE:
CONTAINER = _CONTAINER()
MANAGER_SHUTDOWN_TIMEOUT_S = 2
BEC_VERSION = version("bec_lib")
+ REDIS_HOST = "redis"
PROCEDURE = _PROCEDURE()
@@ -77,4 +78,5 @@ class PodmanContainerStates(str, Enum):
RUNNING = "running"
PAUSED = "paused"
STOPPED = "stopped"
+ STOPPING = "stopping"
EXITED = "exited"
diff --git a/bec_server/bec_server/scan_server/procedures/container_utils.py b/bec_server/bec_server/scan_server/procedures/container_utils.py
index 83e23c31e..766c11901 100644
--- a/bec_server/bec_server/scan_server/procedures/container_utils.py
+++ b/bec_server/bec_server/scan_server/procedures/container_utils.py
@@ -96,6 +96,7 @@ def run(
volumes: list[VolumeSpec],
command: str,
pod_name: str | None = None,
+ container_name: str | None = None,
) -> str:
with PodmanClient(base_url=self.uri) as client:
try:
@@ -106,6 +107,7 @@ def run(
environment=environment,
mounts=volumes,
pod=pod_name,
+ name=container_name,
) # type: ignore # running with detach returns container object
except APIError as e:
if e.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
@@ -151,8 +153,9 @@ def pretty_print(self) -> str:
class PodmanCliUtils(_PodmanUtilsBase):
- def _run_and_capture_error(self, *args: str):
- logger.debug(f"Running {args}")
+ def _run_and_capture_error(self, *args: str, log: bool = True):
+ if log:
+ logger.debug(f"Running {args}")
output = subprocess.run([*args], capture_output=True)
if output.returncode != 0:
raise ProcedureWorkerError(
@@ -166,7 +169,7 @@ def _run_and_capture_error(self, *args: str):
def _podman_ls_json(self, subcom: Literal["image", "container"] = "container"):
return json.loads(
self._run_and_capture_error(
- "podman", subcom, "list", "--all", "--format", "json"
+ "podman", subcom, "list", "--all", "--format", "json", log=False
).stdout
)
@@ -193,6 +196,7 @@ def run(
volumes: list[VolumeSpec],
command: str,
pod_name: str | None = None,
+ container_name: str | None = None,
) -> str:
_volumes = [
f"{vol['source']}:{vol['target']}{':ro' if vol['read_only'] else ''}" for vol in volumes
@@ -200,22 +204,39 @@ def run(
_volume_args = list(chain(*(("-v", vol) for vol in _volumes)))
_environment = _multi_args_from_dict("-e", environment) # type: ignore # this is actually a dict[str, str]
_pod_arg = ["--pod", pod_name] if pod_name else []
+ _name_arg = ["--replace", "--name", container_name] if container_name else []
return (
self._run_and_capture_error(
- "podman", "run", *_environment, "-d", *_volume_args, *_pod_arg, image_tag, command
+ "podman",
+ "run",
+ *_environment,
+ "-d",
+ *_name_arg,
+ *_volume_args,
+ *_pod_arg,
+ image_tag,
+ command,
)
.stdout.decode()
.strip()
)
def kill(self, id: str):
- self._run_and_capture_error("podman", "kill", id)
- self._run_and_capture_error("podman", "rm", id)
+ try:
+ self._run_and_capture_error("podman", "kill", id)
+ except ProcedureWorkerError as e:
+ logger.error(e)
def logs(self, id: str) -> list[str]:
- return self._run_and_capture_error("podman", "logs", id).stderr.decode().splitlines()
+ try:
+ return self._run_and_capture_error("podman", "logs", id).stderr.decode().splitlines()
+ except ProcedureWorkerError as e:
+ logger.error(e)
+ return [f"No logs found for container {id}\n"]
def state(self, id: str) -> PodmanContainerStates | None:
for container in self._podman_ls_json():
if container["Id"] == id or container["Id"].startswith(id):
+ if names := container.get("Names"):
+ logger.debug(f"Container {names[0]} status: {container['State']}")
return PodmanContainerStates(container["State"])
diff --git a/bec_server/bec_server/scan_server/procedures/container_worker.py b/bec_server/bec_server/scan_server/procedures/container_worker.py
index 0d6c2e5b1..8a9f3a554 100644
--- a/bec_server/bec_server/scan_server/procedures/container_worker.py
+++ b/bec_server/bec_server/scan_server/procedures/container_worker.py
@@ -1,10 +1,15 @@
+import inspect
import os
+import sys
+from contextlib import redirect_stdout
+from typing import AnyStr, TextIO
+from bec_ipython_client.main import BECIPythonClient
from bec_lib import messages
from bec_lib.client import BECClient
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import LogLevel, bec_logger
-from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus
+from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus, RawMessage
from bec_lib.redis_connector import RedisConnector
from bec_lib.service_config import ServiceConfig
from bec_server.scan_server.procedures import procedure_registry
@@ -15,12 +20,35 @@
ProcedureWorkerError,
)
from bec_server.scan_server.procedures.container_utils import get_backend
+from bec_server.scan_server.procedures.helper import BackendProcedureHelper
from bec_server.scan_server.procedures.protocol import ContainerCommandBackend
from bec_server.scan_server.procedures.worker_base import ProcedureWorker
logger = bec_logger.logger
+class RedisOutputDiverter(TextIO):
+ def __init__(self, conn: RedisConnector, queue: str):
+
+ self._conn = conn
+ self._ep = MessageEndpoints.procedure_logs(queue)
+ self._conn.delete(self._ep)
+
+ def write(self, data: AnyStr):
+ if data:
+ self._conn.xadd(self._ep, {"data": RawMessage(data=str(data))})
+ return len(data)
+
+ def flush(self): ...
+
+ @property
+ def encoding(self):
+ return "utf-8"
+
+ def close(self):
+ return
+
+
class ContainerProcedureWorker(ProcedureWorker):
"""A worker which runs scripts in a container with a full BEC environment,
mounted from the filesystem, and only access to Redis"""
@@ -34,7 +62,7 @@ def _worker_environment(self) -> ContainerWorkerEnv:
minimum necessary, or things which are only necessary for the functioning of the worker,
and other information should be passed through redis"""
return {
- "redis_server": f"{self._conn.host}:{self._conn.port}",
+ "redis_server": f"{PROCEDURE.REDIS_HOST}:{self._conn.port}",
"queue": self._queue,
"timeout_s": str(self._lifetime_s),
}
@@ -42,6 +70,7 @@ def _worker_environment(self) -> ContainerWorkerEnv:
def _setup_execution_environment(self):
self._backend: ContainerCommandBackend = get_backend()
image_tag = f"{PROCEDURE.CONTAINER.IMAGE_NAME}:v{PROCEDURE.BEC_VERSION}"
+ self.container_name = f"bec_procedure_{PROCEDURE.BEC_VERSION}_{self._queue}"
if not self._backend.image_exists(image_tag):
self._backend.build_worker_image()
self._container_id = self._backend.run(
@@ -57,6 +86,7 @@ def _setup_execution_environment(self):
],
PROCEDURE.CONTAINER.COMMAND,
pod_name=PROCEDURE.CONTAINER.POD_NAME,
+ container_name=self.container_name,
)
def _run_task(self, item: ProcedureExecutionMessage):
@@ -64,12 +94,16 @@ def _run_task(self, item: ProcedureExecutionMessage):
f"Container worker _run_task() called with {item} - this should never happen!"
)
- def _kill_process(self):
- if self._backend.state(self._container_id) not in [
+ def _ending_or_ended(self):
+ return self._backend.state(self._container_id) in [
PodmanContainerStates.EXITED,
PodmanContainerStates.STOPPED,
- ]:
- self._backend.kill(self._container_id)
+ PodmanContainerStates.STOPPING,
+ ]
+
+ def _kill_process(self):
+ if not self._ending_or_ended():
+ self._backend.kill(self.container_name)
def work(self):
"""block until the container is finished, listen for status updates in the meantime"""
@@ -77,34 +111,48 @@ def work(self):
# on timeout check if container is still running
status_update = None
- while self._backend.state(self._container_id) not in [
- PodmanContainerStates.EXITED,
- PodmanContainerStates.STOPPED,
- ]:
+ while not self._ending_or_ended():
status_update = self._conn.blocking_list_pop(
- MessageEndpoints.procedure_worker_status_update(self._queue), timeout_s=1
+ MessageEndpoints.procedure_worker_status_update(self._queue), timeout_s=0.2
)
if status_update is not None:
if not isinstance(status_update, messages.ProcedureWorkerStatusMessage):
raise ProcedureWorkerError(f"Received unexpected message {status_update}")
self.status = status_update.status
+ self._current_execution_id = status_update.current_execution_id
logger.info(
f"Container worker '{self._queue}' status update: {status_update.status.name}"
)
# TODO: we probably do want to handle some kind of timeout here but we don't know how
# long a running procedure should actually take - it could theoretically be infinite
+ if self.status != ProcedureWorkerStatus.FINISHED:
+ self.status = ProcedureWorkerStatus.DEAD
+
+ def abort_execution(self, execution_id: str):
+ """Abort the execution with the given id. Has no effect if the given ID is not the current job"""
+ if execution_id == self._current_execution_id:
+ self._backend.kill(self._container_id)
+ self._helper.remove_from_active.by_exec_id(execution_id)
+ logger.info(
+ f"Aborting execution {execution_id}, restarting worker for queue: {self._queue}"
+ )
+ self._setup_execution_environment()
+ def logs(self):
+ if self._container_id is None:
+ return [""]
+ return self._backend.logs(self._container_id)
-def main():
- """Replaces the main contents of Worker.work() - should be called as the container entrypoint or command"""
- logger.info(f"Container worker starting up")
+
+def _setup():
+ logger.info("Container worker starting up")
try:
needed_keys = ContainerWorkerEnv.__annotations__.keys()
logger.debug(f"Checking for environment variables: {needed_keys}")
env: ContainerWorkerEnv = {k: os.environ[k] for k in needed_keys} # type: ignore
except KeyError as e:
logger.error(f"Missing environment variable needed by container worker: {e}")
- return
+ exit(1)
bec_logger.level = LogLevel.DEBUG
bec_logger._console_log = True
@@ -117,19 +165,27 @@ def main():
host, port = env["redis_server"].split(":")
redis = {"host": host, "port": port}
- client = BECClient(config=ServiceConfig(redis=redis))
+
+ client = BECIPythonClient(config=ServiceConfig(redis=redis))
+ logger.debug("starting client")
client.start()
logger.info(f"ContainerWorker started container for queue {env['queue']}")
logger.debug(f"ContainerWorker environment: {env}")
- endpoint_info = MessageEndpoints.procedure_execution(env["queue"])
conn = RedisConnector(env["redis_server"])
+ logger.debug(f"ContainerWorker {env['queue']} connected to Redis at {conn.host}:{conn.port}")
+ helper = BackendProcedureHelper(conn)
+
+ return env, helper, client, conn
+
+
+def _main(env, helper, client, conn):
+
+ exec_endpoint = MessageEndpoints.procedure_execution(env["queue"])
active_procs_endpoint = MessageEndpoints.active_procedure_executions()
status_endpoint = MessageEndpoints.procedure_worker_status_update(env["queue"])
- logger.debug(f"ContainerWorker connecting to Redis at {conn.host}:{conn.port}")
-
try:
timeout_s = int(env["timeout_s"])
except ValueError as e:
@@ -138,30 +194,43 @@ def main():
)
timeout_s = PROCEDURE.WORKER.QUEUE_TIMEOUT_S
- def _push_status(status: ProcedureWorkerStatus):
+ def _push_status(status: ProcedureWorkerStatus, id: str | None = None):
logger.debug(f"Updating container worker status to {status.name}")
conn.rpush(
status_endpoint,
- messages.ProcedureWorkerStatusMessage(worker_queue=env["queue"], status=status),
+ messages.ProcedureWorkerStatusMessage(
+ worker_queue=env["queue"], status=status, current_execution_id=id
+ ),
)
def _run_task(item: ProcedureExecutionMessage):
- procedure_registry.callable_from_execution_message(item)(
- *item.args_kwargs[0], **item.args_kwargs[1]
- )
+ kwargs = item.args_kwargs[1]
+ proc_func = procedure_registry.callable_from_execution_message(item)
+ if bec_arg := inspect.signature(proc_func).parameters.get("bec"):
+ if bec_arg.kind == bec_arg.KEYWORD_ONLY and bec_arg.annotation.__name__ == "BECClient":
+ logger.debug(f"Injecting BEC client argument for {item}")
+ kwargs["bec"] = client
+ procedure_registry.callable_from_execution_message(item)(*item.args_kwargs[0], **kwargs)
_push_status(ProcedureWorkerStatus.IDLE)
item = None
try:
- logger.debug(f"ContainerWorker waiting for instructions on {endpoint_info}")
+ logger.debug(f"ContainerWorker waiting for instructions on {exec_endpoint}")
while (
item := conn.blocking_list_pop_to_set_add(
- endpoint_info, active_procs_endpoint, timeout_s=timeout_s
+ exec_endpoint, active_procs_endpoint, timeout_s=timeout_s
)
) is not None:
- _push_status(ProcedureWorkerStatus.RUNNING)
+ _push_status(ProcedureWorkerStatus.RUNNING, item.execution_id)
+ helper.notify_watchers(env["queue"], queue_type="execution")
logger.debug(f"running task {item!r}")
- _run_task(item)
+ try:
+ _run_task(item)
+ except Exception as e:
+ logger.error(f"Encountered error running procedure {item}")
+ logger.error(e)
+ finally:
+ helper.remove_from_active.by_exec_id(item.execution_id)
_push_status(ProcedureWorkerStatus.IDLE)
except Exception as e:
logger.error(e) # don't stop ProcedureManager.spawn from cleaning up
@@ -170,8 +239,19 @@ def _run_task(item: ProcedureExecutionMessage):
_push_status(ProcedureWorkerStatus.FINISHED)
client.shutdown()
if item is not None: # in this case we are here due to an exception, not a timeout
- conn.remove_from_set(active_procs_endpoint, item)
+ helper.remove_from_active.by_exec_id(item.execution_id)
-if __name__ == "__main__":
- main()
+def main():
+ """Replaces the main contents of Worker.work() - should be called as the container entrypoint or command"""
+
+ env, helper, client, conn = _setup()
+ output_diverter = RedisOutputDiverter(conn, env["queue"])
+ with redirect_stdout(output_diverter):
+ logger.add(
+ output_diverter,
+ level=LogLevel.SUCCESS,
+ format=bec_logger.formatting(is_container=True),
+ filter=bec_logger.filter(),
+ )
+ _main(env, helper, client, conn)
diff --git a/bec_server/bec_server/scan_server/procedures/helper.py b/bec_server/bec_server/scan_server/procedures/helper.py
new file mode 100644
index 000000000..b68bafa31
--- /dev/null
+++ b/bec_server/bec_server/scan_server/procedures/helper.py
@@ -0,0 +1,195 @@
+from typing import Literal
+
+from bec_lib.endpoints import EndpointInfo
+from bec_lib.endpoints import MessageEndpoints as ME
+from bec_lib.logger import bec_logger
+from bec_lib.messages import BECMessage
+from bec_lib.messages import ProcedureAbortMessage as AbrtMsg
+from bec_lib.messages import ProcedureClearUnhandledMessage as ClrMsg
+from bec_lib.messages import ProcedureExecutionMessage as ExecMsg
+from bec_lib.messages import ProcedureQNotifMessage as QNotifMsg
+from bec_lib.messages import ProcedureRequestMessage as ReqMsg
+from bec_lib.redis_connector import RedisConnector
+
+logger = bec_logger.logger
+
+
+class _HelperBase:
+ def __init__(self, conn: RedisConnector) -> None:
+ self._conn = conn
+
+
+class _Request(_HelperBase):
+ def _xadd(self, ep: EndpointInfo, msg: BECMessage):
+ self._conn.xadd(ep, msg.model_dump())
+
+ def procedure(self, msg: ReqMsg):
+ self._xadd(ME.procedure_request(), msg)
+
+ def abort_execution(self, execution_id: str):
+ """Send a message requesting an abort of execution_id"""
+ return self._xadd(ME.procedure_abort(), AbrtMsg(execution_id=execution_id))
+
+ def abort_queue(self, queue: str):
+ """Send a message requesting an abort of execution_id"""
+ return self._xadd(ME.procedure_abort(), AbrtMsg(queue=queue))
+
+ def abort_all(self):
+ """Send a message requesting an abort of execution_id"""
+ return self._xadd(ME.procedure_abort(), AbrtMsg(abort_all=True))
+
+ def clear_unhandled_execution(self, execution_id: str):
+ """Send a message requesting an abort of execution_id"""
+ return self._xadd(ME.procedure_clear_unhandled(), ClrMsg(execution_id=execution_id))
+
+ def clear_unhandled_queue(self, queue: str):
+ """Send a message requesting an abort of execution_id"""
+ return self._xadd(ME.procedure_clear_unhandled(), ClrMsg(queue=queue))
+
+ def clear_all_unhandled(self):
+ """Send a message requesting an abort of execution_id"""
+ return self._xadd(ME.procedure_clear_unhandled(), ClrMsg(abort_all=True))
+
+
+class _Get(_HelperBase):
+ def running_procedures(self) -> set[ExecMsg]:
+ """Get all the running procedures"""
+ return self._conn.get_set_members(ME.active_procedure_executions())
+
+ def exec_queue(self, queue: str) -> list[ExecMsg]:
+ """Get all the ProcedureExecutionMessages from a given execution queue"""
+ return self._conn.lrange(ME.procedure_execution(queue), 0, -1)
+
+ def unhandled_queue(self, queue: str) -> list[ExecMsg]:
+ """Get all the ProcedureExecutionMessages from a given unhandled execution queue"""
+ return self._conn.lrange(ME.unhandled_procedure_execution(queue), 0, -1)
+
+ def active_and_pending_queue_names(self) -> list[str]:
+ """Get the names of all pending queues and queues of currently running procedures"""
+ return list(set(self.queue_names()) | set(self.active_queue_names()))
+
+ def active_queue_names(self) -> list[str]:
+ """Get the names of all queues of currently running procedures"""
+ return list({msg.queue for msg in self.running_procedures()})
+
+ def queue_names(self, queue_type: Literal["execution", "unhandled"] = "execution") -> list[str]:
+ """Get the names of queues currently containing pending ProcedureExecutionMessages
+
+ Args:
+ queue_type (Literal["execution", "unhandled"]): Type of queue, default "execution" for currently active executions, "unhandled" for aborted executions
+ """
+ ep = (
+ ME.procedure_execution
+ if queue_type == "execution"
+ else ME.unhandled_procedure_execution
+ )
+ raw: list[str] = [s.decode() for s in self._conn.keys(ep("*"))]
+ return [s.split("/")[-1] for s in raw]
+
+
+class FrontendProcedureHelper(_HelperBase):
+
+ def __init__(self, conn: RedisConnector) -> None:
+ super().__init__(conn)
+ self.request = _Request(conn)
+ self.get = _Get(conn)
+
+
+class _BackenHelperBase(_HelperBase):
+ def __init__(self, conn: RedisConnector, parent: "BackendProcedureHelper") -> None:
+ self._conn = conn
+ self._parent = parent
+
+
+class _Push(_BackenHelperBase):
+ def exec(self, queue: str, msg: ExecMsg):
+ """Push execution message `msg` to execution queue `queue`"""
+ self._conn.rpush(ME.procedure_execution(queue), msg)
+ self._parent.notify_watchers(queue, "execution")
+
+ def unhandled(self, queue: str, msg: ExecMsg):
+ """Push execution message `msg` to unhandled execution queue `queue`"""
+ self._conn.rpush(ME.unhandled_procedure_execution(queue), msg)
+ self._parent.notify_watchers(queue, "unhandled")
+
+
+class _Clear(_BackenHelperBase):
+ def all_unhandled(self):
+ """Remove all unhandled execution queues"""
+ for queue in self._parent.get.queue_names("unhandled"):
+ self.unhandled_queue(queue)
+
+ def unhandled_queue(self, queue: str):
+ """Remove an unhandled execution queue"""
+ self._conn.delete(ME.unhandled_procedure_execution(queue))
+ self._parent.notify_watchers(queue, "unhandled")
+
+ def unhandled_execution(self, execution_id: str):
+ """Remove a ProcedureExecutionMessage from its unhandled queue by its execution ID"""
+ for queue in self._parent.get.queue_names("unhandled"):
+ for msg in self._parent.get.unhandled_queue(queue):
+ if msg.execution_id == execution_id:
+ if self._conn.lrem(ME.unhandled_procedure_execution(msg.queue), 0, msg) > 0:
+ logger.debug(f"Removed execution {msg} from queue.")
+ self._parent.notify_watchers(queue, "unhandled")
+ return
+ logger.debug(f"Execution {execution_id} not found in any unhandled queue.")
+
+
+class _Move(_BackenHelperBase):
+ def all_active_to_unhandled(self):
+ """Move all messages in the active executions set to unhandled"""
+ for msg in self._parent.get.running_procedures():
+ self._parent.push.unhandled(msg.queue, msg)
+ self._conn.delete(ME.active_procedure_executions())
+
+ def execution_queue_to_unhandled(self, queue: str):
+ """Move all messages from execution queue to unhandled execution queue of the same name"""
+ for msg in self._parent.get.exec_queue(queue):
+ self._parent.push.unhandled(queue, msg)
+ self._conn.delete(ME.procedure_execution(queue))
+
+ def all_execution_queues_to_unhandled(self):
+ """Move all messages from all execution queues to unhandled execution queues of the same name"""
+ for queue in self._parent.get.queue_names():
+ self.execution_queue_to_unhandled(queue)
+
+
+class _RemoveFromActive(_BackenHelperBase):
+    def by_exec_id(self, execution_id: str):
+        """Remove a message from the set of currently active executions"""
+        for msg in self._conn.get_set_members(ME.active_procedure_executions()):
+            if msg.execution_id == execution_id:
+                self._conn.remove_from_set(ME.active_procedure_executions(), msg)
+                logger.debug(f"removed active procedure {execution_id}")
+                return
+        logger.debug(f"No active procedure {execution_id} to remove")
+
+    def by_queue(self, queue: str):
+        """Remove all messages with the given queue from the set of currently active executions"""
+        removed = False
+        for msg in self._conn.get_set_members(ME.active_procedure_executions()):
+            if msg.queue == queue:
+                self._conn.remove_from_set(ME.active_procedure_executions(), msg)
+                logger.debug(f"removed active procedure {msg} with queue {queue}")
+                removed = True
+        if not removed:
+            logger.debug(f"No active procedure with queue {queue} to remove")
+
+
+class BackendProcedureHelper(FrontendProcedureHelper):
+ def __init__(self, conn: RedisConnector) -> None:
+ super().__init__(conn)
+ self.push = _Push(conn, self)
+ self.clear = _Clear(conn, self)
+ self.move = _Move(conn, self)
+ self.remove_from_active = _RemoveFromActive(conn, self)
+
+ def notify_watchers(self, queue: str, queue_type: Literal["execution", "unhandled"]):
+ return self._conn.send(
+ ME.procedure_queue_notif(), QNotifMsg(queue_name=queue, queue_type=queue_type)
+ )
+
+ def notify_all(self, queue_type: Literal["execution", "unhandled"]):
+ for queue in self.get.queue_names(queue_type):
+ self.notify_watchers(queue, queue_type)
diff --git a/bec_server/bec_server/scan_server/procedures/in_process_worker.py b/bec_server/bec_server/scan_server/procedures/in_process_worker.py
index e225a88b4..d56c9eebe 100644
--- a/bec_server/bec_server/scan_server/procedures/in_process_worker.py
+++ b/bec_server/bec_server/scan_server/procedures/in_process_worker.py
@@ -12,7 +12,7 @@
class InProcessProcedureWorker(ProcedureWorker):
"""A simple in-process procedure worker. Be careful with this, it should only run trusted code.
- Intended for built-in procedures like those to run a single scan."""
+ Intended for built-in procedures like those to run a single scan, or testing."""
def _setup_execution_environment(self):
from bec_lib.logger import bec_logger
@@ -45,3 +45,7 @@ def _kill_process(self):
self.logger.info(
f"In-process procedure worker for queue {self.key.endpoint} timed out after {self._lifetime_s} s, shutting down"
)
+
+ def abort_execution(self, execution_id: str):
+ """Abort the execution with the given id"""
+ ... # No nice way to abort a running function, don't use this class for such things
diff --git a/bec_server/bec_server/scan_server/procedures/manager.py b/bec_server/bec_server/scan_server/procedures/manager.py
index d402d92e3..faa401ac7 100644
--- a/bec_server/bec_server/scan_server/procedures/manager.py
+++ b/bec_server/bec_server/scan_server/procedures/manager.py
@@ -1,25 +1,38 @@
from __future__ import annotations
import atexit
+from collections import deque
from concurrent import futures
from concurrent.futures import Future, ThreadPoolExecutor
from threading import RLock
-from typing import Any, Callable, TypedDict
+from typing import Any, Callable, TypedDict, TypeVar
from pydantic import ValidationError
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
-from bec_lib.messages import ProcedureRequestMessage, RequestResponseMessage
+from bec_lib.messages import (
+ BECMessage,
+ ProcedureAbortMessage,
+ ProcedureClearUnhandledMessage,
+ ProcedureExecutionMessage,
+ ProcedureRequestMessage,
+ ProcedureWorkerStatus,
+ RequestResponseMessage,
+)
from bec_lib.redis_connector import RedisConnector
from bec_server.scan_server.procedures import procedure_registry
from bec_server.scan_server.procedures.constants import PROCEDURE, WorkerAlreadyExists
+from bec_server.scan_server.procedures.helper import BackendProcedureHelper
from bec_server.scan_server.procedures.worker_base import ProcedureWorker
from bec_server.scan_server.scan_server import ScanServer
logger = bec_logger.logger
+# TODO garbage collect IDs
+
+
class ProcedureWorkerEntry(TypedDict):
worker: ProcedureWorker | None
future: Future
@@ -33,6 +46,15 @@ def _log_on_end(future: Future):
logger.success(f"Procedure worker {future} shut down gracefully")
+_T = TypeVar("_T", bound=BECMessage)
+
+
+def _resolve_dict(msg: dict[str, Any] | _T, MsgType: type[_T]) -> _T:
+ if isinstance(msg, dict):
+ return MsgType.model_validate(msg)
+ return msg
+
+
class ProcedureManager:
def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]):
@@ -42,10 +64,18 @@ def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]):
Args:
parent (ScanServer): the scan server to get the Redis server address from.
worker_type (type[ProcedureWorker]): which kind of worker to use."""
-
self._parent = parent
self.lock = RLock()
- self.active_workers: dict[str, ProcedureWorkerEntry] = {}
+
+ logger.success("Initialising procedure manager...")
+
+ self._logs = deque([], maxlen=1000)
+ self._conn = RedisConnector([self._parent.bootstrap_server])
+ self._helper = BackendProcedureHelper(self._conn)
+ self._startup()
+
+ self._active_workers: dict[str, ProcedureWorkerEntry] = {}
+ self._messages_by_ids: dict[str, ProcedureExecutionMessage] = {}
self.executor = ThreadPoolExecutor(
max_workers=PROCEDURE.WORKER.MAX_WORKERS, thread_name_prefix="user_procedure_"
)
@@ -53,11 +83,26 @@ def __init__(self, parent: ScanServer, worker_type: type[ProcedureWorker]):
self._callbacks: dict[str, list[Callable[[ProcedureWorker], Any]]] = {}
self._worker_cls = worker_type
- self._conn = RedisConnector([self._parent.bootstrap_server])
self._reply_endpoint = MessageEndpoints.procedure_request_response()
self._server = f"{self._conn.host}:{self._conn.port}"
- self._conn.register(MessageEndpoints.procedure_request(), None, self.process_queue_request)
+ self._conn.register(MessageEndpoints.procedure_abort(), None, self._process_abort)
+ self._conn.register(
+ MessageEndpoints.procedure_clear_unhandled(), None, self._process_clear_unhandled
+ )
+ self._conn.register(MessageEndpoints.procedure_request(), None, self._process_queue_request)
+ logger.success("Done initialising procedure manager.")
+
+ def _startup(self):
+ # If the server is restarted, clear any pending requests, they'll have to be resubmitted
+ self._conn.delete(MessageEndpoints.procedure_request())
+ previous_queues = self._helper.get.active_and_pending_queue_names()
+ logger.debug(f"Clearing previous procedure queues {previous_queues}...")
+ self._helper.move.all_execution_queues_to_unhandled()
+ self._helper.move.all_active_to_unhandled()
+ for queue in previous_queues:
+ self._helper.notify_watchers(queue, "execution")
+ self._helper.notify_all("unhandled")
def _ack(self, accepted: bool, msg: str):
logger.info(f"procedure accepted: {accepted}, message: {msg}")
@@ -65,9 +110,9 @@ def _ack(self, accepted: bool, msg: str):
self._reply_endpoint, RequestResponseMessage(accepted=accepted, message=msg)
)
- def _validate_request(self, msg: dict[str, Any]):
+ def _validate_request(self, msg: dict[str, Any] | ProcedureRequestMessage):
try:
- message_obj = ProcedureRequestMessage.model_validate(msg)
+ message_obj = _resolve_dict(msg, ProcedureRequestMessage)
if not procedure_registry.is_registered(message_obj.identifier):
self._ack(
False,
@@ -86,13 +131,17 @@ def add_callback(self, queue: str, cb: Callable[[ProcedureWorker], Any]):
self._callbacks[queue].append(cb)
def _run_callbacks(self, queue: str):
- if (worker := self.active_workers[queue]["worker"]) is None:
- return
+ with self.lock:
+ if queue not in self._active_workers:
+ logger.error(f"Attempted to run callbacks for nonexistent worker {queue}")
+ return
+ if (worker := self._active_workers[queue]["worker"]) is None:
+ return
for cb in self._callbacks.get(queue, []):
cb(worker)
self._callbacks[queue] = []
- def process_queue_request(self, msg: dict[str, Any]):
+ def _process_queue_request(self, msg: dict[str, Any] | ProcedureRequestMessage):
"""Read a `ProcedureRequestMessage` and if it is valid, create a corresponding `ProcedureExecutionMessage`.
If there is already a worker for the queue for that request message, add the execution message to that queue,
otherwise create a new queue and a new worker.
@@ -101,33 +150,93 @@ def process_queue_request(self, msg: dict[str, Any]):
msg (dict[str, Any]): dict corresponding to a ProcedureRequestMessage"""
logger.debug(f"Procedure manager got request message {msg}")
- if (message_obj := self._validate_request(msg)) is None:
+ if (message := self._validate_request(msg)) is None:
return
- self._ack(True, f"Running procedure {message_obj.identifier}")
- queue = message_obj.queue or PROCEDURE.WORKER.DEFAULT_QUEUE
- endpoint = MessageEndpoints.procedure_execution(queue)
- logger.debug(f"active workers: {self.active_workers}, worker requested: {queue}")
- self._conn.rpush(
- endpoint,
- endpoint.message_type(
- identifier=message_obj.identifier,
- queue=queue,
- args_kwargs=message_obj.args_kwargs or ((), {}),
- ),
+ self._ack(True, f"Running procedure {message.identifier}")
+ queue = message.queue or PROCEDURE.WORKER.DEFAULT_QUEUE
+ exec_message = ProcedureExecutionMessage(
+ identifier=message.identifier, queue=queue, args_kwargs=message.args_kwargs or ((), {})
)
+ logger.debug(f"active workers: {self._active_workers}, worker requested: {queue}")
+ self._helper.push.exec(queue, exec_message)
def cleanup_worker(fut):
with self.lock:
logger.debug(f"cleaning up worker {fut} for queue {queue}...")
+ self._helper.remove_from_active.by_queue(queue)
self._run_callbacks(queue)
- del self.active_workers[queue]
+ self._helper.notify_watchers(queue, "execution")
+ del self._active_workers[queue]
with self.lock:
- if queue not in self.active_workers:
+ if queue not in self._active_workers:
new_worker = self.executor.submit(self.spawn, queue=queue)
new_worker.add_done_callback(_log_on_end)
new_worker.add_done_callback(cleanup_worker)
- self.active_workers[queue] = {"worker": None, "future": new_worker}
+ self._active_workers[queue] = {"worker": None, "future": new_worker}
+ self._messages_by_ids[exec_message.execution_id] = exec_message
+
+ def _process_abort(self, msg: dict[str, Any] | ProcedureAbortMessage):
+ message = _resolve_dict(msg, ProcedureAbortMessage)
+ with self.lock:
+ if message.abort_all:
+ self._abort_all()
+ if message.queue is not None:
+ self._abort_queue(message.queue)
+ if message.execution_id is not None:
+ self._abort_execution(message.execution_id)
+
+ def _abort_execution(self, execution_id: str):
+ if (msg := self._messages_by_ids.get(execution_id)) is None:
+ logger.warning(f"Procedure execution with ID {execution_id} not known.")
+ return
+ # Remove it from the queue if not yet started
+ if self._conn.lrem(MessageEndpoints.procedure_execution(msg.queue), 0, msg) > 0:
+ logger.debug(f"Removed execution {msg} from queue.")
+ self._helper.notify_watchers(msg.queue, "execution")
+ # Otherwise try to remove it from whichever worker has it
+ for entry in self._active_workers.values():
+ if (worker := entry["worker"]) is not None and worker:
+ worker.abort_execution(execution_id)
+ # Move it to unhandled and stop tracking
+ self._helper.push.unhandled(msg.queue, msg)
+ del self._messages_by_ids[execution_id]
+
+ def _abort_queue(self, queue: str):
+ self._helper.move.execution_queue_to_unhandled(queue)
+ if (entry := self._active_workers.get(queue)) is not None:
+ if entry["worker"] is not None:
+ entry["worker"].abort()
+ entry["future"].cancel()
+ futures.wait((entry["future"],), PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S)
+ else:
+ logger.warning(f"Received abort request for unknown queue {queue}!")
+ self._helper.notify_watchers(queue, "execution")
+
+ def _abort_all(self):
+ for entry in self._active_workers.values():
+ if entry["worker"] is not None:
+ entry["worker"].abort()
+ for entry in self._active_workers.values():
+ entry["future"].cancel()
+ self._helper.move.all_execution_queues_to_unhandled()
+ self._wait_for_all_futures()
+
+ def _process_clear_unhandled(self, msg: dict[str, Any] | ProcedureClearUnhandledMessage):
+ message = _resolve_dict(msg, ProcedureClearUnhandledMessage)
+ with self.lock:
+ if message.abort_all:
+ self._helper.clear.all_unhandled()
+ if message.queue is not None:
+ self._helper.clear.unhandled_queue(message.queue)
+ if message.execution_id is not None:
+ self._helper.clear.unhandled_execution(message.execution_id)
+
+ def _wait_for_all_futures(self):
+ futures.wait(
+ (entry["future"] for entry in self._active_workers.values()),
+ timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S,
+ )
def spawn(self, queue: str):
"""Spawn a procedure worker future which listens to a given queue, i.e. procedure queue list in Redis.
@@ -135,32 +244,41 @@ def spawn(self, queue: str):
Args:
queue (str): name of the queue to spawn a worker for"""
- if queue in self.active_workers and self.active_workers[queue]["worker"] is not None:
+ if queue in self._active_workers and self._active_workers[queue]["worker"] is not None:
raise WorkerAlreadyExists(
- f"Queue {queue} already has an active worker in {self.active_workers}!"
+ f"Queue {queue} already has an active worker in {self._active_workers}!"
)
with self._worker_cls(self._server, queue, PROCEDURE.WORKER.QUEUE_TIMEOUT_S) as worker:
with self.lock:
- self.active_workers[queue]["worker"] = worker
+ self._active_workers[queue]["worker"] = worker
worker.work()
+ self._logs.extend(worker.logs())
def shutdown(self):
"""Shutdown the procedure manager. Unregisters from the request endpoint, cancel any
procedure workers which haven't started, and abort any which have."""
self._conn.unregister(
- MessageEndpoints.procedure_request(), None, self.process_queue_request
+ MessageEndpoints.procedure_request(), None, self._process_queue_request
)
self._conn.shutdown()
# cancel futures by hand to give us the opportunity to detatch them from redis if they have started
- for entry in self.active_workers.values():
+ for entry in self._active_workers.values():
cancelled = entry["future"].cancel()
if not cancelled:
# unblock any waiting workers and let them shutdown
if worker := entry["worker"]:
# redis unblock executor.client_id
worker.abort()
- futures.wait(
- (entry["future"] for entry in self.active_workers.values()),
- timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S,
- )
+ self._wait_for_all_futures()
self.executor.shutdown()
+
+ def active_workers(self) -> list[str]:
+ with self.lock:
+ return list(self._active_workers.keys())
+
+ def worker_statuses(self) -> dict[str, ProcedureWorkerStatus]:
+ with self.lock:
+ return {
+ q: w["worker"].status if w["worker"] is not None else ProcedureWorkerStatus.NONE
+ for q, w in self._active_workers.items()
+ }
diff --git a/bec_server/bec_server/scan_server/procedures/procedure_registry.py b/bec_server/bec_server/scan_server/procedures/procedure_registry.py
index 8a2582f56..4ed3ff114 100644
--- a/bec_server/bec_server/scan_server/procedures/procedure_registry.py
+++ b/bec_server/bec_server/scan_server/procedures/procedure_registry.py
@@ -1,9 +1,10 @@
-from typing import Any, Callable, Iterable, Iterator
+from typing import Iterable
from bec_lib.messages import ProcedureExecutionMessage
from bec_server.scan_server.procedures.builtin_procedures import (
log_message_args_kwargs,
run_scan,
+ run_script,
sleep,
)
from bec_server.scan_server.procedures.constants import BecProcedure
@@ -12,6 +13,7 @@
"log execution message args": log_message_args_kwargs,
"run scan": run_scan,
"sleep": sleep,
+ "run_script": run_script,
}
_PROCEDURE_REGISTRY: dict[str, BecProcedure] = {} | _BUILTIN_PROCEDURES
diff --git a/bec_server/bec_server/scan_server/procedures/protocol.py b/bec_server/bec_server/scan_server/procedures/protocol.py
index 6d79b80a6..06d40c818 100644
--- a/bec_server/bec_server/scan_server/procedures/protocol.py
+++ b/bec_server/bec_server/scan_server/procedures/protocol.py
@@ -11,7 +11,6 @@ class VolumeSpec(TypedDict):
class ContainerCommandOutput(Protocol):
-
def pretty_print(self) -> str: ...
@@ -30,6 +29,7 @@ def run(
volumes: list[VolumeSpec],
command: str,
pod_name: str | None = None,
+ container_name: str | None = None,
) -> str: ...
def kill(self, id: str): ...
def logs(self, id: str) -> list[str]: ...
diff --git a/bec_server/bec_server/scan_server/procedures/worker_base.py b/bec_server/bec_server/scan_server/procedures/worker_base.py
index a18b0dcf8..7b07256db 100644
--- a/bec_server/bec_server/scan_server/procedures/worker_base.py
+++ b/bec_server/bec_server/scan_server/procedures/worker_base.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from enum import Enum, auto
+from threading import Event
from typing import cast
from bec_lib.endpoints import MessageEndpoints
@@ -9,6 +9,7 @@
from bec_lib.messages import ProcedureExecutionMessage, ProcedureWorkerStatus
from bec_lib.redis_connector import RedisConnector
from bec_server.scan_server.procedures.constants import PROCEDURE
+from bec_server.scan_server.procedures.helper import BackendProcedureHelper
logger = bec_logger.logger
@@ -36,8 +37,11 @@ def __init__(self, server: str, queue: str, lifetime_s: float | None = None):
self._active_procs_endpoint = MessageEndpoints.active_procedure_executions()
self.status = ProcedureWorkerStatus.IDLE
self._conn = RedisConnector([server])
+ self._helper = BackendProcedureHelper(self._conn)
self._lifetime_s = lifetime_s or PROCEDURE.WORKER.QUEUE_TIMEOUT_S
self.client_id = self._conn.client_id()
+ self._current_execution_id: str | None = None
+ self._aborted = Event()
self._setup_execution_environment()
@@ -61,9 +65,19 @@ def _run_task(self, item: ProcedureExecutionMessage):
@abstractmethod
def _setup_execution_environment(self): ...
+ def logs(self) -> list[str]:
+ return [""]
+
def abort(self):
+ """Abort the entire worker"""
+ self._aborted.set()
self._kill_process()
+ @abstractmethod
+ def abort_execution(self, execution_id: str):
+ """Abort the execution with the given id"""
+ ...
+
def __exit__(self, exc_type, exc_val, exc_tb):
self._kill_process()
@@ -76,9 +90,10 @@ def work(self):
self.key, self._active_procs_endpoint, timeout_s=self._lifetime_s
)
) is not None:
- self.status = ProcedureWorkerStatus.RUNNING
- self._run_task(cast(ProcedureExecutionMessage, item))
- self.status = ProcedureWorkerStatus.IDLE
+ if not self._aborted.is_set():
+ self.status = ProcedureWorkerStatus.RUNNING
+ self._run_task(cast(ProcedureExecutionMessage, item))
+ self.status = ProcedureWorkerStatus.IDLE
except Exception as e:
logger.error(e) # don't stop ProcedureManager.spawn from cleaning up
finally:
diff --git a/bec_server/tests/tests_scan_server/test_container_utils.py b/bec_server/tests/tests_scan_server/test_container_utils.py
index 55fb81408..31c948f6b 100644
--- a/bec_server/tests/tests_scan_server/test_container_utils.py
+++ b/bec_server/tests/tests_scan_server/test_container_utils.py
@@ -138,6 +138,7 @@ def test_api_utils_run(api_utils: tuple[PodmanApiUtils, MagicMock]):
environment={"a": "b"},
mounts=[{"source": "a", "target": "b", "read_only": True, "type": "bind"}],
pod=None,
+ name=None,
)
@@ -185,7 +186,7 @@ def test_cli_kill(cli_utils: tuple[PodmanCliUtils, MagicMock]):
utils, run_mock = cli_utils
run_mock.reset_mock()
utils.kill("test")
- assert run_mock.call_count == 2
+ assert run_mock.call_count == 1
def test_cli_get_state(cli_utils_with_fake_container_json: PodmanCliUtils):
diff --git a/bec_server/tests/tests_scan_server/test_procedure_container_worker.py b/bec_server/tests/tests_scan_server/test_procedure_container_worker.py
index b80ffabc9..5b686a04a 100644
--- a/bec_server/tests/tests_scan_server/test_procedure_container_worker.py
+++ b/bec_server/tests/tests_scan_server/test_procedure_container_worker.py
@@ -29,7 +29,7 @@ def test_container_worker_init(redis_mock):
redis_mock().port = "port"
worker = ContainerProcedureWorker(server="server:port", queue="test_queue", lifetime_s=1)
assert worker._worker_environment() == {
- "redis_server": "server:port",
+ "redis_server": "redis:port",
"queue": "test_queue",
"timeout_s": "1",
}
@@ -45,20 +45,20 @@ def test_container_worker_work(redis_mock, logger_mock):
redis_mock().host = "server"
redis_mock().port = "port"
+ msgs = [
+ ProcedureWorkerStatusMessage(
+ worker_queue="test_queue",
+ status=ProcedureWorkerStatus.RUNNING,
+ current_execution_id="test",
+ ),
+ ProcedureWorkerStatusMessage(
+ worker_queue="test_queue", status=ProcedureWorkerStatus.FINISHED
+ ),
+ ]
+
def _mock_pop():
- msgs = [
- ProcedureWorkerStatusMessage(
- worker_queue="test_queue", status=ProcedureWorkerStatus.RUNNING
- ),
- ProcedureWorkerStatusMessage(
- worker_queue="test_queue", status=ProcedureWorkerStatus.FINISHED
- ),
- ]
- for msg in msgs:
- sleep(0.05)
- yield msg
+ yield from msgs
while True:
- sleep(0.05)
yield from repeat(None)
mock_pop = _mock_pop()
@@ -74,7 +74,7 @@ def cleanup():
t.start()
start = time.monotonic()
- while time.monotonic() < start + 5:
+ while time.monotonic() < start + 1000:
try:
assert (
logger_mock.info.call_args_list[0].args[0]
@@ -100,7 +100,7 @@ def cleanup():
@patch("bec_server.scan_server.procedures.container_worker.logger")
def test_main_exits_without_env_variables(logger_mock):
- with patch.dict(os.environ, clear=True):
+ with patch.dict(os.environ, clear=True), pytest.raises(SystemExit):
container_worker_main()
assert "Missing environment variable " in logger_mock.error.call_args.args[0]
@@ -125,6 +125,7 @@ def __init__(self, name: str, args: tuple = (), kwargs: dict = {}):
self.identifier = name
self.args = args
self.kwargs = kwargs
+ self.execution_id = f"test_{name}"
def __repr__(self) -> str:
return self.identifier
diff --git a/bec_server/tests/tests_scan_server/test_procedures.py b/bec_server/tests/tests_scan_server/test_procedures.py
index 0fda944c9..7775e5281 100644
--- a/bec_server/tests/tests_scan_server/test_procedures.py
+++ b/bec_server/tests/tests_scan_server/test_procedures.py
@@ -1,6 +1,7 @@
import threading
import time
from functools import partial
+from itertools import starmap
from typing import Any, Callable
from unittest.mock import MagicMock, patch
@@ -8,6 +9,7 @@
import pytest
from bec_lib.client import BECClient, RedisConnector
+from bec_lib.endpoints import MessageEndpoints
from bec_lib.messages import (
ProcedureExecutionMessage,
ProcedureRequestMessage,
@@ -15,6 +17,7 @@
RequestResponseMessage,
)
from bec_lib.serialization import MsgpackSerialization
+from bec_lib.service_config import ServiceConfig
from bec_server.scan_server.procedures.constants import PROCEDURE, BecProcedure, WorkerAlreadyExists
from bec_server.scan_server.procedures.in_process_worker import InProcessProcedureWorker
from bec_server.scan_server.procedures.manager import ProcedureManager, ProcedureWorker
@@ -32,11 +35,17 @@
LOG_MSG_PROC_NAME = "log execution message args"
+FAKEREDIS_HOST = "127.0.0.1"
+FAKEREDIS_PORT = 6380
@pytest.fixture(autouse=True)
+@patch("bec_lib.bec_service.BECAccess", MagicMock)
def shutdown_client():
- bec_client = BECClient()
+ bec_client = BECClient(
+ config=ServiceConfig(config={"redis": {"host": FAKEREDIS_HOST, "port": FAKEREDIS_PORT}}),
+ connector_cls=partial(RedisConnector, redis_cls=fakeredis.FakeRedis),
+ )
bec_client.start()
yield
bec_client.shutdown()
@@ -45,7 +54,7 @@ def shutdown_client():
@pytest.fixture
def procedure_manager():
server = MagicMock()
- server.bootstrap_server = "localhost:1"
+ server.bootstrap_server = f"{FAKEREDIS_HOST}:{FAKEREDIS_PORT}"
with patch(
"bec_server.scan_server.procedures.manager.RedisConnector",
partial(RedisConnector, redis_cls=fakeredis.FakeRedis), # type: ignore
@@ -104,7 +113,7 @@ def process_request_manager(procedure_manager: ProcedureManager):
@pytest.mark.parametrize("message", PROCESS_REQUEST_TEST_CASES)
def test_process_request_happy_paths(process_request_manager, message: ProcedureRequestMessage):
- process_request_manager.process_queue_request(message)
+ process_request_manager._process_queue_request(message)
process_request_manager._ack.assert_called_with(True, f"Running procedure {message.identifier}")
process_request_manager._conn.rpush.assert_called()
endpoint, execution_msg = process_request_manager._conn.rpush.call_args.args
@@ -112,15 +121,15 @@ def test_process_request_happy_paths(process_request_manager, message: Procedure
assert queue in endpoint.endpoint
assert execution_msg.identifier == message.identifier
process_request_manager.spawn.assert_called()
- assert queue in process_request_manager.active_workers.keys()
+ assert queue in process_request_manager._active_workers.keys()
def test_process_request_failure(process_request_manager):
- process_request_manager.process_queue_request(None)
+ process_request_manager._process_queue_request(None)
process_request_manager._ack.assert_not_called()
process_request_manager._conn.rpush.assert_not_called()
process_request_manager.spawn.assert_not_called()
- assert process_request_manager.active_workers == {}
+ assert process_request_manager._active_workers == {}
class UnlockableWorker(ProcedureWorker):
@@ -131,6 +140,7 @@ def __init__(self, server: str, queue: str, lifetime_s: int | None = None):
self.event_1 = threading.Event()
self.event_2 = threading.Event()
+ def abort_execution(self, execution_id: str): ...
def _setup_execution_environment(self): ...
def _kill_process(self): ...
def _run_task(self, item):
@@ -161,14 +171,14 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager):
queue = message.queue or PROCEDURE.WORKER.DEFAULT_QUEUE
procedure_manager._validate_request = MagicMock(side_effect=lambda msg: msg)
# trigger the running of the test message
- procedure_manager.process_queue_request(message) # type: ignore
- assert queue in procedure_manager.active_workers.keys()
+ procedure_manager._process_queue_request(message) # type: ignore
+ assert queue in procedure_manager._active_workers.keys()
# spawn method should be added as a future
- _wait_until(procedure_manager.active_workers[queue]["future"].running)
+ _wait_until(procedure_manager._active_workers[queue]["future"].running)
# and then create the worker
- _wait_until(lambda: procedure_manager.active_workers[queue].get("worker") is not None)
- worker = procedure_manager.active_workers[queue]["worker"]
+ _wait_until(lambda: procedure_manager._active_workers[queue].get("worker") is not None)
+ worker = procedure_manager._active_workers[queue]["worker"]
assert isinstance(worker, UnlockableWorker)
_wait_until(lambda: worker.status == ProcedureWorkerStatus.RUNNING)
@@ -185,7 +195,7 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager):
worker.event_2.set()
_wait_until(lambda: worker.status == ProcedureWorkerStatus.FINISHED)
# spawn deletes the worker queue
- _wait_until(lambda: len(procedure_manager.active_workers) == 0)
+ _wait_until(lambda: len(procedure_manager._active_workers) == 0)
@patch("bec_server.scan_server.procedures.worker_base.RedisConnector", MagicMock())
@@ -259,7 +269,7 @@ def test_callable_from_message():
def test_register_rejects_wrong_type():
with pytest.raises(ProcedureRegistryError) as e:
- register("test", "test")
+ register("test", "test") # type: ignore
assert e.match("not a valid procedure")
@@ -267,3 +277,162 @@ def test_register_rejects_already_registered():
with pytest.raises(ProcedureRegistryError) as e:
register("run scan", lambda *_, **__: None)
assert e.match("already registered")
+
+
+def _yield_once():
+ yield "value"
+ while True:
+ yield None
+
+
+@patch(
+ "bec_server.scan_server.procedures.worker_base.RedisConnector",
+ side_effect=lambda *_: MagicMock(
+ blocking_list_pop_to_set_add=MagicMock(side_effect=_yield_once())
+ ),
+)
+def test_manager_status_api(_conn, procedure_manager):
+ procedure_manager._worker_cls = UnlockableWorker
+ for message in PROCESS_REQUEST_TEST_CASES:
+ procedure_manager._process_queue_request(message)
+ _wait_until(lambda: procedure_manager.active_workers() == ["primary", "queue2"])
+ _wait_until(
+ lambda: procedure_manager.worker_statuses()
+ == {"primary": ProcedureWorkerStatus.RUNNING, "queue2": ProcedureWorkerStatus.RUNNING}
+ )
+ for w in procedure_manager._active_workers.values():
+ w["worker"].event_1.set()
+ _wait_until(
+ lambda: procedure_manager.worker_statuses()
+ == {"primary": ProcedureWorkerStatus.IDLE, "queue2": ProcedureWorkerStatus.IDLE}
+ )
+ for w in procedure_manager._active_workers.values():
+ w["worker"].event_2.set()
+ _wait_until(lambda: procedure_manager.active_workers() == [])
+
+
+_ManagerWithMsgs = tuple[ProcedureManager, list[ProcedureExecutionMessage]]
+
+
+@pytest.fixture
+def manager_with_test_msgs(procedure_manager: ProcedureManager):
+ procedure_manager._conn._redis_conn.flushdb()
+ contents = [
+ ("test_identifier_1", "queue1", ((), {})),
+ ("test_identifier_2", "queue1", ((), {})),
+ ("test_identifier_1", "queue2", ((), {})),
+ ("test_identifier_2", "queue2", ((), {})),
+ ]
+ msgs = iter(
+ ProcedureRequestMessage(identifier=c[0], queue=c[1], args_kwargs=c[2]) for c in contents
+ )
+ procedure_manager._validate_request = lambda msg: next(msgs)
+ for _ in range(len(contents)):
+ procedure_manager._process_queue_request({})
+ return (
+ procedure_manager,
+ [
+ ProcedureExecutionMessage(metadata={}, identifier=c[0], queue=c[1], args_kwargs=c[2])
+ for c in contents
+ ],
+ )
+
+
+def _eq_except_id(a: ProcedureExecutionMessage, b: ProcedureExecutionMessage):
+ return a.identifier == b.identifier and a.queue == b.queue and a.args_kwargs == b.args_kwargs
+
+
+def _all_eq_except_id(a: list[ProcedureExecutionMessage], b: list[ProcedureExecutionMessage]):
+ if len(a) != len(b):
+ return False
+ return all(starmap(_eq_except_id, zip(a, b)))
+
+
+@pytest.mark.parametrize("queue", ["queue1", "queue2"])
+@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock())
+def test_startup(manager_with_test_msgs: _ManagerWithMsgs, queue: str):
+ procedure_manager, expected = manager_with_test_msgs
+ queue_expected = list(filter(lambda msg: msg.queue == queue, expected))
+
+ execution_list = procedure_manager._helper.get.exec_queue(queue)
+ assert _all_eq_except_id(execution_list, queue_expected)
+
+ procedure_manager._startup()
+
+ # on startup, the manager should move active queues to unhandled queues
+ execution_list = procedure_manager._helper.get.exec_queue(queue)
+ unhandled_execution_list = procedure_manager._conn.lrange(
+ MessageEndpoints.unhandled_procedure_execution(queue), 0, -1
+ )
+ assert execution_list == []
+ assert _all_eq_except_id(unhandled_execution_list, queue_expected)
+
+
+@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock())
+def test_abort_queue(manager_with_test_msgs: _ManagerWithMsgs):
+ procedure_manager, expected = manager_with_test_msgs
+ remaining_expected = list(filter(lambda msg: msg.queue == "queue2", expected))
+ aborted_expected = list(filter(lambda msg: msg.queue == "queue1", expected))
+
+ q1_execution_list = procedure_manager._helper.get.exec_queue("queue1")
+ assert _all_eq_except_id(q1_execution_list, aborted_expected)
+ q2_execution_list = procedure_manager._helper.get.exec_queue("queue2")
+ assert _all_eq_except_id(q2_execution_list, remaining_expected)
+
+ procedure_manager._process_abort({"queue": "queue1"})
+
+ # on abort, the manager should move active queues to unhandled queues
+ # this should happen for q1 and not q2
+ q2_execution_list = procedure_manager._helper.get.exec_queue("queue2")
+ unhandled_execution_list = procedure_manager._conn.lrange(
+ MessageEndpoints.unhandled_procedure_execution("queue1"), 0, -1
+ )
+ assert _all_eq_except_id(q2_execution_list, remaining_expected)
+ assert _all_eq_except_id(unhandled_execution_list, aborted_expected)
+
+
+@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock())
+def test_abort_individual(manager_with_test_msgs: _ManagerWithMsgs):
+ procedure_manager, expected = manager_with_test_msgs
+ q1_expected = list(filter(lambda msg: msg.queue == "queue1", expected))
+ q2_expected = list(filter(lambda msg: msg.queue == "queue2", expected))
+
+ q1_execution_list = procedure_manager._helper.get.exec_queue("queue1")
+ assert _all_eq_except_id(
+ q1_execution_list, list(filter(lambda msg: msg.queue == "queue1", q1_expected))
+ )
+ q2_execution_list = procedure_manager._helper.get.exec_queue("queue2")
+ assert _all_eq_except_id(q2_execution_list, q2_expected)
+
+ procedure_manager._process_abort({"execution_id": q2_execution_list[1].execution_id})
+
+ q1_execution_list = procedure_manager._helper.get.exec_queue("queue1")
+ q2_execution_list = procedure_manager._helper.get.exec_queue("queue2")
+ assert _all_eq_except_id(q1_execution_list, q1_expected)
+ assert _all_eq_except_id(q2_execution_list, [q2_expected[0]])
+
+
+@patch("bec_server.scan_server.procedures.in_process_worker.BECClient", MagicMock())
+def test_abort_all(manager_with_test_msgs: _ManagerWithMsgs):
+ procedure_manager, expected = manager_with_test_msgs
+ q1_expected = list(filter(lambda msg: msg.queue == "queue1", expected))
+ q2_expected = list(filter(lambda msg: msg.queue == "queue2", expected))
+
+ q1_execution_list = procedure_manager._helper.get.exec_queue("queue1")
+ assert _all_eq_except_id(
+ q1_execution_list, list(filter(lambda msg: msg.queue == "queue1", q1_expected))
+ )
+ q2_execution_list = procedure_manager._helper.get.exec_queue("queue2")
+ assert _all_eq_except_id(q2_execution_list, q2_expected)
+
+ procedure_manager._process_abort({"abort_all": True})
+
+ q1_execution_list = procedure_manager._helper.get.exec_queue("queue1")
+ q2_execution_list = procedure_manager._helper.get.exec_queue("queue2")
+ q1_unhandled_list = procedure_manager._helper.get.unhandled_queue("queue1")
+ q2_unhandled_list = procedure_manager._helper.get.unhandled_queue("queue2")
+
+ assert q1_execution_list == []
+ assert q2_execution_list == []
+ assert _all_eq_except_id(q1_unhandled_list, q1_expected)
+ assert _all_eq_except_id(q2_unhandled_list, q2_expected)