Skip to content

Add MQTT Sink #659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion conda/post-link.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ $PREFIX/bin/pip install \
'confluent-kafka[avro,json,protobuf,schemaregistry]>=2.8.2,<2.10' \
'influxdb>=5.3,<6' \
'jsonpath_ng>=1.7.0,<2' \
'types-psycopg2>=2.9,<3'
'types-psycopg2>=2.9,<3' \
'paho-mqtt>=2.1.0,<3'
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ all = [
"pymongo>=4.11,<5",
"pandas>=1.0.0,<3.0",
"elasticsearch>=8.17,<9",
"influxdb>=5.3,<6"
"influxdb>=5.3,<6",
"paho-mqtt>=2.1.0,<3"
]

avro = ["fastavro>=1.8,<2.0"]
Expand All @@ -62,6 +63,7 @@ neo4j = ["neo4j>=5.27.0,<6"]
mongodb = ["pymongo>=4.11,<5"]
pandas = ["pandas>=1.0.0,<3.0"]
elasticsearch = ["elasticsearch>=8.17,<9"]
mqtt = ["paho-mqtt>=2.1.0,<3"]

# AWS dependencies are separated by service to support
# different requirements in the future.
Expand Down
254 changes: 254 additions & 0 deletions quixstreams/sinks/community/mqtt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import json
import logging
import time
from datetime import datetime
from typing import Any, Callable, Literal, Optional, Union, get_args

from quixstreams.models.types import HeadersTuples
from quixstreams.sinks import (
BaseSink,
ClientConnectFailureCallback,
ClientConnectSuccessCallback,
)

try:
import paho.mqtt.client as paho
except ImportError as exc:
raise ImportError(
'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it"
) from exc


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Maps the user-facing protocol version string to the matching paho constant.
VERSION_MAP = {
"3.1": paho.MQTTv31,
"3.1.1": paho.MQTTv311,
"5": paho.MQTTv5,
}
# Return code that paho's publish result is compared against to detect failure.
MQTT_SUCCESS = paho.MQTT_ERR_SUCCESS
# Accepted protocol version strings; these are exactly the keys of VERSION_MAP.
ProtocolVersion = Literal["3.1", "3.1.1", "5"]
# Either a fixed Properties instance or a one-argument callable returning one.
MqttPropertiesHandler = Union[paho.Properties, Callable[[Any], paho.Properties]]
# Either a fixed retain flag or a one-argument callable returning one.
RetainHandler = Union[bool, Callable[[Any], bool]]


