Skip to content
12 changes: 10 additions & 2 deletions awscrt/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from awscrt.http import HttpProxyOptions, HttpRequest
from awscrt.io import ClientBootstrap, ClientTlsContext, SocketOptions
from dataclasses import dataclass
from awscrt.mqtt5 import Client as Mqtt5Client
from awscrt.mqtt5 import Client as Mqtt5Client, _get_awsiot_metrics_str


class QoS(IntEnum):
Expand Down Expand Up @@ -330,6 +330,8 @@ class Connection(NativeResource):

proxy_options (Optional[awscrt.http.HttpProxyOptions]):
Optional proxy options for all connections.

enable_aws_metrics (bool): If true, append AWS IoT metrics to the username. (Default to true)
"""

def __init__(self,
Expand All @@ -355,7 +357,8 @@ def __init__(self,
proxy_options=None,
on_connection_success=None,
on_connection_failure=None,
on_connection_closed=None
on_connection_closed=None,
enable_aws_metrics=True
):

assert isinstance(client, Client) or isinstance(client, Mqtt5Client)
Expand Down Expand Up @@ -404,6 +407,11 @@ def __init__(self,
self.ping_timeout_ms = ping_timeout_ms
self.protocol_operation_timeout_ms = protocol_operation_timeout_ms
self.will = will

if enable_aws_metrics:
username = username if username else ""
username += _get_awsiot_metrics_str(username)

self.username = username
self.password = password
self.socket_options = socket_options if socket_options else SocketOptions()
Expand Down
41 changes: 41 additions & 0 deletions awscrt/mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,40 @@
from collections.abc import Sequence
from inspect import signature

# Global variable to cache metrics string
_metrics_str = None


def _get_awsiot_metrics_str(current_username=""):
global _metrics_str

username_has_query = False
if current_username.find("?") != -1:
username_has_query = True
# The SDK query is already set, skip adding it again
if username_has_query and current_username.find("SDK=") != -1:
return ""

if _metrics_str is None:
try:
import importlib.metadata
try:
version = importlib.metadata.version("awscrt")
_metrics_str = "SDK=CRTPython&Version={}&Platform={}".format(
version, _awscrt.get_platform_build_os_string())
except importlib.metadata.PackageNotFoundError:
_metrics_str = "SDK=CRTPython&Version=dev&Platform={}".format(_awscrt.get_platform_build_os_string())
except BaseException:
_metrics_str = ""

if not _metrics_str == "":
if username_has_query:
return "&" + _metrics_str
else:
return "?" + _metrics_str
else:
return ""


class QoS(IntEnum):
"""MQTT message delivery quality of service.
Expand Down Expand Up @@ -1338,6 +1372,7 @@ class ClientOptions:
on_lifecycle_event_connection_success_fn (Callable[[LifecycleConnectSuccessData],]): Callback for Lifecycle Event Connection Success.
on_lifecycle_event_connection_failure_fn (Callable[[LifecycleConnectFailureData],]): Callback for Lifecycle Event Connection Failure.
on_lifecycle_event_disconnection_fn (Callable[[LifecycleDisconnectData],]): Callback for Lifecycle Event Disconnection.
enable_aws_metrics (bool): Whether to append AWS IoT metrics to the username field during CONNECT. Default: True
"""
host_name: str
port: int = None
Expand All @@ -1364,6 +1399,7 @@ class ClientOptions:
on_lifecycle_event_connection_success_fn: Callable[[LifecycleConnectSuccessData], None] = None
on_lifecycle_event_connection_failure_fn: Callable[[LifecycleConnectFailureData], None] = None
on_lifecycle_event_disconnection_fn: Callable[[LifecycleDisconnectData], None] = None
enable_aws_metrics: bool = True


def _check_callback(callback):
Expand Down Expand Up @@ -1753,6 +1789,11 @@ def __init__(self, client_options: ClientOptions):
is_will_none = False
will = connect_options.will

username = connect_options.username
if client_options.enable_aws_metrics:
username = username if username else ""
username += _get_awsiot_metrics_str(username)
connect_options.username = username
websocket_is_none = client_options.websocket_handshake_transform is None
self.tls_ctx = client_options.tls_ctx
self._binding = _awscrt.mqtt5_client_new(self,
Expand Down
8 changes: 8 additions & 0 deletions source/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ PyObject *aws_py_get_cpu_count_for_group(PyObject *self, PyObject *args) {
return PyLong_FromSize_t(count);
}

PyObject *aws_py_get_platform_build_os_string(PyObject *self, PyObject *args) {
(void)self;
(void)args;

struct aws_byte_cursor os_string = aws_get_platform_build_os_string();
return PyUnicode_FromAwsByteCursor(&os_string);
}

PyObject *aws_py_thread_join_all_managed(PyObject *self, PyObject *args) {
(void)self;

Expand Down
1 change: 1 addition & 0 deletions source/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

PyObject *aws_py_get_cpu_group_count(PyObject *self, PyObject *args);
PyObject *aws_py_get_cpu_count_for_group(PyObject *self, PyObject *args);
PyObject *aws_py_get_platform_build_os_string(PyObject *self, PyObject *args);

PyObject *aws_py_thread_join_all_managed(PyObject *self, PyObject *args);

Expand Down
1 change: 1 addition & 0 deletions source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ static PyMethodDef s_module_methods[] = {
AWS_PY_METHOD_DEF(get_corresponding_builtin_exception, METH_VARARGS),
AWS_PY_METHOD_DEF(get_cpu_group_count, METH_VARARGS),
AWS_PY_METHOD_DEF(get_cpu_count_for_group, METH_VARARGS),
AWS_PY_METHOD_DEF(get_platform_build_os_string, METH_VARARGS),
AWS_PY_METHOD_DEF(native_memory_usage, METH_NOARGS),
AWS_PY_METHOD_DEF(native_memory_dump, METH_NOARGS),
AWS_PY_METHOD_DEF(thread_join_all_managed, METH_VARARGS),
Expand Down
6 changes: 4 additions & 2 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,8 @@ def _test_mqtt311_direct_connect_basic_auth(self):
host_name=input_host_name,
port=input_port,
username=input_username,
password=input_password)
password=input_password,
enable_aws_metrics=False) # Disable AWS metrics for basic auth on non-AWS broker
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

Expand Down Expand Up @@ -760,7 +761,8 @@ def sign_function(transform_args, **kwargs):
username=input_username,
password=input_password,
use_websockets=True,
websocket_handshake_transform=sign_function)
websocket_handshake_transform=sign_function,
enable_aws_metrics=False) # Disable AWS metrics for basic auth on non-AWS broker
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

Expand Down
9 changes: 6 additions & 3 deletions test/test_mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def _test_direct_connect_basic_auth(self):
client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=input_port,
connect_options=connect_options
connect_options=connect_options,
enable_aws_metrics=False # Disable AWS metrics for basic auth on non-AWS broker
)
callbacks = Mqtt5TestCallbacks()
client = self._create_client(client_options=client_options, callbacks=callbacks)
Expand Down Expand Up @@ -416,7 +417,8 @@ def _test_websocket_connect_basic_auth(self):
client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=input_port,
connect_options=connect_options
connect_options=connect_options,
enable_aws_metrics=False # Disable AWS metrics for basic auth on non-AWS broker
)
callbacks = Mqtt5TestCallbacks()
client_options.websocket_handshake_transform = callbacks.ws_handshake_transform
Expand Down Expand Up @@ -615,7 +617,8 @@ def test_connect_with_incorrect_basic_authentication_credentials(self):
client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=input_port,
connect_options=connect_options
connect_options=connect_options,
enable_aws_metrics=False # Disable AWS metrics for basic auth on non-AWS broker
)
callbacks = Mqtt5TestCallbacks()
client = self._create_client(client_options=client_options, callbacks=callbacks)
Expand Down
Loading