diff --git a/README.md b/README.md index be74c36..b6b63b7 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,11 @@ User (PySpark) → Ingress (spark.berdl.kbase.us:443) → Spark Connect Proxy ```bash # Install -pip install -e ".[dev]" +uv sync --dev # Run tests -pytest +uv run pytest # Run locally -python -m spark_connect_proxy +uv run python -m spark_connect_proxy ``` diff --git a/pyproject.toml b/pyproject.toml index 18105d7..abe95dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ dependencies = [ "grpcio>=1.60.0", + "grpcio-health-checking>=1.60.0", "httpx>=0.27.0", "pydantic>=2.0.0", "pydantic-settings>=2.0.0", diff --git a/src/spark_connect_proxy/config.py b/src/spark_connect_proxy/config.py index 80b9432..6fb4c57 100644 --- a/src/spark_connect_proxy/config.py +++ b/src/spark_connect_proxy/config.py @@ -35,6 +35,9 @@ class ProxySettings(BaseSettings): # Maximum concurrent gRPC connections to keep per backend MAX_CHANNELS_PER_BACKEND: int = 5 + # Timeout (seconds) for initial backend connectivity check + BACKEND_CONNECT_TIMEOUT: float = 5.0 + model_config = {"env_prefix": ""} def backend_target(self, username: str) -> str: diff --git a/src/spark_connect_proxy/proxy.py b/src/spark_connect_proxy/proxy.py index 68599a2..bf3309a 100644 --- a/src/spark_connect_proxy/proxy.py +++ b/src/spark_connect_proxy/proxy.py @@ -7,11 +7,14 @@ Messages are forwarded as opaque bytes — no proto definitions required. """ +import asyncio import logging +import time from collections.abc import AsyncIterator import grpc from grpc import aio +from grpc_health.v1 import health, health_pb2, health_pb2_grpc from spark_connect_proxy.auth import AuthError, TokenValidator from spark_connect_proxy.config import ProxySettings @@ -51,26 +54,51 @@ def _extract_token(metadata: tuple[tuple[str, str | bytes], ...] 
| None) -> str: # --------------------------------------------------------------------------- -# Channel pool — reuse channels to the same backend +# Channel pool — reuse channels to the same backend with LRU eviction # --------------------------------------------------------------------------- class ChannelPool: - """Manages a pool of gRPC channels to backend Spark Connect servers.""" + """Manages a pool of gRPC channels to backend Spark Connect servers. - def __init__(self) -> None: - self._channels: dict[str, aio.Channel] = {} + Tracks last-used time for each channel and evicts least-recently-used + entries when the pool exceeds max_size. + """ + + def __init__(self, max_size: int = 100) -> None: + self._max_size = max_size + # target → (channel, last_used_timestamp) + self._channels: dict[str, tuple[aio.Channel, float]] = {} def get_channel(self, target: str) -> aio.Channel: """Get or create a channel to the specified backend target.""" - if target not in self._channels: - logger.info("Opening channel to backend: %s", target) - self._channels[target] = aio.insecure_channel(target) - return self._channels[target] + if target in self._channels: + channel, _ = self._channels[target] + self._channels[target] = (channel, time.monotonic()) + return channel + + # Evict LRU if at capacity + if len(self._channels) >= self._max_size: + self._evict_lru() + + logger.info("Opening channel to backend: %s", target) + channel = aio.insecure_channel(target) + self._channels[target] = (channel, time.monotonic()) + return channel + + def _evict_lru(self) -> None: + """Evict the least-recently-used channel.""" + if not self._channels: + return + lru_target = min(self._channels, key=lambda t: self._channels[t][1]) + channel, _ = self._channels.pop(lru_target) + logger.info("Evicting idle channel to: %s", lru_target) + # Close asynchronously — fire and forget + asyncio.ensure_future(channel.close()) async def close_all(self) -> None: """Close all open channels.""" - for target, 
channel in self._channels.items(): +        for target, (channel, _) in self._channels.items():             logger.info("Closing channel to: %s", target)             await channel.close()         self._channels.clear() @@ -142,7 +170,7 @@ async def _proxy_stream_unary(  # --------------------------------------------------------------------------- -# Generic RPC handler +# Generic RPC handler — authentication is deferred to async context # ---------------------------------------------------------------------------  @@ -150,6 +178,9 @@ class SparkConnectProxyHandler(grpc.GenericRpcHandler):     """     Generic gRPC handler that intercepts all Spark Connect RPCs and proxies     them to the correct user's backend based on KBase token authentication. + +    Authentication is performed inside the async handler behaviors (not in +    the synchronous service() method) to avoid blocking the event loop.     """      def __init__(self, settings: ProxySettings, validator: TokenValidator, pool: ChannelPool): @@ -157,38 +188,73 @@ def __init__(self, settings: ProxySettings, validator: TokenValidator, pool: Cha         self._validator = validator         self._pool = pool  -    def service( -        self, handler_call_details: grpc.HandlerCallDetails -    ) -> grpc.RpcMethodHandler | None: -        method = handler_call_details.method +    async def _authenticate( +        self, metadata: tuple[tuple[str, str | bytes], ...] | None, context: aio.ServicerContext +    ) -> tuple[aio.Channel, tuple[tuple[str, str | bytes], ...]]: +        """Authenticate, resolve backend, and verify connectivity.  -        if method not in SPARK_CONNECT_METHODS: -            return None - -        req_type, resp_type = SPARK_CONNECT_METHODS[method] -        metadata = handler_call_details.invocation_metadata +        Returns: +            (channel, forward_metadata) tuple on success.  -        # Authenticate and resolve backend target +        Raises: +            Aborts the gRPC context with UNAUTHENTICATED on auth failure, +            or FAILED_PRECONDITION if the backend is not reachable. 
+ """ try: token = _extract_token(metadata) - username = self._validator.get_username(token) + # Run the synchronous auth call in a thread to avoid blocking + username = await asyncio.to_thread(self._validator.get_username, token) except AuthError as e: logger.warning("Authentication failed: %s", e) - return self._unauthenticated_handler(str(e), req_type, resp_type) + await context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e)) + raise # unreachable after abort, satisfies type checker target = self._settings.backend_target(username) channel = self._pool.get_channel(target) - # Forward original metadata (including x-kbase-token for server-side validation) + # Fast connectivity check — fail early if backend is unreachable + try: + await asyncio.wait_for( + channel.channel_ready(), + timeout=self._settings.BACKEND_CONNECT_TIMEOUT, + ) + except TimeoutError: + logger.error( + "Backend %s unreachable for user %s (timeout=%.1fs)", + target, + username, + self._settings.BACKEND_CONNECT_TIMEOUT, + ) + await context.abort( + grpc.StatusCode.FAILED_PRECONDITION, + f"Spark Connect server at {target} is not reachable. " + f"Please ensure you have logged in to BERDL JupyterHub " + f"and your notebook's Spark Connect service is running.", + ) + raise # unreachable after abort + fwd_metadata: tuple[tuple[str, str | bytes], ...] 
= tuple(metadata) if metadata else () - logger.debug("Proxying %s for user %s → %s", method, username, target) + logger.debug("Proxying for user %s → %s", username, target) + return channel, fwd_metadata + + def service( + self, handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler | None: + method = handler_call_details.method + + if method not in SPARK_CONNECT_METHODS: + return None + + req_type, resp_type = SPARK_CONNECT_METHODS[method] + metadata = handler_call_details.invocation_metadata # Return the appropriate handler type - # Return the appropriate handler type + # Authentication is deferred to the async behavior function if req_type == "unary" and resp_type == "unary": - async def unary_unary_behavior(request, context): + async def unary_unary_behavior(request: bytes, context: aio.ServicerContext) -> bytes: + channel, fwd_metadata = await self._authenticate(metadata, context) return await _proxy_unary_unary(method, request, context, channel, fwd_metadata) return grpc.unary_unary_rpc_method_handler( @@ -198,7 +264,10 @@ async def unary_unary_behavior(request, context): ) elif req_type == "unary" and resp_type == "stream": - async def unary_stream_behavior(request, context): + async def unary_stream_behavior( + request: bytes, context: aio.ServicerContext + ) -> AsyncIterator[bytes]: + channel, fwd_metadata = await self._authenticate(metadata, context) async for response in _proxy_unary_stream( method, request, context, channel, fwd_metadata ): @@ -211,7 +280,10 @@ async def unary_stream_behavior(request, context): ) elif req_type == "stream" and resp_type == "unary": - async def stream_unary_behavior(request_iterator, context): + async def stream_unary_behavior( + request_iterator: AsyncIterator[bytes], context: aio.ServicerContext + ) -> bytes: + channel, fwd_metadata = await self._authenticate(metadata, context) return await _proxy_stream_unary( method, request_iterator, context, channel, fwd_metadata ) @@ -224,53 +296,6 @@ async def 
stream_unary_behavior(request_iterator, context): else: return None - def _unauthenticated_handler( - self, message: str, req_type: str, resp_type: str - ) -> grpc.RpcMethodHandler: - """Return a handler that immediately aborts with UNAUTHENTICATED.""" - - async def _abort_unary(_request: bytes, context: aio.ServicerContext) -> bytes: - await context.abort(grpc.StatusCode.UNAUTHENTICATED, message) - return b"" # unreachable but satisfies type checker - - async def _abort_stream( - _request: bytes, context: aio.ServicerContext - ) -> AsyncIterator[bytes]: - await context.abort(grpc.StatusCode.UNAUTHENTICATED, message) - return # type: ignore[return-value] - yield # noqa: F841 — unreachable, makes this a generator - - async def _abort_client_stream( - _request_iterator: AsyncIterator[bytes], context: aio.ServicerContext - ) -> bytes: - await context.abort(grpc.StatusCode.UNAUTHENTICATED, message) - return b"" - - if req_type == "unary" and resp_type == "unary": - return grpc.unary_unary_rpc_method_handler( - _abort_unary, - request_deserializer=_IDENTITY, - response_serializer=_IDENTITY, - ) - elif req_type == "unary" and resp_type == "stream": - return grpc.unary_stream_rpc_method_handler( - _abort_stream, - request_deserializer=_IDENTITY, - response_serializer=_IDENTITY, - ) - elif req_type == "stream" and resp_type == "unary": - return grpc.stream_unary_rpc_method_handler( - _abort_client_stream, - request_deserializer=_IDENTITY, - response_serializer=_IDENTITY, - ) - else: - return grpc.unary_unary_rpc_method_handler( - _abort_unary, - request_deserializer=_IDENTITY, - response_serializer=_IDENTITY, - ) - # --------------------------------------------------------------------------- # Server lifecycle @@ -288,11 +313,23 @@ async def serve(settings: ProxySettings | None = None) -> None: cache_max_size=settings.TOKEN_CACHE_MAX_SIZE, require_mfa=settings.REQUIRE_MFA, ) - pool = ChannelPool() + pool = ChannelPool(max_size=settings.MAX_CHANNELS_PER_BACKEND) handler = 
SparkConnectProxyHandler(settings, validator, pool) server = aio.server() server.add_generic_rpc_handlers([handler]) + + # Register gRPC health check service + health_servicer = health.HealthServicer() + health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) + # Mark the proxy as serving + health_servicer.set( + "spark.connect.SparkConnectService", + health_pb2.HealthCheckResponse.SERVING, + ) + # Also set the overall server health + health_servicer.set("", health_pb2.HealthCheckResponse.SERVING) + listen_addr = f"[::]:{settings.PROXY_LISTEN_PORT}" server.add_insecure_port(listen_addr) @@ -305,6 +342,12 @@ async def serve(settings: ProxySettings | None = None) -> None: await server.wait_for_termination() finally: logger.info("Shutting down proxy server...") + # Mark as not serving before closing + health_servicer.set( + "spark.connect.SparkConnectService", + health_pb2.HealthCheckResponse.NOT_SERVING, + ) + health_servicer.set("", health_pb2.HealthCheckResponse.NOT_SERVING) await pool.close_all() await server.stop(grace=5) logger.info("Proxy server stopped.") diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 93e0984..235e2f3 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -55,7 +55,7 @@ def test_empty_metadata(self) -> None: class TestChannelPool: - """Tests for ChannelPool.""" + """Tests for ChannelPool with LRU eviction.""" def test_reuses_channels(self) -> None: pool = ChannelPool() @@ -86,6 +86,40 @@ async def test_close_all(self) -> None: ch2.close.assert_awaited_once() assert len(pool._channels) == 0 + def test_lru_eviction(self) -> None: + """Evicts least-recently-used channel when pool is full.""" + pool = ChannelPool(max_size=2) + + pool.get_channel("target1:15002") + pool.get_channel("target2:15002") + assert len(pool._channels) == 2 + + # Access target1 to make target2 the LRU + pool.get_channel("target1:15002") + + # Adding target3 should evict target2 (LRU) + with 
patch("spark_connect_proxy.proxy.asyncio.ensure_future"): + pool.get_channel("target3:15002") + + assert len(pool._channels) == 2 + assert "target1:15002" in pool._channels + assert "target3:15002" in pool._channels + assert "target2:15002" not in pool._channels + + def test_updates_last_used_on_access(self) -> None: + """Accessing an existing channel updates its last-used timestamp.""" + pool = ChannelPool() + pool.get_channel("target1:15002") + _, ts1 = pool._channels["target1:15002"] + + import time + + time.sleep(0.01) + pool.get_channel("target1:15002") + _, ts2 = pool._channels["target1:15002"] + + assert ts2 > ts1 + # --------------------------------------------------------------------------- # Proxy function tests @@ -116,7 +150,6 @@ async def test_proxy_unary_stream(self) -> None: """_proxy_unary_stream yields all responses from the backend.""" mock_channel = MagicMock() - # Create an async iterator for the streaming response async def mock_stream(*_args, **_kwargs): for chunk in [b"chunk-1", b"chunk-2", b"chunk-3"]: yield chunk @@ -185,21 +218,19 @@ def test_unknown_method_returns_none(self) -> None: result = self.handler.service(details) assert result is None - def test_valid_unary_unary(self) -> None: - """Valid token for a unary-unary method returns a handler.""" - self.validator.get_username.return_value = "alice" + def test_valid_unary_unary_returns_handler(self) -> None: + """Valid unary-unary method returns a handler (auth deferred to async).""" metadata = (("x-kbase-token", "valid"),) details = self._make_call_details( "/spark.connect.SparkConnectService/AnalyzePlan", metadata ) result = self.handler.service(details) assert result is not None - self.validator.get_username.assert_called_once_with("valid") - self.pool.get_channel.assert_called_once_with("jupyter-alice.test-ns.svc.local:15002") + # Auth should NOT be called in service() — it's deferred + self.validator.get_username.assert_not_called() - def test_valid_unary_stream(self) -> None: - 
"""Valid token for a server-streaming method returns a handler.""" - self.validator.get_username.return_value = "bob" + def test_valid_unary_stream_returns_handler(self) -> None: + """Valid unary-stream method returns a handler.""" metadata = (("x-kbase-token", "valid"),) details = self._make_call_details( "/spark.connect.SparkConnectService/ExecutePlan", metadata @@ -207,9 +238,8 @@ def test_valid_unary_stream(self) -> None: result = self.handler.service(details) assert result is not None - def test_valid_stream_unary(self) -> None: - """Valid token for a client-streaming method returns a handler.""" - self.validator.get_username.return_value = "carol" + def test_valid_stream_unary_returns_handler(self) -> None: + """Valid stream-unary method returns a handler.""" metadata = (("x-kbase-token", "valid"),) details = self._make_call_details( "/spark.connect.SparkConnectService/AddArtifacts", metadata @@ -217,44 +247,138 @@ def test_valid_stream_unary(self) -> None: result = self.handler.service(details) assert result is not None - def test_missing_token_returns_unauthenticated_handler(self) -> None: - """Missing token returns an error handler (not None).""" - metadata = (("other-header", "value"),) - details = self._make_call_details("/spark.connect.SparkConnectService/Config", metadata) + def test_no_metadata_still_returns_handler(self) -> None: + """Missing metadata still returns a handler — auth failure happens in async.""" + details = self._make_call_details("/spark.connect.SparkConnectService/Config", None) result = self.handler.service(details) - # Should return an error handler, not None + # Handler is returned; auth will fail when invoked asynchronously assert result is not None - self.pool.get_channel.assert_not_called() - def test_invalid_token_returns_unauthenticated_handler(self) -> None: - """Invalid token returns an error handler.""" + +# --------------------------------------------------------------------------- +# Async authentication tests +# 
--------------------------------------------------------------------------- + + +class TestAuthenticate: + """Tests for the async _authenticate method.""" + + def setup_method(self) -> None: + self.settings = ProxySettings( + BACKEND_NAMESPACE="test-ns", + SERVICE_TEMPLATE="jupyter-{username}.{namespace}.svc.local", + BACKEND_CONNECT_TIMEOUT=1.0, + ) + self.validator = MagicMock(spec=TokenValidator) + self.mock_channel = AsyncMock() + self.mock_channel.channel_ready = AsyncMock() # resolves immediately + self.pool = MagicMock(spec=ChannelPool) + self.pool.get_channel = MagicMock(return_value=self.mock_channel) + self.handler = SparkConnectProxyHandler(self.settings, self.validator, self.pool) + + @pytest.mark.asyncio + async def test_authenticate_success(self) -> None: + """Successful auth returns channel and forward metadata.""" + self.validator.get_username.return_value = "alice" + metadata = (("x-kbase-token", "valid-token"),) + context = AsyncMock() + + channel, fwd_metadata = await self.handler._authenticate(metadata, context) + + assert channel is self.mock_channel + assert ("x-kbase-token", "valid-token") in fwd_metadata + context.abort.assert_not_called() + self.mock_channel.channel_ready.assert_awaited_once() + + @pytest.mark.asyncio + async def test_authenticate_routes_correctly(self) -> None: + """Auth resolves to correct backend target.""" + self.validator.get_username.return_value = "tgu2" + metadata = (("x-kbase-token", "tok"),) + context = AsyncMock() + + await self.handler._authenticate(metadata, context) + + self.pool.get_channel.assert_called_once_with("jupyter-tgu2.test-ns.svc.local:15002") + + @pytest.mark.asyncio + async def test_authenticate_missing_token(self) -> None: + """Missing token aborts with UNAUTHENTICATED.""" + metadata = (("other-header", "value"),) + context = AsyncMock() + context.abort = AsyncMock(side_effect=grpc.RpcError()) + + with pytest.raises(grpc.RpcError): + await self.handler._authenticate(metadata, context) + + 
context.abort.assert_awaited_once() +        args = context.abort.call_args[0] +        assert args[0] == grpc.StatusCode.UNAUTHENTICATED + +    @pytest.mark.asyncio +    async def test_authenticate_invalid_token(self) -> None: +        """Invalid token aborts with UNAUTHENTICATED."""         self.validator.get_username.side_effect = AuthError("Invalid token")         metadata = (("x-kbase-token", "bad"),) -        details = self._make_call_details("/spark.connect.SparkConnectService/Config", metadata) -        result = self.handler.service(details) -        assert result is not None -        self.pool.get_channel.assert_not_called() +        context = AsyncMock() +        context.abort = AsyncMock(side_effect=grpc.RpcError()) + +        with pytest.raises(grpc.RpcError): +            await self.handler._authenticate(metadata, context)  -    def test_routes_to_correct_user(self) -> None: +        context.abort.assert_awaited_once() + +    @pytest.mark.asyncio +    async def test_authenticate_different_users(self) -> None:         """Different tokens route to different backends.""" -        # First user +        context = AsyncMock() +         self.validator.get_username.return_value = "tgu2" -        details = self._make_call_details( -            "/spark.connect.SparkConnectService/Config", -            (("x-kbase-token", "token-tgu2"),), -        ) -        self.handler.service(details) -        self.pool.get_channel.assert_called_with("jupyter-tgu2.test-ns.svc.local:15002") +        await self.handler._authenticate((("x-kbase-token", "tok1"),), context)  -        # Second user -        self.pool.reset_mock()         self.validator.get_username.return_value = "bsadkhin" -        details = self._make_call_details( -            "/spark.connect.SparkConnectService/Config", -            (("x-kbase-token", "token-boris"),), -        ) -        self.handler.service(details) -        self.pool.get_channel.assert_called_with("jupyter-bsadkhin.test-ns.svc.local:15002") +        await self.handler._authenticate((("x-kbase-token", "tok2"),), context) + +        assert self.pool.get_channel.call_count == 2 + +    @pytest.mark.asyncio +    async def test_authenticate_backend_unreachable(self) -> None: +        """Unreachable backend aborts with FAILED_PRECONDITION and clear error 
message.""" + self.validator.get_username.return_value = "tgu3" + + # Simulate channel_ready() that never completes + async def never_ready(): + await asyncio.sleep(999) + + self.mock_channel.channel_ready = never_ready + self.settings.BACKEND_CONNECT_TIMEOUT = 0.1 # short timeout for test + + metadata = (("x-kbase-token", "tok"),) + context = AsyncMock() + context.abort = AsyncMock() + + with pytest.raises(TimeoutError): + await self.handler._authenticate(metadata, context) + + context.abort.assert_awaited_once() + args = context.abort.call_args[0] + assert args[0] == grpc.StatusCode.FAILED_PRECONDITION + assert "not reachable" in args[1] + assert "jupyter-tgu3" in args[1] + assert "BERDL JupyterHub" in args[1] + + @pytest.mark.asyncio + async def test_authenticate_backend_reachable(self) -> None: + """Reachable backend passes the connectivity check.""" + self.validator.get_username.return_value = "tgu2" + # channel_ready resolves immediately (default mock behavior) + metadata = (("x-kbase-token", "tok"),) + context = AsyncMock() + + channel, _ = await self.handler._authenticate(metadata, context) + + assert channel is self.mock_channel + context.abort.assert_not_called() # --------------------------------------------------------------------------- @@ -263,7 +387,7 @@ def test_routes_to_correct_user(self) -> None: class TestHandlerBehaviors: - """Tests that actually invoke the async handler behaviors returned by service().""" + """Tests that invoke the async handler behaviors returned by service().""" def setup_method(self) -> None: self.settings = ProxySettings( @@ -272,7 +396,10 @@ def setup_method(self) -> None: ) self.validator = MagicMock(spec=TokenValidator) self.validator.get_username.return_value = "testuser" - self.pool = ChannelPool() + self.mock_channel = AsyncMock() + self.mock_channel.channel_ready = AsyncMock() + self.pool = MagicMock(spec=ChannelPool) + self.pool.get_channel = MagicMock(return_value=self.mock_channel) self.handler = 
SparkConnectProxyHandler(self.settings, self.validator, self.pool) def _make_call_details( @@ -284,23 +411,23 @@ def _make_call_details( return details @pytest.mark.asyncio - async def test_unary_unary_behavior_invokes_proxy(self) -> None: - """The unary-unary handler behavior calls _proxy_unary_unary.""" + async def test_unary_unary_behavior(self) -> None: + """The unary-unary handler authenticates and proxies.""" metadata = (("x-kbase-token", "tok"),) details = self._make_call_details("/spark.connect.SparkConnectService/Config", metadata) result = self.handler.service(details) assert result is not None - # Patch the channel's unary_unary to return a callable with patch("spark_connect_proxy.proxy._proxy_unary_unary", new_callable=AsyncMock) as mock: mock.return_value = b"response" - context = MagicMock() + context = AsyncMock() response = await result.unary_unary(b"request", context) assert response == b"response" + mock.assert_awaited_once() @pytest.mark.asyncio - async def test_unary_stream_behavior_invokes_proxy(self) -> None: - """The unary-stream handler behavior calls _proxy_unary_stream.""" + async def test_unary_stream_behavior(self) -> None: + """The unary-stream handler authenticates and proxies.""" metadata = (("x-kbase-token", "tok"),) details = self._make_call_details( "/spark.connect.SparkConnectService/ExecutePlan", metadata @@ -313,15 +440,15 @@ async def mock_proxy(*_args, **_kwargs): yield b"chunk-2" with patch("spark_connect_proxy.proxy._proxy_unary_stream", side_effect=mock_proxy): - context = MagicMock() + context = AsyncMock() chunks = [] async for chunk in result.unary_stream(b"request", context): chunks.append(chunk) assert chunks == [b"chunk-1", b"chunk-2"] @pytest.mark.asyncio - async def test_stream_unary_behavior_invokes_proxy(self) -> None: - """The stream-unary handler behavior calls _proxy_stream_unary.""" + async def test_stream_unary_behavior(self) -> None: + """The stream-unary handler authenticates and proxies.""" metadata = 
(("x-kbase-token", "tok"),) details = self._make_call_details( "/spark.connect.SparkConnectService/AddArtifacts", metadata @@ -335,90 +462,11 @@ async def test_stream_unary_behavior_invokes_proxy(self) -> None: async def request_iter(): yield b"part" - context = MagicMock() + context = AsyncMock() response = await result.stream_unary(request_iter(), context) assert response == b"aggregated" -# --------------------------------------------------------------------------- -# Unauthenticated handler tests -# --------------------------------------------------------------------------- - - -class TestUnauthenticatedHandlers: - """Tests for _unauthenticated_handler for all RPC types.""" - - def setup_method(self) -> None: - self.settings = ProxySettings( - BACKEND_NAMESPACE="test-ns", - SERVICE_TEMPLATE="jupyter-{username}.{namespace}.svc.local", - ) - self.validator = MagicMock(spec=TokenValidator) - self.pool = MagicMock(spec=ChannelPool) - self.handler = SparkConnectProxyHandler(self.settings, self.validator, self.pool) - - def test_unauthenticated_unary_unary(self) -> None: - """Returns unary-unary abort handler for unary-unary methods.""" - result = self.handler._unauthenticated_handler("bad token", "unary", "unary") - assert result is not None - assert result.unary_unary is not None - - def test_unauthenticated_unary_stream(self) -> None: - """Returns unary-stream abort handler for server-streaming methods.""" - result = self.handler._unauthenticated_handler("bad token", "unary", "stream") - assert result is not None - assert result.unary_stream is not None - - def test_unauthenticated_stream_unary(self) -> None: - """Returns stream-unary abort handler for client-streaming methods.""" - result = self.handler._unauthenticated_handler("bad token", "stream", "unary") - assert result is not None - assert result.stream_unary is not None - - def test_unauthenticated_unknown_type_fallback(self) -> None: - """Unknown RPC types fall back to unary-unary abort handler.""" - 
result = self.handler._unauthenticated_handler("bad token", "stream", "stream") - assert result is not None - assert result.unary_unary is not None - - @pytest.mark.asyncio - async def test_abort_unary_calls_context_abort(self) -> None: - """The unary abort handler calls context.abort with UNAUTHENTICATED.""" - result = self.handler._unauthenticated_handler("invalid", "unary", "unary") - context = AsyncMock() - context.abort = AsyncMock(side_effect=grpc.RpcError()) - - with pytest.raises(grpc.RpcError): - await result.unary_unary(b"req", context) - context.abort.assert_awaited_once_with(grpc.StatusCode.UNAUTHENTICATED, "invalid") - - @pytest.mark.asyncio - async def test_abort_stream_calls_context_abort(self) -> None: - """The stream abort handler calls context.abort with UNAUTHENTICATED.""" - result = self.handler._unauthenticated_handler("invalid", "unary", "stream") - context = AsyncMock() - context.abort = AsyncMock(side_effect=grpc.RpcError()) - - with pytest.raises(grpc.RpcError): - async for _ in result.unary_stream(b"req", context): - pass - context.abort.assert_awaited_once_with(grpc.StatusCode.UNAUTHENTICATED, "invalid") - - @pytest.mark.asyncio - async def test_abort_client_stream_calls_context_abort(self) -> None: - """The client-stream abort handler calls context.abort with UNAUTHENTICATED.""" - result = self.handler._unauthenticated_handler("invalid", "stream", "unary") - context = AsyncMock() - context.abort = AsyncMock(side_effect=grpc.RpcError()) - - async def request_iter(): - yield b"data" - - with pytest.raises(grpc.RpcError): - await result.stream_unary(request_iter(), context) - context.abort.assert_awaited_once_with(grpc.StatusCode.UNAUTHENTICATED, "invalid") - - # --------------------------------------------------------------------------- # serve() function tests # --------------------------------------------------------------------------- @@ -431,18 +479,25 @@ class TestServe: async def test_serve_starts_and_stops(self) -> None: """serve() 
creates a server, starts it, and can be shut down.""" mock_server = AsyncMock() - # Simulate termination by raising an exception mock_server.wait_for_termination = AsyncMock(side_effect=asyncio.CancelledError) + # Make sync methods return non-coroutines + mock_server.add_generic_rpc_handlers = MagicMock() + mock_server.add_insecure_port = MagicMock() with ( patch("spark_connect_proxy.proxy.aio.server", return_value=mock_server), patch("spark_connect_proxy.proxy.TokenValidator"), patch("spark_connect_proxy.proxy.ChannelPool") as mock_pool_cls, + patch("spark_connect_proxy.proxy.health.HealthServicer") as mock_health_cls, + patch("spark_connect_proxy.proxy.health_pb2_grpc.add_HealthServicer_to_server"), ): mock_pool = MagicMock() mock_pool.close_all = AsyncMock() mock_pool_cls.return_value = mock_pool + mock_health = MagicMock() + mock_health_cls.return_value = mock_health + settings = ProxySettings() with pytest.raises(asyncio.CancelledError): @@ -459,16 +514,23 @@ async def test_serve_uses_default_settings(self) -> None: """serve() creates default ProxySettings when none provided.""" mock_server = AsyncMock() mock_server.wait_for_termination = AsyncMock(side_effect=asyncio.CancelledError) + mock_server.add_generic_rpc_handlers = MagicMock() + mock_server.add_insecure_port = MagicMock() with ( patch("spark_connect_proxy.proxy.aio.server", return_value=mock_server), patch("spark_connect_proxy.proxy.TokenValidator") as mock_validator_cls, patch("spark_connect_proxy.proxy.ChannelPool") as mock_pool_cls, + patch("spark_connect_proxy.proxy.health.HealthServicer") as mock_health_cls, + patch("spark_connect_proxy.proxy.health_pb2_grpc.add_HealthServicer_to_server"), ): mock_pool = MagicMock() mock_pool.close_all = AsyncMock() mock_pool_cls.return_value = mock_pool + mock_health = MagicMock() + mock_health_cls.return_value = mock_health + with pytest.raises(asyncio.CancelledError): await serve() # No settings — uses defaults @@ -476,21 +538,60 @@ async def 
test_serve_uses_default_settings(self) -> None: call_kwargs = mock_validator_cls.call_args[1] assert call_kwargs["auth_url"] == "https://kbase.us/services/auth/" + @pytest.mark.asyncio + async def test_serve_registers_health_check(self) -> None: + """serve() registers gRPC health check service.""" + mock_server = AsyncMock() + mock_server.wait_for_termination = AsyncMock(side_effect=asyncio.CancelledError) + mock_server.add_generic_rpc_handlers = MagicMock() + mock_server.add_insecure_port = MagicMock() + + with ( + patch("spark_connect_proxy.proxy.aio.server", return_value=mock_server), + patch("spark_connect_proxy.proxy.TokenValidator"), + patch("spark_connect_proxy.proxy.ChannelPool") as mock_pool_cls, + patch("spark_connect_proxy.proxy.health.HealthServicer") as mock_health_cls, + patch( + "spark_connect_proxy.proxy.health_pb2_grpc.add_HealthServicer_to_server" + ) as mock_add, + ): + mock_pool = MagicMock() + mock_pool.close_all = AsyncMock() + mock_pool_cls.return_value = mock_pool + + mock_health = MagicMock() + mock_health_cls.return_value = mock_health + + with pytest.raises(asyncio.CancelledError): + await serve(ProxySettings()) + + # Health servicer should be added to the server + mock_add.assert_called_once_with(mock_health, mock_server) + # Health status should be set to SERVING + assert mock_health.set.call_count >= 2 # overall + service-specific + @pytest.mark.asyncio async def test_serve_listen_address(self) -> None: """serve() binds to the configured port.""" mock_server = AsyncMock() mock_server.wait_for_termination = AsyncMock(side_effect=asyncio.CancelledError) + mock_server.add_generic_rpc_handlers = MagicMock() + mock_server.add_insecure_port = MagicMock() with ( patch("spark_connect_proxy.proxy.aio.server", return_value=mock_server), patch("spark_connect_proxy.proxy.TokenValidator"), patch("spark_connect_proxy.proxy.ChannelPool") as mock_pool_cls, + patch("spark_connect_proxy.proxy.health.HealthServicer") as mock_health_cls, + 
patch("spark_connect_proxy.proxy.health_pb2_grpc.add_HealthServicer_to_server"), ): mock_pool = MagicMock() mock_pool.close_all = AsyncMock() mock_pool_cls.return_value = mock_pool + mock_health = MagicMock() + mock_health_cls.return_value = mock_health + settings = ProxySettings(PROXY_LISTEN_PORT=9999) with pytest.raises(asyncio.CancelledError): diff --git a/uv.lock b/uv.lock index 50e77ee..fecbf74 100644 --- a/uv.lock +++ b/uv.lock @@ -141,6 +141,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/b2/b096ccce418882fbfda4f7496f9357aaa9a5af1896a9a7f60d9f2b275a06/grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb", size = 4929852, upload-time = "2026-02-06T09:56:45.885Z" }, ] +[[package]] +name = "grpcio-health-checking" +version = "1.78.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8e/ac/8eb871f4e47b11abfe45497e6187a582ec680ccd7232706d228474a8c7a5/grpcio_health_checking-1.78.0.tar.gz", hash = "sha256:78526d5c60b9b99fd18954b89f86d70033c702e96ad6ccc9749baf16136979b3", size = 17008, upload-time = "2026-02-06T10:01:47.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/30/dbaf47e2210697e2923b49eb62a6a2c07d5ee55bb40cff1e6cc0c5bb22e1/grpcio_health_checking-1.78.0-py3-none-any.whl", hash = "sha256:309798c098c5de72a9bff7172d788fdf309d246d231db9955b32e7c1c773fbeb", size = 19010, upload-time = "2026-02-06T10:01:37.949Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -300,6 +313,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "protobuf" +version = 
"6.33.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/25/7c72c307aafc96fa87062aa6291d9f7c94836e43214d43722e86037aac02/protobuf-6.33.5.tar.gz", hash = "sha256:6ddcac2a081f8b7b9642c09406bc6a4290128fce5f471cddd165960bb9119e5c", size = 444465, upload-time = "2026-01-29T21:51:33.494Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/79/af92d0a8369732b027e6d6084251dd8e782c685c72da161bd4a2e00fbabb/protobuf-6.33.5-cp310-abi3-win32.whl", hash = "sha256:d71b040839446bac0f4d162e758bea99c8251161dae9d0983a3b88dee345153b", size = 425769, upload-time = "2026-01-29T21:51:21.751Z" }, + { url = "https://files.pythonhosted.org/packages/55/75/bb9bc917d10e9ee13dee8607eb9ab963b7cf8be607c46e7862c748aa2af7/protobuf-6.33.5-cp310-abi3-win_amd64.whl", hash = "sha256:3093804752167bcab3998bec9f1048baae6e29505adaf1afd14a37bddede533c", size = 437118, upload-time = "2026-01-29T21:51:24.022Z" }, + { url = "https://files.pythonhosted.org/packages/a2/6b/e48dfc1191bc5b52950246275bf4089773e91cb5ba3592621723cdddca62/protobuf-6.33.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:a5cb85982d95d906df1e2210e58f8e4f1e3cdc088e52c921a041f9c9a0386de5", size = 427766, upload-time = "2026-01-29T21:51:25.413Z" }, + { url = "https://files.pythonhosted.org/packages/4e/b1/c79468184310de09d75095ed1314b839eb2f72df71097db9d1404a1b2717/protobuf-6.33.5-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:9b71e0281f36f179d00cbcb119cb19dec4d14a81393e5ea220f64b286173e190", size = 324638, upload-time = "2026-01-29T21:51:26.423Z" }, + { url = "https://files.pythonhosted.org/packages/c5/f5/65d838092fd01c44d16037953fd4c2cc851e783de9b8f02b27ec4ffd906f/protobuf-6.33.5-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:8afa18e1d6d20af15b417e728e9f60f3aa108ee76f23c3b2c07a2c3b546d3afd", size = 339411, upload-time = "2026-01-29T21:51:27.446Z" }, + { url = 
"https://files.pythonhosted.org/packages/9b/53/a9443aa3ca9ba8724fdfa02dd1887c1bcd8e89556b715cfbacca6b63dbec/protobuf-6.33.5-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:cbf16ba3350fb7b889fca858fb215967792dc125b35c7976ca4818bee3521cf0", size = 323465, upload-time = "2026-01-29T21:51:28.925Z" }, + { url = "https://files.pythonhosted.org/packages/57/bf/2086963c69bdac3d7cff1cc7ff79b8ce5ea0bec6797a017e1be338a46248/protobuf-6.33.5-py3-none-any.whl", hash = "sha256:69915a973dd0f60f31a08b8318b73eab2bd6a392c79184b3612226b0a3f8ec02", size = 170687, upload-time = "2026-01-29T21:51:32.557Z" }, +] + [[package]] name = "pydantic" version = "2.12.5" @@ -473,6 +501,7 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "grpcio" }, + { name = "grpcio-health-checking" }, { name = "httpx" }, { name = "pydantic" }, { name = "pydantic-settings" }, @@ -490,6 +519,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "grpcio", specifier = ">=1.60.0" }, + { name = "grpcio-health-checking", specifier = ">=1.60.0" }, { name = "httpx", specifier = ">=0.27.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.10.0" }, { name = "pydantic", specifier = ">=2.0.0" },