class MQTTSink(BaseSink):
    """
    A sink that publishes messages to an MQTT broker.

    Every record handed to `add()` is published right away to the topic
    ``<topic_root>/<serialized key>``. When QoS is non-zero, `flush()` waits
    until all published messages have been acknowledged, or raises
    `MqttPublishAckTimeout` once the configured timeout has elapsed.
    """

    def __init__(
        self,
        client_id: str,
        server: str,
        port: int,
        topic_root: str,
        username: Optional[str] = None,
        password: Optional[str] = None,
        version: ProtocolVersion = "3.1.1",
        tls_enabled: bool = True,
        key_serializer: Callable[[Any], str] = bytes.decode,
        value_serializer: Callable[[Any], str] = json.dumps,
        qos: Literal[0, 1] = 1,
        mqtt_flush_timeout_seconds: int = 10,
        retain: Union[bool, Callable[[Any], bool]] = False,
        properties: Optional[MqttPropertiesHandler] = None,
        on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
        on_client_connect_failure: Optional[ClientConnectFailureCallback] = None,
    ):
        """
        Initialize the MQTTSink.

        :param client_id: MQTT client identifier.
        :param server: MQTT broker server address.
        :param port: MQTT broker server port.
        :param topic_root: Root topic to publish messages to.
        :param username: Username for MQTT broker authentication. Default = None
        :param password: Password for MQTT broker authentication. Default = None
        :param version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
        :param tls_enabled: Whether to use TLS encryption. Default = True
        :param key_serializer: How to serialize the MQTT message key for producing.
        :param value_serializer: How to serialize the MQTT message value for producing.
        :param qos: Quality of Service level (0 or 1; 2 not yet supported) Default = 1.
        :param mqtt_flush_timeout_seconds: how long to wait for publish acknowledgment
            of MQTT messages before failing. Default = 10.
        :param retain: Retain last message for new subscribers. Default = False.
            Also accepts a callable that uses the current message value as input.
        :param properties: An optional Properties instance for messages. Default = None.
            Also accepts a callable that uses the current message value as input.
        :param on_client_connect_success: An optional callback made after successful
            client authentication, primarily for additional logging.
        :param on_client_connect_failure: An optional callback made after failed
            client authentication (which should raise an Exception).
            Callback should accept the raised Exception as an argument.
            Callback must resolve (or propagate/re-raise) the Exception.

        :raises ValueError: if `qos` is 2, if `version` is not a supported
            protocol version, or if `properties` is given for a version other
            than "5".
        """
        super().__init__(
            on_client_connect_success=on_client_connect_success,
            on_client_connect_failure=on_client_connect_failure,
        )
        if qos == 2:
            raise ValueError(f"MQTT QoS level {2} is currently not supported.")
        if not (protocol := VERSION_MAP.get(version)):
            raise ValueError(
                f"Invalid MQTT version {version}; valid: {get_args(ProtocolVersion)}"
            )
        # Compare the version *string*: `protocol` is the paho constant looked
        # up from VERSION_MAP and would never be equal to the string "5".
        if properties and version != "5":
            raise ValueError(
                "MQTT Properties can only be used with MQTT protocol version 5"
            )

        self._version = version
        self._server = server
        self._port = port
        self._topic_root = topic_root
        self._key_serializer = key_serializer
        self._value_serializer = value_serializer
        self._qos = qos
        self._flush_timeout = mqtt_flush_timeout_seconds
        # Message ids published with QoS > 0 that have not been acknowledged yet.
        self._pending_acks: set[int] = set()
        self._retain = _get_retain_callable(retain)
        self._properties = _get_properties_callable(properties)

        self._client = paho.Client(
            callback_api_version=paho.CallbackAPIVersion.VERSION2,
            client_id=client_id,
            userdata=None,
            protocol=protocol,
        )

        if username:
            self._client.username_pw_set(username, password)
        if tls_enabled:
            self._client.tls_set(tls_version=paho.ssl.PROTOCOL_TLS)
        self._client.reconnect_delay_set(5, 60)
        self._client.on_connect = _mqtt_on_connect_cb
        self._client.on_disconnect = _mqtt_on_disconnect_cb
        self._client.on_publish = self._on_publish_cb
        # Number of messages published since the last flush (used for logging).
        self._publish_count = 0

    def setup(self):
        """Connect to the configured broker and start the client's loop."""
        self._client.connect(self._server, self._port)
        self._client.loop_start()

    def _publish_to_mqtt(
        self,
        data: Any,
        topic_suffix: Any,
    ):
        """
        Publish a single message to ``<topic_root>/<serialized topic_suffix>``.

        :param data: message value; passed through the value serializer and
            to the `retain` / `properties` callables.
        :param topic_suffix: message key; passed through the key serializer.

        :raises MqttPublishEnqueueFailed: when QoS > 0 and the client reports
            a non-success return code for the publish call.
        """
        properties = self._properties
        info = self._client.publish(
            f"{self._topic_root}/{self._key_serializer(topic_suffix)}",
            payload=self._value_serializer(data),
            qos=self._qos,
            properties=properties(data) if properties else None,
            retain=self._retain(data),
        )
        if self._qos:
            if info.rc != MQTT_SUCCESS:
                raise MqttPublishEnqueueFailed(
                    f"Failed adding message to MQTT publishing queue; "
                    f"error code {info.rc}: {paho.error_string(info.rc)}"
                )
            # Counted later, in the publish callback, once acknowledged.
            self._pending_acks.add(info.mid)
        else:
            # QoS 0 has no acknowledgement to wait for; count it right away.
            self._publish_count += 1

    def _on_publish_cb(
        self,
        client: paho.Client,
        userdata: Any,
        mid: int,
        rc: paho.ReasonCode,
        p: paho.Properties,
    ):
        """
        Publish callback registered on the paho client.

        Marks message id `mid` as acknowledged and counts it as published.
        Only acts when self._qos > 0: QoS 0 messages are already counted at
        publish time and are never added to the pending set.
        """
        if not self._qos:
            return
        self._publish_count += 1
        # `discard` instead of `remove`: an unknown mid must not raise inside
        # the client's callback.
        self._pending_acks.discard(mid)

    def add(
        self,
        topic: str,
        partition: int,
        offset: int,
        key: bytes,
        value: bytes,
        timestamp: datetime,
        headers: HeadersTuples,
    ):
        """
        Publish one record to MQTT, using `key` as the topic suffix.

        Only `key` and `value` are used; the remaining arguments are ignored.
        If publishing fails, the client loop is stopped and disconnected
        before the original exception is re-raised.
        """
        try:
            self._publish_to_mqtt(value, key)
        except Exception:
            self._cleanup()
            raise

    def flush(self):
        """
        Wait for outstanding publish acknowledgements, then log the count.

        :raises MqttPublishAckTimeout: if acknowledgements are still pending
            after `mqtt_flush_timeout_seconds`; the client is shut down first.
        """
        if self._pending_acks:
            deadline = time.monotonic() + self._flush_timeout
            # Re-read the clock on every iteration so the timeout can expire.
            while (
                self._pending_acks
                and (remaining := deadline - time.monotonic()) > 0
            ):
                logger.debug(f"Pending acks remaining: {len(self._pending_acks)}")
                # Never sleep past the deadline.
                time.sleep(min(1.0, remaining))
            if self._pending_acks:
                self._cleanup()
                raise MqttPublishAckTimeout(
                    f"Mqtt acknowledgement timeout of {self._flush_timeout}s reached."
                )
        logger.info(f"{self._publish_count} MQTT messages published.")
        self._publish_count = 0

    def on_paused(self):
        """No-op: this sink takes no action when paused."""
        pass

    def _cleanup(self):
        """Stop the client's loop and disconnect from the broker."""
        self._client.loop_stop()
        self._client.disconnect()


