Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 15 additions & 39 deletions agentlightning/llm_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult

from agentlightning.types import LLM, ProxyLLM
from agentlightning.utils.uvicorn_server import (
UvicornServerHandle,
create_uvicorn_server,
start_uvicorn_in_thread,
)

from .store.base import LightningStore

Expand Down Expand Up @@ -522,10 +527,10 @@ def __init__(
self.litellm_config.setdefault("litellm_settings", {})
self.litellm_config["litellm_settings"].setdefault("num_retries", num_retries)

self._server_thread = None
self._server_handle: UvicornServerHandle | None = None
self._config_file = None
self._uvicorn_server = None
self._ready_event = threading.Event()
self._uvicorn_server: uvicorn.Server | None = None

def set_store(self, store: LightningStore) -> None:
"""Set the store for the proxy.
Expand All @@ -547,25 +552,6 @@ def update_model_list(self, model_list: List[ModelConfig]) -> None:
self.restart()
# Do nothing if the server is not running.

def _wait_until_started(self, startup_timeout: float = 20.0):
"""Block until the uvicorn server reports started or timeout.

Args:
startup_timeout: Maximum seconds to wait.
"""
start = time.time()
while True:
if self._uvicorn_server is None:
break
if self._uvicorn_server.started:
self._ready_event.set()
break
if self._uvicorn_server.should_exit:
break
if time.time() - start > startup_timeout:
break
time.sleep(0.01)

def start(self):
"""Start the proxy server thread and initialize global wiring.

Expand Down Expand Up @@ -601,18 +587,13 @@ def start(self):
save_worker_config(config=self._config_file)

# Bind to all interfaces to allow other hosts to reach it if needed.
self._uvicorn_server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=self.port))

def run_server():
# Serve uvicorn in this background thread with its own event loop.
assert self._uvicorn_server is not None
asyncio.run(self._uvicorn_server.serve())
self._uvicorn_server = create_uvicorn_server(app, host="0.0.0.0", port=self.port)

logger.info("Starting LLMProxy server thread...")
self._ready_event.clear()
self._server_thread = threading.Thread(target=run_server, daemon=True)
self._server_thread.start()
self._wait_until_started()
self._server_handle = start_uvicorn_in_thread(self._uvicorn_server)
if self._server_handle.wait_until_started():
self._ready_event.set()

def stop(self):
"""Stop the proxy server and clean up temporary artifacts.
Expand All @@ -629,15 +610,10 @@ def stop(self):

logger.info("Stopping LLMProxy server thread...")
stop_success = True
if self._server_thread is not None and self._uvicorn_server is not None and self._uvicorn_server.started:
self._uvicorn_server.should_exit = True
self._server_thread.join(timeout=10.0) # Allow time for graceful shutdown.
if self._server_thread.is_alive():
logger.error(
"LLMProxy server thread is still alive after 10 seconds. Cannot kill it because it's a thread."
)
if self._server_handle is not None and self._uvicorn_server is not None:
if not self._server_handle.stop(timeout=10.0):
stop_success = False
self._server_thread = None
self._server_handle = None
self._uvicorn_server = None
self._config_file = None
self._ready_event.clear()
Expand Down Expand Up @@ -667,7 +643,7 @@ def is_running(self) -> bool:
Returns:
bool: True if server was started and did not signal exit.
"""
return self._uvicorn_server is not None and self._uvicorn_server.started
return self._server_handle is not None and self._server_handle.is_running()

def as_resource(
self,
Expand Down
44 changes: 22 additions & 22 deletions agentlightning/store/client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
Span,
TaskInput,
)
from agentlightning.utils.uvicorn_server import (
UvicornServerHandle,
create_uvicorn_server,
start_uvicorn_in_thread,
)

from .base import UNSET, LightningStore, Unset

Expand Down Expand Up @@ -95,12 +100,8 @@ def __init__(self, store: LightningStore, host: str, port: int):
self.port = port
self.app: FastAPI | None = FastAPI(title="LightningStore Server")
self._setup_routes()
self._uvicorn_config: uvicorn.Config | None = uvicorn.Config(
self.app, host="0.0.0.0", port=self.port, log_level="error"
)
self._uvicorn_server: uvicorn.Server | None = uvicorn.Server(self._uvicorn_config)

