From 85e0b06375047863fa4927d0aacf06e89411febc Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Sun, 8 Sep 2024 20:57:30 +0300 Subject: [PATCH] Implemented JwtSuperuserConnection --- arangoasync/auth.py | 33 +++++++++ arangoasync/connection.py | 140 +++++++++++++++++++++++++++++++------- arangoasync/request.py | 3 + tests/conftest.py | 7 +- tests/helpers.py | 25 ------- tests/test_connection.py | 107 +++++++++++++++++++++++++++-- 6 files changed, 259 insertions(+), 56 deletions(-) delete mode 100644 tests/helpers.py diff --git a/arangoasync/auth.py b/arangoasync/auth.py index 586047f..9b5da9c 100644 --- a/arangoasync/auth.py +++ b/arangoasync/auth.py @@ -5,6 +5,7 @@ import time from dataclasses import dataclass +from typing import Optional import jwt @@ -39,6 +40,38 @@ def __init__(self, token: str) -> None: self._token = token self._validate() + @staticmethod + def generate_token( + secret: str | bytes, + iat: Optional[int] = None, + exp: int = 3600, + iss: str = "arangodb", + server_id: str = "client", + ) -> "JwtToken": + """Generate and return a JWT token. + + Args: + secret (str | bytes): JWT secret. + iat (int): Time the token was issued in seconds. Defaults to current time. + exp (int): Time to expire in seconds. + iss (str): Issuer. + server_id (str): Server ID. + + Returns: + str: JWT token. + """ + iat = iat or int(time.time()) + token = jwt.encode( + payload={ + "iat": iat, + "exp": iat + exp, + "iss": iss, + "server_id": server_id, + }, + key=secret, + ) + return JwtToken(token) + @property def token(self) -> str: """Get token.""" diff --git a/arangoasync/connection.py b/arangoasync/connection.py index 7a998af..cb52a4c 100644 --- a/arangoasync/connection.py +++ b/arangoasync/connection.py @@ -1,6 +1,8 @@ __all__ = [ "BaseConnection", "BasicConnection", + "JwtConnection", + "JwtSuperuserConnection", ] import json @@ -9,6 +11,7 @@ import jwt +from arangoasync import errno, logger from arangoasync.auth import Auth, JwtToken from arangoasync.compression import CompressionManager, DefaultCompressionManager from arangoasync.exceptions import ( @@ -55,25 +58,45 @@ def db_name(self) -> str: """Return the database name.""" return self._db_name - def prep_response(self, request: Request, resp: Response) -> Response: - """Prepare response for return. + @staticmethod + def raise_for_status(request: Request, resp: Response) -> None: + """Raise an exception based on the response. Args: request (Request): Request object. resp (Response): Response object. - Returns: - Response: Response object - Raises: ServerConnectionError: If the response status code is not successful. """ - # TODO needs refactoring such that it does not throw - resp.is_success = 200 <= resp.status_code < 300 if resp.status_code in {401, 403}: raise ServerConnectionError(resp, request, "Authentication failed.") if not resp.is_success: raise ServerConnectionError(resp, request, "Bad server response.") + + @staticmethod + def prep_response(request: Request, resp: Response) -> Response: + """Prepare response for return. + + Args: + request (Request): Request object. + resp (Response): Response object. + + Returns: + Response: Response object + """ + resp.is_success = 200 <= resp.status_code < 300 + if not resp.is_success: + try: + body = json.loads(resp.raw_body) + except json.JSONDecodeError as e: + logger.debug( + f"Failed to decode response body: {e} (from request {request})" + ) + else: + if body.get("error") is True: + resp.error_code = body.get("errorNum") + resp.error_message = body.get("errorMessage") return resp async def process_request(self, request: Request) -> Response: @@ -86,7 +109,7 @@ async def process_request(self, request: Request) -> Response: Response: Response object. Raises: - ConnectionAbortedError: If can't connect to host(s) within limit. + ConnectionAbortedError: If it can't connect to host(s) within limit. """ host_index = self._host_resolver.get_host_index() @@ -100,6 +123,7 @@ async def process_request(self, request: Request) -> Response: ex_host_index = host_index host_index = self._host_resolver.get_host_index() if ex_host_index == host_index: + # Force change host if the same host is selected self._host_resolver.change_host() host_index = self._host_resolver.get_host_index() @@ -117,8 +141,8 @@ async def ping(self) -> int: ServerConnectionError: If the response status code is not successful. """ request = Request(method=Method.GET, endpoint="/_api/collection") - request.headers = {"abde": "fghi"} resp = await self.send_request(request) + self.raise_for_status(request, resp) return resp.status_code @abstractmethod @@ -257,7 +281,7 @@ async def refresh_token(self) -> None: if self._auth is None: raise JWTRefreshError("Auth must be provided to refresh the token.") - data = json.dumps( + auth_data = json.dumps( dict(username=self._auth.username, password=self._auth.password), separators=(",", ":"), ensure_ascii=False, @@ -265,7 +289,7 @@ async def refresh_token(self) -> None: request = Request( method=Method.POST, endpoint="/_open/auth", - data=data.encode("utf-8"), + data=auth_data.encode("utf-8"), ) try: @@ -310,16 +334,86 @@ async def send_request(self, request: Request) -> Response: request.headers["authorization"] = self._auth_header - try: - resp = await self.process_request(request) - if ( - resp.status_code == 401 # Unauthorized - and self._token is not None - and self._token.needs_refresh(self._expire_leeway) - ): - await self.refresh_token() - return await self.process_request(request) # Retry with new token - except ServerConnectionError: - # TODO modify after refactoring of prep_response, so we can inspect response + resp = await self.process_request(request) + if ( + resp.status_code == errno.HTTP_UNAUTHORIZED + and self._token is not None + and self._token.needs_refresh(self._expire_leeway) + ): + # If the token has expired, refresh it and retry the request await self.refresh_token() - return await self.process_request(request) # Retry with new token + resp = await self.process_request(request) + self.raise_for_status(request, resp) + return resp + + +class JwtSuperuserConnection(BaseConnection): + """Connection to a specific ArangoDB database, using superuser JWT. + + The JWT token is not refreshed and (username and password) are not required. + + Args: + sessions (list): List of client sessions. + host_resolver (HostResolver): Host resolver. + http_client (HTTPClient): HTTP client. + db_name (str): Database name. + compression (CompressionManager | None): Compression manager. + token (JwtToken | None): JWT token. + """ + + def __init__( + self, + sessions: List[Any], + host_resolver: HostResolver, + http_client: HTTPClient, + db_name: str, + compression: Optional[CompressionManager] = None, + token: Optional[JwtToken] = None, + ) -> None: + super().__init__(sessions, host_resolver, http_client, db_name, compression) + self._expire_leeway: int = 0 + self._token: Optional[JwtToken] = None + self._auth_header: Optional[str] = None + self.token = token + + @property + def token(self) -> Optional[JwtToken]: + """Get the JWT token. + + Returns: + JwtToken | None: JWT token. + """ + return self._token + + @token.setter + def token(self, token: Optional[JwtToken]) -> None: + """Set the JWT token. + + Args: + token (JwtToken | None): JWT token. + Setting it to None will cause the token to be automatically + refreshed on the next request, if auth information is provided. + """ + self._token = token + self._auth_header = f"bearer {self._token.token}" if self._token else None + + async def send_request(self, request: Request) -> Response: + """Send an HTTP request to the ArangoDB server. + + Args: + request (Request): HTTP request. + + Returns: + Response: HTTP response + + Raises: + ArangoClientError: If an error occurred from the client side. + ArangoServerError: If an error occurred from the server side. + """ + if self._auth_header is None: + raise AuthHeaderError("Failed to generate authorization header.") + request.headers["authorization"] = self._auth_header + + resp = await self.process_request(request) + self.raise_for_status(request, resp) + return resp diff --git a/arangoasync/request.py b/arangoasync/request.py index 8890468..9e824f0 100644 --- a/arangoasync/request.py +++ b/arangoasync/request.py @@ -102,3 +102,6 @@ def normalized_params(self) -> Params: normalized_params[key] = str(value) return normalized_params + + def __repr__(self) -> str: + return f"<{self.method.name} {self.endpoint}>" diff --git a/tests/conftest.py b/tests/conftest.py index 22d63d4..d6f5bbc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import pytest import pytest_asyncio -from tests.helpers import generate_jwt +from arangoasync.auth import JwtToken @dataclass @@ -45,8 +45,8 @@ def pytest_configure(config): global_data.url = url global_data.root = config.getoption("root") global_data.password = config.getoption("password") - global_data.secret = generate_jwt(config.getoption("secret")) - global_data.token = generate_jwt(global_data.secret) + global_data.secret = config.getoption("secret") + global_data.token = JwtToken.generate_token(global_data.secret) @pytest.fixture(autouse=False) @@ -76,6 +76,7 @@ def sys_db_name(): @pytest_asyncio.fixture async def client_session(): + """Make sure we close all sessions after the test is done.""" sessions = [] def get_client_session(client, url): diff --git a/tests/helpers.py b/tests/helpers.py deleted file mode 100644 index 78693c4..0000000 --- a/tests/helpers.py +++ /dev/null @@ -1,25 +0,0 @@ -import time - -import jwt - - -def generate_jwt(secret, exp=3600) -> str: - """Generate and return a JWT token. - - Args: - secret (str | bytes): JWT secret. - exp (int): Time to expire in seconds. - - Returns: - str: JWT token. - """ - now = int(time.time()) - return jwt.encode( - payload={ - "iat": now, - "exp": now + exp, - "iss": "arangodb", - "server_id": "client", - }, - key=secret, - ) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6e0d8ff..43f0a4f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,9 +2,13 @@ import pytest -from arangoasync.auth import Auth +from arangoasync.auth import Auth, JwtToken from arangoasync.compression import AcceptEncoding, DefaultCompressionManager -from arangoasync.connection import BasicConnection, JwtConnection +from arangoasync.connection import ( + BasicConnection, + JwtConnection, + JwtSuperuserConnection, +) from arangoasync.exceptions import ( ClientConnectionAbortedError, ClientConnectionError, @@ -100,9 +104,16 @@ async def test_BasicConnection_prep_response_bad_response( request = Request(method=Method.GET, endpoint="/_api/collection") response = Response(Method.GET, url, {}, 0, "ERROR", b"") - + connection.prep_response(request, response) + assert response.is_success is False with pytest.raises(ServerConnectionError): - connection.prep_response(request, response) + connection.raise_for_status(request, response) + + error = b'{"error": true, "errorMessage": "msg", "errorNum": 404}' + response = Response(Method.GET, url, {}, 0, "ERROR", error) + connection.prep_response(request, response) + assert response.error_code == 404 + assert response.error_message == "msg" @pytest.mark.asyncio @@ -111,11 +122,16 @@ async def test_BasicConnection_process_request_connection_aborted( ): client = AioHTTPClient() session = client_session(client, url) - resolver = DefaultHostResolver(1, 1) + max_tries = 4 + resolver = DefaultHostResolver(1, max_tries=max_tries) request = Request(method=Method.GET, endpoint="/_api/collection") + tries = 0 + async def mock_send_request(*args, **kwargs): + nonlocal tries + tries += 1 raise ClientConnectionError("test") monkeypatch.setattr(client, "send_request", mock_send_request) @@ -130,6 +146,38 @@ async def mock_send_request(*args, **kwargs): with pytest.raises(ClientConnectionAbortedError): await connection.process_request(request) + assert tries == max_tries + + +def test_JwtConnection_no_auth(client_session, url, sys_db_name): + client = AioHTTPClient() + session = client_session(client, url) + resolver = DefaultHostResolver(1) + with pytest.raises(ValueError): + _ = JwtConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + ) + + +@pytest.mark.asyncio +async def test_JwtConnection_invalid_token(client_session, url, sys_db_name): + client = AioHTTPClient() + session = client_session(client, url) + resolver = DefaultHostResolver(1) + + invalid_token = JwtToken.generate_token("invalid token") + connection = JwtConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + token=invalid_token, + ) + with pytest.raises(ServerConnectionError): + await connection.ping() @pytest.mark.asyncio @@ -151,6 +199,7 @@ async def test_JwtConnection_ping_success( status_code = await connection1.ping() assert status_code == 200 + # Test reusing the token connection2 = JwtConnection( sessions=[session], host_resolver=resolver, @@ -161,3 +210,51 @@ async def test_JwtConnection_ping_success( assert connection2.db_name == sys_db_name status_code = await connection2.ping() assert status_code == 200 + + connection3 = JwtConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + auth=Auth(username=root, password=password), + ) + connection3.token = connection1.token + status_code = await connection1.ping() + assert status_code == 200 + + +@pytest.mark.asyncio +async def test_JwtSuperuserConnection_ping_success( + client_session, url, sys_db_name, token +): + client = AioHTTPClient() + session = client_session(client, url) + resolver = DefaultHostResolver(1) + + connection = JwtSuperuserConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + token=token, + ) + assert connection.db_name == sys_db_name + status_code = await connection.ping() + assert status_code == 200 + + +@pytest.mark.asyncio +async def test_JwtSuperuserConnection_ping_failed(client_session, url, sys_db_name): + client = AioHTTPClient() + session = client_session(client, url) + resolver = DefaultHostResolver(1) + + connection = JwtSuperuserConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + token=JwtToken.generate_token("invalid token"), + ) + with pytest.raises(ServerConnectionError): + await connection.ping()