Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/bec_e2e_install/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ runs:
cd ./_e2e_test_checkout_/bec
source ./bin/install_bec_dev.sh -t
pip install -e ../ophyd_devices
podman pod create --net host local_bec
podman pod create --network=host local_bec
python ./bec_ipython_client/tests/end-2-end/_ensure_requirements_container.py
pytest -v --files-path ./ --start-servers --random-order ./bec_ipython_client/tests/end-2-end/

43 changes: 26 additions & 17 deletions bec_ipython_client/tests/end-2-end/test_procedures_e2e.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from importlib.metadata import version
from typing import TYPE_CHECKING, Callable, Generator
from unittest.mock import MagicMock, patch
Expand All @@ -11,7 +12,7 @@
from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_server.scan_server.procedures.constants import PROCEDURE
from bec_server.scan_server.procedures.constants import _CONTAINER, _WORKER
from bec_server.scan_server.procedures.container_utils import get_backend
from bec_server.scan_server.procedures.container_worker import ContainerProcedureWorker
from bec_server.scan_server.procedures.manager import ProcedureManager
Expand All @@ -28,6 +29,15 @@
pytestmark = pytest.mark.random_order(disabled=True)


@dataclass(frozen=True)
class PATCHED_CONSTANTS:
WORKER = _WORKER()
CONTAINER = _CONTAINER()
MANAGER_SHUTDOWN_TIMEOUT_S = 2
BEC_VERSION = version("bec_lib")
REDIS_HOST = "localhost"


@pytest.fixture
def client_logtool_and_manager(
bec_ipython_client_fixture_with_logtool: tuple[BECIPythonClient, "LogTestTool"],
Expand All @@ -52,7 +62,7 @@ def _wait_while(cond: Callable[[], bool], timeout_s):
def test_building_worker_image():
podman_utils = get_backend()
build = podman_utils.build_worker_image()
assert len(build._command_output.splitlines()[-1]) == 64
assert len(build._command_output.splitlines()[-1]) == 64 # type: ignore
assert podman_utils.image_exists(f"bec_procedure_worker:v{version('bec_lib')}")


Expand All @@ -62,7 +72,7 @@ def test_procedure_runner_spawns_worker(
client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager],
):
client, _, manager = client_logtool_and_manager
assert manager.active_workers == {}
assert manager._active_workers == {}
endpoint = MessageEndpoints.procedure_request()
msg = messages.ProcedureRequestMessage(
identifier="sleep", args_kwargs=((), {"time_s": 2}), queue="test"
Expand All @@ -77,34 +87,38 @@ def cb(worker: ContainerProcedureWorker):
manager.add_callback("test", cb)
client.connector.xadd(topic=endpoint, msg_dict=msg.model_dump())

_wait_while(lambda: manager.active_workers == {}, 5)
_wait_while(lambda: manager.active_workers != {}, 20)
_wait_while(lambda: manager._active_workers == {}, 5)
_wait_while(lambda: manager._active_workers != {}, 20)

assert logs != []


@pytest.mark.timeout(100)
@patch("bec_server.scan_server.procedures.manager.procedure_registry.is_registered", lambda _: True)
@patch("bec_server.scan_server.procedures.container_worker.PROCEDURE", PATCHED_CONSTANTS())
def test_happy_path_container_procedure_runner(
client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager],
):
test_args = (1, 2, 3)
test_kwargs = {"a": "b", "c": "d"}
client, logtool, manager = client_logtool_and_manager
assert manager.active_workers == {}
assert manager._active_workers == {}
conn = client.connector
endpoint = MessageEndpoints.procedure_request()
msg = messages.ProcedureRequestMessage(
identifier="log execution message args", args_kwargs=(test_args, test_kwargs)
)
conn.xadd(topic=endpoint, msg_dict=msg.model_dump())

_wait_while(lambda: manager.active_workers == {}, 5)
_wait_while(lambda: manager.active_workers != {}, 20)
_wait_while(lambda: manager._active_workers == {}, 5)
_wait_while(lambda: manager._active_workers != {}, 20)

logtool.fetch()
assert logtool.is_present_in_any_message("procedure accepted: True, message:")
assert logtool.is_present_in_any_message("ContainerWorker started container for queue primary")
assert logtool.is_present_in_any_message(
"ContainerWorker started container for queue primary"
), f"Log content relating to procedures: {manager._logs}"