class MqttPublishEnqueueFailed(Exception):
    """Raised when the MQTT client reports a failure for a publish call."""


class MqttPublishAckTimeout(Exception):
    """Raised when publish acknowledgements are still pending at flush timeout."""


def _mqtt_on_connect_cb(
    client: paho.Client,
    userdata: Any,
    connect_flags: paho.ConnectFlags,
    reason_code: paho.ReasonCode,
    properties: paho.Properties,
):
    """
    Connect callback registered on the paho client.

    Only `reason_code` is inspected; the other arguments are unused.

    :raises ConnectionError: if `reason_code` is not 0.
    """
    # NOTE(review): this raises in whatever thread paho runs the callback in;
    # confirm the error reaches the caller as intended.
    if reason_code != 0:
        raise ConnectionError(
            f"Failed to connect to MQTT broker; ERROR: ({reason_code.value}).{reason_code.getName()}"
        )


def _mqtt_on_disconnect_cb(
    client: paho.Client,
    userdata: Any,
    disconnect_flags: paho.DisconnectFlags,
    reason_code: paho.ReasonCode,
    properties: paho.Properties,
):
    """
    Disconnect callback registered on the paho client.

    Logs the numeric value and name of `reason_code` at INFO level; the other
    arguments are unused.
    """
    logger.info(
        f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!"
    )


def _get_properties_callable(
    properties: Optional[MqttPropertiesHandler],
) -> Optional[Callable[[Any], paho.Properties]]:
    """
    Normalize the `properties` setting into a callable (or None).

    :param properties: a Properties instance, a callable taking the message
        value and returning Properties, or None.
    :return: a callable returning Properties for a message value, or None when
        no properties were configured.
    """
    if isinstance(properties, paho.Properties):
        # A fixed instance is returned unchanged for every message; it is not
        # itself callable, so it must not be invoked with the data.
        return lambda data: properties
    return properties


