20 changes: 18 additions & 2 deletions src/isolate/connections/grpc/_base.py
@@ -1,4 +1,6 @@
import os
import socket
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
@@ -23,6 +25,11 @@ class AgentError(Exception):
"""An internal problem caused by (most probably) the agent."""


PROCESS_SHUTDOWN_TIMEOUT_SECONDS = float(
os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60")
Comment (Member): very good solution

)


@dataclass
class GRPCExecutionBase(EnvironmentConnection):
"""A customizable gRPC-based execution backend."""
@@ -128,9 +135,18 @@ def find_free_port() -> Tuple[str, int]:
with self.start_process(address) as process:
yield address, grpc.local_channel_credentials()
finally:
if process is not None:
# TODO: should we check the status code here?
self.terminate_process(process)

def terminate_process(self, process: Union[None, subprocess.Popen]) -> None:
if process is not None:
try:
print("Terminating the agent process...")
process.terminate()
process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS)
print("Agent process shutdown gracefully")
Comment (Author): This might be excessive logging; I can take it out, but it was useful for debugging.

except Exception as exc:
print(f"Failed to shutdown the agent process gracefully: {exc}")
process.kill()

def get_python_cmd(
self,
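A minimal, standalone sketch of the terminate-then-kill pattern this file adds, assuming only the standard library and a POSIX `sleep` binary for the illustrative child process; the real code manages the isolate agent instead and catches exceptions more broadly.

```python
import os
import subprocess

# Same default as the diff: 60 seconds unless overridden via the environment.
GRACE_PERIOD_SECONDS = float(os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60"))


def terminate_gracefully(process: subprocess.Popen) -> None:
    """Send SIGTERM, wait up to the grace period, then fall back to SIGKILL."""
    process.terminate()
    try:
        process.wait(timeout=GRACE_PERIOD_SECONDS)
    except subprocess.TimeoutExpired:
        # The child ignored SIGTERM or is stuck past the grace period; force it down.
        process.kill()
        process.wait()


if __name__ == "__main__":
    # Illustrative child process; it exits promptly on SIGTERM.
    child = subprocess.Popen(["sleep", "300"])
    terminate_gracefully(child)
    print("child exit code:", child.returncode)
```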
48 changes: 27 additions & 21 deletions src/isolate/connections/grpc/agent.py
@@ -10,20 +10,21 @@

from __future__ import annotations

import asyncio
import os
import sys
import traceback
from argparse import ArgumentParser
from concurrent import futures
from dataclasses import dataclass
from typing import (
Any,
AsyncIterator,
Iterable,
Iterator,
)

import grpc
from grpc import ServicerContext, StatusCode
from grpc import StatusCode, aio, local_server_credentials

from isolate.connections.grpc.definitions import PartialRunResult

try:
from isolate import __version__ as agent_version
@@ -49,11 +50,11 @@ def __init__(self, log_fd: int | None = None):
self._run_cache: dict[str, Any] = {}
self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w")

def Run(
async def Run(
self,
request: definitions.FunctionCall,
context: ServicerContext,
) -> Iterator[definitions.PartialRunResult]:
context: aio.ServicerContext,
) -> AsyncIterator[PartialRunResult]:
self.log(f"A connection has been established: {context.peer()}!")
server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown"
self.log(f"Isolate info: server {server_version}, agent {agent_version}")
@@ -87,7 +88,8 @@ def Run(
)
raise AbortException("The setup function has thrown an error.")
except AbortException as exc:
return self.abort_with_msg(context, exc.message)
self.abort_with_msg(context, exc.message)
return
else:
assert not was_it_raised
self._run_cache[cache_key] = result
@@ -107,7 +109,8 @@ def Run(
stringized_tb,
)
except AbortException as exc:
return self.abort_with_msg(context, exc.message)
self.abort_with_msg(context, exc.message)
return

def execute_function(
self,
@@ -195,7 +198,7 @@ def log(self, message: str) -> None:

def abort_with_msg(
self,
context: ServicerContext,
context: aio.ServicerContext,
message: str,
*,
code: StatusCode = StatusCode.INVALID_ARGUMENT,
@@ -205,23 +208,26 @@ def abort_with_msg(
return None


def create_server(address: str) -> grpc.Server:
def create_server(address: str) -> aio.Server:
"""Create a new (temporary) gRPC server listening on the given
address."""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1),
maximum_concurrent_rpcs=1,
# Use the asyncio server so requests run in the main thread and can intercept signals.
# There seems to be a weird bug in grpcio that makes subsequent requests fail with
# "concurrent rpc limit exceeded" if we set maximum_concurrent_rpcs to 1. Setting it to 2
# fixes it, even though in practice we only run one request at a time.
server = aio.server(
Comment (Author): This is the main change: all requests now run in the main thread (see the sketch after this file's diff).

maximum_concurrent_rpcs=2,
options=get_default_options(),
)

# Local server credentials allow us to ensure that the
# connection is established by a local process.
server_credentials = grpc.local_server_credentials()
server_credentials = local_server_credentials()
server.add_secure_port(address, server_credentials)
return server


def run_agent(address: str, log_fd: int | None = None) -> int:
async def run_agent(address: str, log_fd: int | None = None) -> int:
"""Run the agent servicer on the given address."""
server = create_server(address)
servicer = AgentServicer(log_fd=log_fd)
@@ -231,19 +237,19 @@ def run_agent(address: str, log_fd: int | None = None) -> int:
# not have any global side effects.
definitions.register_agent(servicer, server)

server.start()
server.wait_for_termination()
await server.start()
await server.wait_for_termination()
return 0


def main() -> int:
async def main() -> int:
parser = ArgumentParser()
parser.add_argument("address", type=str)
parser.add_argument("--log-fd", type=int)

options = parser.parse_args()
return run_agent(options.address, log_fd=options.log_fd)
return await run_agent(options.address, log_fd=options.log_fd)


if __name__ == "__main__":
main()
asyncio.run(main())
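To make the comment above concrete, here is a hedged sketch (not the isolate servicer; the address is a placeholder) of why the move to grpc.aio matters: the server now runs inside the main thread's event loop, so signal handlers can be installed on that loop (POSIX-only add_signal_handler) and drive a graceful stop. The maximum_concurrent_rpcs=2 value mirrors the workaround noted in the diff.

```python
import asyncio
import signal

from grpc import aio


async def serve(address: str = "127.0.0.1:50055") -> None:
    # The aio server lives in the main thread's event loop, unlike the old
    # ThreadPoolExecutor-based server, so signals can be handled right here.
    server = aio.server(maximum_concurrent_rpcs=2)
    server.add_insecure_port(address)

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        # Schedule a graceful stop on the loop when a termination signal arrives.
        loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(grace=5)))

    await server.start()
    print(f"listening on {address}; send SIGTERM to stop")
    await server.wait_for_termination()


if __name__ == "__main__":
    asyncio.run(serve())
```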
40 changes: 36 additions & 4 deletions src/isolate/server/server.py
@@ -2,6 +2,7 @@

import functools
import os
import signal
import threading
import time
import traceback
@@ -178,11 +179,17 @@ class RunTask:

def cancel(self):
while True:
self.future.cancel()
# Cancelling a running future is not possible, and it sometimes blocks,
# which means we never terminate the agent. So check if it's not running
if self.future and not self.future.running():
self.future.cancel()
Comment on lines +184 to +185

Comment (Member): But if we don't cancel it, then what happens?

Comment (Author): I think in almost all cases, nothing. But there could be rare race conditions where a future that hasn't started yet begins executing after this, leaving an orphaned agent process (more likely when the server is handling multiple tasks). The chances are quite low, but it's more correct to always cancel, in my opinion.

Comment (Member): I guess we can just log that scenario then (a sketch follows this hunk).


if self.agent:
self.agent.terminate()

try:
self.future.exception(timeout=0.1)
if self.future:
self.future.exception(timeout=0.1)
return
except futures.TimeoutError:
pass
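Per the thread above, a hypothetical variant of the guard that logs when the cancel is skipped; the helper name and the log wording are illustrative, not part of this PR.

```python
from concurrent import futures


def cancel_or_log(future: futures.Future, task_id: str) -> None:
    # Per the discussion: cancelling a running future does nothing useful and can
    # block, so only attempt it when the future has not started running yet.
    if future.running():
        print(f"[{task_id}] future is already running; skipping cancel, "
              "the agent will be terminated instead")
    elif not future.cancel():
        print(f"[{task_id}] future could not be cancelled (likely already finished)")


if __name__ == "__main__":
    with futures.ThreadPoolExecutor(max_workers=1) as pool:
        cancel_or_log(pool.submit(lambda: None), task_id="example")
```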
@@ -197,6 +204,7 @@ class IsolateServicer(definitions.IsolateServicer):
bridge_manager: BridgeManager
default_settings: IsolateSettings = field(default_factory=IsolateSettings)
background_tasks: dict[str, RunTask] = field(default_factory=dict)
_shutting_down: bool = field(default=False)

_thread_pool: futures.ThreadPoolExecutor = field(
default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS)
@@ -420,6 +428,17 @@ def Cancel(

return definitions.CancelResponse()

def shutdown(self) -> None:
if self._shutting_down:
print("Shutdown already in progress...")
return

self._shutting_down = True
task_count = len(self.background_tasks)
print(f"Shutting down, canceling {task_count} tasks...")
self.cancel_tasks()
Comment (Author): This cancels tasks in sequence rather than in parallel, but that should be fine since we only run one at a time AFAIK (a parallel variant is sketched after this method).

print("All tasks canceled.")

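Not part of this PR: if cancellation ever needs to happen in parallel rather than in sequence, a sketch along these lines could work, assuming each task exposes a blocking cancel() the way RunTask does.

```python
from concurrent import futures
from typing import Any


def cancel_tasks_in_parallel(tasks: dict[str, Any], max_workers: int = 4) -> None:
    # Fan the blocking cancel() calls out to a short-lived pool and wait for all of them.
    with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {pool.submit(task.cancel): task_id for task_id, task in tasks.items()}
        for done in futures.as_completed(pending):
            exc = done.exception()
            if exc is not None:
                print(f"Failed to cancel task {pending[done]}: {exc}")
```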
def watch_queue_until_completed(
self, queue: Queue, is_completed: Callable[[], bool]
) -> Iterator[definitions.PartialRunResult]:
@@ -584,8 +603,10 @@ def _wrapper(request: Any, context: grpc.ServicerContext) -> Any:
def termination() -> None:
if is_run:
print("Stopping server since run is finished")
self.servicer.shutdown()
# Stop the server after the Run task is finished
self.server.stop(grace=0.1)
print("Server stopped")

elif is_submit:
# Wait until the task_id is assigned
@@ -610,7 +631,9 @@ def _stop(*args):
# Small sleep to make sure the cancellation is processed
time.sleep(0.1)
print("Stopping server since the task is finished")
self.servicer.shutdown()
self.server.stop(grace=0.1)
print("Server stopped")

# Add a callback which will stop the server
# after the task is finished
@@ -671,11 +694,20 @@ def main(argv: list[str] | None = None) -> None:
definitions.register_isolate(servicer, server)
health.register_health(HealthServicer(), server)

server.add_insecure_port("[::]:50001")
print("Started listening at localhost:50001")
def handle_termination(*args):
print("Termination signal received, shutting down...")
servicer.shutdown()
server.stop(grace=0.1)

signal.signal(signal.SIGINT, handle_termination)
signal.signal(signal.SIGTERM, handle_termination)

server.add_insecure_port(f"[::]:{options.port}")
print(f"Started listening at {options.host}:{options.port}")

server.start()
server.wait_for_termination()
print("Server shut down")


if __name__ == "__main__":
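A small standalone sketch of what the _shutting_down guard buys: if both SIGINT and SIGTERM arrive (or a signal races with the run-finished path), the shutdown body runs only once. signal.raise_signal is used purely to simulate delivery on POSIX; the real handlers call servicer.shutdown() and server.stop().

```python
import signal


class ShutdownGuard:
    def __init__(self) -> None:
        self._shutting_down = False

    def shutdown(self) -> None:
        if self._shutting_down:
            print("Shutdown already in progress...")
            return
        self._shutting_down = True
        # The real servicer cancels background tasks here before the server stops.
        print("Shutting down...")


guard = ShutdownGuard()


def handle_termination(signum, frame) -> None:
    print(f"Termination signal {signum} received")
    guard.shutdown()


signal.signal(signal.SIGINT, handle_termination)
signal.signal(signal.SIGTERM, handle_termination)

if __name__ == "__main__":
    # Simulate SIGTERM followed by SIGINT; only the first triggers the shutdown body.
    signal.raise_signal(signal.SIGTERM)
    signal.raise_signal(signal.SIGINT)
```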
2 changes: 1 addition & 1 deletion tests/test_server.py
@@ -538,7 +538,7 @@ def test_bridge_connection_reuse_logs(
run_request(stub, request, user_logs=logs)

str_logs = [log.message for log in logs if log.message]
assert str_logs == [
assert str_logs[:4] == [
"setup",
"run",
"run",