diff --git a/messageflux/iodevices/rabbitmq/rabbitmq_output_device.py b/messageflux/iodevices/rabbitmq/rabbitmq_output_device.py index 2b88bd6..da28dcb 100644 --- a/messageflux/iodevices/rabbitmq/rabbitmq_output_device.py +++ b/messageflux/iodevices/rabbitmq/rabbitmq_output_device.py @@ -1,8 +1,8 @@ +import decimal import logging import ssl -from typing import BinaryIO, Dict, Any, Union, Optional, List, TYPE_CHECKING - import time +from typing import BinaryIO, Dict, Any, Union, Optional, List, TYPE_CHECKING from messageflux.iodevices.base import OutputDevice, OutputDeviceException, OutputDeviceManager from messageflux.iodevices.base.common import MessageBundle @@ -65,8 +65,9 @@ class RabbitMQOutputDeviceManager(RabbitMQDeviceManagerBase, OutputDeviceManager """ _PUBLISH_CONFIRM_HEADER = "__RABBITMQ_PUBLISH_CONFIRM__" - _outgoing_channel: Union[ThreadLocalMember[Optional['BlockingChannel']], - Optional['BlockingChannel']] = ThreadLocalMember(init_value=None) + _outgoing_channel: Union[ + ThreadLocalMember[Optional['BlockingChannel']], + Optional['BlockingChannel']] = ThreadLocalMember(init_value=None) def __init__(self, hosts: Union[List[str], str], @@ -270,6 +271,9 @@ def _inner_publish(self, headers = headers.copy() headers[self._PUBLISH_CONFIRM_HEADER] = self.publish_confirm + for header in headers: # this solves a weird bug in pika, which doesn't handle floats... + if isinstance(headers[header], float): + headers[header] = decimal.Decimal(headers[header]) data.seek(0) str_expiration: Optional[str] = None diff --git a/messageflux/iodevices/sqs/message_attributes.py b/messageflux/iodevices/sqs/message_attributes.py index 53ce274..73c72cc 100644 --- a/messageflux/iodevices/sqs/message_attributes.py +++ b/messageflux/iodevices/sqs/message_attributes.py @@ -1,29 +1,56 @@ import json - +import numbers from typing import Any, Dict, TYPE_CHECKING if TYPE_CHECKING: - from mypy_boto3_sqs.type_defs import MessageAttributeValueQueueTypeDef + from mypy_boto3_sqs.type_defs import MessageAttributeValueQueueTypeDef, MessageAttributeValueTypeDef def get_aws_data_type(value: Any) -> str: - if isinstance(value, (list, set, frozenset, tuple)): - return "String.Array" - elif isinstance(value, bool): + if isinstance(value, str): return "String" - elif isinstance(value, (int, float)): + elif isinstance(value, numbers.Number): return "Number" elif isinstance(value, bytes): return "Binary" + elif isinstance(value, bool): + return "String.bool" + elif isinstance(value, (list, set, frozenset, tuple)): + return "String.list" else: - return "String" + return "String.other" + + +def generate_message_attributes(headers: Dict[str, Any]) -> Dict[str, 'MessageAttributeValueQueueTypeDef']: + results: Dict[str, 'MessageAttributeValueQueueTypeDef'] = {} + for key, value in headers.items(): + data_type = get_aws_data_type(value) + results[key] = {"DataType": data_type} + + if data_type == "String": + results[key]["StringValue"] = value + + elif data_type == "Binary": + results[key]["BinaryValue"] = value + + else: + results[key]["StringValue"] = json.dumps(value) + + return results + + +def decode_message_attributes(message_attributes: Dict[str, 'MessageAttributeValueTypeDef']) -> Dict[str, Any]: + result: Dict[str, Any] = {} + for key, value in message_attributes.items(): + decoded_val: Any + if value["DataType"] == "String": + decoded_val = value["StringValue"] + + elif value["DataType"] == "Binary": + decoded_val = value["BinaryValue"] + else: + decoded_val = json.loads(value["StringValue"]) + result[key] = decoded_val -def generate_message_attributes(attributes: Dict[str, Any]) -> Dict[str, 'MessageAttributeValueQueueTypeDef']: - return { - key: { - "DataType": get_aws_data_type(value), - "StringValue": value if isinstance(value, str) else json.dumps(value) # to avoid double encoding - } - for key, value in attributes.items() - } + return result diff --git a/messageflux/iodevices/sqs/sqs_input_device.py b/messageflux/iodevices/sqs/sqs_input_device.py index 3ea7429..d9e180a 100644 --- a/messageflux/iodevices/sqs/sqs_input_device.py +++ b/messageflux/iodevices/sqs/sqs_input_device.py @@ -12,6 +12,7 @@ InputDeviceManager, ) from messageflux.iodevices.base.input_transaction import NULLTransaction +from messageflux.iodevices.sqs.message_attributes import decode_message_attributes from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase if TYPE_CHECKING: @@ -126,10 +127,7 @@ def _read_message( return ReadResult( message=Message( - headers={ - key: value["BinaryValue"] if value['DataType'] == "Binary" else value['StringValue'] - for key, value in message_attributes.items() - }, + headers=decode_message_attributes(message_attributes), data=BytesIO(sqs_message.body.encode()), ), transaction=transaction diff --git a/messageflux/iodevices/sqs/sqs_output_device.py b/messageflux/iodevices/sqs/sqs_output_device.py index 1989d88..4a74d66 100644 --- a/messageflux/iodevices/sqs/sqs_output_device.py +++ b/messageflux/iodevices/sqs/sqs_output_device.py @@ -29,6 +29,7 @@ def __init__(self, device_manager: "SQSOutputDeviceManager", queue_name: str): """ super(SQSOutputDevice, self).__init__(device_manager, queue_name) self._sqs_queue = self.manager.get_queue(queue_name) + self._message_group_id = get_random_id() # https://awscli.amazonaws.com/v2/documentation/api/latest/reference/sqs/get-queue-attributes.html#get-queue-attributes self._is_fifo = queue_name.endswith(".fifo") @@ -37,7 +38,7 @@ def __init__(self, device_manager: "SQSOutputDeviceManager", queue_name: str): def _send_message(self, message_bundle: MessageBundle): additional_args: Dict[str, Any] = {} if self._is_fifo: - additional_args = dict(MessageGroupId=get_random_id()) + additional_args = dict(MessageGroupId=self._message_group_id) response = self._sqs_queue.send_message( MessageBody=message_bundle.message.bytes.decode(), diff --git a/tests/devices/common.py b/tests/devices/common.py index 114257d..25ebae0 100644 --- a/tests/devices/common.py +++ b/tests/devices/common.py @@ -1,8 +1,7 @@ +import time import uuid from threading import Event -from typing import Optional - -import time +from typing import Optional, Dict, Any from messageflux.iodevices.base import InputDeviceManager, OutputDeviceManager, Message @@ -16,13 +15,25 @@ def _assert_messages_equal(org_message: Message, new_message: Message): def sanity_test(input_device_manager: InputDeviceManager, output_device_manager: OutputDeviceManager, device_name: Optional[str] = None, - sleep_between_sends=0.01): + sleep_between_sends: float = 0.01, + extra_headers: Optional[Dict[str, Any]] = None): """ Common test for all devices. """ + if extra_headers is None: + extra_headers = {} device_name = device_name or str(uuid.uuid4()) - test_message_1 = Message(str(uuid.uuid4()).encode(), headers={'test': 'test1'}) - test_message_2 = Message(str(uuid.uuid4()).encode(), headers={'test': 'test2'}) + test_message_1 = Message(str(uuid.uuid4()).encode(), headers={'test_str': 'test1', + 'test_int': 1, + 'test_float': 2.5, + 'test_bool': True, + **extra_headers}) + + test_message_2 = Message(str(uuid.uuid4()).encode(), headers={'test_str': 'test2', + 'test_int': 2, + 'test_float': 3.5, + 'test_bool': False, + **extra_headers}) output_device_manager.connect() try: diff --git a/tests/devices/sqs_test.py b/tests/devices/sqs_test.py index 88e9dfe..4f9185d 100644 --- a/tests/devices/sqs_test.py +++ b/tests/devices/sqs_test.py @@ -20,7 +20,8 @@ def test_generic_sanity(): try: sanity_test(input_device_manager=input_manager, output_device_manager=output_manager, - device_name=queue_name) + device_name=queue_name, + extra_headers={'test_bytes': b'bytes'}) finally: q.delete()