diff --git a/src/isolate/connections/grpc/_base.py b/src/isolate/connections/grpc/_base.py index a4d3bfc..5ebcd80 100644 --- a/src/isolate/connections/grpc/_base.py +++ b/src/isolate/connections/grpc/_base.py @@ -1,4 +1,6 @@ +import os import socket +import subprocess from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path @@ -23,6 +25,11 @@ class AgentError(Exception): """An internal problem caused by (most probably) the agent.""" +PROCESS_SHUTDOWN_TIMEOUT_SECONDS = float( + os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60") +) + + @dataclass class GRPCExecutionBase(EnvironmentConnection): """A customizable gRPC-based execution backend.""" @@ -128,9 +135,18 @@ def find_free_port() -> Tuple[str, int]: with self.start_process(address) as process: yield address, grpc.local_channel_credentials() finally: - if process is not None: - # TODO: should we check the status code here? + self.terminate_process(process) + + def terminate_process(self, process: Union[None, subprocess.Popen]) -> None: + if process is not None: + try: + print("Terminating the agent process...") process.terminate() + process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS) + print("Agent process shutdown gracefully") + except Exception as exc: + print(f"Failed to shutdown the agent process gracefully: {exc}") + process.kill() def get_python_cmd( self, diff --git a/src/isolate/connections/grpc/agent.py b/src/isolate/connections/grpc/agent.py index 08bca36..33c7205 100644 --- a/src/isolate/connections/grpc/agent.py +++ b/src/isolate/connections/grpc/agent.py @@ -10,20 +10,21 @@ from __future__ import annotations +import asyncio import os import sys import traceback from argparse import ArgumentParser -from concurrent import futures from dataclasses import dataclass from typing import ( Any, + AsyncIterator, Iterable, - Iterator, ) -import grpc -from grpc import ServicerContext, StatusCode +from grpc import StatusCode, aio, local_server_credentials + +from isolate.connections.grpc.definitions import PartialRunResult try: from isolate import __version__ as agent_version @@ -49,11 +50,11 @@ def __init__(self, log_fd: int | None = None): self._run_cache: dict[str, Any] = {} self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w") - def Run( + async def Run( self, request: definitions.FunctionCall, - context: ServicerContext, - ) -> Iterator[definitions.PartialRunResult]: + context: aio.ServicerContext, + ) -> AsyncIterator[PartialRunResult]: self.log(f"A connection has been established: {context.peer()}!") server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown" self.log(f"Isolate info: server {server_version}, agent {agent_version}") @@ -87,7 +88,8 @@ def Run( ) raise AbortException("The setup function has thrown an error.") except AbortException as exc: - return self.abort_with_msg(context, exc.message) + self.abort_with_msg(context, exc.message) + return else: assert not was_it_raised self._run_cache[cache_key] = result @@ -107,7 +109,8 @@ def Run( stringized_tb, ) except AbortException as exc: - return self.abort_with_msg(context, exc.message) + self.abort_with_msg(context, exc.message) + return def execute_function( self, @@ -195,7 +198,7 @@ def log(self, message: str) -> None: def abort_with_msg( self, - context: ServicerContext, + context: aio.ServicerContext, message: str, *, code: StatusCode = StatusCode.INVALID_ARGUMENT, @@ -205,23 +208,26 @@ def abort_with_msg( return None -def create_server(address: str) -> grpc.Server: +def create_server(address: str) -> aio.Server: """Create a new (temporary) gRPC server listening on the given address.""" - server = grpc.server( - futures.ThreadPoolExecutor(max_workers=1), - maximum_concurrent_rpcs=1, + # Use asyncio server so requests can run in the main thread and intercept signals + # There seems to be a weird bug with grpcio that makes subsequent requests fail with + # concurrent rpc limit exceeded if we set maximum_current_rpcs to 1. Setting it to 2 + # fixes it, even though in practice, we only run one request at a time. + server = aio.server( + maximum_concurrent_rpcs=2, options=get_default_options(), ) # Local server credentials allow us to ensure that the # connection is established by a local process. - server_credentials = grpc.local_server_credentials() + server_credentials = local_server_credentials() server.add_secure_port(address, server_credentials) return server -def run_agent(address: str, log_fd: int | None = None) -> int: +async def run_agent(address: str, log_fd: int | None = None) -> int: """Run the agent servicer on the given address.""" server = create_server(address) servicer = AgentServicer(log_fd=log_fd) @@ -231,19 +237,19 @@ def run_agent(address: str, log_fd: int | None = None) -> int: # not have any global side effects. definitions.register_agent(servicer, server) - server.start() - server.wait_for_termination() + await server.start() + await server.wait_for_termination() return 0 -def main() -> int: +async def main() -> int: parser = ArgumentParser() parser.add_argument("address", type=str) parser.add_argument("--log-fd", type=int) options = parser.parse_args() - return run_agent(options.address, log_fd=options.log_fd) + return await run_agent(options.address, log_fd=options.log_fd) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/src/isolate/server/server.py b/src/isolate/server/server.py index 12a9931..b8a8884 100644 --- a/src/isolate/server/server.py +++ b/src/isolate/server/server.py @@ -2,6 +2,7 @@ import functools import os +import signal import threading import time import traceback @@ -178,11 +179,17 @@ class RunTask: def cancel(self): while True: - self.future.cancel() + # Cancelling a running future is not possible, and it sometimes blocks, + # which means we never terminate the agent. So check if it's not running + if self.future and not self.future.running(): + self.future.cancel() + if self.agent: self.agent.terminate() + try: - self.future.exception(timeout=0.1) + if self.future: + self.future.exception(timeout=0.1) return except futures.TimeoutError: pass @@ -197,6 +204,7 @@ class IsolateServicer(definitions.IsolateServicer): bridge_manager: BridgeManager default_settings: IsolateSettings = field(default_factory=IsolateSettings) background_tasks: dict[str, RunTask] = field(default_factory=dict) + _shutting_down: bool = field(default=False) _thread_pool: futures.ThreadPoolExecutor = field( default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS) @@ -420,6 +428,17 @@ def Cancel( return definitions.CancelResponse() + def shutdown(self) -> None: + if self._shutting_down: + print("Shutdown already in progress...") + return + + self._shutting_down = True + task_count = len(self.background_tasks) + print(f"Shutting down, canceling {task_count} tasks...") + self.cancel_tasks() + print("All tasks canceled.") + def watch_queue_until_completed( self, queue: Queue, is_completed: Callable[[], bool] ) -> Iterator[definitions.PartialRunResult]: @@ -584,8 +603,10 @@ def _wrapper(request: Any, context: grpc.ServicerContext) -> Any: def termination() -> None: if is_run: print("Stopping server since run is finished") + self.servicer.shutdown() # Stop the server after the Run task is finished self.server.stop(grace=0.1) + print("Server stopped") elif is_submit: # Wait until the task_id is assigned @@ -610,7 +631,9 @@ def _stop(*args): # Small sleep to make sure the cancellation is processed time.sleep(0.1) print("Stopping server since the task is finished") + self.servicer.shutdown() self.server.stop(grace=0.1) + print("Server stopped") # Add a callback which will stop the server # after the task is finished @@ -671,11 +694,20 @@ def main(argv: list[str] | None = None) -> None: definitions.register_isolate(servicer, server) health.register_health(HealthServicer(), server) - server.add_insecure_port("[::]:50001") - print("Started listening at localhost:50001") + def handle_termination(*args): + print("Termination signal received, shutting down...") + servicer.shutdown() + server.stop(grace=0.1) + + signal.signal(signal.SIGINT, handle_termination) + signal.signal(signal.SIGTERM, handle_termination) + + server.add_insecure_port(f"[::]:{options.port}") + print(f"Started listening at {options.host}:{options.port}") server.start() server.wait_for_termination() + print("Server shut down") if __name__ == "__main__": diff --git a/tests/test_server.py b/tests/test_server.py index d41b55e..3cbc388 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -538,7 +538,7 @@ def test_bridge_connection_reuse_logs( run_request(stub, request, user_logs=logs) str_logs = [log.message for log in logs if log.message] - assert str_logs == [ + assert str_logs[:4] == [ "setup", "run", "run", diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py new file mode 100644 index 0000000..4e58e6c --- /dev/null +++ b/tests/test_shutdown.py @@ -0,0 +1,179 @@ +"""End-to-end tests for graceful shutdown behavior of IsolateServicer.""" + +import functools +import os +import signal +import subprocess +import sys +import threading +import time +from unittest.mock import Mock + +import grpc +import pytest +from isolate.server.definitions.server_pb2 import BoundFunction, EnvironmentDefinition +from isolate.server.definitions.server_pb2_grpc import IsolateStub +from isolate.server.interface import to_serialized_object +from isolate.server.server import BridgeManager, IsolateServicer, RunnerAgent, RunTask + + +def create_run_request(func, *args, **kwargs): + """Convert a Python function into a BoundFunction request for stub.Run().""" + bound_function = functools.partial(func, *args, **kwargs) + serialized_function = to_serialized_object(bound_function, method="cloudpickle") + + env_def = EnvironmentDefinition() + env_def.kind = "local" + + request = BoundFunction() + request.function.CopyFrom(serialized_function) + request.environments.append(env_def) + request.stream_logs = True + + return request + + +@pytest.fixture +def servicer(): + """Create a real IsolateServicer instance for testing.""" + with BridgeManager() as bridge_manager: + servicer = IsolateServicer(bridge_manager) + yield servicer + + +@pytest.fixture +def isolate_server_subprocess(monkeypatch): + """Set up a gRPC server with the IsolateServicer for testing.""" + # Find a free port + import socket + + monkeypatch.setenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "2") + + with socket.socket() as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + process = subprocess.Popen( + [ + sys.executable, + "-m", + "isolate.server.server", + "--single-use", + "--port", + str(port), + ] + ) + + time.sleep(5) # Wait for server to start + yield process, port + + # Cleanup + if process.poll() is None: + process.terminate() + process.wait(timeout=10) + + +def consume_responses(responses): + def _consume(): + try: + for response in responses: + pass + except grpc.RpcError: + # Expected when connection is closed + pass + + response_thread = threading.Thread(target=_consume, daemon=True) + response_thread.start() + + +def test_shutdown_with_terminate(servicer): + task = RunTask(request=Mock(), future=Mock()) + servicer.background_tasks["TEST_BLOCKING"] = task + task.agent = RunnerAgent(Mock(), Mock(), Mock(), Mock()) + task.agent.terminate = Mock(wraps=task.agent.terminate) + servicer.shutdown() + task.agent.terminate.assert_called_once() # agent should be terminated + + +def test_exit_on_client_close(isolate_server_subprocess): + """Connect with grpc client, run a task and then close the client.""" + process, port = isolate_server_subprocess + channel = grpc.insecure_channel(f"localhost:{port}") + stub = IsolateStub(channel) + + def fn(): + import time + + time.sleep(30) # Simulate long-running task + + responses = stub.Run(create_run_request(fn)) + consume_responses(responses) + + # Give task time to start + time.sleep(2) + + # there is a running grpc client connected to an isolate servicer which is + # emitting responses from an agent running a infinite loop + assert process.poll() is None, "Server should be running while client is connected" + + # Close the channel to simulate client disconnect + channel.close() + + # Give time for the channel close to propagate and trigger termination + time.sleep(1.0) + + try: + # Wait for server process to exit + process.wait(timeout=5) + except subprocess.TimeoutExpired: + raise AssertionError("Server did not shut down after client disconnect") + + assert ( + process.poll() is not None + ), "Server should have shut down after client disconnect" + + +def test_running_function_receives_sigterm(isolate_server_subprocess, tmp_path): + """Test that the user provided code receives the SIGTERM""" + process, port = isolate_server_subprocess + channel = grpc.insecure_channel(f"localhost:{port}") + stub = IsolateStub(channel) + + # Send SIGTERM to the current process + assert process.poll() is None, "Server should be running initially" + + def func_with_sigterm_handler(filepath): + import os + import pathlib + import signal + import time + + def handle_term(signum, frame): + print("Received SIGTERM, exiting gracefully...") + pathlib.Path(filepath).touch() + os._exit(0) + + signal.signal(signal.SIGTERM, handle_term) + + time.sleep(30) # Simulate long-running task + + sigterm_file_path = tmp_path.joinpath("sigterm_test") + + assert not sigterm_file_path.exists() + + responses = stub.Run( + create_run_request(func_with_sigterm_handler, str(sigterm_file_path)) + ) + consume_responses(responses) + time.sleep(2) # Give task time to start + + os.kill(process.pid, signal.SIGTERM) + process.wait(timeout=5) + assert process.poll() is not None, "Server should have shut down after SIGTERM" + assert ( + sigterm_file_path.exists() + ), "Function should have received SIGTERM and created the file" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tools/isolate_client.py b/tools/isolate_client.py new file mode 100644 index 0000000..a0873d7 --- /dev/null +++ b/tools/isolate_client.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import argparse +import sys + +import grpc +from isolate.connections.common import SerializationError +from isolate.connections.grpc.interface import to_serialized_object +from isolate.server import definitions + +environment = definitions.EnvironmentDefinition(kind="local") + + +def func_to_submit() -> str: + import time + + print("Task started, sleeping for 10 seconds...") + time.sleep(10) + return "hello" + + +def describe_rpc_error(error: grpc.RpcError) -> str: + detail = error.details() if hasattr(error, "details") else "" + if detail: + return detail + try: + code = error.code() + code_name = code.name if hasattr(code, "name") else str(code) + except Exception: + code_name = "UNKNOWN" + return f"{code_name}: {error}" + + +class ClientApp: + def __init__( + self, + stub: definitions.IsolateStub, + ) -> None: + self.stub = stub + + def run(self) -> None: + while True: + try: + choice = ( + input( + "\nChoose an action: submit [s], list [l], " + "cancel [c], metadata [m], quit [q]: " + ) + .strip() + .lower() + ) + except EOFError: + print() + break + if choice in {"q", "quit"}: + break + if choice in {"s", "submit"}: + self.handle_submit() + elif choice in {"l", "list"}: + self.handle_list() + elif choice in {"c", "cancel"}: + self.handle_cancel() + else: + print("Unknown choice.", file=sys.stderr) + + def handle_submit(self) -> None: + try: + serialized_function = to_serialized_object( + func_to_submit, + method="dill", + ) + except SerializationError as exc: + print(f"Failed to serialize the function: {exc.message}", file=sys.stderr) + return + except Exception as exc: + print(f"Failed to serialize the function: {exc}", file=sys.stderr) + return + bound = definitions.BoundFunction( + function=serialized_function, + environments=[environment], + ) + request = definitions.SubmitRequest(function=bound) + try: + response = self.stub.Submit(request) + except grpc.RpcError as exc: + print(f"Submit failed: {describe_rpc_error(exc)}", file=sys.stderr) + return + print(f"Task submitted with id: {response.task_id}") + + def handle_list(self) -> None: + try: + response = self.stub.List(definitions.ListRequest()) + except grpc.RpcError as exc: + print(f"List failed: {describe_rpc_error(exc)}", file=sys.stderr) + return + if not response.tasks: + print("No active tasks.") + return + print("Active task ids:") + for info in response.tasks: + print(f" {info.task_id}") + + def handle_cancel(self) -> None: + try: + task_id = input("Task id to cancel: ").strip() + except EOFError: + print() + return + if not task_id: + print("Task id is required.", file=sys.stderr) + return + request = definitions.CancelRequest(task_id=task_id) + try: + self.stub.Cancel(request) + except grpc.RpcError as exc: + print(f"Cancel failed: {describe_rpc_error(exc)}", file=sys.stderr) + return + print(f"Cancellation requested for task {task_id}.") + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Interactive client for isolate.server." + ) + parser.add_argument( + "--host", + default="localhost:50001", + help="gRPC host of the isolate server.", + ) + return parser + + +def main() -> int: + parser = build_parser() + args = parser.parse_args() + channel = grpc.insecure_channel(args.host) + stub = definitions.IsolateStub(channel) + client = ClientApp(stub=stub) + try: + client.run() + except KeyboardInterrupt: + print() + finally: + channel.close() + return 0 + + +if __name__ == "__main__": + sys.exit(main())