diff --git a/Makefile b/Makefile index eac0b51..f468af2 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,8 @@ Variables: currently: "$(MANAGER_KEY)" - MANAGER_DATA host directory to mount into `/data` (shared with Manager) currently: "$(MANAGER_DATA)" + - MANAGER_WORKFLOWS host directory to mount into `/workflows` (shared with Manager) + currently: "$(MANAGER_WORKFLOWS)" - NETWORK Docker network to use (manage via "docker network") currently: $(NETWORK) - CONTROLLER_HOST network address for the Controller client @@ -42,6 +44,7 @@ help: ; @eval "$$HELP" MANAGER_KEY ?= $(firstword $(filter-out %.pub,$(wildcard $(HOME)/.ssh/id_*))) MANAGER_DATA ?= $(CURDIR) +MANAGER_WORKFLOWS ?= $(CURDIR) MONITOR_PORT_WEB ?= 5000 NETWORK ?= bridge CONTROLLER_HOST ?= $(shell dig +short $$HOSTNAME) @@ -54,6 +57,8 @@ run: $(DATA) -p $(MONITOR_PORT_WEB):5000 \ -v ${MANAGER_KEY}:/id_rsa \ --mount type=bind,source=$(MANAGER_KEY),target=/id_rsa \ + -v $(MANAGER_DATA):/data \ + -v $(MANAGER_WORKFLOWS):/workflows \ -v shared:/run/lock/ocrd.jobs \ -e CONTROLLER=$(CONTROLLER_HOST):$(CONTROLLER_PORT_SSH) \ -e MONITOR_PORT_LOG=${MONITOR_PORT_LOG} \ diff --git a/README.md b/README.md index 4ab2c16..3f7a563 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ In order to work properly, the following **environment variables** must be set: | CONTROLLER_HOST | Hostname of the OCR-D Controller | | CONTROLLER_PORT_SSH | Port on the OCR-D Controller host that allows a SSH connection | | MANAGER_DATA | Path to the OCR-D workspaces on the host | +| MANAGER_WORKFLOWS | Path to the OCR-D workflows on the host | | MANAGER_KEY | Path to a private key that can be used to authenticate with the OCR-D Controller | | MONITOR_PORT_WEB | The port at which the OCR-D Monitor will be available on the host | | MONITOR_PORT_LOG | The port at which the Dozzle logs will be available on the host | diff --git a/docker-compose.yml b/docker-compose.yml index 1771203..afbe83a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,6 +21,7 @@ services: volumes: - ${MANAGER_DATA}:/data + - ${MANAGER_WORKFLOWS}:/workflows - ${MANAGER_KEY}:/id_rsa - shared:/run/lock/ocrd.jobs diff --git a/init.sh b/init.sh index 5671db6..f74408e 100755 --- a/init.sh +++ b/init.sh @@ -24,7 +24,7 @@ export OCRD_LOGVIEW__PORT=$MONITOR_PORT_LOG export OCRD_CONTROLLER__JOB_DIR=/run/lock/ocrd.jobs export OCRD_CONTROLLER__HOST=$CONTROLLER_HOST export OCRD_CONTROLLER__PORT=$CONTROLLER_PORT -export OCRD_CONTROLLER__USER=ocrd +export OCRD_CONTROLLER__USER=admin export OCRD_CONTROLLER__KEYFILE=~/.ssh/id_rsa cd /usr/local/ocrd-monitor diff --git a/ocrdmonitor/ocrdcontroller.py b/ocrdmonitor/ocrdcontroller.py index 9ed497d..a980a9a 100644 --- a/ocrdmonitor/ocrdcontroller.py +++ b/ocrdmonitor/ocrdcontroller.py @@ -1,11 +1,13 @@ from __future__ import annotations import sys +import logging from pathlib import Path from typing import Protocol +from ocrdmonitor import sshremote from ocrdmonitor.ocrdjob import OcrdJob -from ocrdmonitor.processstatus import ProcessStatus +from ocrdmonitor.processstatus import ProcessStatus, ProcessState if sys.version_info >= (3, 10): from typing import TypeGuard @@ -13,40 +15,53 @@ from typing_extensions import TypeGuard -class ProcessQuery(Protocol): - def __call__(self, process_group: int) -> list[ProcessStatus]: +class RemoteServer(Protocol): + async def read_file(self, path: str) -> str: + ... + + async def process_status(self, process_group: int) -> list[ProcessStatus]: ... class OcrdController: - def __init__(self, process_query: ProcessQuery, job_dir: Path) -> None: - self._process_query = process_query + def __init__(self, remote: RemoteServer, job_dir: Path) -> None: + self._remote = remote self._job_dir = job_dir + logging.info(f"process_query: {remote}") + logging.info(f"job_dir: {job_dir}") def get_jobs(self) -> list[OcrdJob]: def is_ocrd_job(j: OcrdJob | None) -> TypeGuard[OcrdJob]: return j is not None job_candidates = [ - self._try_parse(job_file.read_text()) + self._try_parse(job_file) for job_file in self._job_dir.iterdir() if job_file.is_file() ] return list(filter(is_ocrd_job, job_candidates)) - def _try_parse(self, job_str: str) -> OcrdJob | None: + def _try_parse(self, job_file: Path) -> OcrdJob | None: + job_str = job_file.read_text() try: return OcrdJob.from_str(job_str) - except (ValueError, KeyError): + except (ValueError, KeyError) as e: + logging.warning(f"found invalid job file: {job_file.name} - {e}") return None - def status_for(self, ocrd_job: OcrdJob) -> ProcessStatus | None: - if ocrd_job.pid is None: + async def status_for(self, ocrd_job: OcrdJob) -> ProcessStatus | None: + if ocrd_job.remotedir is None: return None - process_statuses = self._process_query(ocrd_job.pid) - matching_statuses = ( - status for status in process_statuses if status.pid == ocrd_job.pid - ) - return next(matching_statuses, None) + pid = await self._remote.read_file(f"/data/{ocrd_job.remotedir}/ocrd.pid") + process_statuses = await self._remote.process_status(int(pid)) + + for status in process_statuses: + if status.state == ProcessState.RUNNING: + return status + + if process_statuses: + return process_statuses[0] + + return None diff --git a/ocrdmonitor/ocrdjob.py b/ocrdmonitor/ocrdjob.py index 58d51bf..51080a3 100644 --- a/ocrdmonitor/ocrdjob.py +++ b/ocrdmonitor/ocrdjob.py @@ -1,15 +1,18 @@ from __future__ import annotations +from datetime import datetime from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, NamedTuple, Type +from typing import Any, Callable, NamedTuple, Type -_KEYMAP: dict[str, tuple[Type[int] | Type[str] | Type[Path], str]] = { +_KEYMAP: dict[str, tuple[Type[int] | Type[str] | Type[Path] | Callable[[str], datetime], str]] = { "PID": (int, "pid"), "RETVAL": (int, "return_code"), - "PROCESS_ID": (int, "process_id"), - "TASK_ID": (int, "task_id"), + "TIME_CREATED": (datetime.fromisoformat, "time_created"), + "TIME_TERMINATED": (datetime.fromisoformat, "time_terminated"), + "PROCESS_ID": (str, "process_id"), + "TASK_ID": (str, "task_id"), "PROCESS_DIR": (Path, "processdir"), "WORKDIR": (Path, "workdir"), "WORKFLOW": (Path, "workflow_file"), @@ -18,7 +21,7 @@ } -def _into_dict(content: str) -> dict[str, int | str | Path]: +def _into_dict(content: str) -> dict[str, int | str | Path | datetime]: result_dict = {} lines = content.splitlines() for line in lines: @@ -35,8 +38,8 @@ def _into_dict(content: str) -> dict[str, int | str | Path]: class KitodoProcessDetails(NamedTuple): - process_id: int - task_id: int + process_id: str + task_id: str processdir: Path @@ -59,6 +62,9 @@ class OcrdJob: pid: int | None = None return_code: int | None = None + time_created: datetime | None = None + time_terminated: datetime | None = None + @classmethod def from_str(cls, content: str) -> "OcrdJob": """ diff --git a/ocrdmonitor/processstatus.py b/ocrdmonitor/processstatus.py index de4528a..433284b 100644 --- a/ocrdmonitor/processstatus.py +++ b/ocrdmonitor/processstatus.py @@ -1,17 +1,17 @@ from __future__ import annotations -import subprocess from dataclasses import dataclass from datetime import timedelta from enum import Enum -PS_CMD = "ps -g {} -o pid,state,%cpu,rss,cputime --no-headers" - class ProcessState(Enum): + # see ps(1)#PROCESS_STATE_CODES RUNNING = "R" SLEEPING = "S" + SLEEPIO = "D" STOPPED = "T" + TRACING = "t" ZOMBIE = "Z" UNKNOWN = "?" @@ -28,7 +28,11 @@ class ProcessStatus: cpu_time: timedelta @classmethod - def from_ps_output(cls, ps_output: str) -> list["ProcessStatus"]: + def shell_command(cls, pid: int) -> str: + return f"ps -s {pid} -o pid,state,%cpu,rss,cputime --no-headers" + + @classmethod + def from_shell_output(cls, ps_output: str) -> list["ProcessStatus"]: def is_error(lines: list[str]) -> bool: return lines[0].startswith("error:") @@ -49,13 +53,6 @@ def parse_line(line: str) -> "ProcessStatus": return [parse_line(line) for line in lines] -def run(group: int) -> list[ProcessStatus]: - cmd = PS_CMD.format(group) - result = subprocess.run(cmd, shell=True, capture_output=True, text=True) - - return ProcessStatus.from_ps_output(result.stdout) - - def _cpu_time_to_seconds(cpu_time: str) -> int: hours, minutes, seconds, *_ = cpu_time.split(":") return int(hours) * 3600 + int(minutes) * 60 + int(seconds) diff --git a/ocrdmonitor/server/app.py b/ocrdmonitor/server/app.py index 551fc17..df1a593 100644 --- a/ocrdmonitor/server/app.py +++ b/ocrdmonitor/server/app.py @@ -35,7 +35,7 @@ async def swallow_exceptions(request: Request, err: Exception) -> Response: create_jobs( templates, OcrdController( - settings.ocrd_controller.process_query(), + settings.ocrd_controller.controller_remote(), settings.ocrd_controller.job_dir, ), ) diff --git a/ocrdmonitor/server/jobs.py b/ocrdmonitor/server/jobs.py index 23fcf03..87b82f5 100644 --- a/ocrdmonitor/server/jobs.py +++ b/ocrdmonitor/server/jobs.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime, timezone from dataclasses import dataclass from typing import Iterable @@ -42,19 +43,26 @@ def create_jobs(templates: Jinja2Templates, controller: OcrdController) -> APIRo router = APIRouter(prefix="/jobs") @router.get("/", name="jobs") - def jobs(request: Request) -> Response: + async def jobs(request: Request) -> Response: jobs = controller.get_jobs() running, completed = split_into_running_and_completed(jobs) - job_status = [controller.status_for(job) for job in running] + job_status = [await controller.status_for(job) for job in running] running_jobs = wrap_in_running_job_type(running, job_status) + now = datetime.now(timezone.utc) return templates.TemplateResponse( "jobs.html.j2", { "request": request, - "running_jobs": running_jobs, - "completed_jobs": completed, + "running_jobs": sorted( + running_jobs, + key=lambda x: x.ocrd_job.time_created or now, + ), + "completed_jobs": sorted( + completed, + key=lambda x: x.time_terminated or now, + ), }, ) diff --git a/ocrdmonitor/server/settings.py b/ocrdmonitor/server/settings.py index 3a3ab07..2ccfc27 100644 --- a/ocrdmonitor/server/settings.py +++ b/ocrdmonitor/server/settings.py @@ -2,7 +2,6 @@ import asyncio import atexit -from functools import partial from pathlib import Path from typing import Literal @@ -13,8 +12,9 @@ OcrdBrowserFactory, SubProcessOcrdBrowserFactory, ) -from ocrdmonitor.ocrdcontroller import ProcessQuery -from ocrdmonitor.sshps import process_status + +from ocrdmonitor.ocrdcontroller import RemoteServer +from ocrdmonitor.sshremote import SSHRemote class OcrdControllerSettings(BaseModel): @@ -24,8 +24,8 @@ class OcrdControllerSettings(BaseModel): port: int = 22 keyfile: Path = Path.home() / ".ssh" / "id_rsa" - def process_query(self) -> ProcessQuery: - return partial(process_status, self) + def controller_remote(self) -> RemoteServer: + return SSHRemote(self) class OcrdLogViewSettings(BaseModel): diff --git a/ocrdmonitor/server/templates/jobs.html.j2 b/ocrdmonitor/server/templates/jobs.html.j2 index f1833bc..65e98a6 100644 --- a/ocrdmonitor/server/templates/jobs.html.j2 +++ b/ocrdmonitor/server/templates/jobs.html.j2 @@ -9,10 +9,17 @@ {% endblock %} {% block content %} +
TSTART | TASK ID | PROCESS ID | WORKFLOW | @@ -21,19 +28,22 @@% CPU | MB RSS | DURATION | +ACTION | |||
---|---|---|---|---|---|---|---|---|---|---|
{{ job.ocrd_job.time_created }} | {{ job.ocrd_job.kitodo_details.task_id }} | {{ job.ocrd_job.kitodo_details.process_id }} | -{{ job.ocrd_job.workflow }} | +{{ job.ocrd_job.workflow }} | {{ job.process_status.pid }} | {{ job.process_status.state }} | {{ job.process_status.percent_cpu }} | {{ job.process_status.memory }} | {{ job.process_status.cpu_time }} | +
TSTOP | TASK ID | PROCESS ID | WORKFLOW | @@ -53,9 +64,10 @@||||
---|---|---|---|---|---|---|---|
{{ job.time_terminated }} | {{ job.kitodo_details.task_id }} | {{ job.kitodo_details.process_id }} | -{% if job.workflow is defined %}{{ job.workflow }}{% endif %} | +{{ job.workflow }} | {{ job.return_code }} {% if job.return_code == 0 %}(SUCCESS){% else %}(FAILURE){% endif %} | {{ job.kitodo_details.processdir.name }} | diff --git a/ocrdmonitor/sshps.py b/ocrdmonitor/sshps.py deleted file mode 100644 index b37610c..0000000 --- a/ocrdmonitor/sshps.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import subprocess -from pathlib import Path -from typing import Protocol - -from ocrdmonitor.processstatus import PS_CMD, ProcessStatus - - -class SSHConfig(Protocol): - host: str - port: int - user: str - keyfile: Path - - -_SSH = ( - "ssh -o StrictHostKeyChecking=no -i '{keyfile}' -p {port} {user}@{host} '{ps_cmd}'" -) - - -def process_status(config: SSHConfig, process_group: int) -> list[ProcessStatus]: - ssh_cmd = _build_ssh_command(config, process_group) - - result = subprocess.run( - ssh_cmd, - shell=True, - universal_newlines=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - encoding="utf-8", - ) - - return ProcessStatus.from_ps_output(result.stdout) - - -def _build_ssh_command(config: SSHConfig, process_group: int) -> str: - ps_cmd = PS_CMD.format(process_group or "") - return _SSH.format( - port=config.port, - keyfile=config.keyfile, - user=config.user, - host=config.host, - ps_cmd=ps_cmd, - ) diff --git a/ocrdmonitor/sshremote.py b/ocrdmonitor/sshremote.py new file mode 100644 index 0000000..598255a --- /dev/null +++ b/ocrdmonitor/sshremote.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import asyncio +import logging +import shlex +from pathlib import Path +from typing import Protocol + +from ocrdmonitor.processstatus import ProcessStatus + + +class SSHConfig(Protocol): + host: str + port: int + user: str + keyfile: Path + + +class SSHRemote: + def __init__(self, config: SSHConfig) -> None: + self._config = config + + async def read_file(self, path: str) -> str: + result = await asyncio.create_subprocess_shell( + _ssh(self._config, f"cat {path}"), + stdout=asyncio.subprocess.PIPE, + ) + await result.wait() + + if not result.stdout: + return "" + + return (await result.stdout.read()).decode() + + async def process_status(self, process_group: int) -> list[ProcessStatus]: + pid_cmd = ProcessStatus.shell_command(process_group) + result = await asyncio.create_subprocess_shell( + _ssh(self._config, pid_cmd), + stdout=asyncio.subprocess.PIPE, + ) + + if await result.wait() > 0: + logging.error( + f"checking status of process {process_group} failed: {result.stderr}" + ) + return [] + + if not result.stdout: + return [] + + output = (await result.stdout.read()).decode() + return ProcessStatus.from_shell_output(output) + + +def _ssh(config: SSHConfig, cmd: str) -> str: + return shlex.join( + ( + "ssh", + "-o", + "StrictHostKeyChecking=no", + "-i", + str(config.keyfile), + "-p", + str(config.port), + f"{config.user}@{config.host}", + *shlex.split(cmd), + ) + ) diff --git a/tests/ocrdmonitor/server/test_job_endpoint.py b/tests/ocrdmonitor/server/test_job_endpoint.py index 5715da9..58fa120 100644 --- a/tests/ocrdmonitor/server/test_job_endpoint.py +++ b/tests/ocrdmonitor/server/test_job_endpoint.py @@ -8,7 +8,7 @@ import pytest from fastapi.testclient import TestClient from httpx import Response -from ocrdmonitor.ocrdcontroller import ProcessQuery +from ocrdmonitor.ocrdcontroller import RemoteServer from ocrdmonitor.ocrdjob import OcrdJob from ocrdmonitor.processstatus import ProcessState, ProcessStatus from ocrdmonitor.server.settings import OcrdControllerSettings @@ -37,7 +37,7 @@ def running_ocrd_job( expected_status = make_status(pid) running_job = replace(JOB_TEMPLATE, pid=pid) jobfile = write_job_file_for(running_job) - patch_process_query(monkeypatch, expected_status) + patch_controller_remote(monkeypatch, expected_status) yield running_job, expected_status @@ -86,18 +86,20 @@ def make_status(pid: int) -> ProcessStatus: return expected_status -def patch_process_query( +def patch_controller_remote( monkeypatch: pytest.MonkeyPatch, expected_status: ProcessStatus ) -> None: - def make_process_query(self: OcrdControllerSettings) -> ProcessQuery: - def process_query_stub(process_group: int) -> list[ProcessStatus]: - if process_group != expected_status.pid: - raise ValueError(f"Unexpected process group {process_group}") - return [expected_status] + def make_remote_stub(self: OcrdControllerSettings) -> RemoteServer: + class RemoteStub: + async def read_file(self, path: str) -> str: + return str(expected_status.pid) - return process_query_stub + async def process_status(self, process_group: int) -> list[ProcessStatus]: + return [expected_status] - monkeypatch.setattr(OcrdControllerSettings, "process_query", make_process_query) + return RemoteStub() + + monkeypatch.setattr(OcrdControllerSettings, "controller_remote", make_remote_stub) def assert_lists_completed_job( @@ -106,6 +108,7 @@ def assert_lists_completed_job( texts = collect_texts_from_job_table(response.content, "completed-jobs") assert texts == [ + str(completed_job.time_terminated), str(completed_job.kitodo_details.task_id), str(completed_job.kitodo_details.process_id), completed_job.workflow_file.name, @@ -123,6 +126,7 @@ def assert_lists_running_job( texts = collect_texts_from_job_table(response.content, "running-jobs") assert texts == [ + str(running_job.time_created), str(running_job.kitodo_details.task_id), str(running_job.kitodo_details.process_id), running_job.workflow_file.name, @@ -135,7 +139,7 @@ def assert_lists_running_job( def collect_texts_from_job_table(content: bytes, table_id: str) -> list[str]: - selector = f"#{table_id} td:not(:has(a)), #{table_id} td > a" + selector = f"#{table_id} td:not(:has(a)):not(:has(button)), #{table_id} td > a" return scraping.parse_texts(content, selector) diff --git a/tests/ocrdmonitor/test_jobs.py b/tests/ocrdmonitor/test_jobs.py index dd51c54..c47fb97 100644 --- a/tests/ocrdmonitor/test_jobs.py +++ b/tests/ocrdmonitor/test_jobs.py @@ -1,4 +1,5 @@ from dataclasses import replace +from datetime import datetime, timedelta from pathlib import Path @@ -16,19 +17,25 @@ REMOTEDIR={remotedir} WORKFLOW={workflow} CONTROLLER={controller_address} +TIME_CREATED={created_at} +TIME_TERMINATED={terminated_at} """ +created_at = datetime(2023, 4, 12, hour=13, minute=0, second=0) +terminated_at = created_at + timedelta(hours=1) JOB_TEMPLATE = OcrdJob( kitodo_details=KitodoProcessDetails( - process_id=5432, - task_id=45989, + process_id="5432", + task_id="45989", processdir=Path("/data/5432"), ), workdir=Path("ocr-d/data/5432"), workflow_file=Path("ocr-workflow-default.sh"), remotedir="/remote/job/dir", controller_address="controller.ocrdhost.com", + time_created=created_at, + time_terminated=terminated_at, ) @@ -41,6 +48,8 @@ def jobfile_content_for(job: OcrdJob) -> str: workflow=job.workflow_file.as_posix(), remotedir=job.remotedir, controller_address=job.controller_address, + created_at=created_at, + terminated_at=terminated_at, ) if job.pid is not None: @@ -52,7 +61,9 @@ def jobfile_content_for(job: OcrdJob) -> str: return out -def test__parsing_a_ocrd_job_file_for_completed_job__returns_ocrdjob_with_a_return_code() -> None: +def test__parsing_a_ocrd_job_file_for_completed_job__returns_ocrdjob_with_a_return_code() -> ( + None +): expected = replace(JOB_TEMPLATE, return_code=0) content = jobfile_content_for(expected) @@ -61,7 +72,9 @@ def test__parsing_a_ocrd_job_file_for_completed_job__returns_ocrdjob_with_a_retu assert actual == expected -def test__parsing_a_ocrd_job_file_for_running_job__returns_ocrdjob_with_a_process_id() -> None: +def test__parsing_a_ocrd_job_file_for_running_job__returns_ocrdjob_with_a_process_id() -> ( + None +): expected = replace(JOB_TEMPLATE, pid=1) content = jobfile_content_for(expected) diff --git a/tests/ocrdmonitor/test_processstatus.py b/tests/ocrdmonitor/test_processstatus.py index a0f1111..3347959 100644 --- a/tests/ocrdmonitor/test_processstatus.py +++ b/tests/ocrdmonitor/test_processstatus.py @@ -1,7 +1,7 @@ import datetime import pytest -from ocrdmonitor.processstatus import ProcessState, ProcessStatus, run +from ocrdmonitor.processstatus import ProcessState, ProcessStatus PS_OUTPUT = """ 1 Ss 0.0 3872 01:12:46 @@ -16,7 +16,7 @@ def test__parsing_psoutput__returns_list_of_process_status() -> None: - actual = ProcessStatus.from_ps_output(PS_OUTPUT) + actual = ProcessStatus.from_shell_output(PS_OUTPUT) assert actual == [ ProcessStatus( @@ -38,6 +38,6 @@ def test__parsing_psoutput__returns_list_of_process_status() -> None: @pytest.mark.parametrize("output", FAILING_OUTPUTS) def test__parsing_psoutput_with_error__returns_empty_list(output: str) -> None: - actual = ProcessStatus.from_ps_output(output) + actual = ProcessStatus.from_shell_output(output) assert actual == [] diff --git a/tests/ocrdmonitor/test_sshps.py b/tests/ocrdmonitor/test_sshremote.py similarity index 51% rename from tests/ocrdmonitor/test_sshps.py rename to tests/ocrdmonitor/test_sshremote.py index e008383..5f94add 100644 --- a/tests/ocrdmonitor/test_sshps.py +++ b/tests/ocrdmonitor/test_sshremote.py @@ -1,11 +1,11 @@ from pathlib import Path -from typing import Any, Callable, TypeVar - import pytest +from typing import Any, Awaitable, Callable, TypeVar + from testcontainers.general import DockerContainer from ocrdmonitor.processstatus import ProcessState -from ocrdmonitor.sshps import process_status +from ocrdmonitor.sshremote import SSHRemote from tests.ocrdmonitor.sshcontainer import ( get_process_group_from_container, SSHConfig, @@ -15,28 +15,30 @@ T = TypeVar("T") -def run_until_truthy(fn: Callable[..., T], *args: Any) -> T: - while not (result := fn(*args)): - continue - - return result - - +@pytest.mark.asyncio @pytest.mark.integration -def test_ps_over_ssh__returns_list_of_process_status( +async def test_ps_over_ssh__returns_list_of_process_status( openssh_server: DockerContainer, ) -> None: process_group = get_process_group_from_container(openssh_server) - - config = SSHConfig( - host="localhost", - port=2222, - user="testcontainer", - keyfile=Path(KEYDIR) / "id.rsa", + sut = SSHRemote( + config=SSHConfig( + host="localhost", + port=2222, + user="testcontainer", + keyfile=Path(KEYDIR) / "id.rsa", + ), ) - actual = run_until_truthy(process_status, config, process_group) + actual = await run_until_truthy(sut.process_status, process_group) first_process = actual[0] assert first_process.pid == process_group assert first_process.state == ProcessState.SLEEPING + + +async def run_until_truthy(fn: Callable[..., Awaitable[T]], *args: Any) -> T: + while not (result := await fn(*args)): + continue + + return result |