Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved tests #13

Merged
merged 2 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions arangoasync/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def compress(self, data: str | bytes) -> bytes:
"""
raise NotImplementedError

@property
@abstractmethod
def content_encoding(self) -> str:
"""Return the content encoding.
Expand All @@ -65,6 +66,7 @@ def content_encoding(self) -> str:
"""
raise NotImplementedError

@property
@abstractmethod
def accept_encoding(self) -> str | None:
"""Return the accept encoding.
Expand Down Expand Up @@ -101,18 +103,38 @@ def __init__(
self._content_encoding = ContentEncoding.DEFLATE.name.lower()
self._accept_encoding = accept.name.lower() if accept else None

def needs_compression(self, data: str | bytes) -> bool:
return self._threshold != -1 and len(data) >= self._threshold
@property
def threshold(self) -> int:
return self._threshold

def compress(self, data: str | bytes) -> bytes:
if data is not None:
if isinstance(data, bytes):
return zlib.compress(data, self._level)
return zlib.compress(data.encode("utf-8"), self._level)
return b""
@threshold.setter
def threshold(self, value: int) -> None:
self._threshold = value

def content_encoding(self) -> str:
return self._content_encoding
@property
def level(self) -> int:
return self._level

@level.setter
def level(self, value: int) -> None:
self._level = value

@property
def accept_encoding(self) -> str | None:
return self._accept_encoding

@accept_encoding.setter
def accept_encoding(self, value: AcceptEncoding | None) -> None:
self._accept_encoding = value.name.lower() if value else None

@property
def content_encoding(self) -> str:
return self._content_encoding

def needs_compression(self, data: str | bytes) -> bool:
return self._threshold != -1 and len(data) >= self._threshold

def compress(self, data: str | bytes) -> bytes:
if isinstance(data, bytes):
return zlib.compress(data, self._level)
return zlib.compress(data.encode("utf-8"), self._level)
12 changes: 5 additions & 7 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def prep_response(self, request: Request, resp: Response) -> Response:
ServerConnectionError: If the response status code is not successful.
"""
resp.is_success = 200 <= resp.status_code < 300
if resp.status_code in {401, 403}:
raise ServerConnectionError(resp, request, "Authentication failed.")
if not resp.is_success:
raise ServerConnectionError(resp, request)
raise ServerConnectionError(resp, request, "Bad server response.")
return resp

async def process_request(self, request: Request) -> Response:
Expand Down Expand Up @@ -110,10 +112,6 @@ async def ping(self) -> int:
"""
request = Request(method=Method.GET, endpoint="/_api/collection")
resp = await self.send_request(request)
if resp.status_code in {401, 403}:
raise ServerConnectionError(resp, request, "Authentication failed.")
if not resp.is_success:
raise ServerConnectionError(resp, request, "Bad server response.")
return resp.status_code

@abstractmethod
Expand Down Expand Up @@ -161,9 +159,9 @@ async def send_request(self, request: Request) -> Response:
request.data
):
request.data = self._compression.compress(request.data)
request.headers["content-encoding"] = self._compression.content_encoding()
request.headers["content-encoding"] = self._compression.content_encoding

accept_encoding: str | None = self._compression.accept_encoding()
accept_encoding: str | None = self._compression.accept_encoding
if accept_encoding is not None:
request.headers["accept-encoding"] = accept_encoding

Expand Down
4 changes: 2 additions & 2 deletions arangoasync/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ async def send_request(
async with session.request(
request.method.name,
request.endpoint,
headers=request.headers,
params=request.params,
headers=request.normalized_headers(),
params=request.normalized_params(),
data=request.data,
auth=auth,
) as response:
Expand Down
24 changes: 8 additions & 16 deletions arangoasync/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,14 @@ def __init__(
) -> None:
self.method: Method = method
self.endpoint: str = endpoint
self.headers: RequestHeaders = self._normalize_headers(headers)
self.params: Params = self._normalize_params(params)
self.headers: RequestHeaders = headers or dict()
self.params: Params = params or dict()
self.data: Optional[bytes] = data
self.auth: Optional[Auth] = auth

@staticmethod
def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders:
def normalized_headers(self) -> RequestHeaders:
"""Normalize request headers.

Parameters:
headers (dict | None): Request headers.

Returns:
dict: Normalized request headers.
"""
Expand All @@ -85,26 +81,22 @@ def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders:
"x-arango-driver": driver_header,
}

if headers is not None:
for key, value in headers.items():
if self.headers is not None:
for key, value in self.headers.items():
normalized_headers[key.lower()] = value

