From 5762e7cef4675f542f92aca5fe4ada13a467f302 Mon Sep 17 00:00:00 2001 From: Rafael Zilberman Date: Mon, 10 Apr 2023 08:16:20 +0300 Subject: [PATCH 01/15] Started working on SQS devices --- messageflux/iodevices/sqs/__init__.py | 0 messageflux/iodevices/sqs/sqs_input_device.py | 397 ++++++++++++++++++ .../iodevices/sqs/sqs_output_device.py | 81 ++++ requirements-sqs.txt | 2 + 4 files changed, 480 insertions(+) create mode 100644 messageflux/iodevices/sqs/__init__.py create mode 100644 messageflux/iodevices/sqs/sqs_input_device.py create mode 100644 messageflux/iodevices/sqs/sqs_output_device.py create mode 100644 requirements-sqs.txt diff --git a/messageflux/iodevices/sqs/__init__.py b/messageflux/iodevices/sqs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py new file mode 100644 index 0000000..56b3227 --- /dev/null +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -0,0 +1,397 @@ +import logging +import os +import socket +import ssl +from io import BytesIO +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, List, Any + +from messageflux.iodevices.base import InputDevice, InputTransaction, ReadResult, InputDeviceException, Message, \ + InputDeviceManager +from messageflux.iodevices.base.input_transaction import NULLTransaction +from messageflux.iodevices.rabbitmq.rabbitmq_device_manager_base import RabbitMQDeviceManagerBase +from messageflux.utils import ThreadLocalMember + +try: + from pika import spec +except ImportError as ex: + raise ImportError('Please Install the required extra: messageflux[rabbitmq]') from ex + +if TYPE_CHECKING: + from pika.adapters.blocking_connection import BlockingChannel + + +class RabbitMQInputTransaction(InputTransaction): + """ + represents a InputTransaction for RabbitMQ + """ + + def __init__(self, + device: 'RabbitMQInputDevice', + channel: 'BlockingChannel', + delivery_tag: int): + """ + + :param device: the device that returned this transaction + :param channel: the BlockingChannel that the item was read from + :param delivery_tag: the delivery tag for this item + """ + super(RabbitMQInputTransaction, self).__init__(device=device) + self._channel = channel + self._delivery_tag = delivery_tag + self._logger = logging.getLogger(__name__) + + @property + def channel(self) -> 'BlockingChannel': + """ + the channel that the item was read from + """ + return self._channel + + @property + def delivery_tag(self) -> int: + """ + the delivery tag for this item + """ + return self._delivery_tag + + def _commit(self): + try: + self._channel.basic_ack(self._delivery_tag) + except Exception: + self._logger.warning('commit failed', exc_info=True) + + def _rollback(self): + try: + self._channel.basic_nack(self._delivery_tag, requeue=True) + except Exception: + self._logger.warning('rollback failed', exc_info=True) + + +class RabbitMQInputDevice(InputDevice['RabbitMQInputDeviceManager']): + """ + represents an RabbitMQ input device + """ + + _channel: Union[ThreadLocalMember[Optional['BlockingChannel']], + Optional['BlockingChannel']] = ThreadLocalMember(init_value=None) + + @staticmethod + def _get_rabbit_headers(method_frame, header_frame): + return { + "exchange": method_frame.exchange, + "routing_key": method_frame.routing_key, + "content_type": header_frame.content_type, + "content_encoding": header_frame.content_encoding, + "priority": header_frame.priority, + "correlation_id": header_frame.correlation_id, + "reply_to": header_frame.reply_to, + "expiration": header_frame.expiration, + "message_id": header_frame.message_id, + "timestamp": header_frame.timestamp, + "type": header_frame.type, + "user_id": header_frame.user_id, + "app_id": header_frame.app_id + } + + def __init__(self, + device_manager: 'RabbitMQInputDeviceManager', + queue_name: str, + consumer_args: Optional[Dict[str, str]] = None, + prefetch_count: int = 1, + use_consumer: bool = True): + """ + constructs a new input RabbitMQ device + + :param device_manager: the RabbitMQ device Manager that holds this device + :param queue_name: the name for the queue + :param consumer_args: the arguments to create the consumer with + only relevent if "use_consumer" is True + :param int prefetch_count: the number of unacked messages that can be consumed + only relevent if "use_consumer" is True + :param bool use_consumer: True to use the 'consume' method, False to use 'basic_get' + """ + super().__init__(device_manager, queue_name) + self._device_manager = device_manager + self._queue_name = queue_name + self._logger = logging.getLogger(__name__) + if consumer_args is None: + consumer_args = {'hostname': socket.gethostname(), 'PID': str(os.getpid())} + + self._consumer_args = consumer_args + self._prefetch_count = max(1, prefetch_count) + self._use_consumer = use_consumer + self._last_consumer_auto_ack: Optional[bool] = None + + def _reconnect_device_manager(self): + """ + reconnects the RabbitMQ device manager + """ + try: + if self._channel is not None and self._channel.is_open: + assert self._channel is not None + if self._use_consumer: + self._channel.cancel() + self._channel.close() + self._channel = self._device_manager.connection.channel() + + assert self._channel is not None + self._channel.basic_qos(prefetch_count=self._prefetch_count) + except Exception as e: + raise InputDeviceException('Could not connect to rabbitmq.') from e + + def _get_channel(self) -> 'BlockingChannel': + """ + gets a channel + """ + if self._channel is None or not self._channel.is_open: + self._reconnect_device_manager() + + assert self._channel is not None + return self._channel + + def _get_data_from_queue(self, timeout: Optional[float], with_transaction: bool) -> Optional['ReadResult']: + """ + performs a single read from queue + + :param timeout: the timeout in seconds to block. negative number means no blocking + :param with_transaction: does this read is to be done with transaction? + :return: the stream and metadata, or None,None if no message in queue + """ + channel = self._get_channel() + get_timeout: Optional[float] = None + if timeout is not None: + get_timeout = max(0.01, timeout) + body: Optional[bytes] + header_frame: Optional[spec.BasicProperties] + method_frame: Optional[Union[spec.Basic.Deliver, spec.Basic.GetOk]] + + body, header_frame, method_frame = self._get_frames_from_queue(channel, + get_timeout, + with_transaction=with_transaction) + if method_frame is None: # no message in queue + return None + + assert body is not None + assert header_frame is not None + + return self._create_response_from_frames(body, + header_frame, + method_frame, + channel, + with_transaction) + + def _create_response_from_frames(self, + body: bytes, + header_frame: spec.BasicProperties, + method_frame: Union[spec.Basic.Deliver, spec.Basic.GetOk], + channel: 'BlockingChannel', + with_transaction: bool) -> ReadResult: + """ + creates the read result from the data returned from rabbitmq + + :param body: the body of the message + :param header_frame: the header frame of the message + :param method_frame: the method frame of the message + :param channel: the channel we read the message from + :param with_transaction: should we use a transaction + + :return: ReadResult object + """ + + delivery_tag = method_frame.delivery_tag + assert delivery_tag is not None + if with_transaction: + transaction: InputTransaction = RabbitMQInputTransaction(self, channel, delivery_tag) + else: + transaction = NULLTransaction(self) + headers: Dict[str, Any] = header_frame.headers or {} # type: ignore + # get the rabbitmq headers as for device headers + rabbit_headers = self._get_rabbit_headers(method_frame, header_frame) + + buf = BytesIO(body) + return ReadResult(message=Message(buf, headers), + device_headers=rabbit_headers, + transaction=transaction) + + def _get_frames_from_queue(self, + channel: 'BlockingChannel', + timeout: Optional[float], + with_transaction: bool) -> Tuple[Optional[bytes], + Optional[spec.BasicProperties], + Optional[Union[spec.Basic.Deliver, spec.Basic.GetOk]]]: + """ + gets the actual frame from queue. use_consumer effects this method + + :param BlockingChannel channel: the channel + :param float timeout: the timeout + :param bool with_transaction: do we operate within a transaction scope + """ + # this could be auto_ack = not has_transaction, but it's less clear... so it's verbose here... + if with_transaction: + auto_ack = False + else: + auto_ack = True + + if self._use_consumer: + if self._last_consumer_auto_ack is None: + self._last_consumer_auto_ack = auto_ack + + if self._last_consumer_auto_ack != auto_ack: + channel.cancel() + self._last_consumer_auto_ack = auto_ack + + method_frame: Optional[Union[spec.Basic.Deliver, spec.Basic.GetOk]] + header_frame: Optional[spec.BasicProperties] + body: Optional[bytes] + + method_frame, header_frame, body = next(channel.consume(queue=self._queue_name, + inactivity_timeout=timeout, + arguments=self._consumer_args, auto_ack=auto_ack)) + else: + method_frame, header_frame, body = channel.basic_get(queue=self._queue_name, # type: ignore + auto_ack=auto_ack) + + return body, header_frame, method_frame + + def _read_message(self, timeout: Optional[float] = None, with_transaction: bool = True) -> Optional['ReadResult']: + """ + reads a stream from InputDevice (tries getting a message. if it fails, reconnects and tries again once) + + :param timeout: the timeout in seconds to block. negative number means no blocking + :return: a tuple of stream and metadata from InputDevice, or (None, None) if no message is available + """ + try: + from pika.exceptions import AMQPConnectionError, AMQPChannelError + except ImportError as exc: + raise ImportError('Please Install the required extra: messageflux[rabbitmq]') from exc + + try: + return self._get_data_from_queue(timeout=timeout, with_transaction=with_transaction) + except (AMQPConnectionError, AMQPChannelError): + self._reconnect_device_manager() + try: + return self._get_data_from_queue(timeout=timeout, with_transaction=with_transaction) + except Exception: + self._logger.exception(f"AMQError thrown. failed to get message. device name: {self._queue_name}") + raise + except Exception as e: + raise InputDeviceException('Error reading from device') from e + + def close(self): + """ + closes the connection to device + """ + try: + if self._channel is not None and self._channel.is_open: + if self._use_consumer: + self._channel.cancel() + self._channel.close() + except Exception: + self._logger.warning('Error Closing Device', exc_info=True) + + self._channel = None + + +class RabbitMQInputDeviceManager(RabbitMQDeviceManagerBase, InputDeviceManager[RabbitMQInputDevice]): + """ + rabbitmq input device manager + """ + + def __init__(self, + hosts: Union[List[str], str], + user: str, + password: str, + port: Optional[int] = None, + ssl_context: Optional[ssl.SSLContext] = None, + virtual_host: Optional[str] = None, + client_args: Optional[Dict[str, str]] = None, + heartbeat: int = 300, + connection_attempts: int = 5, + prefetch_count: int = 1, + use_consumer: bool = True, + blocked_connection_timeout: Optional[float] = None, + default_direct_exchange: Optional[str] = None + ): + """ + This manager used to create RabbitMQ devices (direct queues) + + :param hosts: the hostname or a list of hostnames of the manager + :param user: the username for the rabbitMQ manager + :param password: the password for the rabbitMQ manager + :param port: the port to connect the hosts to + :param ssl_context: the ssl context to use. None means don't use ssl at all + :param virtual_host: the virtual host to connect to + :param client_args: the arguments to create the client with + :param int heartbeat: heartbeat interval for the connection (between 0 and 65536 + :param int connection_attempts: Maximum number of retry attempts + (-1 means not to handle poison messages at all, 0 means reject all redelivered messages right away) + :param int prefetch_count: the number of unacked messages that can be consumed + :param bool use_consumer: True to use the 'consume' method, False to use 'basic_get' + :param blocked_connection_timeout: If not None, + the value is a non-negative timeout, in seconds, for the + connection to remain blocked (triggered by Connection.Blocked from + broker); if the timeout expires before connection becomes unblocked, + the connection will be torn down, triggering the adapter-specific + mechanism for informing client app about the closed connection: + passing `ConnectionBlockedTimeout` exception to on_close_callback + in asynchronous adapters or raising it in `BlockingConnection`. + + :param default_direct_exchange: optional direct exchange to bind all the queues to (None means no bind) + """ + super().__init__(hosts=hosts, + user=user, + password=password, + port=port, + ssl_context=ssl_context, + virtual_host=virtual_host, + client_args=client_args, + connection_type="Input", + heartbeat=heartbeat, + connection_attempts=connection_attempts, + blocked_connection_timeout=blocked_connection_timeout) + + self._prefetch_count = prefetch_count + self._use_consumer = use_consumer + self._default_direct_exchange = default_direct_exchange + + def _device_factory(self, device_name: str) -> RabbitMQInputDevice: + + return RabbitMQInputDevice(device_manager=self, + queue_name=device_name, + consumer_args=self._client_args, + prefetch_count=self._prefetch_count, + use_consumer=self._use_consumer) + + def get_input_device(self, device_name: str) -> RabbitMQInputDevice: + """ + Returns an incoming device by name + + :param device_name: the name of the device to read from + :return: an input device for 'device_name' + """ + try: + self.create_queue(queue_name=device_name, + passive=True, + direct_bind_to_exchange=self._default_direct_exchange) + + return self._device_factory(device_name) + + except Exception as e: + message = f"Couldn't create input device '{device_name}'" + self._logger.exception(message) + raise InputDeviceException(message) from e + + def connect(self): + """ + connects to the device manager + """ + try: + self._connect() + except Exception as e: + raise InputDeviceException('Could not connect to rabbitmq.') from e + + def disconnect(self): + """ + disconnects from the device manager + """ + self._disconnect() diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py new file mode 100644 index 0000000..05fdcff --- /dev/null +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -0,0 +1,81 @@ +import logging + +from messageflux.iodevices.base import OutputDevice, OutputDeviceException, OutputDeviceManager +from messageflux.iodevices.base.common import MessageBundle +from messageflux.metadata_headers import MetadataHeaders +from messageflux.utils import get_random_id + +try: + from mypy_boto3_sqs.service_resource import SQSServiceResource, Queue +except ImportError as ex: + raise ImportError('Please Install the required extra: messageflux[sqs]') from ex + + +class SQSOutputDevice(OutputDevice['SQSOutputDeviceManager']): + """ + represents an SQS output devices + """ + + def __init__(self, device_manager: 'SQSOutputDeviceManager', queue_name: str): + """ + constructs a new output SQS device + + :param device_manager: the SQS device Manager that holds this device + :param queue_name: the name of the queue + """ + super(SQSOutputDevice, self).__init__(device_manager, queue_name) + self._queue_name = queue_name + self._logger = logging.getLogger(__name__) + + def _send_message(self, message_bundle: MessageBundle): + sqs_queue = self.manager.get_queue(self._queue_name) + + sqs_queue.send_message( + MessageBody=message_bundle.message.bytes.decode(), + MessageAttributes=message_bundle.message.headers, + MessageDeduplicationId=message_bundle.device_headers.get('message_id', + message_bundle.message.headers.get( + MetadataHeaders.ITEM_ID, get_random_id())), + ) + + +class SQSOutputDeviceManager(OutputDeviceManager[SQSOutputDevice]): + """ + this manager is used to create SQS devices + """ + + def __init__(self, sqs_resource: SQSServiceResource): + """ + This manager used to create SQS devices + + :param sqs_resource: the SQS resource from boto + """ + self._logger = logging.getLogger(__name__) + + self._sqs_resource = sqs_resource + self._queue_cache = {} + + def get_output_device(self, queue_name: str) -> SQSOutputDevice: + """ + Returns and outgoing device by name + + :param queue_name: the name of the queue + :return: an output device for 'queue_url' + """ + try: + return SQSOutputDevice(self, queue_name) + except Exception as e: + message = f"Couldn't create output device '{queue_name}'" + self._logger.exception(message) + raise OutputDeviceException(message) from e + + def get_queue(self, queue_name: str, auto_create=False) -> Queue: + """ + gets the bucket from cache + """ + queue = self._queue_cache.get(queue_name, None) + if queue is None: + queue = self._sqs_resource.get_queue_by_name(QueueName=queue_name) + self._queue_cache[queue_name] = queue + + return queue diff --git a/requirements-sqs.txt b/requirements-sqs.txt new file mode 100644 index 0000000..90c08c0 --- /dev/null +++ b/requirements-sqs.txt @@ -0,0 +1,2 @@ +boto3>=1.25,<2 +boto3-stubs[sqs]>=1.25,<2 \ No newline at end of file From 41b1eb69a7b34e224787eb8ec7d1c574b49d3b6d Mon Sep 17 00:00:00 2001 From: avivs Date: Mon, 10 Apr 2023 09:16:51 +0300 Subject: [PATCH 02/15] fixed mypy problems --- messageflux/iodevices/sqs/sqs_output_device.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index 05fdcff..f70ac1a 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -1,4 +1,5 @@ import logging +from typing import Dict from messageflux.iodevices.base import OutputDevice, OutputDeviceException, OutputDeviceManager from messageflux.iodevices.base.common import MessageBundle @@ -53,7 +54,7 @@ def __init__(self, sqs_resource: SQSServiceResource): self._logger = logging.getLogger(__name__) self._sqs_resource = sqs_resource - self._queue_cache = {} + self._queue_cache: Dict[str, Queue] = {} def get_output_device(self, queue_name: str) -> SQSOutputDevice: """ From ac199a6db093b40ef67f5bb9ae30591a64deaf5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CNoam?= <“noamisraeli97@gmail.com”> Date: Sun, 13 Aug 2023 21:17:32 +0300 Subject: [PATCH 03/15] fix output device manager --- .../iodevices/sqs/sqs_output_device.py | 46 +++++++++++-------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index f70ac1a..dc692f2 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -7,7 +7,8 @@ from messageflux.utils import get_random_id try: - from mypy_boto3_sqs.service_resource import SQSServiceResource, Queue + from mypy_boto3_sqs.service_resource import Queue + import boto3 except ImportError as ex: raise ImportError('Please Install the required extra: messageflux[sqs]') from ex @@ -25,19 +26,30 @@ def __init__(self, device_manager: 'SQSOutputDeviceManager', queue_name: str): :param queue_name: the name of the queue """ super(SQSOutputDevice, self).__init__(device_manager, queue_name) - self._queue_name = queue_name + self._sqs_queue = self.manager.get_queue(queue_name) + + # https://awscli.amazonaws.com/v2/documentation/api/latest/reference/sqs/get-queue-attributes.html#get-queue-attributes + self._is_fifo = queue_name.endswith(".fifo") self._logger = logging.getLogger(__name__) def _send_message(self, message_bundle: MessageBundle): - sqs_queue = self.manager.get_queue(self._queue_name) + if self._is_fifo: + response = self._sqs_queue.send_message( + MessageBody=message_bundle.message.bytes.decode(), + MessageAttributes=message_bundle.message.headers, + MessageGroupId=get_random_id(), + ) + else: + response = self._sqs_queue.send_message( + MessageBody=message_bundle.message.bytes.decode(), + MessageAttributes=message_bundle.message.headers, + ) - sqs_queue.send_message( - MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=message_bundle.message.headers, - MessageDeduplicationId=message_bundle.device_headers.get('message_id', - message_bundle.message.headers.get( - MetadataHeaders.ITEM_ID, get_random_id())), - ) + if "MessageId" not in response: + raise OutputDeviceException("Couldn't send message to SQS") + + if "Failed" in response: + raise OutputDeviceException(f"Couldn't send message to SQS: {response['Failed']}") class SQSOutputDeviceManager(OutputDeviceManager[SQSOutputDevice]): @@ -45,15 +57,13 @@ class SQSOutputDeviceManager(OutputDeviceManager[SQSOutputDevice]): this manager is used to create SQS devices """ - def __init__(self, sqs_resource: SQSServiceResource): + def __init__(self): """ This manager used to create SQS devices - - :param sqs_resource: the SQS resource from boto """ self._logger = logging.getLogger(__name__) - - self._sqs_resource = sqs_resource + + self._sqs_resource = boto3.resource('sqs') self._queue_cache: Dict[str, Queue] = {} def get_output_device(self, queue_name: str) -> SQSOutputDevice: @@ -61,7 +71,7 @@ def get_output_device(self, queue_name: str) -> SQSOutputDevice: Returns and outgoing device by name :param queue_name: the name of the queue - :return: an output device for 'queue_url' + :return: an output device for 'queue_name' """ try: return SQSOutputDevice(self, queue_name) @@ -70,9 +80,9 @@ def get_output_device(self, queue_name: str) -> SQSOutputDevice: self._logger.exception(message) raise OutputDeviceException(message) from e - def get_queue(self, queue_name: str, auto_create=False) -> Queue: + def get_queue(self, queue_name: str) -> Queue: """ - gets the bucket from cache + gets the queue from cache """ queue = self._queue_cache.get(queue_name, None) if queue is None: From c751d3eca9ba43d0a733cc8d70c77f8363c11a47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CNoam?= <“noamisraeli97@gmail.com”> Date: Sun, 13 Aug 2023 21:57:04 +0300 Subject: [PATCH 04/15] fix sqs input device --- messageflux/iodevices/sqs/sqs_input_device.py | 394 ++++-------------- 1 file changed, 85 insertions(+), 309 deletions(-) diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py index 56b3227..e4a480b 100644 --- a/messageflux/iodevices/sqs/sqs_input_device.py +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -1,52 +1,40 @@ import logging -import os -import socket -import ssl + from io import BytesIO -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, List, Any +from typing import Dict, Optional, Union, Any from messageflux.iodevices.base import InputDevice, InputTransaction, ReadResult, InputDeviceException, Message, \ InputDeviceManager from messageflux.iodevices.base.input_transaction import NULLTransaction from messageflux.iodevices.rabbitmq.rabbitmq_device_manager_base import RabbitMQDeviceManagerBase -from messageflux.utils import ThreadLocalMember +from messageflux.iodevices.sqs.sqs_base import SQSManagerBase try: - from pika import spec + from mypy_boto3_sqs.service_resource import Queue + from mypy_boto3_sqs.client import SQSClient + import boto3 except ImportError as ex: - raise ImportError('Please Install the required extra: messageflux[rabbitmq]') from ex - -if TYPE_CHECKING: - from pika.adapters.blocking_connection import BlockingChannel + raise ImportError('Please Install the required extra: messageflux[sqs]') from ex -class RabbitMQInputTransaction(InputTransaction): +class SQSInputTransaction(InputTransaction): """ - represents a InputTransaction for RabbitMQ + represents a InputTransaction for SQS """ + _device: 'SQSInputDevice' def __init__(self, - device: 'RabbitMQInputDevice', - channel: 'BlockingChannel', + device: 'SQSInputDevice', delivery_tag: int): """ :param device: the device that returned this transaction - :param channel: the BlockingChannel that the item was read from :param delivery_tag: the delivery tag for this item """ - super(RabbitMQInputTransaction, self).__init__(device=device) - self._channel = channel + super(SQSInputTransaction, self).__init__(device=device) self._delivery_tag = delivery_tag self._logger = logging.getLogger(__name__) - @property - def channel(self) -> 'BlockingChannel': - """ - the channel that the item was read from - """ - return self._channel - @property def delivery_tag(self) -> int: """ @@ -56,313 +44,112 @@ def delivery_tag(self) -> int: def _commit(self): try: - self._channel.basic_ack(self._delivery_tag) + self._device.delete_message(self._delivery_tag) except Exception: - self._logger.warning('commit failed', exc_info=True) + self._logger.exception('commit failed') def _rollback(self): try: - self._channel.basic_nack(self._delivery_tag, requeue=True) + self._device.change_visability_timeout(self._delivery_tag, 0) except Exception: self._logger.warning('rollback failed', exc_info=True) + - -class RabbitMQInputDevice(InputDevice['RabbitMQInputDeviceManager']): +class SQSInputDevice(InputDevice['SQSInputDeviceManager']): """ - represents an RabbitMQ input device + represents an SQS input device """ - _channel: Union[ThreadLocalMember[Optional['BlockingChannel']], - Optional['BlockingChannel']] = ThreadLocalMember(init_value=None) - - @staticmethod - def _get_rabbit_headers(method_frame, header_frame): - return { - "exchange": method_frame.exchange, - "routing_key": method_frame.routing_key, - "content_type": header_frame.content_type, - "content_encoding": header_frame.content_encoding, - "priority": header_frame.priority, - "correlation_id": header_frame.correlation_id, - "reply_to": header_frame.reply_to, - "expiration": header_frame.expiration, - "message_id": header_frame.message_id, - "timestamp": header_frame.timestamp, - "type": header_frame.type, - "user_id": header_frame.user_id, - "app_id": header_frame.app_id - } - def __init__(self, - device_manager: 'RabbitMQInputDeviceManager', + device_manager: 'SQSInputDeviceManager', queue_name: str, - consumer_args: Optional[Dict[str, str]] = None, - prefetch_count: int = 1, - use_consumer: bool = True): + included_message_attributes: Optional[Union[str, list]] = None): """ - constructs a new input RabbitMQ device + constructs a new input SQS device - :param device_manager: the RabbitMQ device Manager that holds this device + :param device_manager: the SQS device Manager that holds this device :param queue_name: the name for the queue - :param consumer_args: the arguments to create the consumer with - only relevent if "use_consumer" is True - :param int prefetch_count: the number of unacked messages that can be consumed - only relevent if "use_consumer" is True - :param bool use_consumer: True to use the 'consume' method, False to use 'basic_get' - """ - super().__init__(device_manager, queue_name) - self._device_manager = device_manager - self._queue_name = queue_name - self._logger = logging.getLogger(__name__) - if consumer_args is None: - consumer_args = {'hostname': socket.gethostname(), 'PID': str(os.getpid())} - - self._consumer_args = consumer_args - self._prefetch_count = max(1, prefetch_count) - self._use_consumer = use_consumer - self._last_consumer_auto_ack: Optional[bool] = None - - def _reconnect_device_manager(self): - """ - reconnects the RabbitMQ device manager - """ - try: - if self._channel is not None and self._channel.is_open: - assert self._channel is not None - if self._use_consumer: - self._channel.cancel() - self._channel.close() - self._channel = self._device_manager.connection.channel() - - assert self._channel is not None - self._channel.basic_qos(prefetch_count=self._prefetch_count) - except Exception as e: - raise InputDeviceException('Could not connect to rabbitmq.') from e - def _get_channel(self) -> 'BlockingChannel': - """ - gets a channel """ - if self._channel is None or not self._channel.is_open: - self._reconnect_device_manager() - - assert self._channel is not None - return self._channel + super().__init__(device_manager, queue_name) + self._queue_url = self.manager.client.get_queue_url(queue_name) + self._included_message_attributes = ( + included_message_attributes + if included_message_attributes is not None + else ["All"] + ) + self._max_messages_per_request = 1 - def _get_data_from_queue(self, timeout: Optional[float], with_transaction: bool) -> Optional['ReadResult']: + def _read_message(self, timeout: Optional[float] = None, with_transaction: bool = True) -> Optional['ReadResult']: """ - performs a single read from queue + reads a stream from InputDevice (tries getting a message. if it fails, reconnects and tries again once) :param timeout: the timeout in seconds to block. negative number means no blocking - :param with_transaction: does this read is to be done with transaction? - :return: the stream and metadata, or None,None if no message in queue - """ - channel = self._get_channel() - get_timeout: Optional[float] = None - if timeout is not None: - get_timeout = max(0.01, timeout) - body: Optional[bytes] - header_frame: Optional[spec.BasicProperties] - method_frame: Optional[Union[spec.Basic.Deliver, spec.Basic.GetOk]] - - body, header_frame, method_frame = self._get_frames_from_queue(channel, - get_timeout, - with_transaction=with_transaction) - if method_frame is None: # no message in queue - return None - - assert body is not None - assert header_frame is not None - - return self._create_response_from_frames(body, - header_frame, - method_frame, - channel, - with_transaction) - - def _create_response_from_frames(self, - body: bytes, - header_frame: spec.BasicProperties, - method_frame: Union[spec.Basic.Deliver, spec.Basic.GetOk], - channel: 'BlockingChannel', - with_transaction: bool) -> ReadResult: - """ - creates the read result from the data returned from rabbitmq - - :param body: the body of the message - :param header_frame: the header frame of the message - :param method_frame: the method frame of the message - :param channel: the channel we read the message from - :param with_transaction: should we use a transaction - - :return: ReadResult object - """ - - delivery_tag = method_frame.delivery_tag - assert delivery_tag is not None - if with_transaction: - transaction: InputTransaction = RabbitMQInputTransaction(self, channel, delivery_tag) - else: - transaction = NULLTransaction(self) - headers: Dict[str, Any] = header_frame.headers or {} # type: ignore - # get the rabbitmq headers as for device headers - rabbit_headers = self._get_rabbit_headers(method_frame, header_frame) - - buf = BytesIO(body) - return ReadResult(message=Message(buf, headers), - device_headers=rabbit_headers, - transaction=transaction) - - def _get_frames_from_queue(self, - channel: 'BlockingChannel', - timeout: Optional[float], - with_transaction: bool) -> Tuple[Optional[bytes], - Optional[spec.BasicProperties], - Optional[Union[spec.Basic.Deliver, spec.Basic.GetOk]]]: - """ - gets the actual frame from queue. use_consumer effects this method - - :param BlockingChannel channel: the channel - :param float timeout: the timeout - :param bool with_transaction: do we operate within a transaction scope + :return: a tuple of stream and metadata from InputDevice, or (None, None) if no message is available """ - # this could be auto_ack = not has_transaction, but it's less clear... so it's verbose here... - if with_transaction: - auto_ack = False - else: - auto_ack = True - if self._use_consumer: - if self._last_consumer_auto_ack is None: - self._last_consumer_auto_ack = auto_ack + + response = self.manager.client.receive_message( + QueueUrl=self._queue_url, + MessageAttributeNames=["All"], + MaxNumberOfMessages=self._max_messages_per_request + ) - if self._last_consumer_auto_ack != auto_ack: - channel.cancel() - self._last_consumer_auto_ack = auto_ack + sqs_messages = response["Messages"] - method_frame: Optional[Union[spec.Basic.Deliver, spec.Basic.GetOk]] - header_frame: Optional[spec.BasicProperties] - body: Optional[bytes] - method_frame, header_frame, body = next(channel.consume(queue=self._queue_name, - inactivity_timeout=timeout, - arguments=self._consumer_args, auto_ack=auto_ack)) - else: - method_frame, header_frame, body = channel.basic_get(queue=self._queue_name, # type: ignore - auto_ack=auto_ack) + if sqs_messages: + assert len(sqs_messages) == 1, "SQSDevice should only return one message at a time" + sqs_message = sqs_messages[0] - return body, header_frame, method_frame + return ReadResult( + message=Message( + headers=sqs_message["MessageAttributes"], + data=BytesIO(sqs_message["Body"].encode()), + ), + transaction = SQSInputTransaction( + device=self, + delivery_tag=sqs_message["ReceiptHandle"], + ) if with_transaction else NULLTransaction(), + ) + - def _read_message(self, timeout: Optional[float] = None, with_transaction: bool = True) -> Optional['ReadResult']: + def delete_message(self, receipt_handle: str): """ - reads a stream from InputDevice (tries getting a message. if it fails, reconnects and tries again once) + deletes a message from the queue - :param timeout: the timeout in seconds to block. negative number means no blocking - :return: a tuple of stream and metadata from InputDevice, or (None, None) if no message is available + :param receipt_handle: the receipt handle of the message """ - try: - from pika.exceptions import AMQPConnectionError, AMQPChannelError - except ImportError as exc: - raise ImportError('Please Install the required extra: messageflux[rabbitmq]') from exc - - try: - return self._get_data_from_queue(timeout=timeout, with_transaction=with_transaction) - except (AMQPConnectionError, AMQPChannelError): - self._reconnect_device_manager() - try: - return self._get_data_from_queue(timeout=timeout, with_transaction=with_transaction) - except Exception: - self._logger.exception(f"AMQError thrown. failed to get message. device name: {self._queue_name}") - raise - except Exception as e: - raise InputDeviceException('Error reading from device') from e + self.manager.client.delete_message( + QueueUrl=self._queue_url, + ReceiptHandle=receipt_handle + ) - def close(self): + def change_visability_timeout(self, receipt_handle: str, timeout: int): """ - closes the connection to device - """ - try: - if self._channel is not None and self._channel.is_open: - if self._use_consumer: - self._channel.cancel() - self._channel.close() - except Exception: - self._logger.warning('Error Closing Device', exc_info=True) - - self._channel = None + changes the visibility timeout of a message + :param receipt_handle: the receipt handle of the message + :param timeout: the new timeout in seconds + """ + self.manager.client.change_message_visibility( + QueueUrl=self._queue_url, + ReceiptHandle=receipt_handle, + VisibilityTimeout=timeout + ) -class RabbitMQInputDeviceManager(RabbitMQDeviceManagerBase, InputDeviceManager[RabbitMQInputDevice]): +class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]): """ - rabbitmq input device manager + SQS input device manager """ - def __init__(self, - hosts: Union[List[str], str], - user: str, - password: str, - port: Optional[int] = None, - ssl_context: Optional[ssl.SSLContext] = None, - virtual_host: Optional[str] = None, - client_args: Optional[Dict[str, str]] = None, - heartbeat: int = 300, - connection_attempts: int = 5, - prefetch_count: int = 1, - use_consumer: bool = True, - blocked_connection_timeout: Optional[float] = None, - default_direct_exchange: Optional[str] = None - ): - """ - This manager used to create RabbitMQ devices (direct queues) - - :param hosts: the hostname or a list of hostnames of the manager - :param user: the username for the rabbitMQ manager - :param password: the password for the rabbitMQ manager - :param port: the port to connect the hosts to - :param ssl_context: the ssl context to use. None means don't use ssl at all - :param virtual_host: the virtual host to connect to - :param client_args: the arguments to create the client with - :param int heartbeat: heartbeat interval for the connection (between 0 and 65536 - :param int connection_attempts: Maximum number of retry attempts - (-1 means not to handle poison messages at all, 0 means reject all redelivered messages right away) - :param int prefetch_count: the number of unacked messages that can be consumed - :param bool use_consumer: True to use the 'consume' method, False to use 'basic_get' - :param blocked_connection_timeout: If not None, - the value is a non-negative timeout, in seconds, for the - connection to remain blocked (triggered by Connection.Blocked from - broker); if the timeout expires before connection becomes unblocked, - the connection will be torn down, triggering the adapter-specific - mechanism for informing client app about the closed connection: - passing `ConnectionBlockedTimeout` exception to on_close_callback - in asynchronous adapters or raising it in `BlockingConnection`. - - :param default_direct_exchange: optional direct exchange to bind all the queues to (None means no bind) - """ - super().__init__(hosts=hosts, - user=user, - password=password, - port=port, - ssl_context=ssl_context, - virtual_host=virtual_host, - client_args=client_args, - connection_type="Input", - heartbeat=heartbeat, - connection_attempts=connection_attempts, - blocked_connection_timeout=blocked_connection_timeout) - - self._prefetch_count = prefetch_count - self._use_consumer = use_consumer - self._default_direct_exchange = default_direct_exchange - - def _device_factory(self, device_name: str) -> RabbitMQInputDevice: - - return RabbitMQInputDevice(device_manager=self, - queue_name=device_name, - consumer_args=self._client_args, - prefetch_count=self._prefetch_count, - use_consumer=self._use_consumer) - - def get_input_device(self, device_name: str) -> RabbitMQInputDevice: + def __init__(self): + self._sqs_client = boto3.client('sqs') + self._queue_cache: Dict[str, Queue] = {} + self._logger = logging.getLogger(__name__) + + + def get_input_device(self, device_name: str) -> SQSInputDevice: """ Returns an incoming device by name @@ -370,28 +157,17 @@ def get_input_device(self, device_name: str) -> RabbitMQInputDevice: :return: an input device for 'device_name' """ try: - self.create_queue(queue_name=device_name, - passive=True, - direct_bind_to_exchange=self._default_direct_exchange) + return SQSInputDevice(self, device_name) - return self._device_factory(device_name) except Exception as e: message = f"Couldn't create input device '{device_name}'" self._logger.exception(message) raise InputDeviceException(message) from e - def connect(self): - """ - connects to the device manager - """ - try: - self._connect() - except Exception as e: - raise InputDeviceException('Could not connect to rabbitmq.') from e - - def disconnect(self): + @property + def client(self) -> SQSClient: """ - disconnects from the device manager + returns the sqs client """ - self._disconnect() + return self._sqs_client \ No newline at end of file From 11353b32fc05c73b1f77c9f9528d8dbe0295dbca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CNoam?= <“noamisraeli97@gmail.com”> Date: Sun, 13 Aug 2023 22:29:46 +0300 Subject: [PATCH 05/15] Add message attributes --- .../iodevices/sqs/message_attributes.py | 36 ++++++ messageflux/iodevices/sqs/sqs_input_device.py | 109 +++++++++--------- messageflux/iodevices/sqs/sqs_manager_base.py | 27 +++++ .../iodevices/sqs/sqs_output_device.py | 8 +- 4 files changed, 119 insertions(+), 61 deletions(-) create mode 100644 messageflux/iodevices/sqs/message_attributes.py create mode 100644 messageflux/iodevices/sqs/sqs_manager_base.py diff --git a/messageflux/iodevices/sqs/message_attributes.py b/messageflux/iodevices/sqs/message_attributes.py new file mode 100644 index 0000000..54e0b2e --- /dev/null +++ b/messageflux/iodevices/sqs/message_attributes.py @@ -0,0 +1,36 @@ +import json + +from typing import Any, TypedDict + + +try: + from mypy_boto3_sqs.type_defs import MessageAttributeValueTypeDef +except ImportError as ex: + raise ImportError('Please Install the required extra: messageflux[sqs]') from ex + + +def get_aws_data_type(value: Any) -> str: + if isinstance(value, (list, set, frozenset, tuple)): + return "String.Array" + elif isinstance(value, bool): + return "String" + elif isinstance(value, (int, float)): + return "Number" + elif isinstance(value, bytes): + return "Binary" + else: + return "String" + + +def geterate_message_attributes( + attributes: dict[str, Any] +) -> dict[str, MessageAttributeValueTypeDef]: + return { + key: { + "DataType": get_aws_data_type(value), + "StringValue": json.dumps(value) + if not isinstance(value, str) + else value, # to avoid double encoding + } + for key, value in attributes.items() + } \ No newline at end of file diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py index e4a480b..bcd1b07 100644 --- a/messageflux/iodevices/sqs/sqs_input_device.py +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -1,18 +1,20 @@ import logging +import threading from io import BytesIO -from typing import Dict, Optional, Union, Any +from typing import Dict, Optional, Union from messageflux.iodevices.base import InputDevice, InputTransaction, ReadResult, InputDeviceException, Message, \ InputDeviceManager from messageflux.iodevices.base.input_transaction import NULLTransaction -from messageflux.iodevices.rabbitmq.rabbitmq_device_manager_base import RabbitMQDeviceManagerBase -from messageflux.iodevices.sqs.sqs_base import SQSManagerBase +from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase +from messageflux.utils import get_random_id + try: + import boto3 from mypy_boto3_sqs.service_resource import Queue from mypy_boto3_sqs.client import SQSClient - import boto3 except ImportError as ex: raise ImportError('Please Install the required extra: messageflux[sqs]') from ex @@ -25,32 +27,25 @@ class SQSInputTransaction(InputTransaction): def __init__(self, device: 'SQSInputDevice', - delivery_tag: int): + receipt_handle: str): """ :param device: the device that returned this transaction - :param delivery_tag: the delivery tag for this item + :param receipt_handle: the receipt handle for this item """ super(SQSInputTransaction, self).__init__(device=device) - self._delivery_tag = delivery_tag + self._receipt_handle = receipt_handle self._logger = logging.getLogger(__name__) - @property - def delivery_tag(self) -> int: - """ - the delivery tag for this item - """ - return self._delivery_tag - def _commit(self): try: - self._device.delete_message(self._delivery_tag) + self._device.delete_message(self._receipt_handle) except Exception: self._logger.exception('commit failed') def _rollback(self): try: - self._device.change_visability_timeout(self._delivery_tag, 0) + self._device.change_visability_timeout(self._receipt_handle, 0) except Exception: self._logger.warning('rollback failed', exc_info=True) @@ -72,57 +67,63 @@ def __init__(self, """ super().__init__(device_manager, queue_name) - self._queue_url = self.manager.client.get_queue_url(queue_name) self._included_message_attributes = ( included_message_attributes if included_message_attributes is not None else ["All"] ) self._max_messages_per_request = 1 + self._queue = self.manager.get_queue(queue_name) - def _read_message(self, timeout: Optional[float] = None, with_transaction: bool = True) -> Optional['ReadResult']: + def _read_message(self, cancellation_token: threading.Event, + timeout: Optional[float] = None, + with_transaction: bool = True) -> Optional['ReadResult']: """ reads a stream from InputDevice (tries getting a message. if it fails, reconnects and tries again once) :param timeout: the timeout in seconds to block. negative number means no blocking :return: a tuple of stream and metadata from InputDevice, or (None, None) if no message is available """ - - response = self.manager.client.receive_message( - QueueUrl=self._queue_url, + sqs_messages = self._queue.receive_messages( MessageAttributeNames=["All"], MaxNumberOfMessages=self._max_messages_per_request ) - sqs_messages = response["Messages"] - - - if sqs_messages: - assert len(sqs_messages) == 1, "SQSDevice should only return one message at a time" - sqs_message = sqs_messages[0] - - return ReadResult( - message=Message( - headers=sqs_message["MessageAttributes"], - data=BytesIO(sqs_message["Body"].encode()), - ), - transaction = SQSInputTransaction( - device=self, - delivery_tag=sqs_message["ReceiptHandle"], - ) if with_transaction else NULLTransaction(), - ) + if not sqs_messages: + return None - + assert len(sqs_messages) == 1, "SQSInputDevice should only return one message at a time" + + sqs_message = sqs_messages[0] + + return ReadResult( + message=Message( + headers={ + key: value["StringValue"] + for key, value in sqs_message.message_attributes.items() + }, + data=BytesIO(sqs_message.body.encode()), + ), + transaction = SQSInputTransaction( + device=self, + receipt_handle=sqs_message.receipt_handle, + ) if with_transaction else NULLTransaction(self), + ) + def delete_message(self, receipt_handle: str): """ deletes a message from the queue :param receipt_handle: the receipt handle of the message """ - self.manager.client.delete_message( - QueueUrl=self._queue_url, - ReceiptHandle=receipt_handle + self._queue.delete_messages( + Entries=[ + { + "Id": get_random_id(), + "ReceiptHandle": receipt_handle + } + ] ) def change_visability_timeout(self, receipt_handle: str, timeout: int): @@ -132,10 +133,14 @@ def change_visability_timeout(self, receipt_handle: str, timeout: int): :param receipt_handle: the receipt handle of the message :param timeout: the new timeout in seconds """ - self.manager.client.change_message_visibility( - QueueUrl=self._queue_url, - ReceiptHandle=receipt_handle, - VisibilityTimeout=timeout + self._queue.change_message_visibility_batch( + Entries=[ + { + "Id": get_random_id(), + "ReceiptHandle": receipt_handle, + "VisibilityTimeout": timeout + } + ] ) class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]): @@ -144,9 +149,7 @@ class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]): """ def __init__(self): - self._sqs_client = boto3.client('sqs') - self._queue_cache: Dict[str, Queue] = {} - self._logger = logging.getLogger(__name__) + super(SQSManagerBase, self).__init__() def get_input_device(self, device_name: str) -> SQSInputDevice: @@ -159,15 +162,7 @@ def get_input_device(self, device_name: str) -> SQSInputDevice: try: return SQSInputDevice(self, device_name) - except Exception as e: message = f"Couldn't create input device '{device_name}'" self._logger.exception(message) raise InputDeviceException(message) from e - - @property - def client(self) -> SQSClient: - """ - returns the sqs client - """ - return self._sqs_client \ No newline at end of file diff --git a/messageflux/iodevices/sqs/sqs_manager_base.py b/messageflux/iodevices/sqs/sqs_manager_base.py new file mode 100644 index 0000000..0713d86 --- /dev/null +++ b/messageflux/iodevices/sqs/sqs_manager_base.py @@ -0,0 +1,27 @@ +import logging + +from typing import Dict + +try: + import boto3 + from mypy_boto3_sqs.service_resource import Queue +except ImportError as ex: + raise ImportError('Please Install the required extra: messageflux[sqs]') from ex + + +class SQSManagerBase: + def __init__(self) -> None: + self._sqs_resource = boto3.resource('sqs') + self._queue_cache: Dict[str, Queue] = {} + self._logger = logging.getLogger(__name__) + + def get_queue(self, queue_name: str) -> Queue: + """ + gets the queue from cache + """ + queue = self._queue_cache.get(queue_name, None) + if queue is None: + queue = self._sqs_resource.get_queue_by_name(QueueName=queue_name) + self._queue_cache[queue_name] = queue + + return queue diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index dc692f2..b389f82 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -3,12 +3,12 @@ from messageflux.iodevices.base import OutputDevice, OutputDeviceException, OutputDeviceManager from messageflux.iodevices.base.common import MessageBundle -from messageflux.metadata_headers import MetadataHeaders +from messageflux.iodevices.sqs.message_attributes import geterate_message_attributes from messageflux.utils import get_random_id try: - from mypy_boto3_sqs.service_resource import Queue import boto3 + from mypy_boto3_sqs.service_resource import Queue except ImportError as ex: raise ImportError('Please Install the required extra: messageflux[sqs]') from ex @@ -36,13 +36,13 @@ def _send_message(self, message_bundle: MessageBundle): if self._is_fifo: response = self._sqs_queue.send_message( MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=message_bundle.message.headers, + MessageAttributes=geterate_message_attributes(message_bundle.message.headers), MessageGroupId=get_random_id(), ) else: response = self._sqs_queue.send_message( MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=message_bundle.message.headers, + MessageAttributes=geterate_message_attributes(message_bundle.message.headers), ) if "MessageId" not in response: From 5c517e66e12b4f08819675f112d78d6815040472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CNoam?= <“noamisraeli97@gmail.com”> Date: Sun, 13 Aug 2023 22:30:59 +0300 Subject: [PATCH 06/15] Fix PR --- messageflux/iodevices/sqs/sqs_output_device.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index b389f82..d0c2281 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -48,9 +48,6 @@ def _send_message(self, message_bundle: MessageBundle): if "MessageId" not in response: raise OutputDeviceException("Couldn't send message to SQS") - if "Failed" in response: - raise OutputDeviceException(f"Couldn't send message to SQS: {response['Failed']}") - class SQSOutputDeviceManager(OutputDeviceManager[SQSOutputDevice]): """ From 54d1f40f5820df7993473d1227db84e13a2af8d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CNoam?= <“noamisraeli97@gmail.com”> Date: Sun, 13 Aug 2023 22:32:39 +0300 Subject: [PATCH 07/15] Fix mypy --- messageflux/iodevices/sqs/message_attributes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/messageflux/iodevices/sqs/message_attributes.py b/messageflux/iodevices/sqs/message_attributes.py index 54e0b2e..a03b537 100644 --- a/messageflux/iodevices/sqs/message_attributes.py +++ b/messageflux/iodevices/sqs/message_attributes.py @@ -1,10 +1,10 @@ import json -from typing import Any, TypedDict +from typing import Any, Dict try: - from mypy_boto3_sqs.type_defs import MessageAttributeValueTypeDef + from mypy_boto3_sqs.type_defs import MessageAttributeValueQueueTypeDef except ImportError as ex: raise ImportError('Please Install the required extra: messageflux[sqs]') from ex @@ -23,8 +23,8 @@ def get_aws_data_type(value: Any) -> str: def geterate_message_attributes( - attributes: dict[str, Any] -) -> dict[str, MessageAttributeValueTypeDef]: + attributes: Dict[str, Any] +) -> Dict[str, MessageAttributeValueQueueTypeDef]: return { key: { "DataType": get_aws_data_type(value), From 7a335c7d3f939875896a372ad73e2f6b6c653db8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CNoam?= <“noamisraeli97@gmail.com”> Date: Sun, 13 Aug 2023 22:39:24 +0300 Subject: [PATCH 08/15] lint --- .../iodevices/sqs/message_attributes.py | 6 +- messageflux/iodevices/sqs/sqs_input_device.py | 75 +++++++++++-------- messageflux/iodevices/sqs/sqs_manager_base.py | 4 +- .../iodevices/sqs/sqs_output_device.py | 28 ++++--- 4 files changed, 65 insertions(+), 48 deletions(-) diff --git a/messageflux/iodevices/sqs/message_attributes.py b/messageflux/iodevices/sqs/message_attributes.py index a03b537..5079900 100644 --- a/messageflux/iodevices/sqs/message_attributes.py +++ b/messageflux/iodevices/sqs/message_attributes.py @@ -6,7 +6,7 @@ try: from mypy_boto3_sqs.type_defs import MessageAttributeValueQueueTypeDef except ImportError as ex: - raise ImportError('Please Install the required extra: messageflux[sqs]') from ex + raise ImportError("Please Install the required extra: messageflux[sqs]") from ex def get_aws_data_type(value: Any) -> str: @@ -30,7 +30,7 @@ def geterate_message_attributes( "DataType": get_aws_data_type(value), "StringValue": json.dumps(value) if not isinstance(value, str) - else value, # to avoid double encoding + else value, # to avoid double encoding } for key, value in attributes.items() - } \ No newline at end of file + } diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py index bcd1b07..893eb57 100644 --- a/messageflux/iodevices/sqs/sqs_input_device.py +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -4,8 +4,14 @@ from io import BytesIO from typing import Dict, Optional, Union -from messageflux.iodevices.base import InputDevice, InputTransaction, ReadResult, InputDeviceException, Message, \ - InputDeviceManager +from messageflux.iodevices.base import ( + InputDevice, + InputTransaction, + ReadResult, + InputDeviceException, + Message, + InputDeviceManager, +) from messageflux.iodevices.base.input_transaction import NULLTransaction from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase from messageflux.utils import get_random_id @@ -16,18 +22,17 @@ from mypy_boto3_sqs.service_resource import Queue from mypy_boto3_sqs.client import SQSClient except ImportError as ex: - raise ImportError('Please Install the required extra: messageflux[sqs]') from ex + raise ImportError("Please Install the required extra: messageflux[sqs]") from ex class SQSInputTransaction(InputTransaction): """ represents a InputTransaction for SQS """ - _device: 'SQSInputDevice' - def __init__(self, - device: 'SQSInputDevice', - receipt_handle: str): + _device: "SQSInputDevice" + + def __init__(self, device: "SQSInputDevice", receipt_handle: str): """ :param device: the device that returned this transaction @@ -41,24 +46,26 @@ def _commit(self): try: self._device.delete_message(self._receipt_handle) except Exception: - self._logger.exception('commit failed') + self._logger.exception("commit failed") def _rollback(self): try: self._device.change_visability_timeout(self._receipt_handle, 0) except Exception: - self._logger.warning('rollback failed', exc_info=True) - + self._logger.warning("rollback failed", exc_info=True) + -class SQSInputDevice(InputDevice['SQSInputDeviceManager']): +class SQSInputDevice(InputDevice["SQSInputDeviceManager"]): """ represents an SQS input device """ - def __init__(self, - device_manager: 'SQSInputDeviceManager', - queue_name: str, - included_message_attributes: Optional[Union[str, list]] = None): + def __init__( + self, + device_manager: "SQSInputDeviceManager", + queue_name: str, + included_message_attributes: Optional[Union[str, list]] = None, + ): """ constructs a new input SQS device @@ -75,25 +82,30 @@ def __init__(self, self._max_messages_per_request = 1 self._queue = self.manager.get_queue(queue_name) - def _read_message(self, cancellation_token: threading.Event, - timeout: Optional[float] = None, - with_transaction: bool = True) -> Optional['ReadResult']: + def _read_message( + self, + cancellation_token: threading.Event, + timeout: Optional[float] = None, + with_transaction: bool = True, + ) -> Optional["ReadResult"]: """ reads a stream from InputDevice (tries getting a message. if it fails, reconnects and tries again once) :param timeout: the timeout in seconds to block. negative number means no blocking :return: a tuple of stream and metadata from InputDevice, or (None, None) if no message is available """ - + sqs_messages = self._queue.receive_messages( MessageAttributeNames=["All"], - MaxNumberOfMessages=self._max_messages_per_request + MaxNumberOfMessages=self._max_messages_per_request, ) if not sqs_messages: return None - - assert len(sqs_messages) == 1, "SQSInputDevice should only return one message at a time" + + assert ( + len(sqs_messages) == 1 + ), "SQSInputDevice should only return one message at a time" sqs_message = sqs_messages[0] @@ -105,12 +117,14 @@ def _read_message(self, cancellation_token: threading.Event, }, data=BytesIO(sqs_message.body.encode()), ), - transaction = SQSInputTransaction( + transaction=SQSInputTransaction( device=self, receipt_handle=sqs_message.receipt_handle, - ) if with_transaction else NULLTransaction(self), + ) + if with_transaction + else NULLTransaction(self), ) - + def delete_message(self, receipt_handle: str): """ deletes a message from the queue @@ -118,12 +132,7 @@ def delete_message(self, receipt_handle: str): :param receipt_handle: the receipt handle of the message """ self._queue.delete_messages( - Entries=[ - { - "Id": get_random_id(), - "ReceiptHandle": receipt_handle - } - ] + Entries=[{"Id": get_random_id(), "ReceiptHandle": receipt_handle}] ) def change_visability_timeout(self, receipt_handle: str, timeout: int): @@ -138,11 +147,12 @@ def change_visability_timeout(self, receipt_handle: str, timeout: int): { "Id": get_random_id(), "ReceiptHandle": receipt_handle, - "VisibilityTimeout": timeout + "VisibilityTimeout": timeout, } ] ) + class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]): """ SQS input device manager @@ -151,7 +161,6 @@ class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]): def __init__(self): super(SQSManagerBase, self).__init__() - def get_input_device(self, device_name: str) -> SQSInputDevice: """ Returns an incoming device by name diff --git a/messageflux/iodevices/sqs/sqs_manager_base.py b/messageflux/iodevices/sqs/sqs_manager_base.py index 0713d86..365e3cb 100644 --- a/messageflux/iodevices/sqs/sqs_manager_base.py +++ b/messageflux/iodevices/sqs/sqs_manager_base.py @@ -6,12 +6,12 @@ import boto3 from mypy_boto3_sqs.service_resource import Queue except ImportError as ex: - raise ImportError('Please Install the required extra: messageflux[sqs]') from ex + raise ImportError("Please Install the required extra: messageflux[sqs]") from ex class SQSManagerBase: def __init__(self) -> None: - self._sqs_resource = boto3.resource('sqs') + self._sqs_resource = boto3.resource("sqs") self._queue_cache: Dict[str, Queue] = {} self._logger = logging.getLogger(__name__) diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index d0c2281..af9dfdd 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -1,7 +1,11 @@ import logging from typing import Dict -from messageflux.iodevices.base import OutputDevice, OutputDeviceException, OutputDeviceManager +from messageflux.iodevices.base import ( + OutputDevice, + OutputDeviceException, + OutputDeviceManager, +) from messageflux.iodevices.base.common import MessageBundle from messageflux.iodevices.sqs.message_attributes import geterate_message_attributes from messageflux.utils import get_random_id @@ -10,15 +14,15 @@ import boto3 from mypy_boto3_sqs.service_resource import Queue except ImportError as ex: - raise ImportError('Please Install the required extra: messageflux[sqs]') from ex + raise ImportError("Please Install the required extra: messageflux[sqs]") from ex -class SQSOutputDevice(OutputDevice['SQSOutputDeviceManager']): +class SQSOutputDevice(OutputDevice["SQSOutputDeviceManager"]): """ represents an SQS output devices """ - def __init__(self, device_manager: 'SQSOutputDeviceManager', queue_name: str): + def __init__(self, device_manager: "SQSOutputDeviceManager", queue_name: str): """ constructs a new output SQS device @@ -27,22 +31,26 @@ def __init__(self, device_manager: 'SQSOutputDeviceManager', queue_name: str): """ super(SQSOutputDevice, self).__init__(device_manager, queue_name) self._sqs_queue = self.manager.get_queue(queue_name) - + # https://awscli.amazonaws.com/v2/documentation/api/latest/reference/sqs/get-queue-attributes.html#get-queue-attributes - self._is_fifo = queue_name.endswith(".fifo") + self._is_fifo = queue_name.endswith(".fifo") self._logger = logging.getLogger(__name__) def _send_message(self, message_bundle: MessageBundle): if self._is_fifo: response = self._sqs_queue.send_message( MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=geterate_message_attributes(message_bundle.message.headers), + MessageAttributes=geterate_message_attributes( + message_bundle.message.headers + ), MessageGroupId=get_random_id(), ) else: response = self._sqs_queue.send_message( MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=geterate_message_attributes(message_bundle.message.headers), + MessageAttributes=geterate_message_attributes( + message_bundle.message.headers + ), ) if "MessageId" not in response: @@ -59,8 +67,8 @@ def __init__(self): This manager used to create SQS devices """ self._logger = logging.getLogger(__name__) - - self._sqs_resource = boto3.resource('sqs') + + self._sqs_resource = boto3.resource("sqs") self._queue_cache: Dict[str, Queue] = {} def get_output_device(self, queue_name: str) -> SQSOutputDevice: From 02b1e2c52a6ed4a4b1eeb860b3cd891bbdd93104 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CNoam?= <“noamisraeli97@gmail.com”> Date: Sun, 13 Aug 2023 22:40:38 +0300 Subject: [PATCH 09/15] Remove imports --- messageflux/iodevices/sqs/sqs_input_device.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py index 893eb57..604898e 100644 --- a/messageflux/iodevices/sqs/sqs_input_device.py +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -2,7 +2,7 @@ import threading from io import BytesIO -from typing import Dict, Optional, Union +from typing import Optional, Union from messageflux.iodevices.base import ( InputDevice, @@ -17,14 +17,6 @@ from messageflux.utils import get_random_id -try: - import boto3 - from mypy_boto3_sqs.service_resource import Queue - from mypy_boto3_sqs.client import SQSClient -except ImportError as ex: - raise ImportError("Please Install the required extra: messageflux[sqs]") from ex - - class SQSInputTransaction(InputTransaction): """ represents a InputTransaction for SQS From 0511668700ec707a53e004431efef6a585fb7ddd Mon Sep 17 00:00:00 2001 From: avivs Date: Mon, 14 Aug 2023 17:20:20 +0300 Subject: [PATCH 10/15] style fix and comments --- .../iodevices/sqs/message_attributes.py | 15 +-- messageflux/iodevices/sqs/sqs_input_device.py | 114 ++++++++---------- messageflux/iodevices/sqs/sqs_manager_base.py | 28 +++-- .../iodevices/sqs/sqs_output_device.py | 42 +++---- pyproject.toml | 2 + requirements-sqs.txt | 1 - requirements-sqs_mypy.txt | 2 + 7 files changed, 89 insertions(+), 115 deletions(-) create mode 100644 requirements-sqs_mypy.txt diff --git a/messageflux/iodevices/sqs/message_attributes.py b/messageflux/iodevices/sqs/message_attributes.py index 5079900..53ce274 100644 --- a/messageflux/iodevices/sqs/message_attributes.py +++ b/messageflux/iodevices/sqs/message_attributes.py @@ -1,12 +1,9 @@ import json -from typing import Any, Dict +from typing import Any, Dict, TYPE_CHECKING - -try: +if TYPE_CHECKING: from mypy_boto3_sqs.type_defs import MessageAttributeValueQueueTypeDef -except ImportError as ex: - raise ImportError("Please Install the required extra: messageflux[sqs]") from ex def get_aws_data_type(value: Any) -> str: @@ -22,15 +19,11 @@ def get_aws_data_type(value: Any) -> str: return "String" -def geterate_message_attributes( - attributes: Dict[str, Any] -) -> Dict[str, MessageAttributeValueQueueTypeDef]: +def generate_message_attributes(attributes: Dict[str, Any]) -> Dict[str, 'MessageAttributeValueQueueTypeDef']: return { key: { "DataType": get_aws_data_type(value), - "StringValue": json.dumps(value) - if not isinstance(value, str) - else value, # to avoid double encoding + "StringValue": value if isinstance(value, str) else json.dumps(value) # to avoid double encoding } for key, value in attributes.items() } diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py index 604898e..988e79f 100644 --- a/messageflux/iodevices/sqs/sqs_input_device.py +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -1,8 +1,7 @@ import logging import threading - from io import BytesIO -from typing import Optional, Union +from typing import Optional, Union, TYPE_CHECKING from messageflux.iodevices.base import ( InputDevice, @@ -14,7 +13,9 @@ ) from messageflux.iodevices.base.input_transaction import NULLTransaction from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase -from messageflux.utils import get_random_id + +if TYPE_CHECKING: + from mypy_boto3_sqs.service_resource import SQSServiceResource, Message as SQSMessage class SQSInputTransaction(InputTransaction): @@ -24,25 +25,25 @@ class SQSInputTransaction(InputTransaction): _device: "SQSInputDevice" - def __init__(self, device: "SQSInputDevice", receipt_handle: str): + def __init__(self, device: "SQSInputDevice", message: 'SQSMessage'): """ :param device: the device that returned this transaction - :param receipt_handle: the receipt handle for this item + :param message: the received message """ super(SQSInputTransaction, self).__init__(device=device) - self._receipt_handle = receipt_handle + self._message = message self._logger = logging.getLogger(__name__) def _commit(self): try: - self._device.delete_message(self._receipt_handle) + self._message.delete() except Exception: self._logger.exception("commit failed") def _rollback(self): try: - self._device.change_visability_timeout(self._receipt_handle, 0) + self._message.change_visibility(VisibilityTimeout=0) except Exception: self._logger.warning("rollback failed", exc_info=True) @@ -53,10 +54,10 @@ class SQSInputDevice(InputDevice["SQSInputDeviceManager"]): """ def __init__( - self, - device_manager: "SQSInputDeviceManager", - queue_name: str, - included_message_attributes: Optional[Union[str, list]] = None, + self, + device_manager: "SQSInputDeviceManager", + queue_name: str, + included_message_attributes: Optional[Union[str, list]] = None, # TODO: what's this? ): """ constructs a new input SQS device @@ -66,19 +67,19 @@ def __init__( """ super().__init__(device_manager, queue_name) - self._included_message_attributes = ( - included_message_attributes - if included_message_attributes is not None - else ["All"] - ) - self._max_messages_per_request = 1 + + if included_message_attributes is None: + included_message_attributes = ["All"] + + self._included_message_attributes = included_message_attributes + self._max_messages_per_request = 1 # TODO: get this in manager self._queue = self.manager.get_queue(queue_name) def _read_message( - self, - cancellation_token: threading.Event, - timeout: Optional[float] = None, - with_transaction: bool = True, + self, + cancellation_token: threading.Event, + timeout: Optional[float] = None, + with_transaction: bool = True, ) -> Optional["ReadResult"]: """ reads a stream from InputDevice (tries getting a message. if it fails, reconnects and tries again once) @@ -86,62 +87,42 @@ def _read_message( :param timeout: the timeout in seconds to block. negative number means no blocking :return: a tuple of stream and metadata from InputDevice, or (None, None) if no message is available """ - - sqs_messages = self._queue.receive_messages( - MessageAttributeNames=["All"], - MaxNumberOfMessages=self._max_messages_per_request, - ) + if timeout is None: + sqs_messages = self._queue.receive_messages( + MessageAttributeNames=self._included_message_attributes, + MaxNumberOfMessages=self._max_messages_per_request, + ) # TODO: what's the visibility timeout? should we extend it? + else: + sqs_messages = self._queue.receive_messages( + MessageAttributeNames=self._included_message_attributes, + MaxNumberOfMessages=self._max_messages_per_request, + WaitTimeSeconds=int(timeout) + ) # TODO: what's the visibility timeout? should we extend it? if not sqs_messages: return None - assert ( - len(sqs_messages) == 1 - ), "SQSInputDevice should only return one message at a time" + assert (len(sqs_messages) == 1), "SQSInputDevice should only return one message at a time" sqs_message = sqs_messages[0] + transaction: InputTransaction + if with_transaction: + transaction = SQSInputTransaction(device=self, + message=sqs_message) + else: + transaction = NULLTransaction(self) + sqs_message.delete() + return ReadResult( message=Message( headers={ - key: value["StringValue"] + key: value["BinaryValue"] if value['DataType'] == "Binary" else value['StringValue'] for key, value in sqs_message.message_attributes.items() }, data=BytesIO(sqs_message.body.encode()), ), - transaction=SQSInputTransaction( - device=self, - receipt_handle=sqs_message.receipt_handle, - ) - if with_transaction - else NULLTransaction(self), - ) - - def delete_message(self, receipt_handle: str): - """ - deletes a message from the queue - - :param receipt_handle: the receipt handle of the message - """ - self._queue.delete_messages( - Entries=[{"Id": get_random_id(), "ReceiptHandle": receipt_handle}] - ) - - def change_visability_timeout(self, receipt_handle: str, timeout: int): - """ - changes the visibility timeout of a message - - :param receipt_handle: the receipt handle of the message - :param timeout: the new timeout in seconds - """ - self._queue.change_message_visibility_batch( - Entries=[ - { - "Id": get_random_id(), - "ReceiptHandle": receipt_handle, - "VisibilityTimeout": timeout, - } - ] + transaction=transaction ) @@ -150,8 +131,8 @@ class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]): SQS input device manager """ - def __init__(self): - super(SQSManagerBase, self).__init__() + def __init__(self, sqs_resource: 'SQSServiceResource'): + super().__init__(sqs_resource=sqs_resource) def get_input_device(self, device_name: str) -> SQSInputDevice: """ @@ -162,7 +143,6 @@ def get_input_device(self, device_name: str) -> SQSInputDevice: """ try: return SQSInputDevice(self, device_name) - except Exception as e: message = f"Couldn't create input device '{device_name}'" self._logger.exception(message) diff --git a/messageflux/iodevices/sqs/sqs_manager_base.py b/messageflux/iodevices/sqs/sqs_manager_base.py index 365e3cb..5282580 100644 --- a/messageflux/iodevices/sqs/sqs_manager_base.py +++ b/messageflux/iodevices/sqs/sqs_manager_base.py @@ -1,21 +1,29 @@ import logging -from typing import Dict +from typing import Dict, TYPE_CHECKING -try: - import boto3 - from mypy_boto3_sqs.service_resource import Queue -except ImportError as ex: - raise ImportError("Please Install the required extra: messageflux[sqs]") from ex +# try: +# import boto3 +# except ImportError as ex: +# raise ImportError("Please Install the required extra: messageflux[sqs]") from ex + +if TYPE_CHECKING: + from mypy_boto3_sqs.service_resource import Queue, SQSServiceResource class SQSManagerBase: - def __init__(self) -> None: - self._sqs_resource = boto3.resource("sqs") - self._queue_cache: Dict[str, Queue] = {} + """ + base class for sqs device managers + """ + def __init__(self, sqs_resource: 'SQSServiceResource') -> None: + """ + :param sqs_resource: the boto sqs service resource + """ + self._sqs_resource = sqs_resource + self._queue_cache: Dict[str, 'Queue'] = {} self._logger = logging.getLogger(__name__) - def get_queue(self, queue_name: str) -> Queue: + def get_queue(self, queue_name: str) -> 'Queue': """ gets the queue from cache """ diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index af9dfdd..95ebd0d 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -1,5 +1,5 @@ import logging -from typing import Dict +from typing import TYPE_CHECKING from messageflux.iodevices.base import ( OutputDevice, @@ -7,14 +7,17 @@ OutputDeviceManager, ) from messageflux.iodevices.base.common import MessageBundle -from messageflux.iodevices.sqs.message_attributes import geterate_message_attributes +from messageflux.iodevices.sqs.message_attributes import generate_message_attributes +from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase from messageflux.utils import get_random_id -try: - import boto3 - from mypy_boto3_sqs.service_resource import Queue -except ImportError as ex: - raise ImportError("Please Install the required extra: messageflux[sqs]") from ex +# try: +# import boto3 +# except ImportError as ex: +# raise ImportError("Please Install the required extra: messageflux[sqs]") from ex +# +if TYPE_CHECKING: + from mypy_boto3_sqs.service_resource import SQSServiceResource class SQSOutputDevice(OutputDevice["SQSOutputDeviceManager"]): @@ -40,7 +43,7 @@ def _send_message(self, message_bundle: MessageBundle): if self._is_fifo: response = self._sqs_queue.send_message( MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=geterate_message_attributes( + MessageAttributes=generate_message_attributes( message_bundle.message.headers ), MessageGroupId=get_random_id(), @@ -48,7 +51,7 @@ def _send_message(self, message_bundle: MessageBundle): else: response = self._sqs_queue.send_message( MessageBody=message_bundle.message.bytes.decode(), - MessageAttributes=geterate_message_attributes( + MessageAttributes=generate_message_attributes( message_bundle.message.headers ), ) @@ -57,20 +60,18 @@ def _send_message(self, message_bundle: MessageBundle): raise OutputDeviceException("Couldn't send message to SQS") -class SQSOutputDeviceManager(OutputDeviceManager[SQSOutputDevice]): +class SQSOutputDeviceManager(SQSManagerBase, OutputDeviceManager[SQSOutputDevice]): """ this manager is used to create SQS devices """ - def __init__(self): + def __init__(self, sqs_resource: 'SQSServiceResource'): """ - This manager used to create SQS devices + :param sqs_resource: the boto sqs service resource """ + super().__init__(sqs_resource=sqs_resource) self._logger = logging.getLogger(__name__) - self._sqs_resource = boto3.resource("sqs") - self._queue_cache: Dict[str, Queue] = {} - def get_output_device(self, queue_name: str) -> SQSOutputDevice: """ Returns and outgoing device by name @@ -84,14 +85,3 @@ def get_output_device(self, queue_name: str) -> SQSOutputDevice: message = f"Couldn't create output device '{queue_name}'" self._logger.exception(message) raise OutputDeviceException(message) from e - - def get_queue(self, queue_name: str) -> Queue: - """ - gets the queue from cache - """ - queue = self._queue_cache.get(queue_name, None) - if queue is None: - queue = self._sqs_resource.get_queue_by_name(QueueName=queue_name) - self._queue_cache[queue_name] = queue - - return queue diff --git a/pyproject.toml b/pyproject.toml index a497660..8fb2e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,8 @@ version = { file = "VERSION" } dev = { file = "requirements-dev.txt" } objectstorage = { file = "requirements-objectstorage.txt" } objectstorage_mypy = { file = "requirements-objectstorage_mypy.txt" } +sqs = { file = "requirements-sqs.txt" } +sqs_mypy = { file = "requirements-sqs_mypy.txt" } rabbitmq = { file = "requirements-rabbitmq.txt" } rabbitmq_mypy = { file = "requirements-rabbitmq_mypy.txt" } all = { file = "requirements-all.txt" } diff --git a/requirements-sqs.txt b/requirements-sqs.txt index 90c08c0..d16fa6b 100644 --- a/requirements-sqs.txt +++ b/requirements-sqs.txt @@ -1,2 +1 @@ boto3>=1.25,<2 -boto3-stubs[sqs]>=1.25,<2 \ No newline at end of file diff --git a/requirements-sqs_mypy.txt b/requirements-sqs_mypy.txt new file mode 100644 index 0000000..90c08c0 --- /dev/null +++ b/requirements-sqs_mypy.txt @@ -0,0 +1,2 @@ +boto3>=1.25,<2 +boto3-stubs[sqs]>=1.25,<2 \ No newline at end of file From b2f0fb6794875c55d75469f86e933115ecafe177 Mon Sep 17 00:00:00 2001 From: avivs Date: Tue, 15 Aug 2023 17:57:25 +0300 Subject: [PATCH 11/15] CR --- messageflux/iodevices/sqs/sqs_input_device.py | 5 +---- messageflux/iodevices/sqs/sqs_manager_base.py | 18 +++++++++++------- messageflux/iodevices/sqs/sqs_output_device.py | 16 ---------------- 3 files changed, 12 insertions(+), 27 deletions(-) diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py index 988e79f..6043332 100644 --- a/messageflux/iodevices/sqs/sqs_input_device.py +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -15,7 +15,7 @@ from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase if TYPE_CHECKING: - from mypy_boto3_sqs.service_resource import SQSServiceResource, Message as SQSMessage + from mypy_boto3_sqs.service_resource import Message as SQSMessage class SQSInputTransaction(InputTransaction): @@ -131,9 +131,6 @@ class SQSInputDeviceManager(SQSManagerBase, InputDeviceManager[SQSInputDevice]): SQS input device manager """ - def __init__(self, sqs_resource: 'SQSServiceResource'): - super().__init__(sqs_resource=sqs_resource) - def get_input_device(self, device_name: str) -> SQSInputDevice: """ Returns an incoming device by name diff --git a/messageflux/iodevices/sqs/sqs_manager_base.py b/messageflux/iodevices/sqs/sqs_manager_base.py index 5282580..acffff6 100644 --- a/messageflux/iodevices/sqs/sqs_manager_base.py +++ b/messageflux/iodevices/sqs/sqs_manager_base.py @@ -1,11 +1,11 @@ import logging -from typing import Dict, TYPE_CHECKING +from typing import Dict, TYPE_CHECKING, Optional -# try: -# import boto3 -# except ImportError as ex: -# raise ImportError("Please Install the required extra: messageflux[sqs]") from ex +try: + import boto3 +except ImportError as ex: + raise ImportError("Please Install the required extra: messageflux[sqs]") from ex if TYPE_CHECKING: from mypy_boto3_sqs.service_resource import Queue, SQSServiceResource @@ -15,10 +15,14 @@ class SQSManagerBase: """ base class for sqs device managers """ - def __init__(self, sqs_resource: 'SQSServiceResource') -> None: + + def __init__(self, sqs_resource: Optional['SQSServiceResource'] = None) -> None: """ - :param sqs_resource: the boto sqs service resource + :param sqs_resource: the boto sqs service resource. Defaults to creating from env vars """ + if sqs_resource is None: + sqs_resource = boto3.resource('sqs') + self._sqs_resource = sqs_resource self._queue_cache: Dict[str, 'Queue'] = {} self._logger = logging.getLogger(__name__) diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index 95ebd0d..60d22cb 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -1,5 +1,4 @@ import logging -from typing import TYPE_CHECKING from messageflux.iodevices.base import ( OutputDevice, @@ -11,14 +10,6 @@ from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase from messageflux.utils import get_random_id -# try: -# import boto3 -# except ImportError as ex: -# raise ImportError("Please Install the required extra: messageflux[sqs]") from ex -# -if TYPE_CHECKING: - from mypy_boto3_sqs.service_resource import SQSServiceResource - class SQSOutputDevice(OutputDevice["SQSOutputDeviceManager"]): """ @@ -65,13 +56,6 @@ class SQSOutputDeviceManager(SQSManagerBase, OutputDeviceManager[SQSOutputDevice this manager is used to create SQS devices """ - def __init__(self, sqs_resource: 'SQSServiceResource'): - """ - :param sqs_resource: the boto sqs service resource - """ - super().__init__(sqs_resource=sqs_resource) - self._logger = logging.getLogger(__name__) - def get_output_device(self, queue_name: str) -> SQSOutputDevice: """ Returns and outgoing device by name From 1fe83bc9054d2b802b63381033c6ddc36a50b863 Mon Sep 17 00:00:00 2001 From: avivs Date: Tue, 15 Aug 2023 18:06:40 +0300 Subject: [PATCH 12/15] CR --- .../iodevices/objectstorage/s3_message_store.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/messageflux/iodevices/objectstorage/s3_message_store.py b/messageflux/iodevices/objectstorage/s3_message_store.py index 9678681..832a01f 100644 --- a/messageflux/iodevices/objectstorage/s3_message_store.py +++ b/messageflux/iodevices/objectstorage/s3_message_store.py @@ -5,6 +5,8 @@ from hashlib import md5 from typing import Optional, BinaryIO, Dict, Any, Tuple, TYPE_CHECKING +import boto3 + from messageflux.iodevices.base.common import MessageBundle, Message from messageflux.iodevices.message_store_device_wrapper.message_store_base import MessageStoreException, \ MessageStoreBase @@ -67,8 +69,8 @@ class _S3MessageStoreBase(MessageStoreBase, metaclass=ABCMeta): _ORIGINAL_HEADERS_KEY = "originalheaders" def __init__(self, - s3_resource: 'S3ServiceResource', magic: bytes, + s3_resource: Optional['S3ServiceResource'] = None, auto_create_bucket: bool = False, bucket_name_formatter: Optional[BucketNameFormatterBase] = None): """ @@ -83,6 +85,9 @@ def __init__(self, self.bucket_name_formatter = bucket_name_formatter or BucketNameFormatterBase() self._magic = magic + if s3_resource is None: + s3_resource = boto3.resource('s3') + self._s3_resource = s3_resource self._auto_create_bucket = auto_create_bucket self._bucket_cache: Dict[str, S3Bucket] = {} @@ -199,7 +204,7 @@ class S3MessageStore(_S3MessageStoreBase): """ def __init__(self, - s3_resource: 'S3ServiceResource', + s3_resource: Optional['S3ServiceResource'] = None, magic: bytes = b"__S3_MSGSTORE__", auto_create_bucket: bool = False, bucket_name_formatter: Optional[BucketNameFormatterBase] = None, @@ -207,7 +212,7 @@ def __init__(self, """ An S3 based message store - :param s3_resource: the s3 resource from boto + :param s3_resource: the s3 resource from boto (or None, to create it from env vars) :param auto_create_bucket: Whether or not a bucket will be created when a message is being put in a nonexistent one. :param bucket_name_formatter: a formatter to use to manipulate the bucket name. @@ -238,7 +243,7 @@ class S3UploadMessageStore(_S3MessageStoreBase): """ def __init__(self, - s3_resource: 'S3ServiceResource', + s3_resource: Optional['S3ServiceResource'] = None, magic: bytes = b"__S3_UPLOAD_MSGSTORE__", auto_create_bucket: bool = False, bucket_name_formatter: Optional[BucketNameFormatterBase] = None, @@ -246,7 +251,7 @@ def __init__(self, """ An S3 based message store - :param s3_resource: the s3 resource from boto + :param s3_resource: the s3 resource from boto (or None, to create it from env vars) :param auto_create_bucket: Whether or not a bucket will be created when a message is being put in a nonexistent one. :param bucket_name_formatter: a formatter to use to manipulate the bucket name. From 6b9fc7f91a4b69fa048ba18260d6cb8e7ee88923 Mon Sep 17 00:00:00 2001 From: avivs Date: Wed, 23 Aug 2023 19:43:30 +0300 Subject: [PATCH 13/15] added SQS tests moved and fixed objectstorage tests --- messageflux/iodevices/sqs/sqs_manager_base.py | 10 ++++ tests/devices/objectstorage/__init__.py | 0 .../objectstorage/s3_message_store_test.py | 50 ------------------- .../headers_test.py => objectstorage_test.py} | 48 +++++++++++++++++- tests/devices/sqs_test.py | 40 +++++++++++++++ 5 files changed, 96 insertions(+), 52 deletions(-) delete mode 100644 tests/devices/objectstorage/__init__.py delete mode 100644 tests/devices/objectstorage/s3_message_store_test.py rename tests/devices/{objectstorage/headers_test.py => objectstorage_test.py} (53%) create mode 100644 tests/devices/sqs_test.py diff --git a/messageflux/iodevices/sqs/sqs_manager_base.py b/messageflux/iodevices/sqs/sqs_manager_base.py index acffff6..4c34c42 100644 --- a/messageflux/iodevices/sqs/sqs_manager_base.py +++ b/messageflux/iodevices/sqs/sqs_manager_base.py @@ -37,3 +37,13 @@ def get_queue(self, queue_name: str) -> 'Queue': self._queue_cache[queue_name] = queue return queue + + def create_queue(self, queue_name: str, **kwargs) -> 'Queue': + """ + creates a queue + + :param queue_name: the queue name to create + :return: the newly created queue + """ + + return self._sqs_resource.create_queue(QueueName=queue_name, **kwargs) diff --git a/tests/devices/objectstorage/__init__.py b/tests/devices/objectstorage/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/devices/objectstorage/s3_message_store_test.py b/tests/devices/objectstorage/s3_message_store_test.py deleted file mode 100644 index 1a2dce2..0000000 --- a/tests/devices/objectstorage/s3_message_store_test.py +++ /dev/null @@ -1,50 +0,0 @@ -from io import BytesIO - -import boto3 -import pytest -from moto import mock_s3 - -from messageflux.iodevices.base.common import MessageBundle, Message -from messageflux.iodevices.objectstorage import S3MessageStore, S3UploadMessageStore, S3NoSuchItem - - -@mock_s3 -def test_s3_message_store(): - data = BytesIO(b"some data") - headers = {"header key": "header value"} - s3_resource = boto3.Session().resource('s3') - with S3MessageStore(s3_resource, - auto_create_bucket=True) as s3_message_store: - key = s3_message_store.put_message("device-name", MessageBundle(Message(data, headers))) - - res_bundle = s3_message_store.read_message(key) - - assert res_bundle.message.bytes == b"some data" - assert "header key" in res_bundle.message.headers - assert res_bundle.message.headers["header key"] == "header value" - - s3_message_store.delete_message(key) - - with pytest.raises(S3NoSuchItem): - res_bundle = s3_message_store.read_message(key) - - -@mock_s3 -def test_s3_upload_message_store(): - data = BytesIO(b"some data") - headers = {"header key": "header value"} - s3_resource = boto3.Session().resource('s3') - with S3UploadMessageStore(s3_resource, - auto_create_bucket=True) as s3_message_store: - key = s3_message_store.put_message("device-name", MessageBundle(Message(data, headers))) - - res_bundle = s3_message_store.read_message(key) - - assert res_bundle.message.bytes == b"some data" - assert "header key" in res_bundle.message.headers - assert res_bundle.message.headers["header key"] == "header value" - - s3_message_store.delete_message(key) - - with pytest.raises(S3NoSuchItem): - res_bundle = s3_message_store.read_message(key) diff --git a/tests/devices/objectstorage/headers_test.py b/tests/devices/objectstorage_test.py similarity index 53% rename from tests/devices/objectstorage/headers_test.py rename to tests/devices/objectstorage_test.py index b632ba9..4cb8ba1 100644 --- a/tests/devices/objectstorage/headers_test.py +++ b/tests/devices/objectstorage_test.py @@ -1,16 +1,60 @@ from io import BytesIO import boto3 +import pytest from moto import mock_s3 from messageflux.iodevices.base.common import MessageBundle, Message -from messageflux.iodevices.objectstorage import S3MessageStore, BucketNameFormatterBase +from messageflux.iodevices.objectstorage import BucketNameFormatterBase +from messageflux.iodevices.objectstorage import S3MessageStore, S3UploadMessageStore, S3NoSuchItem + + +@mock_s3 +def test_s3_message_store(): + data = BytesIO(b"some data") + headers = {"header key": "header value"} + s3_resource = boto3.resource('s3') + with S3MessageStore(s3_resource, + auto_create_bucket=True) as s3_message_store: + key = s3_message_store.put_message("device-name", MessageBundle(Message(data, headers))) + + res_bundle = s3_message_store.read_message(key) + + assert res_bundle.message.bytes == b"some data" + assert "header key" in res_bundle.message.headers + assert res_bundle.message.headers["header key"] == "header value" + + s3_message_store.delete_message(key) + + with pytest.raises(S3NoSuchItem): + res_bundle = s3_message_store.read_message(key) + + +@mock_s3 +def test_s3_upload_message_store(): + data = BytesIO(b"some data") + headers = {"header key": "header value"} + s3_resource = boto3.Session().resource('s3') + with S3UploadMessageStore(s3_resource, + auto_create_bucket=True) as s3_message_store: + key = s3_message_store.put_message("device-name", MessageBundle(Message(data, headers))) + + res_bundle = s3_message_store.read_message(key) + + assert res_bundle.message.bytes == b"some data" + assert "header key" in res_bundle.message.headers + assert res_bundle.message.headers["header key"] == "header value" + + s3_message_store.delete_message(key) + + with pytest.raises(S3NoSuchItem): + res_bundle = s3_message_store.read_message(key) @mock_s3 def test_headers(): from messageflux.iodevices.objectstorage.s3api.s3bucket import S3Bucket - s3_resource = boto3.Session().resource('s3') + s3_resource = boto3.resource('s3') S3Bucket.create_bucket(s3_resource, BucketNameFormatterBase().format_name('SomeFlow', None)) diff --git a/tests/devices/sqs_test.py b/tests/devices/sqs_test.py new file mode 100644 index 0000000..3480b0b --- /dev/null +++ b/tests/devices/sqs_test.py @@ -0,0 +1,40 @@ +import uuid + +import boto3 +from moto import mock_sqs + +from messageflux.iodevices.sqs.sqs_input_device import SQSInputDeviceManager +from messageflux.iodevices.sqs.sqs_output_device import SQSOutputDeviceManager +from tests.devices.common import sanity_test, rollback_test + + +@mock_sqs +def test_generic_sanity(): + sqs_resource = boto3.resource('sqs', region_name='us-west-2') + input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource) + output_manager = SQSOutputDeviceManager(sqs_resource=sqs_resource) + queue_name = str(uuid.uuid4()) + with input_manager, output_manager: + q = output_manager.create_queue(queue_name) + try: + sanity_test(input_device_manager=input_manager, + output_device_manager=output_manager, + device_name=queue_name) + finally: + q.delete() + + +@mock_sqs +def test_generic_rollback(): + sqs_resource = boto3.resource('sqs', region_name='us-west-2') + input_manager = SQSInputDeviceManager(sqs_resource=sqs_resource) + output_manager = SQSOutputDeviceManager(sqs_resource=sqs_resource) + queue_name = str(uuid.uuid4()) + with input_manager, output_manager: + q = output_manager.create_queue(queue_name) + try: + rollback_test(input_device_manager=input_manager, + output_device_manager=output_manager, + device_name=queue_name) + finally: + q.delete() From f85bdfcc5652101363c362dbaf4ec515ae353cd5 Mon Sep 17 00:00:00 2001 From: avivs Date: Wed, 23 Aug 2023 19:49:01 +0300 Subject: [PATCH 14/15] fixed __init__.py --- messageflux/iodevices/sqs/__init__.py | 2 ++ tests/devices/sqs_test.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/messageflux/iodevices/sqs/__init__.py b/messageflux/iodevices/sqs/__init__.py index e69de29..d553a79 100644 --- a/messageflux/iodevices/sqs/__init__.py +++ b/messageflux/iodevices/sqs/__init__.py @@ -0,0 +1,2 @@ +from .sqs_input_device import SQSInputDeviceManager +from .sqs_output_device import SQSOutputDeviceManager \ No newline at end of file diff --git a/tests/devices/sqs_test.py b/tests/devices/sqs_test.py index 3480b0b..193e7d8 100644 --- a/tests/devices/sqs_test.py +++ b/tests/devices/sqs_test.py @@ -3,8 +3,8 @@ import boto3 from moto import mock_sqs -from messageflux.iodevices.sqs.sqs_input_device import SQSInputDeviceManager -from messageflux.iodevices.sqs.sqs_output_device import SQSOutputDeviceManager +from messageflux.iodevices.sqs import SQSInputDeviceManager +from messageflux.iodevices.sqs import SQSOutputDeviceManager from tests.devices.common import sanity_test, rollback_test From ea06d0a177e8aa10af97011f32eaec2f009e1fda Mon Sep 17 00:00:00 2001 From: avivs Date: Wed, 23 Aug 2023 19:53:18 +0300 Subject: [PATCH 15/15] fixed __init__.py --- messageflux/iodevices/sqs/__init__.py | 2 +- tests/devices/common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/messageflux/iodevices/sqs/__init__.py b/messageflux/iodevices/sqs/__init__.py index d553a79..753e469 100644 --- a/messageflux/iodevices/sqs/__init__.py +++ b/messageflux/iodevices/sqs/__init__.py @@ -1,2 +1,2 @@ from .sqs_input_device import SQSInputDeviceManager -from .sqs_output_device import SQSOutputDeviceManager \ No newline at end of file +from .sqs_output_device import SQSOutputDeviceManager diff --git a/tests/devices/common.py b/tests/devices/common.py index 4c70edd..114257d 100644 --- a/tests/devices/common.py +++ b/tests/devices/common.py @@ -81,7 +81,7 @@ def rollback_test(input_device_manager: InputDeviceManager, _assert_messages_equal(org_message=test_message_2, new_message=read_result2.message) read_result1.rollback() read_result2.rollback() - + time.sleep(sleep_between_sends) read_result = input_device.read_message(cancellation_token=cancellation_token) assert read_result is not None _assert_messages_equal(org_message=test_message_1, new_message=read_result.message)