Skip to content

Commit

Permalink
DefaultCompressionManager has "active" defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
apetenchea committed Sep 22, 2024
1 parent 98012ca commit 009da12
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 20 deletions.
6 changes: 3 additions & 3 deletions arangoasync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ArangoClient:
responses. Enable it by passing an instance of
:class:`DefaultCompressionManager
<arangoasync.compression.DefaultCompressionManager>`
or a subclass of :class:`CompressionManager
or a custom subclass of :class:`CompressionManager
<arangoasync.compression.CompressionManager>`.
Raises:
Expand Down Expand Up @@ -143,8 +143,8 @@ async def db(
auth (Auth | None): Login information.
token (JwtToken | None): JWT token.
verify (bool): Verify the connection by sending a test request.
compression (CompressionManager | None): Supersedes the client-level
compression settings.
compression (CompressionManager | None): If set, supersedes the
client-level compression settings.
Returns:
Database: Database API wrapper.
Expand Down
11 changes: 5 additions & 6 deletions arangoasync/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,16 @@ class DefaultCompressionManager(CompressionManager):
Args:
threshold (int): Will compress requests to the server if
the size of the request body (in bytes) is at least the value of this option.
Setting it to -1 will disable request compression (default).
Setting it to -1 will disable request compression.
level (int): Compression level. Defaults to 6.
accept (str | None): Accepted encoding. By default, there is
no compression of responses.
accept (str | None): Accepted encoding. Can be disabled by setting it to `None`.
"""

def __init__(
self,
threshold: int = -1,
threshold: int = 1024,
level: int = 6,
accept: Optional[AcceptEncoding] = None,
accept: Optional[AcceptEncoding] = AcceptEncoding.DEFLATE,
) -> None:
self._threshold = threshold
self._level = level
Expand Down Expand Up @@ -132,7 +131,7 @@ def content_encoding(self) -> str:
return self._content_encoding

def needs_compression(self, data: str | bytes) -> bool:
return self._threshold != -1 and len(data) >= self._threshold
return len(data) >= self._threshold

def compress(self, data: str | bytes) -> bytes:
if isinstance(data, bytes):
Expand Down
48 changes: 37 additions & 11 deletions arangoasync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from arangoasync import errno, logger
from arangoasync.auth import Auth, JwtToken
from arangoasync.compression import CompressionManager, DefaultCompressionManager
from arangoasync.compression import CompressionManager
from arangoasync.exceptions import (
AuthHeaderError,
ClientConnectionAbortedError,
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(
self._host_resolver = host_resolver
self._http_client = http_client
self._db_name = db_name
self._compression = compression or DefaultCompressionManager()
self._compression = compression

@property
def db_name(self) -> str:
Expand Down Expand Up @@ -100,6 +100,38 @@ def prep_response(request: Request, resp: Response) -> Response:
resp.error_message = body.get("errorMessage")
return resp

def compress_request(self, request: Request) -> bool:
"""Compress request if needed.
Additionally, the server may be instructed to compress the response.
The decision to compress the request is based on the compression strategy
passed during the connection initialization.
The request headers and may be modified as a result of this operation.
Args:
request (Request): Request to be compressed.
Returns:
bool: True if compression settings were applied.
"""
if self._compression is None:
return False

result: bool = False
if request.data is not None and self._compression.needs_compression(
request.data
):
request.data = self._compression.compress(request.data)
request.headers["content-encoding"] = self._compression.content_encoding
result = True

accept_encoding: str | None = self._compression.accept_encoding
if accept_encoding is not None:
request.headers["accept-encoding"] = accept_encoding
result = True

return result

async def process_request(self, request: Request) -> Response:
"""Process request, potentially trying multiple hosts.
Expand Down Expand Up @@ -198,15 +230,7 @@ async def send_request(self, request: Request) -> Response:
ArangoClientError: If an error occurred from the client side.
ArangoServerError: If an error occurred from the server side.
"""
if request.data is not None and self._compression.needs_compression(
request.data
):
request.data = self._compression.compress(request.data)
request.headers["content-encoding"] = self._compression.content_encoding

accept_encoding: str | None = self._compression.accept_encoding
if accept_encoding is not None:
request.headers["accept-encoding"] = accept_encoding
self.compress_request(request)

if self._auth:
request.auth = self._auth
Expand Down Expand Up @@ -335,6 +359,7 @@ async def send_request(self, request: Request) -> Response:
raise AuthHeaderError("Failed to generate authorization header.")

request.headers["authorization"] = self._auth_header
self.compress_request(request)

resp = await self.process_request(request)
if (
Expand Down Expand Up @@ -416,6 +441,7 @@ async def send_request(self, request: Request) -> Response:
if self._auth_header is None:
raise AuthHeaderError("Failed to generate authorization header.")
request.headers["authorization"] = self._auth_header
self.compress_request(request)

resp = await self.process_request(request)
self.raise_for_status(request, resp)
Expand Down

0 comments on commit 009da12

Please sign in to comment.