def _get_retain_callable(retain: RetainHandler) -> Callable[[Any], bool]:
    """
    Normalize the `retain` setting into a callable.

    A callable is handed back untouched; a plain bool is wrapped in a function
    that ignores its argument and always yields that bool.
    """
    if not isinstance(retain, bool):
        return retain

    def _fixed_retain(data):
        return retain

    return _fixed_retain
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ redis[hiredis]>=5.2.0,<6
pandas>=1.0.0,<3.0
psycopg2-binary>=2.9,<3
types-psycopg2>=2.9,<3
paho-mqtt>=2.1.0,<3
106 changes: 106 additions & 0 deletions tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from datetime import datetime
from typing import Optional
from unittest.mock import patch

import pytest

from quixstreams.sinks.community.mqtt import MQTTSink


@pytest.fixture()
def mqtt_sink_factory():
    """
    Provide a factory that builds an MQTTSink backed by a mocked paho client.

    The factory returns a ``(sink, mock_client)`` tuple, where `mock_client`
    is the mock instance the sink received when it constructed its client.
    """

    def factory(
        client_id: str = "test_client",
        server: str = "localhost",
        port: int = 1883,
        username: Optional[str] = None,
        password: Optional[str] = None,
        topic_root: str = "test/topic",
        version: str = "3.1.1",
        tls_enabled: bool = True,
        qos: int = 1,
    ) -> tuple:
        # The patch only needs to be active while the sink is constructed:
        # the sink keeps the mock instance it created from the patched class.
        with patch("paho.mqtt.client.Client") as MockClient:
            mock_mqtt_client = MockClient.return_value
            sink = MQTTSink(
                client_id=client_id,
                server=server,
                port=port,
                topic_root=topic_root,
                username=username,
                password=password,
                version=version,
                tls_enabled=tls_enabled,
                qos=qos,
            )
            return sink, mock_mqtt_client

    return factory


class TestMQTTSink:
    """Unit tests for MQTTSink, run against a mocked paho client."""

    def test_mqtt_connect(self, mqtt_sink_factory):
        """`setup()` connects to the configured server and port."""
        sink, client = mqtt_sink_factory()
        sink.setup()
        client.connect.assert_called_once_with("localhost", 1883)

    def test_mqtt_tls_enabled(self, mqtt_sink_factory):
        """TLS is configured on the client when `tls_enabled` is True."""
        _, client = mqtt_sink_factory(tls_enabled=True)
        client.tls_set.assert_called_once()

    def test_mqtt_tls_disabled(self, mqtt_sink_factory):
        """TLS is left unconfigured when `tls_enabled` is False."""
        _, client = mqtt_sink_factory(tls_enabled=False)
        client.tls_set.assert_not_called()

    def test_mqtt_publish(self, mqtt_sink_factory):
        """`add()` publishes the serialized value under root topic + key."""

        class MockInfo:
            # Stand-in for the publish result: success code plus a message id.
            rc = 0
            mid = 123

        sink, client = mqtt_sink_factory()
        client.publish.return_value = MockInfo()

        record = dict(
            topic="test-topic",
            partition=0,
            offset=1,
            key=b"test_key",
            value="test_data",
            timestamp=datetime.now(),
            headers=[],
        )
        sink.add(**record)

        client.publish.assert_called_once_with(
            "test/topic/test_key",
            payload='"test_data"',
            qos=1,
            retain=False,
            properties=None,
        )

    def test_mqtt_authentication(self, mqtt_sink_factory):
        """Credentials are forwarded to the client's `username_pw_set`."""
        _, client = mqtt_sink_factory(username="user", password="pass")
        client.username_pw_set.assert_called_once_with("user", "pass")

    def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory):
        """A failing publish stops the loop, disconnects, and re-raises."""
        sink, client = mqtt_sink_factory()
        client.publish.side_effect = ConnectionError("publish error")

        record = dict(
            topic="test-topic",
            partition=0,
            offset=1,
            key=b"key",
            value="data",
            timestamp=12345,
            headers=(),
        )
        with pytest.raises(ConnectionError):
            sink.add(**record)

        client.loop_stop.assert_called_once()
        client.disconnect.assert_called_once()