self._serving_thread: Optional[threading.Thread] = None
self._uvicorn_server: uvicorn.Server | None = None
self._server_handle: UvicornServerHandle | None = None

# Process-awareness:
# LightningStoreServer holds a plain Python object (self.store) in one process
Expand Down Expand Up @@ -142,8 +143,7 @@ def __setstate__(self, state: Dict[str, Any]):
self.port = state["port"]
self._owner_pid = state["_owner_pid"]
self._client = None
# Do NOT reconstruct app, _uvicorn_config, _uvicorn_server
# to avoid transferring server state to subprocess
# Do NOT reconstruct app or _uvicorn_server to avoid transferring server state to subprocess

@property
def endpoint(self) -> str:
Expand All @@ -154,16 +154,16 @@ async def start(self):

You need to call this method in the same process as the server was created in.
"""
assert self._uvicorn_server is not None
if self._server_handle is not None and self._server_handle.is_running():
await self.stop()
assert self.app is not None
self._uvicorn_server = create_uvicorn_server(self.app, host="0.0.0.0", port=self.port, log_level="error")
logger.info(f"Starting server at {self.endpoint}")

uvicorn_server = self._uvicorn_server

def run_server_forever():
asyncio.run(uvicorn_server.serve())

self._serving_thread = threading.Thread(target=run_server_forever, daemon=True)
self._serving_thread.start()
self._server_handle = start_uvicorn_in_thread(self._uvicorn_server)
started = await self._server_handle.wait_until_started_async()
if not started:
raise RuntimeError("uvicorn server failed to report started state")

# Wait for /health to be available
current_time = time.time()
Expand All @@ -181,14 +181,14 @@ async def stop(self):

You need to call this method in the same process as the server was created in.
"""
assert self._uvicorn_server is not None
if self._uvicorn_server.started:
if self._uvicorn_server is None or self._server_handle is None:
return
if self._server_handle.is_running():
logger.info("Stopping server...")
self._uvicorn_server.should_exit = True
if self._serving_thread is not None:
self._serving_thread.join(timeout=10)
self._serving_thread = None
self._server_handle.stop(timeout=10)
logger.info("Server stopped.")
self._server_handle = None
self._uvicorn_server = None

def _backend(self) -> LightningStore:
"""Returns the object to delegate to in *this* process.
Expand Down
Empty file.
142 changes: 142 additions & 0 deletions agentlightning/utils/uvicorn_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Utilities for running uvicorn servers in background threads or processes."""

from __future__ import annotations

import asyncio
import logging
import multiprocessing
import threading
import time
from dataclasses import dataclass
from typing import Any, Callable, Optional

import uvicorn

logger = logging.getLogger(__name__)


@dataclass
class UvicornServerHandle:
"""Handle for a running uvicorn server."""

server: Optional[uvicorn.Server]
thread: Optional[threading.Thread] = None
process: Optional[multiprocessing.Process] = None

def is_running(self) -> bool:
"""Return True if the underlying server/process is alive."""

if self.server is not None:
return bool(self.server.started and not self.server.should_exit)
if self.process is not None:
return self.process.is_alive()
return False

def wait_until_started(self, timeout: float = 20.0, poll_interval: float = 0.01) -> bool:
"""Block until the server reports started or timeout occurs."""

if self.server is None:
raise RuntimeError("Cannot wait for start on a process-based server handle.")

start_time = time.time()
while time.time() - start_time < timeout:
if self.server.started:
return True
if self.server.should_exit:
return False
time.sleep(poll_interval)
return bool(self.server.started)

async def wait_until_started_async(self, timeout: float = 20.0, poll_interval: float = 0.01) -> bool:
"""Async variant of :meth:`wait_until_started`."""

if self.server is None:
raise RuntimeError("Cannot wait for start on a process-based server handle.")

start_time = time.time()
while time.time() - start_time < timeout:
if self.server.started:
return True
if self.server.should_exit:
return False
await asyncio.sleep(poll_interval)
return bool(self.server.started)

