Skip to content

Commit

Permalink
Merge pull request #89 from Avivsalem/staging
Browse files Browse the repository at this point in the history
Staging
  • Loading branch information
Avivsalem authored Aug 29, 2023
2 parents 47e106e + 7b9322e commit 815bb6e
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 31 deletions.
12 changes: 8 additions & 4 deletions messageflux/iodevices/rabbitmq/rabbitmq_output_device.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import decimal
import logging
import ssl
from typing import BinaryIO, Dict, Any, Union, Optional, List, TYPE_CHECKING

import time
from typing import BinaryIO, Dict, Any, Union, Optional, List, TYPE_CHECKING

from messageflux.iodevices.base import OutputDevice, OutputDeviceException, OutputDeviceManager
from messageflux.iodevices.base.common import MessageBundle
Expand Down Expand Up @@ -65,8 +65,9 @@ class RabbitMQOutputDeviceManager(RabbitMQDeviceManagerBase, OutputDeviceManager
"""
_PUBLISH_CONFIRM_HEADER = "__RABBITMQ_PUBLISH_CONFIRM__"

_outgoing_channel: Union[ThreadLocalMember[Optional['BlockingChannel']],
Optional['BlockingChannel']] = ThreadLocalMember(init_value=None)
_outgoing_channel: Union[
ThreadLocalMember[Optional['BlockingChannel']],
Optional['BlockingChannel']] = ThreadLocalMember(init_value=None)

def __init__(self,
hosts: Union[List[str], str],
Expand Down Expand Up @@ -270,6 +271,9 @@ def _inner_publish(self,
headers = headers.copy()
headers[self._PUBLISH_CONFIRM_HEADER] = self.publish_confirm

for header in headers: # this solves a weird bug in pika, which doesn't handle floats...
if isinstance(headers[header], float):
headers[header] = decimal.Decimal(headers[header])
data.seek(0)

str_expiration: Optional[str] = None
Expand Down
57 changes: 42 additions & 15 deletions messageflux/iodevices/sqs/message_attributes.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,56 @@
import json

import numbers
from typing import Any, Dict, TYPE_CHECKING

if TYPE_CHECKING:
from mypy_boto3_sqs.type_defs import MessageAttributeValueQueueTypeDef
from mypy_boto3_sqs.type_defs import MessageAttributeValueQueueTypeDef, MessageAttributeValueTypeDef


def get_aws_data_type(value: Any) -> str:
if isinstance(value, (list, set, frozenset, tuple)):
return "String.Array"
elif isinstance(value, bool):
if isinstance(value, str):
return "String"
elif isinstance(value, (int, float)):
elif isinstance(value, numbers.Number):
return "Number"
elif isinstance(value, bytes):
return "Binary"
elif isinstance(value, bool):
return "String.bool"
elif isinstance(value, (list, set, frozenset, tuple)):
return "String.list"
else:
return "String"
return "String.other"


def generate_message_attributes(headers: Dict[str, Any]) -> Dict[str, 'MessageAttributeValueQueueTypeDef']:
results: Dict[str, 'MessageAttributeValueQueueTypeDef'] = {}
for key, value in headers.items():
data_type = get_aws_data_type(value)
results[key] = {"DataType": data_type}

if data_type == "String":
results[key]["StringValue"] = value

elif data_type == "Binary":
results[key]["BinaryValue"] = value

else:
results[key]["StringValue"] = json.dumps(value)

return results


def decode_message_attributes(message_attributes: Dict[str, 'MessageAttributeValueTypeDef']) -> Dict[str, Any]:
result: Dict[str, Any] = {}
for key, value in message_attributes.items():
decoded_val: Any
if value["DataType"] == "String":
decoded_val = value["StringValue"]

elif value["DataType"] == "Binary":
decoded_val = value["BinaryValue"]
else:
decoded_val = json.loads(value["StringValue"])

result[key] = decoded_val

def generate_message_attributes(attributes: Dict[str, Any]) -> Dict[str, 'MessageAttributeValueQueueTypeDef']:
return {
key: {
"DataType": get_aws_data_type(value),
"StringValue": value if isinstance(value, str) else json.dumps(value) # to avoid double encoding
}
for key, value in attributes.items()
}
return result
6 changes: 2 additions & 4 deletions messageflux/iodevices/sqs/sqs_input_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
InputDeviceManager,
)
from messageflux.iodevices.base.input_transaction import NULLTransaction
from messageflux.iodevices.sqs.message_attributes import decode_message_attributes
from messageflux.iodevices.sqs.sqs_manager_base import SQSManagerBase

if TYPE_CHECKING:
Expand Down Expand Up @@ -126,10 +127,7 @@ def _read_message(

return ReadResult(
message=Message(
headers={
key: value["BinaryValue"] if value['DataType'] == "Binary" else value['StringValue']
for key, value in message_attributes.items()
},
headers=decode_message_attributes(message_attributes),
data=BytesIO(sqs_message.body.encode()),
),
transaction=transaction
Expand Down
3 changes: 2 additions & 1 deletion messageflux/iodevices/sqs/sqs_output_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, device_manager: "SQSOutputDeviceManager", queue_name: str):
"""
super(SQSOutputDevice, self).__init__(device_manager, queue_name)
self._sqs_queue = self.manager.get_queue(queue_name)
self._message_group_id = get_random_id()

# https://awscli.amazonaws.com/v2/documentation/api/latest/reference/sqs/get-queue-attributes.html#get-queue-attributes
self._is_fifo = queue_name.endswith(".fifo")
Expand All @@ -37,7 +38,7 @@ def __init__(self, device_manager: "SQSOutputDeviceManager", queue_name: str):
def _send_message(self, message_bundle: MessageBundle):
additional_args: Dict[str, Any] = {}
if self._is_fifo:
additional_args = dict(MessageGroupId=get_random_id())
additional_args = dict(MessageGroupId=self._message_group_id)

response = self._sqs_queue.send_message(
MessageBody=message_bundle.message.bytes.decode(),
Expand Down
23 changes: 17 additions & 6 deletions tests/devices/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import time
import uuid
from threading import Event
from typing import Optional

import time
from typing import Optional, Dict, Any

from messageflux.iodevices.base import InputDeviceManager, OutputDeviceManager, Message

Expand All @@ -16,13 +15,25 @@ def _assert_messages_equal(org_message: Message, new_message: Message):
def sanity_test(input_device_manager: InputDeviceManager,
output_device_manager: OutputDeviceManager,
device_name: Optional[str] = None,
sleep_between_sends=0.01):
sleep_between_sends: float = 0.01,
extra_headers: Optional[Dict[str, Any]] = None):
"""
Common test for all devices.
"""
if extra_headers is None:
extra_headers = {}
device_name = device_name or str(uuid.uuid4())
test_message_1 = Message(str(uuid.uuid4()).encode(), headers={'test': 'test1'})
test_message_2 = Message(str(uuid.uuid4()).encode(), headers={'test': 'test2'})
test_message_1 = Message(str(uuid.uuid4()).encode(), headers={'test_str': 'test1',
'test_int': 1,
'test_float': 2.5,
'test_bool': True,
**extra_headers})

test_message_2 = Message(str(uuid.uuid4()).encode(), headers={'test_str': 'test2',
'test_int': 2,
'test_float': 3.5,
'test_bool': False,
**extra_headers})

output_device_manager.connect()
try:
Expand Down
3 changes: 2 additions & 1 deletion tests/devices/sqs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def test_generic_sanity():
try:
sanity_test(input_device_manager=input_manager,
output_device_manager=output_manager,
device_name=queue_name)
device_name=queue_name,
extra_headers={'test_bytes': b'bytes'})
finally:
q.delete()

Expand Down

0 comments on commit 815bb6e

Please sign in to comment.