Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ User (PySpark) → Ingress (spark.berdl.kbase.us:443) → Spark Connect Proxy

```bash
# Install
pip install -e ".[dev]"
uv sync --dev

# Run tests
pytest
uv run pytest

# Run locally
python -m spark_connect_proxy
uv run python -m spark_connect_proxy
```
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ classifiers = [

dependencies = [
"grpcio>=1.60.0",
"grpcio-health-checking>=1.60.0",
"httpx>=0.27.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
Expand Down
3 changes: 3 additions & 0 deletions src/spark_connect_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class ProxySettings(BaseSettings):
# Maximum concurrent gRPC connections to keep per backend
MAX_CHANNELS_PER_BACKEND: int = 5

# Timeout (seconds) for initial backend connectivity check
BACKEND_CONNECT_TIMEOUT: float = 5.0

model_config = {"env_prefix": ""}

def backend_target(self, username: str) -> str:
Expand Down
195 changes: 119 additions & 76 deletions src/spark_connect_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
Messages are forwarded as opaque bytes — no proto definitions required.
"""

import asyncio
import logging
import time
from collections.abc import AsyncIterator

import grpc
from grpc import aio
from grpc_health.v1 import health, health_pb2, health_pb2_grpc

from spark_connect_proxy.auth import AuthError, TokenValidator
from spark_connect_proxy.config import ProxySettings
Expand Down Expand Up @@ -51,26 +54,51 @@ def _extract_token(metadata: tuple[tuple[str, str | bytes], ...] | None) -> str:


# ---------------------------------------------------------------------------
# Channel pool — reuse channels to the same backend
# Channel pool — reuse channels to the same backend with LRU eviction
# ---------------------------------------------------------------------------


class ChannelPool:
"""Manages a pool of gRPC channels to backend Spark Connect servers."""
"""Manages a pool of gRPC channels to backend Spark Connect servers.

def __init__(self) -> None:
self._channels: dict[str, aio.Channel] = {}
Tracks last-used time for each channel and evicts the least-recently-used
entry once the pool reaches max_size, before a new channel is added.
"""

def __init__(self, max_size: int = 100) -> None:
self._max_size = max_size
# target → (channel, last_used_timestamp)
self._channels: dict[str, tuple[aio.Channel, float]] = {}

def get_channel(self, target: str) -> aio.Channel:
"""Get or create a channel to the specified backend target."""
if target not in self._channels:
logger.info("Opening channel to backend: %s", target)
self._channels[target] = aio.insecure_channel(target)
return self._channels[target]
if target in self._channels:
channel, _ = self._channels[target]
self._channels[target] = (channel, time.monotonic())
return channel

# Evict LRU if at capacity
if len(self._channels) >= self._max_size:
self._evict_lru()

logger.info("Opening channel to backend: %s", target)
channel = aio.insecure_channel(target)
self._channels[target] = (channel, time.monotonic())
return channel

def _evict_lru(self) -> None:
"""Evict the least-recently-used channel."""
if not self._channels:
return
lru_target = min(self._channels, key=lambda t: self._channels[t][1])
channel, _ = self._channels.pop(lru_target)
logger.info("Evicting idle channel to: %s", lru_target)
# Close asynchronously — fire and forget
asyncio.ensure_future(channel.close())

async def close_all(self) -> None:
"""Close all open channels."""
for target, channel in self._channels.items():
for target, (channel, _) in self._channels.items():
logger.info("Closing channel to: %s", target)
await channel.close()
self._channels.clear()
Expand Down Expand Up @@ -142,53 +170,91 @@ async def _proxy_stream_unary(


# ---------------------------------------------------------------------------
# Generic RPC handler
# Generic RPC handler — authentication is deferred to async context
# ---------------------------------------------------------------------------


class SparkConnectProxyHandler(grpc.GenericRpcHandler):
"""
Generic gRPC handler that intercepts all Spark Connect RPCs and proxies
them to the correct user's backend based on KBase token authentication.

Authentication is performed inside the async handler behaviors (not in
the synchronous service() method) to avoid blocking the event loop.
"""

def __init__(self, settings: ProxySettings, validator: TokenValidator, pool: ChannelPool):
self._settings = settings
self._validator = validator
self._pool = pool

def service(
self, handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler | None:
method = handler_call_details.method
async def _authenticate(
self, metadata: tuple[tuple[str, str | bytes], ...] | None, context: aio.ServicerContext
) -> tuple[aio.Channel, tuple[tuple[str, str | bytes], ...]]:
"""Authenticate, resolve backend, and verify connectivity.

if method not in SPARK_CONNECT_METHODS:
return None

req_type, resp_type = SPARK_CONNECT_METHODS[method]
metadata = handler_call_details.invocation_metadata
Returns:
(channel, forward_metadata) tuple on success.

# Authenticate and resolve backend target
Raises:
Aborts the gRPC context with UNAUTHENTICATED on auth failure,
or FAILED_PRECONDITION if the backend is not reachable.
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _authenticate() docstring says the backend-unreachable case aborts with UNAVAILABLE, but the implementation currently aborts with FAILED_PRECONDITION. Please update either the docstring or the status code so behavior and documentation match.

Suggested change
or UNAVAILABLE if the backend is not reachable.
or FAILED_PRECONDITION if the backend is not reachable.

Copilot uses AI. Check for mistakes.
"""
try:
token = _extract_token(metadata)
username = self._validator.get_username(token)
# Run the synchronous auth call in a thread to avoid blocking
username = await asyncio.to_thread(self._validator.get_username, token)
except AuthError as e:
logger.warning("Authentication failed: %s", e)
return self._unauthenticated_handler(str(e), req_type, resp_type)
await context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
raise # unreachable after abort, satisfies type checker

target = self._settings.backend_target(username)
channel = self._pool.get_channel(target)

# Forward original metadata (including x-kbase-token for server-side validation)
# Fast connectivity check — fail early if backend is unreachable
try:
await asyncio.wait_for(
channel.channel_ready(),
timeout=self._settings.BACKEND_CONNECT_TIMEOUT,
)
except TimeoutError:
logger.error(
"Backend %s unreachable for user %s (timeout=%.1fs)",
target,
username,
self._settings.BACKEND_CONNECT_TIMEOUT,
)
await context.abort(
grpc.StatusCode.FAILED_PRECONDITION,
f"Spark Connect server at {target} is not reachable. "
f"Please ensure you have logged in to BERDL JupyterHub "
f"and your notebook's Spark Connect service is running.",
)
Comment on lines +228 to +233
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backend connectivity timeout currently aborts with grpc.StatusCode.FAILED_PRECONDITION. For an unreachable backend, gRPC conventions typically use UNAVAILABLE (or DEADLINE_EXCEEDED if you want to surface the timeout) so clients can apply appropriate retry/backoff behavior. Consider switching the status code (and updating tests accordingly).

Copilot uses AI. Check for mistakes.
raise # unreachable after abort

fwd_metadata: tuple[tuple[str, str | bytes], ...] = tuple(metadata) if metadata else ()

logger.debug("Proxying %s for user %s → %s", method, username, target)
logger.debug("Proxying for user %s → %s", username, target)
return channel, fwd_metadata

def service(
self, handler_call_details: grpc.HandlerCallDetails
) -> grpc.RpcMethodHandler | None:
method = handler_call_details.method

if method not in SPARK_CONNECT_METHODS:
return None

req_type, resp_type = SPARK_CONNECT_METHODS[method]
metadata = handler_call_details.invocation_metadata

# Return the appropriate handler type
# Return the appropriate handler type
# Authentication is deferred to the async behavior function
if req_type == "unary" and resp_type == "unary":

async def unary_unary_behavior(request, context):
async def unary_unary_behavior(request: bytes, context: aio.ServicerContext) -> bytes:
channel, fwd_metadata = await self._authenticate(metadata, context)
return await _proxy_unary_unary(method, request, context, channel, fwd_metadata)

return grpc.unary_unary_rpc_method_handler(
Expand All @@ -198,7 +264,10 @@ async def unary_unary_behavior(request, context):
)
elif req_type == "unary" and resp_type == "stream":

async def unary_stream_behavior(request, context):
async def unary_stream_behavior(
request: bytes, context: aio.ServicerContext
) -> AsyncIterator[bytes]:
channel, fwd_metadata = await self._authenticate(metadata, context)
async for response in _proxy_unary_stream(
method, request, context, channel, fwd_metadata
):
Expand All @@ -211,7 +280,10 @@ async def unary_stream_behavior(request, context):
)
elif req_type == "stream" and resp_type == "unary":

async def stream_unary_behavior(request_iterator, context):
async def stream_unary_behavior(
request_iterator: AsyncIterator[bytes], context: aio.ServicerContext
) -> bytes:
channel, fwd_metadata = await self._authenticate(metadata, context)
return await _proxy_stream_unary(
method, request_iterator, context, channel, fwd_metadata
)
Expand All @@ -224,53 +296,6 @@ async def stream_unary_behavior(request_iterator, context):
else:
return None

def _unauthenticated_handler(
self, message: str, req_type: str, resp_type: str
) -> grpc.RpcMethodHandler:
"""Return a handler that immediately aborts with UNAUTHENTICATED."""

async def _abort_unary(_request: bytes, context: aio.ServicerContext) -> bytes:
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
return b"" # unreachable but satisfies type checker

async def _abort_stream(
_request: bytes, context: aio.ServicerContext
) -> AsyncIterator[bytes]:
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
return # type: ignore[return-value]
yield # noqa: F841 — unreachable, makes this a generator

async def _abort_client_stream(
_request_iterator: AsyncIterator[bytes], context: aio.ServicerContext
) -> bytes:
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
return b""

if req_type == "unary" and resp_type == "unary":
return grpc.unary_unary_rpc_method_handler(
_abort_unary,
request_deserializer=_IDENTITY,
response_serializer=_IDENTITY,
)
elif req_type == "unary" and resp_type == "stream":
return grpc.unary_stream_rpc_method_handler(
_abort_stream,
request_deserializer=_IDENTITY,
response_serializer=_IDENTITY,
)
elif req_type == "stream" and resp_type == "unary":
return grpc.stream_unary_rpc_method_handler(
_abort_client_stream,
request_deserializer=_IDENTITY,
response_serializer=_IDENTITY,
)
else:
return grpc.unary_unary_rpc_method_handler(
_abort_unary,
request_deserializer=_IDENTITY,
response_serializer=_IDENTITY,
)


# ---------------------------------------------------------------------------
# Server lifecycle
Expand All @@ -288,11 +313,23 @@ async def serve(settings: ProxySettings | None = None) -> None:
cache_max_size=settings.TOKEN_CACHE_MAX_SIZE,
require_mfa=settings.REQUIRE_MFA,
)
pool = ChannelPool()
pool = ChannelPool(max_size=settings.MAX_CHANNELS_PER_BACKEND)
handler = SparkConnectProxyHandler(settings, validator, pool)
Comment on lines +316 to 317
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serve() passes settings.MAX_CHANNELS_PER_BACKEND into ChannelPool(max_size=...). However ProxySettings describes MAX_CHANNELS_PER_BACKEND as a per-backend limit, while ChannelPool's max_size is a global cap on cached targets/channels. This mismatch can lead to unexpected aggressive eviction; consider renaming the setting and updating its description, or adjusting ChannelPool to enforce a true per-backend limit.

Copilot uses AI. Check for mistakes.

server = aio.server()
server.add_generic_rpc_handlers([handler])

# Register gRPC health check service
health_servicer = health.HealthServicer()
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
# Mark the proxy as serving
health_servicer.set(
"spark.connect.SparkConnectService",
health_pb2.HealthCheckResponse.SERVING,
)
# Also set the overall server health
health_servicer.set("", health_pb2.HealthCheckResponse.SERVING)

listen_addr = f"[::]:{settings.PROXY_LISTEN_PORT}"
server.add_insecure_port(listen_addr)

Expand All @@ -305,6 +342,12 @@ async def serve(settings: ProxySettings | None = None) -> None:
await server.wait_for_termination()
finally:
logger.info("Shutting down proxy server...")
# Mark as not serving before closing
health_servicer.set(
"spark.connect.SparkConnectService",
health_pb2.HealthCheckResponse.NOT_SERVING,
)
health_servicer.set("", health_pb2.HealthCheckResponse.NOT_SERVING)
await pool.close_all()
await server.stop(grace=5)
logger.info("Proxy server stopped.")
Loading