Skip to content

Commit

Permalink
Merge pull request #87 from Avivsalem/staging
Browse files Browse the repository at this point in the history
Staging
  • Loading branch information
Avivsalem authored Aug 24, 2023
2 parents 988df41 + 0dd62f0 commit 47e106e
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 46 deletions.
83 changes: 58 additions & 25 deletions messageflux/iodevices/sqs/sqs_input_device.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import threading
from io import BytesIO
from typing import Optional, Union, TYPE_CHECKING
from typing import Optional, Union, TYPE_CHECKING, Dict, List, Any

from messageflux.iodevices.base import (
InputDevice,
Expand All @@ -15,7 +15,7 @@
from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase

if TYPE_CHECKING:
from mypy_boto3_sqs.service_resource import Message as SQSMessage
from mypy_boto3_sqs.service_resource import Message as SQSMessage, SQSServiceResource


class SQSInputTransaction(InputTransaction):
Expand Down Expand Up @@ -57,13 +57,16 @@ def __init__(
self,
device_manager: "SQSInputDeviceManager",
queue_name: str,
included_message_attributes: Optional[Union[str, list]] = None, # TODO: what's this?
max_messages_per_request: int = 1,
included_message_attributes: Optional[Union[str, List[str]]] = None,
):
"""
constructs a new input SQS device
:param device_manager: the SQS device Manager that holds this device
:param queue_name: the name for the queue
:param max_messages_per_request: maximum messages to retrieve from the queue (max 10)
:param included_message_attributes: list of message attributes to get for the message. defaults to ALL
"""
super().__init__(device_manager, queue_name)
Expand All @@ -72,8 +75,27 @@ def __init__(
included_message_attributes = ["All"]

self._included_message_attributes = included_message_attributes
self._max_messages_per_request = 1 # TODO: get this in manager
self._max_messages_per_request = min(max_messages_per_request, 10)
self._queue = self.manager.get_queue(queue_name)
self._message_cache: List['SQSMessage'] = []

def _get_sqs_message(self, timeout: Optional[float]) -> 'Optional[SQSMessage]':
if not self._message_cache:
additional_args: Dict[str, Any] = {}

if timeout is not None:
additional_args = dict(WaitTimeSeconds=int(timeout))

sqs_messages = self._queue.receive_messages(
MessageAttributeNames=self._included_message_attributes,
MaxNumberOfMessages=self._max_messages_per_request,
**additional_args
)
if not sqs_messages:
return None
self._message_cache.extend(sqs_messages)

return self._message_cache.pop(0)

def _read_message(
self,
Expand All @@ -87,25 +109,11 @@ def _read_message(
:param timeout: the timeout in seconds to block. negative number means no blocking
:return: a tuple of stream and metadata from InputDevice, or (None, None) if no message is available
"""
if timeout is None:
sqs_messages = self._queue.receive_messages(
MessageAttributeNames=self._included_message_attributes,
MaxNumberOfMessages=self._max_messages_per_request,
) # TODO: what's the visibility timeout? should we extend it?
else:
sqs_messages = self._queue.receive_messages(
MessageAttributeNames=self._included_message_attributes,
MaxNumberOfMessages=self._max_messages_per_request,
WaitTimeSeconds=int(timeout)
) # TODO: what's the visibility timeout? should we extend it?

if not sqs_messages:
sqs_message = self._get_sqs_message(timeout=timeout)
if sqs_message is None:
return None

assert (len(sqs_messages) == 1), "SQSInputDevice should only return one message at a time"

sqs_message = sqs_messages[0]

transaction: InputTransaction
if with_transaction:
transaction = SQSInputTransaction(device=self,
Expand All @@ -114,11 +122,13 @@ def _read_message(
transaction = NULLTransaction(self)
sqs_message.delete()

message_attributes = sqs_message.message_attributes or {}

return ReadResult(
message=Message(
headers={
key: value["BinaryValue"] if value['DataType'] == "Binary" else value['StringValue']
for key, value in sqs_message.message_attributes.items()
for key, value in message_attributes.items()
},
data=BytesIO(sqs_message.body.encode()),
),
Expand All @@ -131,16 +141,39 @@ class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]):
SQS input device manager
"""

def get_input_device(self, device_name: str) -> SQSInputDevice:
def __init__(self, *,
sqs_resource: Optional['SQSServiceResource'] = None,
max_messages_per_request: int = 1,
included_message_attributes: Optional[Union[str, List[str]]] = None, ) -> None:
"""
:param sqs_resource: the boto sqs service resource. Defaults to creating from env vars
:param max_messages_per_request: maximum messages to retrieve from the queue (max 10)
:param included_message_attributes: list of message attributes to get for the message. defaults to ALL
"""
super().__init__(sqs_resource=sqs_resource)
self._device_cache: Dict[str, SQSInputDevice] = {}
self._max_messages_per_request = max_messages_per_request
self._included_message_attributes = included_message_attributes

def get_input_device(self, name: str) -> SQSInputDevice:
"""
Returns an incoming device by name
:param device_name: the name of the device to read from
:param name: the name of the device to read from
:return: an input device for 'device_name'
"""
try:
return SQSInputDevice(self, device_name)
device = self._device_cache.get(name, None)
if device is None:
device = SQSInputDevice(device_manager=self,
queue_name=name,
max_messages_per_request=self._max_messages_per_request,
included_message_attributes=self._included_message_attributes)

self._device_cache[name] = device

return device
except Exception as e:
message = f"Couldn't create input device '{device_name}'"
message = f"Couldn't create input device '{name}'"
self._logger.exception(message)
raise InputDeviceException(message) from e
3 changes: 2 additions & 1 deletion messageflux/iodevices/sqs/sqs_manager_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ class SQSManagerBase:
base class for sqs device managers
"""

def __init__(self, sqs_resource: Optional['SQSServiceResource'] = None) -> None:
def __init__(self, *,
sqs_resource: Optional['SQSServiceResource'] = None) -> None:
"""
:param sqs_resource: the boto sqs service resource. Defaults to creating from env vars
"""
Expand Down
49 changes: 31 additions & 18 deletions messageflux/iodevices/sqs/sqs_output_device.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Optional, Dict, TYPE_CHECKING, Any

from messageflux.iodevices.base import (
OutputDevice,
Expand All @@ -10,6 +11,9 @@
from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase
from messageflux.utils import get_random_id

if TYPE_CHECKING:
from mypy_boto3_sqs.service_resource import SQSServiceResource


class SQSOutputDevice(OutputDevice["SQSOutputDeviceManager"]):
"""
Expand All @@ -31,21 +35,17 @@ def __init__(self, device_manager: "SQSOutputDeviceManager", queue_name: str):
self._logger = logging.getLogger(__name__)

def _send_message(self, message_bundle: MessageBundle):
additional_args: Dict[str, Any] = {}
if self._is_fifo:
response = self._sqs_queue.send_message(
MessageBody=message_bundle.message.bytes.decode(),
MessageAttributes=generate_message_attributes(
message_bundle.message.headers
),
MessageGroupId=get_random_id(),
)
else:
response = self._sqs_queue.send_message(
MessageBody=message_bundle.message.bytes.decode(),
MessageAttributes=generate_message_attributes(
message_bundle.message.headers
),
)
additional_args = dict(MessageGroupId=get_random_id())

response = self._sqs_queue.send_message(
MessageBody=message_bundle.message.bytes.decode(),
MessageAttributes=generate_message_attributes(
message_bundle.message.headers
),
**additional_args,
)

if "MessageId" not in response:
raise OutputDeviceException("Couldn't send message to SQS")
Expand All @@ -56,16 +56,29 @@ class SQSOutputDeviceManager(SQSManagerBase, OutputDeviceManager[SQSOutputDevice
this manager is used to create SQS devices
"""

def get_output_device(self, queue_name: str) -> SQSOutputDevice:
def __init__(self, sqs_resource: Optional['SQSServiceResource'] = None) -> None:
"""
:param sqs_resource: the boto sqs service resource. Defaults to creating from env vars
"""
super().__init__(sqs_resource=sqs_resource)
self._device_cache: Dict[str, SQSOutputDevice] = {}

def get_output_device(self, name: str) -> SQSOutputDevice:
"""
Returns and outgoing device by name
:param queue_name: the name of the queue
:param name: the name of the queue
:return: an output device for 'queue_name'
"""
try:
return SQSOutputDevice(self, queue_name)
device = self._device_cache.get(name, None)
if device is None:
device = SQSOutputDevice(self, name)
self._device_cache[name] = device

return device

except Exception as e:
message = f"Couldn't create output device '{queue_name}'"
message = f"Couldn't create output device '{name}'"
self._logger.exception(message)
raise OutputDeviceException(message) from e
23 changes: 21 additions & 2 deletions tests/devices/sqs_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from threading import Event

import boto3
from moto import mock_sqs
Expand All @@ -11,7 +12,7 @@
@mock_sqs
def test_generic_sanity():
sqs_resource = boto3.resource('sqs', region_name='us-west-2')
input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource)
input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource, max_messages_per_request=4)
output_manager = SQSOutputDeviceManager(sqs_resource=sqs_resource)
queue_name = str(uuid.uuid4())
with input_manager, output_manager:
Expand All @@ -27,7 +28,7 @@ def test_generic_sanity():
@mock_sqs
def test_generic_rollback():
sqs_resource = boto3.resource('sqs', region_name='us-west-2')
input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource)
input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource, max_messages_per_request=4)
output_manager = SQSOutputDeviceManager(sqs_resource=sqs_resource)
queue_name = str(uuid.uuid4())
with input_manager, output_manager:
Expand All @@ -38,3 +39,21 @@ def test_generic_rollback():
device_name=queue_name)
finally:
q.delete()


@mock_sqs
def test_empty_headers():
sqs_resource = boto3.resource('sqs', region_name='us-west-2')
input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource)
output_manager = SQSOutputDeviceManager(sqs_resource=sqs_resource)
queue_name = str(uuid.uuid4())
test_message = str(uuid.uuid4())
with input_manager, output_manager:
q = output_manager.create_queue(queue_name)
q.send_message(MessageBody=test_message)
try:
id = input_manager.get_input_device(queue_name)
rr = id.read_message(cancellation_token=Event())
assert rr.message.bytes.decode() == test_message
finally:
q.delete()

0 comments on commit 47e106e

Please sign in to comment.