diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py index 6043332..3ea7429 100644 --- a/messageflux/iodevices/sqs/sqs_input_device.py +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -1,7 +1,7 @@ import logging import threading from io import BytesIO -from typing import Optional, Union, TYPE_CHECKING +from typing import Optional, Union, TYPE_CHECKING, Dict, List, Any from messageflux.iodevices.base import ( InputDevice, @@ -15,7 +15,7 @@ from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase if TYPE_CHECKING: - from mypy_boto3_sqs.service_resource import Message as SQSMessage + from mypy_boto3_sqs.service_resource import Message as SQSMessage, SQSServiceResource class SQSInputTransaction(InputTransaction): @@ -57,13 +57,16 @@ def __init__( self, device_manager: "SQSInputDeviceManager", queue_name: str, - included_message_attributes: Optional[Union[str, list]] = None, # TODO: what's this? + max_messages_per_request: int = 1, + included_message_attributes: Optional[Union[str, List[str]]] = None, ): """ constructs a new input SQS device :param device_manager: the SQS device Manager that holds this device :param queue_name: the name for the queue + :param max_messages_per_request: maximum messages to retrieve from the queue (max 10) + :param included_message_attributes: list of message attributes to get for the message. defaults to ALL """ super().__init__(device_manager, queue_name) @@ -72,8 +75,27 @@ def __init__( included_message_attributes = ["All"] self._included_message_attributes = included_message_attributes - self._max_messages_per_request = 1 # TODO: get this in manager + self._max_messages_per_request = min(max_messages_per_request, 10) self._queue = self.manager.get_queue(queue_name) + self._message_cache: List['SQSMessage'] = [] + + def _get_sqs_message(self, timeout: Optional[float]) -> 'Optional[SQSMessage]': + if not self._message_cache: + additional_args: Dict[str, Any] = {} + + if timeout is not None: + additional_args = dict(WaitTimeSeconds=int(timeout)) + + sqs_messages = self._queue.receive_messages( + MessageAttributeNames=self._included_message_attributes, + MaxNumberOfMessages=self._max_messages_per_request, + **additional_args + ) + if not sqs_messages: + return None + self._message_cache.extend(sqs_messages) + + return self._message_cache.pop(0) def _read_message( self, @@ -87,25 +109,11 @@ def _read_message( :param timeout: the timeout in seconds to block. negative number means no blocking :return: a tuple of stream and metadata from InputDevice, or (None, None) if no message is available """ - if timeout is None: - sqs_messages = self._queue.receive_messages( - MessageAttributeNames=self._included_message_attributes, - MaxNumberOfMessages=self._max_messages_per_request, - ) # TODO: what's the visibility timeout? should we extend it? - else: - sqs_messages = self._queue.receive_messages( - MessageAttributeNames=self._included_message_attributes, - MaxNumberOfMessages=self._max_messages_per_request, - WaitTimeSeconds=int(timeout) - ) # TODO: what's the visibility timeout? should we extend it? - if not sqs_messages: + sqs_message = self._get_sqs_message(timeout=timeout) + if sqs_message is None: return None - assert (len(sqs_messages) == 1), "SQSInputDevice should only return one message at a time" - - sqs_message = sqs_messages[0] - transaction: InputTransaction if with_transaction: transaction = SQSInputTransaction(device=self, @@ -114,11 +122,13 @@ def _read_message( transaction = NULLTransaction(self) sqs_message.delete() + message_attributes = sqs_message.message_attributes or {} + return ReadResult( message=Message( headers={ key: value["BinaryValue"] if value['DataType'] == "Binary" else value['StringValue'] - for key, value in sqs_message.message_attributes.items() + for key, value in message_attributes.items() }, data=BytesIO(sqs_message.body.encode()), ), @@ -131,16 +141,39 @@ class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]): SQS input device manager """ - def get_input_device(self, device_name: str) -> SQSInputDevice: + def __init__(self, *, + sqs_resource: Optional['SQSServiceResource'] = None, + max_messages_per_request: int = 1, + included_message_attributes: Optional[Union[str, List[str]]] = None, ) -> None: + """ + :param sqs_resource: the boto sqs service resource. Defaults to creating from env vars + :param max_messages_per_request: maximum messages to retrieve from the queue (max 10) + :param included_message_attributes: list of message attributes to get for the message. defaults to ALL + """ + super().__init__(sqs_resource=sqs_resource) + self._device_cache: Dict[str, SQSInputDevice] = {} + self._max_messages_per_request = max_messages_per_request + self._included_message_attributes = included_message_attributes + + def get_input_device(self, name: str) -> SQSInputDevice: """ Returns an incoming device by name - :param device_name: the name of the device to read from + :param name: the name of the device to read from :return: an input device for 'device_name' """ try: - return SQSInputDevice(self, device_name) + device = self._device_cache.get(name, None) + if device is None: + device = SQSInputDevice(device_manager=self, + queue_name=name, + max_messages_per_request=self._max_messages_per_request, + included_message_attributes=self._included_message_attributes) + + self._device_cache[name] = device + + return device except Exception as e: - message = f"Couldn't create input device '{device_name}'" + message = f"Couldn't create input device '{name}'" self._logger.exception(message) raise InputDeviceException(message) from e diff --git a/messageflux/iodevices/sqs/sqs_manager_base.py b/messageflux/iodevices/sqs/sqs_manager_base.py index 4c34c42..823985b 100644 --- a/messageflux/iodevices/sqs/sqs_manager_base.py +++ b/messageflux/iodevices/sqs/sqs_manager_base.py @@ -16,7 +16,8 @@ class SQSManagerBase: base class for sqs device managers """ - def __init__(self, sqs_resource: Optional['SQSServiceResource'] = None) -> None: + def __init__(self, *, + sqs_resource: Optional['SQSServiceResource'] = None) -> None: """ :param sqs_resource: the boto sqs service resource. Defaults to creating from env vars """ diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index 60d22cb..1989d88 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -1,4 +1,5 @@ import logging +from typing import Optional, Dict, TYPE_CHECKING, Any from messageflux.iodevices.base import ( OutputDevice, @@ -10,6 +11,9 @@ from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase from messageflux.utils import get_random_id +if TYPE_CHECKING: + from mypy_boto3_sqs.service_resource import SQSServiceResource + class SQSOutputDevice(OutputDevice["SQSOutputDeviceManager"]): """ @@ -31,21 +35,17 @@ def __init__(self, device_manager: "SQSOutputDeviceManager", queue_name: str): self._logger = logging.getLogger(__name__) def _send_message(self, message_bundle: MessageBundle): + additional_args: Dict[str, Any] = {} if self._is_fifo: - response = self._sqs_queue.send_message( - MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=generate_message_attributes( - message_bundle.message.headers - ), - MessageGroupId=get_random_id(), - ) - else: - response = self._sqs_queue.send_message( - MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=generate_message_attributes( - message_bundle.message.headers - ), - ) + additional_args = dict(MessageGroupId=get_random_id()) + + response = self._sqs_queue.send_message( + MessageBody=message_bundle.message.bytes.decode(), + MessageAttributes=generate_message_attributes( + message_bundle.message.headers + ), + **additional_args, + ) if "MessageId" not in response: raise OutputDeviceException("Couldn't send message to SQS") @@ -56,16 +56,29 @@ class SQSOutputDeviceManager(SQSManagerBase, OutputDeviceManager[SQSOutputDevice this manager is used to create SQS devices """ - def get_output_device(self, queue_name: str) -> SQSOutputDevice: + def __init__(self, sqs_resource: Optional['SQSServiceResource'] = None) -> None: + """ + :param sqs_resource: the boto sqs service resource. Defaults to creating from env vars + """ + super().__init__(sqs_resource=sqs_resource) + self._device_cache: Dict[str, SQSOutputDevice] = {} + + def get_output_device(self, name: str) -> SQSOutputDevice: """ Returns and outgoing device by name - :param queue_name: the name of the queue + :param name: the name of the queue :return: an output device for 'queue_name' """ try: - return SQSOutputDevice(self, queue_name) + device = self._device_cache.get(name, None) + if device is None: + device = SQSOutputDevice(self, name) + self._device_cache[name] = device + + return device + except Exception as e: - message = f"Couldn't create output device '{queue_name}'" + message = f"Couldn't create output device '{name}'" self._logger.exception(message) raise OutputDeviceException(message) from e diff --git a/tests/devices/sqs_test.py b/tests/devices/sqs_test.py index 193e7d8..88e9dfe 100644 --- a/tests/devices/sqs_test.py +++ b/tests/devices/sqs_test.py @@ -1,4 +1,5 @@ import uuid +from threading import Event import boto3 from moto import mock_sqs @@ -11,7 +12,7 @@ @mock_sqs def test_generic_sanity(): sqs_resource = boto3.resource('sqs', region_name='us-west-2') - input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource) + input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource, max_messages_per_request=4) output_manager = SQSOutputDeviceManager(sqs_resource=sqs_resource) queue_name = str(uuid.uuid4()) with input_manager, output_manager: @@ -27,7 +28,7 @@ def test_generic_sanity(): @mock_sqs def test_generic_rollback(): sqs_resource = boto3.resource('sqs', region_name='us-west-2') - input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource) + input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource, max_messages_per_request=4) output_manager = SQSOutputDeviceManager(sqs_resource=sqs_resource) queue_name = str(uuid.uuid4()) with input_manager, output_manager: @@ -38,3 +39,21 @@ def test_generic_rollback(): device_name=queue_name) finally: q.delete() + + +@mock_sqs +def test_empty_headers(): + sqs_resource = boto3.resource('sqs', region_name='us-west-2') + input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource) + output_manager = SQSOutputDeviceManager(sqs_resource=sqs_resource) + queue_name = str(uuid.uuid4()) + test_message = str(uuid.uuid4()) + with input_manager, output_manager: + q = output_manager.create_queue(queue_name) + q.send_message(MessageBody=test_message) + try: + id = input_manager.get_input_device(queue_name) + rr = id.read_message(cancellation_token=Event()) + assert rr.message.bytes.decode() == test_message + finally: + q.delete()