return normalized_headers

@staticmethod
def _normalize_params(params: Optional[Params]) -> Params:
def normalized_params(self) -> Params:
"""Normalize URL parameters.

Parameters:
params (dict | None): URL parameters.

Returns:
dict: Normalized URL parameters.
"""
normalized_params: Params = {}

if params is not None:
for key, value in params.items():
if self.params is not None:
for key, value in self.params.items():
if isinstance(value, bool):
value = int(value)
normalized_params[key] = str(value)
Expand Down
23 changes: 21 additions & 2 deletions tests/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,24 @@ def test_DefaultCompressionManager_compress():
data = "a" * 10 + "b" * 10
assert manager.needs_compression(data)
assert len(manager.compress(data)) < len(data)
assert manager.content_encoding() == "deflate"
assert manager.accept_encoding() == "deflate"
assert manager.content_encoding == "deflate"
assert manager.accept_encoding == "deflate"
data = b"a" * 10 + b"b" * 10
assert manager.needs_compression(data)
assert len(manager.compress(data)) < len(data)


def test_DefaultCompressionManager_properties():
manager = DefaultCompressionManager(
threshold=1, level=9, accept=AcceptEncoding.DEFLATE
)
assert manager.threshold == 1
assert manager.level == 9
assert manager.accept_encoding == "deflate"
assert manager.content_encoding == "deflate"
manager.threshold = 10
assert manager.threshold == 10
manager.level = 2
assert manager.level == 2
manager.accept_encoding = AcceptEncoding.GZIP
assert manager.accept_encoding == "gzip"
90 changes: 89 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import zlib

import pytest

from arangoasync.auth import Auth
from arangoasync.compression import AcceptEncoding, DefaultCompressionManager
from arangoasync.connection import BasicConnection
from arangoasync.exceptions import ServerConnectionError
from arangoasync.exceptions import (
ClientConnectionError,
ConnectionAbortedError,
ServerConnectionError,
)
from arangoasync.http import AioHTTPClient
from arangoasync.request import Method, Request
from arangoasync.resolver import DefaultHostResolver
from arangoasync.response import Response


@pytest.mark.asyncio
Expand Down Expand Up @@ -40,5 +49,84 @@ async def test_BasicConnection_ping_success(
auth=Auth(username=root, password=password),
)

assert connection.db_name == sys_db_name
status_code = await connection.ping()
assert status_code == 200


@pytest.mark.asyncio
async def test_BasicConnection_with_compression(
client_session, url, sys_db_name, root, password
):
client = AioHTTPClient()
session = client_session(client, url)
resolver = DefaultHostResolver(1)
compression = DefaultCompressionManager(
threshold=2, level=5, accept=AcceptEncoding.DEFLATE
)

connection = BasicConnection(
sessions=[session],
host_resolver=resolver,
http_client=client,
db_name=sys_db_name,
auth=Auth(username=root, password=password),
compression=compression,
)

data = b"a" * 100
request = Request(method=Method.GET, endpoint="/_api/collection", data=data)
_ = await connection.send_request(request)
assert len(request.data) < len(data)
assert zlib.decompress(request.data) == data
assert request.headers["content-encoding"] == "deflate"
assert request.headers["accept-encoding"] == "deflate"


@pytest.mark.asyncio
async def test_BasicConnection_prep_response_bad_response(
client_session, url, sys_db_name
):
client = AioHTTPClient()
session = client_session(client, url)
resolver = DefaultHostResolver(1)

connection = BasicConnection(
sessions=[session],
host_resolver=resolver,
http_client=client,
db_name=sys_db_name,
)

request = Request(method=Method.GET, endpoint="/_api/collection")
response = Response(Method.GET, url, {}, 0, "ERROR", b"")

with pytest.raises(ServerConnectionError):
connection.prep_response(request, response)


@pytest.mark.asyncio
async def test_BasicConnection_process_request_connection_aborted(
monkeypatch, client_session, url, sys_db_name, root, password
):
client = AioHTTPClient()
session = client_session(client, url)
resolver = DefaultHostResolver(1, 1)

request = Request(method=Method.GET, endpoint="/_api/collection")

async def mock_send_request(*args, **kwargs):
raise ClientConnectionError("test")

monkeypatch.setattr(client, "send_request", mock_send_request)

connection = BasicConnection(
sessions=[session],
host_resolver=resolver,
http_client=client,
db_name=sys_db_name,
auth=Auth(username=root, password=password),
)

with pytest.raises(ConnectionAbortedError):
await connection.process_request(request)
Loading