feat: Allow agent to handle signals for graceful termination #176

@@ -1,4 +1,6 @@
import os
import socket
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
@@ -23,6 +25,11 @@ class AgentError(Exception): | |
| """An internal problem caused by (most probably) the agent.""" | ||
|
|
||
|
|
||
| PROCESS_SHUTDOWN_TIMEOUT_SECONDS = float( | ||
| os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60") | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class GRPCExecutionBase(EnvironmentConnection): | ||
| """A customizable gRPC-based execution backend.""" | ||
|
|
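
Since `PROCESS_SHUTDOWN_TIMEOUT_SECONDS` is evaluated at module level, the grace period has to be set in the environment before isolate is imported. A minimal sketch (the value here is only an illustration):

```python
import os

# Must be set before the module above is imported, because the constant is
# read from the environment exactly once, at import time.
os.environ["ISOLATE_SHUTDOWN_GRACE_PERIOD"] = "10"  # shrink the 60s default, e.g. for tests
```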
@@ -128,9 +135,18 @@ def find_free_port() -> Tuple[str, int]: | |
| with self.start_process(address) as process: | ||
| yield address, grpc.local_channel_credentials() | ||
| finally: | ||
| if process is not None: | ||
| # TODO: should we check the status code here? | ||
| self.terminate_process(process) | ||
|
|
||
| def terminate_process(self, process: Union[None, subprocess.Popen]) -> None: | ||
| if process is not None: | ||
| try: | ||
| print("Terminating the agent process...") | ||
| process.terminate() | ||
| process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS) | ||
| print("Agent process shutdown gracefully") | ||
|

Review comment: Might be excessive logging, I can take it out. But it was useful for debugging.

+            except Exception as exc:
+                print(f"Failed to shutdown the agent process gracefully: {exc}")
+                process.kill()

    def get_python_cmd(
        self,
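
On the logging note above: if the prints turn out to be too chatty, one possible follow-up (not part of this PR; the logger name and constant below are assumptions) would be to route the messages through the standard logging module so they can be silenced or enabled per environment:

```python
import logging
import subprocess
from typing import Union

logger = logging.getLogger("isolate.connections.grpc")  # hypothetical logger name
SHUTDOWN_TIMEOUT_SECONDS = 60.0  # stand-in for PROCESS_SHUTDOWN_TIMEOUT_SECONDS above


def terminate_process(process: Union[None, subprocess.Popen]) -> None:
    if process is None:
        return
    try:
        logger.debug("Terminating the agent process...")
        process.terminate()
        process.wait(timeout=SHUTDOWN_TIMEOUT_SECONDS)
        logger.debug("Agent process shut down gracefully")
    except Exception as exc:
        logger.warning("Failed to shut down the agent process gracefully: %s", exc)
        process.kill()
```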
@@ -10,20 +10,21 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import os | ||
| import sys | ||
| import traceback | ||
| from argparse import ArgumentParser | ||
| from concurrent import futures | ||
| from dataclasses import dataclass | ||
| from typing import ( | ||
| Any, | ||
| AsyncIterator, | ||
| Iterable, | ||
| Iterator, | ||
| ) | ||
|
|
||
| import grpc | ||
| from grpc import ServicerContext, StatusCode | ||
| from grpc import StatusCode, aio, local_server_credentials | ||
|
|
||
| from isolate.connections.grpc.definitions import PartialRunResult | ||
|
|
||
| try: | ||
| from isolate import __version__ as agent_version | ||
|
|
@@ -49,11 +50,11 @@ def __init__(self, log_fd: int | None = None): | |
| self._run_cache: dict[str, Any] = {} | ||
| self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w") | ||
|
|
||
| def Run( | ||
| async def Run( | ||
| self, | ||
| request: definitions.FunctionCall, | ||
| context: ServicerContext, | ||
| ) -> Iterator[definitions.PartialRunResult]: | ||
| context: aio.ServicerContext, | ||
| ) -> AsyncIterator[PartialRunResult]: | ||
| self.log(f"A connection has been established: {context.peer()}!") | ||
| server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown" | ||
| self.log(f"Isolate info: server {server_version}, agent {agent_version}") | ||
|
|
@@ -87,7 +88,8 @@ def Run( | |
| ) | ||
| raise AbortException("The setup function has thrown an error.") | ||
| except AbortException as exc: | ||
| return self.abort_with_msg(context, exc.message) | ||
| self.abort_with_msg(context, exc.message) | ||
| return | ||
| else: | ||
| assert not was_it_raised | ||
| self._run_cache[cache_key] = result | ||
|
|
@@ -107,7 +109,8 @@ def Run( | |
| stringized_tb, | ||
| ) | ||
| except AbortException as exc: | ||
| return self.abort_with_msg(context, exc.message) | ||
| self.abort_with_msg(context, exc.message) | ||
| return | ||
|
|
||
| def execute_function( | ||
| self, | ||
|
|
@@ -195,7 +198,7 @@ def log(self, message: str) -> None: | |
|
|
||
| def abort_with_msg( | ||
| self, | ||
| context: ServicerContext, | ||
| context: aio.ServicerContext, | ||
| message: str, | ||
| *, | ||
| code: StatusCode = StatusCode.INVALID_ARGUMENT, | ||
|
|
@@ -205,23 +208,26 @@ def abort_with_msg( | |
| return None | ||
|
|
||
|
|
||
| def create_server(address: str) -> grpc.Server: | ||
| def create_server(address: str) -> aio.Server: | ||
| """Create a new (temporary) gRPC server listening on the given | ||
| address.""" | ||
| server = grpc.server( | ||
| futures.ThreadPoolExecutor(max_workers=1), | ||
| maximum_concurrent_rpcs=1, | ||
| # Use asyncio server so requests can run in the main thread and intercept signals | ||
| # There seems to be a weird bug with grpcio that makes subsequent requests fail with | ||
| # concurrent rpc limit exceeded if we set maximum_current_rpcs to 1. Setting it to 2 | ||
| # fixes it, even though in practice, we only run one request at a time. | ||
| server = aio.server( | ||
|

Review comment: This is the main change, so all requests run in the main thread.

+        maximum_concurrent_rpcs=2,
        options=get_default_options(),
    )

    # Local server credentials allow us to ensure that the
    # connection is established by a local process.
-    server_credentials = grpc.local_server_credentials()
+    server_credentials = local_server_credentials()
    server.add_secure_port(address, server_credentials)
    return server


-def run_agent(address: str, log_fd: int | None = None) -> int:
+async def run_agent(address: str, log_fd: int | None = None) -> int:
    """Run the agent servicer on the given address."""
    server = create_server(address)
    servicer = AgentServicer(log_fd=log_fd)
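
To illustrate the review note about the main thread with a standalone sketch (not isolate code): CPython delivers Python-level signal handlers only in the main thread, so work running on pool threads cannot react to a signal directly, while work awaited on the main-thread event loop can:

```python
import asyncio
import os
import signal
import threading


def handler(signum, frame):
    # Python-level signal handlers always execute in the main thread.
    print(f"signal {signum} handled in {threading.current_thread().name}")


async def long_running_rpc() -> None:
    # Stand-in for an awaited Run() call executing on the main-thread event loop.
    await asyncio.sleep(3)


async def main() -> None:
    signal.signal(signal.SIGTERM, handler)
    # Deliver SIGTERM to ourselves while the coroutine is running: the handler
    # fires promptly because the event loop (and the RPC) live in MainThread.
    asyncio.get_running_loop().call_later(1, os.kill, os.getpid(), signal.SIGTERM)
    await long_running_rpc()


asyncio.run(main())
```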
@@ -231,19 +237,19 @@ def run_agent(address: str, log_fd: int | None = None) -> int: | |
| # not have any global side effects. | ||
| definitions.register_agent(servicer, server) | ||
|
|
||
| server.start() | ||
| server.wait_for_termination() | ||
| await server.start() | ||
| await server.wait_for_termination() | ||
| return 0 | ||
|
|
||
|
|
||

-def main() -> int:
+async def main() -> int:
    parser = ArgumentParser()
    parser.add_argument("address", type=str)
    parser.add_argument("--log-fd", type=int)

    options = parser.parse_args()
-    return run_agent(options.address, log_fd=options.log_fd)
+    return await run_agent(options.address, log_fd=options.log_fd)


if __name__ == "__main__":
-    main()
+    asyncio.run(main())
@@ -2,6 +2,7 @@ | |
|
|
||
| import functools | ||
| import os | ||
| import signal | ||
| import threading | ||
| import time | ||
| import traceback | ||
|
|
@@ -178,11 +179,17 @@ class RunTask: | |
|
|
||
| def cancel(self): | ||
| while True: | ||
| self.future.cancel() | ||
| # Cancelling a running future is not possible, and it sometimes blocks, | ||
| # which means we never terminate the agent. So check if it's not running | ||
| if self.future and not self.future.running(): | ||
| self.future.cancel() | ||
|

Review thread on lines +184 to +185:
- "But if we don't cancel it, then what happens?"
- "I think in almost all cases, nothing. But there could be rare race conditions where a future that hasn't started yet begins executing after this point, leading to an orphaned agent process (more likely when the server is handling multiple tasks). The chances are quite low, but it's more correct to always cancel, imo."
- "I guess we can just log this scenario then."
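
A small standalone sketch (not from the PR) of the `Future.cancel()` semantics discussed in this thread: cancellation only succeeds for a future that has not started running, which is why attempting it on a still-pending future prevents it from ever spawning work:

```python
import time
from concurrent import futures


def work() -> None:
    time.sleep(1)


with futures.ThreadPoolExecutor(max_workers=1) as pool:
    running = pool.submit(work)  # picked up immediately by the only worker
    pending = pool.submit(work)  # queued behind `running`
    time.sleep(0.1)              # let the first task start

    print(running.cancel())  # False: already running, cannot be cancelled
    print(pending.cancel())  # True: never starts, so no orphaned work
```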

            if self.agent:
                self.agent.terminate()

            try:
-                self.future.exception(timeout=0.1)
+                if self.future:
+                    self.future.exception(timeout=0.1)
                return
            except futures.TimeoutError:
                pass
@@ -197,6 +204,7 @@ class IsolateServicer(definitions.IsolateServicer): | |
| bridge_manager: BridgeManager | ||
| default_settings: IsolateSettings = field(default_factory=IsolateSettings) | ||
| background_tasks: dict[str, RunTask] = field(default_factory=dict) | ||
| _shutting_down: bool = field(default=False) | ||
|
|
||
| _thread_pool: futures.ThreadPoolExecutor = field( | ||
| default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS) | ||
|
|
@@ -420,6 +428,17 @@ def Cancel( | |
|
|
||
| return definitions.CancelResponse() | ||
|
|
||
| def shutdown(self) -> None: | ||
| if self._shutting_down: | ||
| print("Shutdown already in progress...") | ||
| return | ||
|
|
||
| self._shutting_down = True | ||
| task_count = len(self.background_tasks) | ||
| print(f"Shutting down, canceling {task_count} tasks...") | ||
| self.cancel_tasks() | ||
|

Review comment: This cancels in sequence, not in parallel, but it should be OK because we only run one task at a time, AFAIK?
| print("All tasks canceled.") | ||
|
|
||
| def watch_queue_until_completed( | ||
| self, queue: Queue, is_completed: Callable[[], bool] | ||
| ) -> Iterator[definitions.PartialRunResult]: | ||
|
|
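
On the sequence-vs-parallel note above: sequential cancellation is fine while only one task runs at a time, but if it ever needed to be fanned out, a hedged sketch (assuming `tasks` is the collection of `RunTask` objects from `background_tasks`) could look like:

```python
from concurrent import futures


def cancel_tasks_in_parallel(tasks) -> None:
    # Fan the potentially blocking cancel() calls out to worker threads and
    # wait for all of them before declaring shutdown complete.
    with futures.ThreadPoolExecutor(max_workers=4) as pool:
        pending = [pool.submit(task.cancel) for task in tasks]
        futures.wait(pending)
```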
@@ -584,8 +603,10 @@ def _wrapper(request: Any, context: grpc.ServicerContext) -> Any: | |
| def termination() -> None: | ||
| if is_run: | ||
| print("Stopping server since run is finished") | ||
| self.servicer.shutdown() | ||
| # Stop the server after the Run task is finished | ||
| self.server.stop(grace=0.1) | ||
| print("Server stopped") | ||
|
|
||
| elif is_submit: | ||
| # Wait until the task_id is assigned | ||
|
|
@@ -610,7 +631,9 @@ def _stop(*args): | |
| # Small sleep to make sure the cancellation is processed | ||
| time.sleep(0.1) | ||
| print("Stopping server since the task is finished") | ||
| self.servicer.shutdown() | ||
| self.server.stop(grace=0.1) | ||
| print("Server stopped") | ||
|
|
||
| # Add a callback which will stop the server | ||
| # after the task is finished | ||
|
|
@@ -671,11 +694,20 @@ def main(argv: list[str] | None = None) -> None: | |
| definitions.register_isolate(servicer, server) | ||
| health.register_health(HealthServicer(), server) | ||
|
|
||
| server.add_insecure_port("[::]:50001") | ||
| print("Started listening at localhost:50001") | ||
| def handle_termination(*args): | ||
| print("Termination signal received, shutting down...") | ||
| servicer.shutdown() | ||
| server.stop(grace=0.1) | ||
|
|
||
| signal.signal(signal.SIGINT, handle_termination) | ||
| signal.signal(signal.SIGTERM, handle_termination) | ||
|
|
||
| server.add_insecure_port(f"[::]:{options.port}") | ||
| print(f"Started listening at {options.host}:{options.port}") | ||
|
|
||
| server.start() | ||
| server.wait_for_termination() | ||
| print("Server shut down") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
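
One hedged way to exercise the new termination path end to end (the module path below is an assumption, not necessarily the project's real entry point): start the server process, send it SIGTERM, and expect it to exit on its own once in-flight tasks are cancelled:

```python
import signal
import subprocess
import time

# Hypothetical entry point; substitute however the isolate server is actually launched.
proc = subprocess.Popen(["python", "-m", "isolate.server"])
time.sleep(2)  # give the server a moment to start listening

proc.send_signal(signal.SIGTERM)  # triggers handle_termination() above
try:
    proc.wait(timeout=90)  # leaves room for the 60s agent grace period
    print(f"server exited with code {proc.returncode}")
except subprocess.TimeoutExpired:
    proc.kill()  # last resort, mirroring terminate_process() in the first file
```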

Review comment: very good solution