def stop(self, timeout: float = 10.0, force: bool = False) -> bool:
"""Attempt to stop the running server."""

success = True
if self.server is not None:
if self.server.started and not self.server.should_exit:
self.server.should_exit = True
if self.thread is not None:
self.thread.join(timeout=timeout)
if self.thread.is_alive():
logger.error("uvicorn server thread is still alive after %.1f seconds", timeout)
success = False
if force:
logger.warning("Force flag has no effect for threads; manual intervention required.")
self.thread = None
self.server = None
elif self.process is not None:
if self.process.is_alive():
if force:
logger.warning("Forcefully terminating uvicorn process.")
self.process.kill()
else:
self.process.terminate()
self.process.join(timeout=timeout)
if self.process.is_alive():
logger.error("uvicorn server process is still alive after %.1f seconds", timeout)
success = False
self.process = None
else:
success = False
return success


def create_uvicorn_server(
app: Any,
host: str,
port: int,
*,
log_level: str = "info",
config_factory: Callable[..., uvicorn.Config] | None = None,
**config_kwargs: Any,
) -> uvicorn.Server:
"""Create a uvicorn server for the given ASGI app."""

factory = config_factory or uvicorn.Config
config = factory(app, host=host, port=port, log_level=log_level, **config_kwargs)
return uvicorn.Server(config)


def start_uvicorn_in_thread(
server: uvicorn.Server,
*,
daemon: bool = True,
) -> UvicornServerHandle:
"""Start a uvicorn server inside a background thread."""

def run() -> None:
asyncio.run(server.serve())

thread = threading.Thread(target=run, daemon=daemon)
thread.start()
return UvicornServerHandle(server=server, thread=thread)


def start_uvicorn_in_process(
config: uvicorn.Config,
*,
daemon: bool = True,
) -> UvicornServerHandle:
"""Start a uvicorn server inside a separate process."""

def target() -> None:
server = uvicorn.Server(config)
asyncio.run(server.serve())

process = multiprocessing.Process(target=target, daemon=daemon)
process.start()
return UvicornServerHandle(server=None, process=process)
32 changes: 14 additions & 18 deletions tests/tracer/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@
from agentlightning.tracer.agentops import AgentOpsTracer, LightningSpanProcessor
from agentlightning.tracer.http import HttpTracer
from agentlightning.types import Span, Triplet
from agentlightning.utils.uvicorn_server import (
UvicornServerHandle,
create_uvicorn_server,
start_uvicorn_in_thread,
)

from ..common.tracer import clear_agentops_init, clear_tracer_provider

Expand Down Expand Up @@ -105,7 +110,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8000) -> None:
self.host = host
self.port = port
self.app = FastAPI()
self.server_thread = None
self.server_handle: UvicornServerHandle | None = None
self.server = None
self.prompt_caches = self._load_prompt_caches()
self._setup_routes()
Expand Down Expand Up @@ -169,28 +174,19 @@ def chat_completions(request: ChatCompletionRequest):

async def __aenter__(self):
# Start the server manually
config = uvicorn.Config(self.app, host=self.host, port=self.port, log_level="error")
self.server = uvicorn.Server(config)
self.server_thread = threading.Thread(target=self.server.run, daemon=True)
self.server_thread.start()

# Wait for server to start
max_wait = 10 # seconds
wait_time = 0
while not getattr(self.server, "started", False) and wait_time < max_wait:
await asyncio.sleep(0.1)
wait_time += 0.1

if not getattr(self.server, "started", False):
self.server = create_uvicorn_server(self.app, host=self.host, port=self.port, log_level="error")
self.server_handle = start_uvicorn_in_thread(self.server)
started = await self.server_handle.wait_until_started_async(timeout=10.0)
if not started:
raise RuntimeError("Server failed to start within timeout")

return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.server:
self.server.should_exit = True
if self.server_thread and self.server_thread.is_alive():
self.server_thread.join(timeout=5)
if self.server_handle:
self.server_handle.stop(timeout=5)
self.server_handle = None
self.server = None


async def run_agent(agent_func: Callable[[], Any]) -> None:
Expand Down