-
Notifications
You must be signed in to change notification settings - Fork 0
add healthcheck #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -7,11 +7,14 @@ | |||||
| Messages are forwarded as opaque bytes — no proto definitions required. | ||||||
| """ | ||||||
|
|
||||||
| import asyncio | ||||||
| import logging | ||||||
| import time | ||||||
| from collections.abc import AsyncIterator | ||||||
|
|
||||||
| import grpc | ||||||
| from grpc import aio | ||||||
| from grpc_health.v1 import health, health_pb2, health_pb2_grpc | ||||||
|
|
||||||
| from spark_connect_proxy.auth import AuthError, TokenValidator | ||||||
| from spark_connect_proxy.config import ProxySettings | ||||||
|
|
@@ -51,26 +54,51 @@ def _extract_token(metadata: tuple[tuple[str, str | bytes], ...] | None) -> str: | |||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
# Channel pool — reuse channels to the same backend with LRU eviction
# ---------------------------------------------------------------------------


class ChannelPool:
    """Manages a pool of gRPC channels to backend Spark Connect servers.

    Tracks last-used time for each channel and evicts the least-recently-used
    entry when the pool exceeds ``max_size``.
    """

    def __init__(self, max_size: int = 100) -> None:
        """Create an empty pool.

        Args:
            max_size: Maximum number of cached backend channels. When opening
                a channel to a new target would exceed this, the
                least-recently-used channel is evicted and closed.
        """
        self._max_size = max_size
        # target → (channel, last_used_timestamp); timestamps come from
        # time.monotonic() so LRU ordering is immune to wall-clock jumps.
        self._channels: dict[str, tuple[aio.Channel, float]] = {}

    def get_channel(self, target: str) -> aio.Channel:
        """Get or create a channel to the specified backend target."""
        # Single lookup instead of `in` + `[]`.
        entry = self._channels.get(target)
        if entry is not None:
            channel = entry[0]
            # Refresh the last-used timestamp so a hot entry is not evicted.
            self._channels[target] = (channel, time.monotonic())
            return channel

        # Evict LRU if at capacity before opening a new channel.
        if len(self._channels) >= self._max_size:
            self._evict_lru()

        logger.info("Opening channel to backend: %s", target)
        channel = aio.insecure_channel(target)
        self._channels[target] = (channel, time.monotonic())
        return channel

    def _evict_lru(self) -> None:
        """Evict the least-recently-used channel, closing it in the background."""
        if not self._channels:
            return
        # One pass over items() — avoids a second dict lookup per key.
        lru_target, (channel, _) = min(
            self._channels.items(), key=lambda kv: kv[1][1]
        )
        del self._channels[lru_target]
        logger.info("Evicting idle channel to: %s", lru_target)

        def _report(task: "asyncio.Future[None]") -> None:
            # Retrieve the result so a failed close() doesn't trigger
            # asyncio's "Task exception was never retrieved" warning and
            # silently drop the error.
            if not task.cancelled() and task.exception() is not None:
                logger.warning(
                    "Error closing evicted channel to %s: %s",
                    lru_target,
                    task.exception(),
                )

        # Close asynchronously — fire and forget, but surface failures.
        asyncio.ensure_future(channel.close()).add_done_callback(_report)

    async def close_all(self) -> None:
        """Close all open channels and empty the pool."""
        for target, (channel, _) in self._channels.items():
            logger.info("Closing channel to: %s", target)
            await channel.close()
        self._channels.clear()
Tianhao-Gu marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
@@ -142,53 +170,91 @@ async def _proxy_stream_unary( | |||||
|
|
||||||
|
|
||||||
| # --------------------------------------------------------------------------- | ||||||
| # Generic RPC handler | ||||||
| # Generic RPC handler — authentication is deferred to async context | ||||||
| # --------------------------------------------------------------------------- | ||||||
|
|
||||||
|
|
||||||
| class SparkConnectProxyHandler(grpc.GenericRpcHandler): | ||||||
| """ | ||||||
| Generic gRPC handler that intercepts all Spark Connect RPCs and proxies | ||||||
| them to the correct user's backend based on KBase token authentication. | ||||||
|
|
||||||
| Authentication is performed inside the async handler behaviors (not in | ||||||
| the synchronous service() method) to avoid blocking the event loop. | ||||||
| """ | ||||||
|
|
||||||
| def __init__(self, settings: ProxySettings, validator: TokenValidator, pool: ChannelPool): | ||||||
| self._settings = settings | ||||||
| self._validator = validator | ||||||
| self._pool = pool | ||||||
|
|
||||||
| def service( | ||||||
| self, handler_call_details: grpc.HandlerCallDetails | ||||||
| ) -> grpc.RpcMethodHandler | None: | ||||||
| method = handler_call_details.method | ||||||
| async def _authenticate( | ||||||
| self, metadata: tuple[tuple[str, str | bytes], ...] | None, context: aio.ServicerContext | ||||||
| ) -> tuple[aio.Channel, tuple[tuple[str, str | bytes], ...]]: | ||||||
| """Authenticate, resolve backend, and verify connectivity. | ||||||
|
|
||||||
| if method not in SPARK_CONNECT_METHODS: | ||||||
| return None | ||||||
|
|
||||||
| req_type, resp_type = SPARK_CONNECT_METHODS[method] | ||||||
| metadata = handler_call_details.invocation_metadata | ||||||
| Returns: | ||||||
| (channel, forward_metadata) tuple on success. | ||||||
|
|
||||||
| # Authenticate and resolve backend target | ||||||
| Raises: | ||||||
| Aborts the gRPC context with UNAUTHENTICATED on auth failure, | ||||||
| or UNAVAILABLE if the backend is not reachable. | ||||||
|
||||||
| or UNAVAILABLE if the backend is not reachable. | |
| or FAILED_PRECONDITION if the backend is not reachable. |
Tianhao-Gu marked this conversation as resolved.
Show resolved
Hide resolved
Copilot
AI
Feb 11, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Backend connectivity timeout currently aborts with grpc.StatusCode.FAILED_PRECONDITION. For an unreachable backend, gRPC conventions typically use UNAVAILABLE (or DEADLINE_EXCEEDED if you want to surface the timeout) so clients can apply appropriate retry/backoff behavior. Consider switching the status code (and updating tests accordingly).
Copilot
AI
Feb 11, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
serve() passes settings.MAX_CHANNELS_PER_BACKEND into ChannelPool(max_size=...). However ProxySettings describes MAX_CHANNELS_PER_BACKEND as a per-backend limit, while ChannelPool's max_size is a global cap on cached targets/channels. This mismatch can lead to unexpected aggressive eviction; consider renaming the setting and updating its description, or adjusting ChannelPool to enforce a true per-backend limit.
Uh oh!
There was an error while loading. Please reload this page.