res, msg = logtool.are_present_in_order(
[
"Container worker 'primary' status update: IDLE",
Expand All @@ -114,12 +128,7 @@ def test_happy_path_container_procedure_runner(
]
)
assert res, f"failed on {msg}"
res, msg = logtool.are_present_in_order(
[
"Container worker 'primary' status update: IDLE",
f"Builtin procedure log_message_args_kwargs called with args: {test_args} and kwargs: {test_kwargs}",
"Container worker 'primary' status update: IDLE",
"Container worker 'primary' status update: FINISHED",
]

assert logtool.is_present_in_any_message(
f"Builtin procedure log_message_args_kwargs called with args: {test_args} and kwargs: {test_kwargs}"
)
assert res, f"failed on {msg}"
85 changes: 80 additions & 5 deletions bec_lib/bec_lib/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ class MessageOp(list[str], enum.Enum):
SET_PUBLISH = ["register", "set_and_publish", "delete", "get", "keys"]
SEND = ["send", "register"]
STREAM = ["xadd", "xrange", "xread", "register_stream", "keys", "get_last", "delete"]
LIST = ["lpush", "lrange", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"]
LIST = ["lpush", "lrange", "lrem", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"]
KEY_VALUE = ["set", "get", "delete", "keys"]
SET = ["remove_from_set", "get_set_members"]
SET = ["remove_from_set", "get_set_members", "delete"]


MessageType = TypeVar("MessageType", bound="type[messages.BECMessage]")
Expand Down Expand Up @@ -1420,8 +1420,8 @@ def available_procedures() -> EndpointInfo:
@staticmethod
def procedure_request() -> EndpointInfo:
"""
Endpoint for scan queue request. This endpoint is used to request the new scans.
The request is sent using a messages.ScanQueueMessage message.
Endpoint for requesting new procedures.
The request is sent using a messages.ProcedureRequestMessage message.

Returns:
EndpointInfo: Endpoint for scan queue request.
Expand Down Expand Up @@ -1460,7 +1460,24 @@ def procedure_execution(queue_id: str):
Returns:
EndpointInfo: Endpoint for scan queue request.
"""
endpoint = f"{EndpointType.INTERNAL.value}/procedures/procedure_execution/{queue_id}"
endpoint = f"{EndpointType.INFO.value}/procedures/procedure_execution/{queue_id}"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureExecutionMessage,
message_op=MessageOp.LIST,
)

@staticmethod
def unhandled_procedure_execution(queue_id: str):
"""
Endpoint for procedure executions which were pending when the manager was shutdown.
Messages from procedure_execution are moved here on manager startup.
The request is sent using a messages.ProcedureExecutionMessage message.

Returns:
EndpointInfo: Endpoint for scan queue request.
"""
endpoint = f"{EndpointType.INFO.value}/procedures/unhandled_procedure_execution/{queue_id}"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureExecutionMessage,
Expand All @@ -1483,6 +1500,36 @@ def active_procedure_executions():
message_op=MessageOp.SET,
)

@staticmethod
def procedure_abort():
"""
Endpoint to request aborting a running procedure

Returns:
EndpointInfo: Endpoint to request procedure abortion.
"""
endpoint = f"{EndpointType.USER.value}/procedures/abort"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureAbortMessage,
message_op=MessageOp.STREAM,
)

@staticmethod
def procedure_clear_unhandled():
"""
Endpoint to request removing an aborted procedure

Returns:
EndpointInfo: Endpoint to request removing aborted procedures.
"""
endpoint = f"{EndpointType.USER.value}/procedures/clear_unhandled"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureClearUnhandledMessage,
message_op=MessageOp.STREAM,
)

@staticmethod
def procedure_worker_status_update(queue_id: str):
"""
Expand All @@ -1498,6 +1545,34 @@ def procedure_worker_status_update(queue_id: str):
message_op=MessageOp.LIST,
)

@staticmethod
def procedure_queue_notif():
"""
PubSub channel for a consumer (e.g. BEC widgets) to be notified of changes to a procedure queue

Returns:
EndpointInfo: Endpoint for procedure queue updates for given queue.
"""
endpoint = f"{EndpointType.INFO.value}/procedures/queue_notif"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureQNotifMessage,
message_op=MessageOp.SEND,
)

@staticmethod
def procedure_logs(queue: str):
"""
Endpoint for logs for a given procedure queue

Returns:
EndpointInfo: Endpoint for procedure queue updates for given queue.
"""
endpoint = f"{EndpointType.INFO.value}/procedures/logs/{queue}"
return EndpointInfo(
endpoint=endpoint, message_type=messages.RawMessage, message_op=MessageOp.STREAM
)

@staticmethod
def gui_registry_state(gui_id: str):
"""
Expand Down
13 changes: 9 additions & 4 deletions bec_lib/bec_lib/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class BECLogger:
"<green>{service_name} | {{time:YYYY-MM-DD HH:mm:ss.SSS}}</green> | <level>{{level}}</level> |"
" <level>{{thread.name}} ({{thread.id}})</level> | <cyan>{{extra[stack]}}</cyan> - <level>{{message}}</level>\n"
)
CONTAINER_FORMAT = "{{time:YYYY-MM-DD HH:mm:ss.SSS}} | {{level}} | {{message}}\n"
LOGLEVEL = LogLevel

_logger = None
Expand Down Expand Up @@ -224,18 +225,21 @@ def _logger_callback(self, msg):
# because it depends on the connector
pass

def get_format(self, level: LogLevel = None, is_stderr=False) -> str:
def get_format(self, level: LogLevel = None, is_stderr=False, is_container=False) -> str:
"""
Get the format for a specific log level.

Args:
level (LogLevel, optional): Log level. Defaults to None. If None, the current log level will be used.
is_stderr (bool, optional): Whether the log is for stderr. Defaults to False.
is_container (bool, optional): Simple logging for procedure container. Defaults to False.

Returns:
str: Log format.
"""
service_name = self.service_name if self.service_name else ""
if is_container:
return self.CONTAINER_FORMAT.format()
if level is None:
level = self.level
if level > self.LOGLEVEL.DEBUG:
Expand All @@ -246,15 +250,16 @@ def get_format(self, level: LogLevel = None, is_stderr=False) -> str:
return self.DEBUG_FORMAT.format(service_name=service_name)
return self.TRACE_FORMAT.format(service_name=service_name)

def formatting(self, is_stderr=False):
def formatting(self, is_stderr=False, is_container=False):
"""
Format the log message.

Args:
record (dict): Log record.
is_container (bool, optional): Simple logging for procedure container. Defaults to False.

Returns:
dict: Formatted log record.
str: Log format.
"""

def _update_record(record):
Expand All @@ -269,7 +274,7 @@ def _update_record(record):

def _format(record):
level = _update_record(record)
return self.get_format(level)
return self.get_format(level, is_container=is_container)

def _format_stderr(record):
level = _update_record(record)
Expand Down
Loading
Loading