Skip to content

Commit

Permalink
Improved tests (#13)
Browse files Browse the repository at this point in the history
* Improved tests

* Marking test async
  • Loading branch information
apetenchea authored Aug 25, 2024
1 parent b67f60d commit bd76c61
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 38 deletions.
42 changes: 32 additions & 10 deletions arangoasync/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def compress(self, data: str | bytes) -> bytes:
"""
raise NotImplementedError

@property
@abstractmethod
def content_encoding(self) -> str:
"""Return the content encoding.
Expand All @@ -65,6 +66,7 @@ def content_encoding(self) -> str:
"""
raise NotImplementedError

@property
@abstractmethod
def accept_encoding(self) -> str | None:
"""Return the accept encoding.
Expand Down Expand Up @@ -101,18 +103,38 @@ def __init__(
self._content_encoding = ContentEncoding.DEFLATE.name.lower()
self._accept_encoding = accept.name.lower() if accept else None

def needs_compression(self, data: str | bytes) -> bool:
return self._threshold != -1 and len(data) >= self._threshold
@property
def threshold(self) -> int:
return self._threshold

def compress(self, data: str | bytes) -> bytes:
if data is not None:
if isinstance(data, bytes):
return zlib.compress(data, self._level)
return zlib.compress(data.encode("utf-8"), self._level)
return b""
@threshold.setter
def threshold(self, value: int) -> None:
self._threshold = value

def content_encoding(self) -> str:
return self._content_encoding
@property
def level(self) -> int:
return self._level

@level.setter
def level(self, value: int) -> None:
self._level = value

@property
def accept_encoding(self) -> str | None:
return self._accept_encoding

@accept_encoding.setter
def accept_encoding(self, value: AcceptEncoding | None) -> None:
self._accept_encoding = value.name.lower() if value else None

@property
def content_encoding(self) -> str:
return self._content_encoding

def needs_compression(self, data: str | bytes) -> bool:
return self._threshold != -1 and len(data) >= self._threshold

def compress(self, data: str | bytes) -> bytes:
if isinstance(data, bytes):
return zlib.compress(data, self._level)
return zlib.compress(data.encode("utf-8"), self._level)
12 changes: 5 additions & 7 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def prep_response(self, request: Request, resp: Response) -> Response:
ServerConnectionError: If the response status code is not successful.
"""
resp.is_success = 200 <= resp.status_code < 300
if resp.status_code in {401, 403}:
raise ServerConnectionError(resp, request, "Authentication failed.")
if not resp.is_success:
raise ServerConnectionError(resp, request)
raise ServerConnectionError(resp, request, "Bad server response.")
return resp

async def process_request(self, request: Request) -> Response:
Expand Down Expand Up @@ -110,10 +112,6 @@ async def ping(self) -> int:
"""
request = Request(method=Method.GET, endpoint="/_api/collection")
resp = await self.send_request(request)
if resp.status_code in {401, 403}:
raise ServerConnectionError(resp, request, "Authentication failed.")
if not resp.is_success:
raise ServerConnectionError(resp, request, "Bad server response.")
return resp.status_code

@abstractmethod
Expand Down Expand Up @@ -161,9 +159,9 @@ async def send_request(self, request: Request) -> Response:
request.data
):
request.data = self._compression.compress(request.data)
request.headers["content-encoding"] = self._compression.content_encoding()
request.headers["content-encoding"] = self._compression.content_encoding

accept_encoding: str | None = self._compression.accept_encoding()
accept_encoding: str | None = self._compression.accept_encoding
if accept_encoding is not None:
request.headers["accept-encoding"] = accept_encoding

Expand Down
4 changes: 2 additions & 2 deletions arangoasync/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ async def send_request(
async with session.request(
request.method.name,
request.endpoint,
headers=request.headers,
params=request.params,
headers=request.normalized_headers(),
params=request.normalized_params(),
data=request.data,
auth=auth,
) as response:
Expand Down
24 changes: 8 additions & 16 deletions arangoasync/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,14 @@ def __init__(
) -> None:
self.method: Method = method
self.endpoint: str = endpoint
self.headers: RequestHeaders = self._normalize_headers(headers)
self.params: Params = self._normalize_params(params)
self.headers: RequestHeaders = headers or dict()
self.params: Params = params or dict()
self.data: Optional[bytes] = data
self.auth: Optional[Auth] = auth

@staticmethod
def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders:
def normalized_headers(self) -> RequestHeaders:
"""Normalize request headers.
Parameters:
headers (dict | None): Request headers.
Returns:
dict: Normalized request headers.
"""
Expand All @@ -85,26 +81,22 @@ def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders:
"x-arango-driver": driver_header,
}

if headers is not None:
for key, value in headers.items():
if self.headers is not None:
for key, value in self.headers.items():
normalized_headers[key.lower()] = value

return normalized_headers

@staticmethod
def _normalize_params(params: Optional[Params]) -> Params:
def normalized_params(self) -> Params:
"""Normalize URL parameters.
Parameters:
params (dict | None): URL parameters.
Returns:
dict: Normalized URL parameters.
"""
normalized_params: Params = {}

if params is not None:
for key, value in params.items():
if self.params is not None:
for key, value in self.params.items():
if isinstance(value, bool):
value = int(value)
normalized_params[key] = str(value)
Expand Down
23 changes: 21 additions & 2 deletions tests/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,24 @@ def test_DefaultCompressionManager_compress():
data = "a" * 10 + "b" * 10
assert manager.needs_compression(data)
assert len(manager.compress(data)) < len(data)
assert manager.content_encoding() == "deflate"
assert manager.accept_encoding() == "deflate"
assert manager.content_encoding == "deflate"
assert manager.accept_encoding == "deflate"
data = b"a" * 10 + b"b" * 10
assert manager.needs_compression(data)
assert len(manager.compress(data)) < len(data)


def test_DefaultCompressionManager_properties():
manager = DefaultCompressionManager(
threshold=1, level=9, accept=AcceptEncoding.DEFLATE
)
assert manager.threshold == 1
assert manager.level == 9
assert manager.accept_encoding == "deflate"
assert manager.content_encoding == "deflate"
manager.threshold = 10
assert manager.threshold == 10
manager.level = 2
assert manager.level == 2
manager.accept_encoding = AcceptEncoding.GZIP
assert manager.accept_encoding == "gzip"
90 changes: 89 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import zlib

import pytest

from arangoasync.auth import Auth
from arangoasync.compression import AcceptEncoding, DefaultCompressionManager
from arangoasync.connection import BasicConnection
from arangoasync.exceptions import ServerConnectionError
from arangoasync.exceptions import (
ClientConnectionError,
ConnectionAbortedError,
ServerConnectionError,
)
from arangoasync.http import AioHTTPClient
from arangoasync.request import Method, Request
from arangoasync.resolver import DefaultHostResolver
from arangoasync.response import Response


@pytest.mark.asyncio
Expand Down Expand Up @@ -40,5 +49,84 @@ async def test_BasicConnection_ping_success(
auth=Auth(username=root, password=password),
)

assert connection.db_name == sys_db_name
status_code = await connection.ping()
assert status_code == 200


@pytest.mark.asyncio
async def test_BasicConnection_with_compression(
client_session, url, sys_db_name, root, password
):
client = AioHTTPClient()
session = client_session(client, url)
resolver = DefaultHostResolver(1)
compression = DefaultCompressionManager(
threshold=2, level=5, accept=AcceptEncoding.DEFLATE
)

connection = BasicConnection(
sessions=[session],
host_resolver=resolver,
http_client=client,
db_name=sys_db_name,
auth=Auth(username=root, password=password),
compression=compression,
)

data = b"a" * 100
request = Request(method=Method.GET, endpoint="/_api/collection", data=data)
_ = await connection.send_request(request)
assert len(request.data) < len(data)
assert zlib.decompress(request.data) == data
assert request.headers["content-encoding"] == "deflate"
assert request.headers["accept-encoding"] == "deflate"


@pytest.mark.asyncio
async def test_BasicConnection_prep_response_bad_response(
client_session, url, sys_db_name
):
client = AioHTTPClient()
session = client_session(client, url)
resolver = DefaultHostResolver(1)

connection = BasicConnection(
sessions=[session],
host_resolver=resolver,
http_client=client,
db_name=sys_db_name,
)

request = Request(method=Method.GET, endpoint="/_api/collection")
response = Response(Method.GET, url, {}, 0, "ERROR", b"")

with pytest.raises(ServerConnectionError):
connection.prep_response(request, response)


@pytest.mark.asyncio
async def test_BasicConnection_process_request_connection_aborted(
monkeypatch, client_session, url, sys_db_name, root, password
):
client = AioHTTPClient()
session = client_session(client, url)
resolver = DefaultHostResolver(1, 1)

request = Request(method=Method.GET, endpoint="/_api/collection")

async def mock_send_request(*args, **kwargs):
raise ClientConnectionError("test")

monkeypatch.setattr(client, "send_request", mock_send_request)

connection = BasicConnection(
sessions=[session],
host_resolver=resolver,
http_client=client,
db_name=sys_db_name,
auth=Auth(username=root, password=password),
)

with pytest.raises(ConnectionAbortedError):
await connection.process_request(request)

0 comments on commit bd76c61

Please sign in to comment.