Skip to content

Commit

Permalink
Generic Collection (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
apetenchea authored Sep 30, 2024
1 parent 8dea5f4 commit 8c8b237
Show file tree
Hide file tree
Showing 11 changed files with 636 additions and 74 deletions.
24 changes: 16 additions & 8 deletions arangoasync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Deserializer,
Serializer,
)
from arangoasync.typings import Json, Jsons
from arangoasync.version import __version__


Expand Down Expand Up @@ -51,14 +52,18 @@ class ArangoClient:
<arangoasync.compression.DefaultCompressionManager>`
or a custom subclass of :class:`CompressionManager
<arangoasync.compression.CompressionManager>`.
serializer (Serializer | None): Custom serializer implementation.
serializer (Serializer | None): Custom JSON serializer implementation.
Leave as `None` to use the default serializer.
See :class:`DefaultSerializer
<arangoasync.serialization.DefaultSerializer>`.
deserializer (Deserializer | None): Custom deserializer implementation.
For custom serialization of collection documents, see :class:`Collection
<arangoasync.collection.Collection>`.
deserializer (Deserializer | None): Custom JSON deserializer implementation.
Leave as `None` to use the default deserializer.
See :class:`DefaultDeserializer
<arangoasync.serialization.DefaultDeserializer>`.
For custom deserialization of collection documents, see :class:`Collection
<arangoasync.collection.Collection>`.
Raises:
ValueError: If the `host_resolver` is not supported.
Expand All @@ -70,8 +75,8 @@ def __init__(
host_resolver: str | HostResolver = "default",
http_client: Optional[HTTPClient] = None,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
serializer: Optional[Serializer[Json]] = None,
deserializer: Optional[Deserializer[Json, Jsons]] = None,
) -> None:
self._hosts = [hosts] if isinstance(hosts, str) else hosts
self._host_resolver = (
Expand All @@ -84,8 +89,10 @@ def __init__(
self._http_client.create_session(host) for host in self._hosts
]
self._compression = compression
self._serializer = serializer or DefaultSerializer()
self._deserializer = deserializer or DefaultDeserializer()
self._serializer: Serializer[Json] = serializer or DefaultSerializer()
self._deserializer: Deserializer[Json, Jsons] = (
deserializer or DefaultDeserializer()
)

def __repr__(self) -> str:
return f"<ArangoClient {','.join(self._hosts)}>"
Expand Down Expand Up @@ -142,8 +149,8 @@ async def db(
token: Optional[JwtToken] = None,
verify: bool = False,
compression: Optional[CompressionManager] = None,
serializer: Optional[Serializer] = None,
deserializer: Optional[Deserializer] = None,
serializer: Optional[Serializer[Json]] = None,
deserializer: Optional[Deserializer[Json, Jsons]] = None,
) -> StandardDatabase:
"""Connects to a database and returns and API wrapper.
Expand Down Expand Up @@ -178,6 +185,7 @@ async def db(
ServerConnectionError: If `verify` is `True` and the connection fails.
"""
connection: Connection

if auth_method == "basic":
if auth is None:
raise ValueError("Basic authentication requires the `auth` parameter")
Expand Down
205 changes: 205 additions & 0 deletions arangoasync/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
__all__ = ["Collection", "Collection", "StandardCollection"]


from enum import Enum
from typing import Generic, Optional, Tuple, TypeVar

from arangoasync.errno import HTTP_NOT_FOUND, HTTP_PRECONDITION_FAILED
from arangoasync.exceptions import (
DocumentGetError,
DocumentParseError,
DocumentRevisionError,
)
from arangoasync.executor import ApiExecutor
from arangoasync.request import Method, Request
from arangoasync.response import Response
from arangoasync.serialization import Deserializer, Serializer
from arangoasync.typings import Json, Result

T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")


class CollectionType(Enum):
"""Collection types."""

DOCUMENT = 2
EDGE = 3


class Collection(Generic[T, U, V]):
"""Base class for collection API wrappers.
Args:
executor (ApiExecutor): API executor.
name (str): Collection name
doc_serializer (Serializer): Document serializer.
doc_deserializer (Deserializer): Document deserializer.
"""

def __init__(
self,
executor: ApiExecutor,
name: str,
doc_serializer: Serializer[T],
doc_deserializer: Deserializer[U, V],
) -> None:
self._executor = executor
self._name = name
self._doc_serializer = doc_serializer
self._doc_deserializer = doc_deserializer
self._id_prefix = f"{self._name}/"

def __repr__(self) -> str:
return f"<StandardCollection {self.name}>"

def _validate_id(self, doc_id: str) -> str:
"""Check the collection name in the document ID.
Args:
doc_id (str): Document ID.
Returns:
str: Verified document ID.
Raises:
DocumentParseError: On bad collection name.
"""
if not doc_id.startswith(self._id_prefix):
raise DocumentParseError(f'Bad collection name in document ID "{doc_id}"')
return doc_id

def _extract_id(self, body: Json) -> str:
"""Extract the document ID from document body.
Args:
body (dict): Document body.
Returns:
str: Document ID.
Raises:
DocumentParseError: On missing ID and key.
"""
try:
if "_id" in body:
return self._validate_id(body["_id"])
else:
key: str = body["_key"]
return self._id_prefix + key
except KeyError:
raise DocumentParseError('Field "_key" or "_id" required')

def _prep_from_doc(
self,
document: str | Json,
rev: Optional[str] = None,
check_rev: bool = False,
) -> Tuple[str, Json]:
"""Prepare document ID, body and request headers before a query.
Args:
document (str | dict): Document ID, key or body.
rev (str | None): Document revision.
check_rev (bool): Whether to check the revision.
Returns:
Document ID and request headers.
Raises:
DocumentParseError: On missing ID and key.
TypeError: On bad document type.
"""
if isinstance(document, dict):
doc_id = self._extract_id(document)
rev = rev or document.get("_rev")
elif isinstance(document, str):
if "/" in document:
doc_id = self._validate_id(document)
else:
doc_id = self._id_prefix + document
else:
raise TypeError("Document must be str or a dict")

if not check_rev or rev is None:
return doc_id, {}
else:
return doc_id, {"If-Match": rev}

@property
def name(self) -> str:
"""Return the name of the collection.
Returns:
str: Collection name.
"""
return self._name


class StandardCollection(Collection[T, U, V]):
"""Standard collection API wrapper.
Args:
executor (ApiExecutor): API executor.
name (str): Collection name
doc_serializer (Serializer): Document serializer.
doc_deserializer (Deserializer): Document deserializer.
"""

def __init__(
self,
executor: ApiExecutor,
name: str,
doc_serializer: Serializer[T],
doc_deserializer: Deserializer[U, V],
) -> None:
super().__init__(executor, name, doc_serializer, doc_deserializer)

async def get(
self,
document: str | Json,
rev: Optional[str] = None,
check_rev: bool = True,
allow_dirty_read: bool = False,
) -> Result[Optional[U]]:
"""Return a document.
Args:
document (str | dict): Document ID, key or body.
Document body must contain the "_id" or "_key" field.
rev (str | None): Expected document revision. Overrides the
value of "_rev" field in **document** if present.
check_rev (bool): If set to True, revision of **document** (if given)
is compared against the revision of target document.
allow_dirty_read (bool): Allow reads from followers in a cluster.
Returns:
Document or None if not found.
Raises:
DocumentRevisionError: If the revision is incorrect.
DocumentGetError: If retrieval fails.
"""
handle, headers = self._prep_from_doc(document, rev, check_rev)

if allow_dirty_read:
headers["x-arango-allow-dirty-read"] = "true"

request = Request(
method=Method.GET,
endpoint=f"/_api/document/{handle}",
headers=headers,
)

def response_handler(resp: Response) -> Optional[U]:
if resp.is_success:
return self._doc_deserializer.loads(resp.raw_body)
elif resp.error_code == HTTP_NOT_FOUND:
return None
elif resp.error_code == HTTP_PRECONDITION_FAILED:
raise DocumentRevisionError(resp, request)
else:
raise DocumentGetError(resp, request)

return await self._executor.execute(request, response_handler)
Loading

0 comments on commit 8c8b237

Please sign in to comment.