Skip to content

Commit

Permalink
JwtSuperuserConnection (#15)
Browse files Browse the repository at this point in the history
* Implemented JwtSuperuserConnection

* Fixed failing test
  • Loading branch information
apetenchea authored Sep 8, 2024
1 parent 44f4fa0 commit e642e5b
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 56 deletions.
33 changes: 33 additions & 0 deletions arangoasync/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import time
from dataclasses import dataclass
from typing import Optional

import jwt

Expand Down Expand Up @@ -39,6 +40,38 @@ def __init__(self, token: str) -> None:
self._token = token
self._validate()

@staticmethod
def generate_token(
secret: str | bytes,
iat: Optional[int] = None,
exp: int = 3600,
iss: str = "arangodb",
server_id: str = "client",
) -> "JwtToken":
"""Generate and return a JWT token.
Args:
secret (str | bytes): JWT secret.
iat (int): Time the token was issued in seconds. Defaults to current time.
exp (int): Time to expire in seconds.
iss (str): Issuer.
server_id (str): Server ID.
Returns:
str: JWT token.
"""
iat = iat or int(time.time())
token = jwt.encode(
payload={
"iat": iat,
"exp": iat + exp,
"iss": iss,
"server_id": server_id,
},
key=secret,
)
return JwtToken(token)

@property
def token(self) -> str:
"""Get token."""
Expand Down
140 changes: 117 additions & 23 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__all__ = [
"BaseConnection",
"BasicConnection",
"JwtConnection",
"JwtSuperuserConnection",
]

import json
Expand All @@ -9,6 +11,7 @@

import jwt

from arangoasync import errno, logger
from arangoasync.auth import Auth, JwtToken
from arangoasync.compression import CompressionManager, DefaultCompressionManager
from arangoasync.exceptions import (
Expand Down Expand Up @@ -55,25 +58,45 @@ def db_name(self) -> str:
"""Return the database name."""
return self._db_name

def prep_response(self, request: Request, resp: Response) -> Response:
"""Prepare response for return.
@staticmethod
def raise_for_status(request: Request, resp: Response) -> None:
"""Raise an exception based on the response.
Args:
request (Request): Request object.
resp (Response): Response object.
Returns:
Response: Response object
Raises:
ServerConnectionError: If the response status code is not successful.
"""
# TODO needs refactoring such that it does not throw
resp.is_success = 200 <= resp.status_code < 300
if resp.status_code in {401, 403}:
raise ServerConnectionError(resp, request, "Authentication failed.")
if not resp.is_success:
raise ServerConnectionError(resp, request, "Bad server response.")

@staticmethod
def prep_response(request: Request, resp: Response) -> Response:
"""Prepare response for return.
Args:
request (Request): Request object.
resp (Response): Response object.
Returns:
Response: Response object
"""
resp.is_success = 200 <= resp.status_code < 300
if not resp.is_success:
try:
body = json.loads(resp.raw_body)
except json.JSONDecodeError as e:
logger.debug(
f"Failed to decode response body: {e} (from request {request})"
)
else:
if body.get("error") is True:
resp.error_code = body.get("errorNum")
resp.error_message = body.get("errorMessage")
return resp

async def process_request(self, request: Request) -> Response:
Expand All @@ -86,7 +109,7 @@ async def process_request(self, request: Request) -> Response:
Response: Response object.
Raises:
ConnectionAbortedError: If can't connect to host(s) within limit.
ConnectionAbortedError: If it can't connect to host(s) within limit.
"""

host_index = self._host_resolver.get_host_index()
Expand All @@ -100,6 +123,7 @@ async def process_request(self, request: Request) -> Response:
ex_host_index = host_index
host_index = self._host_resolver.get_host_index()
if ex_host_index == host_index:
# Force change host if the same host is selected
self._host_resolver.change_host()
host_index = self._host_resolver.get_host_index()

Expand All @@ -117,8 +141,8 @@ async def ping(self) -> int:
ServerConnectionError: If the response status code is not successful.
"""
request = Request(method=Method.GET, endpoint="/_api/collection")
request.headers = {"abde": "fghi"}
resp = await self.send_request(request)
self.raise_for_status(request, resp)
return resp.status_code

@abstractmethod
Expand Down Expand Up @@ -257,15 +281,15 @@ async def refresh_token(self) -> None:
if self._auth is None:
raise JWTRefreshError("Auth must be provided to refresh the token.")

data = json.dumps(
auth_data = json.dumps(
dict(username=self._auth.username, password=self._auth.password),
separators=(",", ":"),
ensure_ascii=False,
)
request = Request(
method=Method.POST,
endpoint="/_open/auth",
data=data.encode("utf-8"),
data=auth_data.encode("utf-8"),
)

try:
Expand Down Expand Up @@ -310,16 +334,86 @@ async def send_request(self, request: Request) -> Response:

request.headers["authorization"] = self._auth_header

try:
resp = await self.process_request(request)
if (
resp.status_code == 401 # Unauthorized
and self._token is not None
and self._token.needs_refresh(self._expire_leeway)
):
await self.refresh_token()
return await self.process_request(request) # Retry with new token
except ServerConnectionError:
# TODO modify after refactoring of prep_response, so we can inspect response
resp = await self.process_request(request)
if (
resp.status_code == errno.HTTP_UNAUTHORIZED
and self._token is not None
and self._token.needs_refresh(self._expire_leeway)
):
# If the token has expired, refresh it and retry the request
await self.refresh_token()
return await self.process_request(request) # Retry with new token
resp = await self.process_request(request)
self.raise_for_status(request, resp)
return resp


class JwtSuperuserConnection(BaseConnection):
"""Connection to a specific ArangoDB database, using superuser JWT.
The JWT token is not refreshed and (username and password) are not required.
Args:
sessions (list): List of client sessions.
host_resolver (HostResolver): Host resolver.
http_client (HTTPClient): HTTP client.
db_name (str): Database name.
compression (CompressionManager | None): Compression manager.
token (JwtToken | None): JWT token.
"""

def __init__(
self,
sessions: List[Any],
host_resolver: HostResolver,
http_client: HTTPClient,
db_name: str,
compression: Optional[CompressionManager] = None,
token: Optional[JwtToken] = None,
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._expire_leeway: int = 0
self._token: Optional[JwtToken] = None
self._auth_header: Optional[str] = None
self.token = token

@property
def token(self) -> Optional[JwtToken]:
"""Get the JWT token.
Returns:
JwtToken | None: JWT token.
"""
return self._token

@token.setter
def token(self, token: Optional[JwtToken]) -> None:
"""Set the JWT token.
Args:
token (JwtToken | None): JWT token.
Setting it to None will cause the token to be automatically
refreshed on the next request, if auth information is provided.
"""
self._token = token
self._auth_header = f"bearer {self._token.token}" if self._token else None

async def send_request(self, request: Request) -> Response:
"""Send an HTTP request to the ArangoDB server.
Args:
request (Request): HTTP request.
Returns:
Response: HTTP response
Raises:
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
if self._auth_header is None:
raise AuthHeaderError("Failed to generate authorization header.")
request.headers["authorization"] = self._auth_header

resp = await self.process_request(request)
self.raise_for_status(request, resp)
return resp
3 changes: 3 additions & 0 deletions arangoasync/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,6 @@ def normalized_params(self) -> Params:
normalized_params[key] = str(value)

return normalized_params

def __repr__(self) -> str:
return f"<{self.method.name} {self.endpoint}>"
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import pytest_asyncio

from tests.helpers import generate_jwt
from arangoasync.auth import JwtToken


@dataclass
Expand Down Expand Up @@ -45,8 +45,8 @@ def pytest_configure(config):
global_data.url = url
global_data.root = config.getoption("root")
global_data.password = config.getoption("password")
global_data.secret = generate_jwt(config.getoption("secret"))
global_data.token = generate_jwt(global_data.secret)
global_data.secret = config.getoption("secret")
global_data.token = JwtToken.generate_token(global_data.secret)


@pytest.fixture(autouse=False)
Expand Down Expand Up @@ -76,6 +76,7 @@ def sys_db_name():

@pytest_asyncio.fixture
async def client_session():
"""Make sure we close all sessions after the test is done."""
sessions = []

def get_client_session(client, url):
Expand Down
25 changes: 0 additions & 25 deletions tests/helpers.py

This file was deleted.

Loading

0 comments on commit e642e5b

Please sign in to comment.