Skip to content

Commit

Permalink
Adding ArangoClient (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
apetenchea authored Sep 21, 2024
1 parent e642e5b commit 98012ca
Show file tree
Hide file tree
Showing 11 changed files with 379 additions and 15 deletions.
4 changes: 2 additions & 2 deletions arangoasync/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class JwtToken:
Raises:
TypeError: If the token type is not str or bytes.
jwt.ExpiredSignatureError: If the token expired.
jwt.exceptions.ExpiredSignatureError: If the token expired.
"""

def __init__(self, token: str) -> None:
Expand Down Expand Up @@ -82,7 +82,7 @@ def token(self, token: str) -> None:
"""Set token.
Raises:
jwt.ExpiredSignatureError: If the token expired.
jwt.exceptions.ExpiredSignatureError: If the token expired.
"""
self._token = token
self._validate()
Expand Down
201 changes: 201 additions & 0 deletions arangoasync/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
__all__ = ["ArangoClient"]

import asyncio
from typing import Any, Optional, Sequence

from arangoasync.auth import Auth, JwtToken
from arangoasync.compression import CompressionManager
from arangoasync.connection import (
BasicConnection,
Connection,
JwtConnection,
JwtSuperuserConnection,
)
from arangoasync.database import Database
from arangoasync.http import DefaultHTTPClient, HTTPClient
from arangoasync.resolver import HostResolver, get_resolver
from arangoasync.version import __version__


class ArangoClient:
"""ArangoDB client.
Args:
hosts (str | Sequence[str]): Host URL or list of URL's.
In case of a cluster, this would be the list of coordinators.
Which coordinator to use is determined by the `host_resolver`.
host_resolver (str | HostResolver): Host resolver strategy.
This determines how the client will choose which server to use.
Passing a string would configure a resolver with the default settings.
See :class:`DefaultHostResolver <arangoasync.resolver.DefaultHostResolver>`
and :func:`get_resolver <arangoasync.resolver.get_resolver>`
for more information.
If you need more customization, pass a subclass of
:class:`HostResolver <arangoasync.resolver.HostResolver>`.
http_client (HTTPClient | None): HTTP client implementation.
This is the core component that sends requests to the ArangoDB server.
Defaults to :class:`DefaultHttpClient <arangoasync.http.DefaultHTTPClient>`,
but you can fully customize its parameters or even use a different
implementation by subclassing
:class:`HTTPClient <arangoasync.http.HTTPClient>`.
compression (CompressionManager | None): Disabled by default.
Used to compress requests to the server or instruct the server to compress
responses. Enable it by passing an instance of
:class:`DefaultCompressionManager
<arangoasync.compression.DefaultCompressionManager>`
or a subclass of :class:`CompressionManager
<arangoasync.compression.CompressionManager>`.
Raises:
ValueError: If the `host_resolver` is not supported.
"""

def __init__(
self,
hosts: str | Sequence[str] = "http://127.0.0.1:8529",
host_resolver: str | HostResolver = "default",
http_client: Optional[HTTPClient] = None,
compression: Optional[CompressionManager] = None,
) -> None:
self._hosts = [hosts] if isinstance(hosts, str) else hosts
self._host_resolver = (
get_resolver(host_resolver, len(self._hosts))
if isinstance(host_resolver, str)
else host_resolver
)
self._http_client = http_client or DefaultHTTPClient()
self._sessions = [
self._http_client.create_session(host) for host in self._hosts
]
self._compression = compression

def __repr__(self) -> str:
return f"<ArangoClient {','.join(self._hosts)}>"

async def __aenter__(self) -> "ArangoClient":
return self

async def __aexit__(self, *exc: Any) -> None:
await self.close()

@property
def hosts(self) -> Sequence[str]:
"""Return the list of hosts."""
return self._hosts

@property
def host_resolver(self) -> HostResolver:
"""Return the host resolver."""
return self._host_resolver

@property
def compression(self) -> Optional[CompressionManager]:
"""Return the compression manager."""
return self._compression

@property
def sessions(self) -> Sequence[Any]:
"""Return the list of sessions.
You may use this to customize sessions on the fly (for example,
adjust the timeout). Not recommended unless you know what you are doing.
Warning:
Modifying only a subset of sessions may lead to unexpected behavior.
In order to keep the client in a consistent state, you should make sure
all sessions are configured in the same way.
"""
return self._sessions

@property
def version(self) -> str:
"""Return the version of the client."""
return __version__

async def close(self) -> None:
"""Close HTTP sessions."""
await asyncio.gather(*(session.close() for session in self._sessions))

async def db(
self,
name: str,
auth_method: str = "basic",
auth: Optional[Auth] = None,
token: Optional[JwtToken] = None,
verify: bool = False,
compression: Optional[CompressionManager] = None,
) -> Database:
"""Connects to a database and returns and API wrapper.
Args:
name (str): Database name.
auth_method (str): The following methods are supported:
- "basic": HTTP authentication.
Requires the `auth` parameter. The `token` parameter is ignored.
- "jwt": User JWT authentication.
At least one of the `auth` or `token` parameters are required.
If `auth` is provided, but the `token` is not, the token will be
refreshed automatically. This assumes that the clocks of the server
and client are synchronized.
- "superuser": Superuser JWT authentication.
The `token` parameter is required. The `auth` parameter is ignored.
auth (Auth | None): Login information.
token (JwtToken | None): JWT token.
verify (bool): Verify the connection by sending a test request.
compression (CompressionManager | None): Supersedes the client-level
compression settings.
Returns:
Database: Database API wrapper.
Raises:
ValueError: If the authentication is invalid.
ServerConnectionError: If `verify` is `True` and the connection fails.
"""
connection: Connection
if auth_method == "basic":
if auth is None:
raise ValueError("Basic authentication requires the `auth` parameter")
connection = BasicConnection(
sessions=self._sessions,
host_resolver=self._host_resolver,
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
auth=auth,
)
elif auth_method == "jwt":
if auth is None and token is None:
raise ValueError(
"JWT authentication requires the `auth` or `token` parameter"
)
connection = JwtConnection(
sessions=self._sessions,
host_resolver=self._host_resolver,
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
auth=auth,
token=token,
)
elif auth_method == "superuser":
if token is None:
raise ValueError(
"Superuser JWT authentication requires the `token` parameter"
)
connection = JwtSuperuserConnection(
sessions=self._sessions,
host_resolver=self._host_resolver,
http_client=self._http_client,
db_name=name,
compression=compression or self._compression,
token=token,
)
else:
raise ValueError(f"Invalid authentication method: {auth_method}")

if verify:
await connection.ping()

return Database(connection)
2 changes: 1 addition & 1 deletion arangoasync/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def accept_encoding(self) -> str | None:
"""Return the accept encoding.
This is the value of the Accept-Encoding header in the HTTP request.
Currently, only deflate and "gzip" are supported.
Currently, only "deflate" and "gzip" are supported.
Returns:
str: Accept encoding
Expand Down
14 changes: 10 additions & 4 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = [
"BaseConnection",
"BasicConnection",
"Connection",
"JwtConnection",
"JwtSuperuserConnection",
]
Expand Down Expand Up @@ -244,9 +245,9 @@ def __init__(
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._auth = auth
self._expire_leeway: int = 0
self._token: Optional[JwtToken] = None
self._token: Optional[JwtToken] = token
self._auth_header: Optional[str] = None
self.token = token
self.token = self._token

if self._token is None and self._auth is None:
raise ValueError("Either token or auth must be provided.")
Expand Down Expand Up @@ -323,6 +324,7 @@ async def send_request(self, request: Request) -> Response:
Response: HTTP response
Raises:
AuthHeaderError: If the authentication header could not be generated.
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
Expand Down Expand Up @@ -372,9 +374,9 @@ def __init__(
) -> None:
super().__init__(sessions, host_resolver, http_client, db_name, compression)
self._expire_leeway: int = 0
self._token: Optional[JwtToken] = None
self._token: Optional[JwtToken] = token
self._auth_header: Optional[str] = None
self.token = token
self.token = self._token

@property
def token(self) -> Optional[JwtToken]:
Expand Down Expand Up @@ -407,6 +409,7 @@ async def send_request(self, request: Request) -> Response:
Response: HTTP response
Raises:
AuthHeaderError: If the authentication header could not be generated.
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
Expand All @@ -417,3 +420,6 @@ async def send_request(self, request: Request) -> Response:
resp = await self.process_request(request)
self.raise_for_status(request, resp)
return resp


Connection = BasicConnection | JwtConnection | JwtSuperuserConnection
17 changes: 17 additions & 0 deletions arangoasync/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
__all__ = [
"Database",
]

from arangoasync.connection import BaseConnection


class Database:
"""Database API."""

def __init__(self, connection: BaseConnection) -> None:
self._conn = connection

@property
def conn(self) -> BaseConnection:
"""Return the HTTP connection."""
return self._conn
4 changes: 0 additions & 4 deletions arangoasync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ class AuthHeaderError(ArangoClientError):
"""The authentication header could not be determined."""


class JWTExpiredError(ArangoClientError):
"""JWT token has expired."""


class JWTRefreshError(ArangoClientError):
"""Failed to refresh the JWT token."""

Expand Down
10 changes: 10 additions & 0 deletions arangoasync/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
]

from abc import ABC, abstractmethod
from ssl import SSLContext, create_default_context
from typing import Any, Optional

from aiohttp import (
Expand Down Expand Up @@ -82,6 +83,10 @@ class AioHTTPClient(HTTPClient):
timeout (aiohttp.ClientTimeout | None): Client timeout settings.
300s total timeout by default for a complete request/response operation.
read_bufsize (int): Size of read buffer (64KB default).
ssl_context (ssl.SSLContext | bool): SSL validation mode.
`True` for default SSL checks (see :func:`ssl.create_default_context`).
`False` disables SSL checks.
Additionally, you can pass a custom :class:`ssl.SSLContext`.
.. _aiohttp:
https://docs.aiohttp.org/en/stable/
Expand All @@ -92,6 +97,7 @@ def __init__(
connector: Optional[BaseConnector] = None,
timeout: Optional[ClientTimeout] = None,
read_bufsize: int = 2**16,
ssl_context: bool | SSLContext = True,
) -> None:
self._connector = connector or TCPConnector(
keepalive_timeout=60, # timeout for connection reusing after releasing
Expand All @@ -102,6 +108,9 @@ def __init__(
connect=60, # max number of seconds for acquiring a pool connection
)
self._read_bufsize = read_bufsize
self._ssl_context = (
ssl_context if ssl_context is not True else create_default_context()
)

def create_session(self, host: str) -> ClientSession:
"""Return a new session given the base host URL.
Expand Down Expand Up @@ -155,6 +164,7 @@ async def send_request(
params=request.normalized_params(),
data=request.data,
auth=auth,
ssl=self._ssl_context,
) as response:
raw_body = await response.read()
return Response(
Expand Down
10 changes: 7 additions & 3 deletions arangoasync/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class HostResolver(ABC):
Args:
host_count (int): Number of hosts.
max_tries (int): Maximum number of attempts to try a host.
max_tries (int | None): Maximum number of attempts to try a host.
Will default to 3 times the number of hosts if not provided.
Raises:
ValueError: If max_tries is less than host_count.
Expand All @@ -42,7 +43,7 @@ def get_host_index(self) -> int: # pragma: no cover
raise NotImplementedError

def change_host(self) -> None:
"""If there aer multiple hosts available, switch to the next one."""
"""If there are multiple hosts available, switch to the next one."""
self._index = (self._index + 1) % self.host_count

@property
Expand All @@ -57,7 +58,10 @@ def max_tries(self) -> int:


class SingleHostResolver(HostResolver):
"""Single host resolver. Always returns the same host index."""
"""Single host resolver.
Always returns the same host index, unless prompted to change.
"""

def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
super().__init__(host_count, max_tries)
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
autodoc_typehints = "none"

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"aiohttp": ("https://docs.aiohttp.org/en/stable/", None),
"jwt": ("https://pyjwt.readthedocs.io/en/stable/", None),
}
Expand Down
Loading

0 comments on commit 98012ca

Please sign in to comment.