Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions mqtt/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,41 +86,96 @@ impl TryFrom<u8> for ConnectReasonCode {
}
}

#[repr(u8)]
enum ConnectionFlags {
CleanStart = 0x02,
Password = 0x40,
Username = 0x80,
}

#[derive(Debug)]
pub struct LoginCredentials<'a> {
username: &'a str,
password: &'a str,
}

impl<'a> LoginCredentials<'a> {
pub fn new(username: &'a str, password: &'a str) -> Self {
Self { username, password }
}
}

pub fn send_connect<E, Writer: Write<E>>(
mut writer: Writer,
client_id: &Option<ClientId>,
login_credentials: &Option<LoginCredentials>,
rx_max: u16,
) -> Result<(), HlError<E>> {
const KEEP_ALIVE: u16 = 15 * 60;

let client_id_len: u8 = client_id.map(|id| id.len()).unwrap_or(0);
let mut flags: u8 = ConnectionFlags::CleanStart as u8;

let client_id = client_id.as_ref();
let client_id_len = client_id.map_or(0, ClientId::len);

let mut suffix_len: u8 = client_id_len;

if let Some(login_credentials) = login_credentials {
flags += ConnectionFlags::Username as u8;
flags += ConnectionFlags::Password as u8;

suffix_len += payload_len(login_credentials.username);
suffix_len += payload_len(login_credentials.password);
}

#[rustfmt::skip]
writer.write_all(&[
// control packet type
(CtrlPkt::CONNECT as u8) << 4,
// remaining length
18 + client_id_len,
18 + suffix_len,
// protocol name length
0, 4,
// protocol name
b'M', b'Q', b'T', b'T',
// protocol version
5,
// flags, clean start is set
0b00000010,
flags,
// keepalive
(KEEP_ALIVE >> 8) as u8, KEEP_ALIVE as u8,
// properties length
5,
// recieve maximum property
// receive maximum property
(Properties::MaxPktSize as u8), 0, 0, (rx_max >> 8) as u8, rx_max as u8,
// client ID length
0, client_id_len,
])?;
if let Some(client_id) = client_id {
writer.write_all(client_id.as_bytes())?;
}

if let Some(login_credentials) = login_credentials {
let LoginCredentials { username, password } = login_credentials;
writer.write_all(str_len_msb_lsb(username).as_slice())?;
writer.write_all(username.as_bytes())?;
writer.write_all(str_len_msb_lsb(password).as_slice())?;
writer.write_all(password.as_bytes())?;
}

writer.send()?;
Ok(())
}

fn payload_len(s: &str) -> u8 {
// str len + 2 bytes for str len prefix
(s.len() + 2) as u8
}

fn str_len_msb_lsb(s: &str) -> [u8; 2] {
let len: u16 = s.len() as u16;
let msb: u8 = (len >> 8) as u8;
let lsb: u8 = len as u8;

[msb, lsb]
}
13 changes: 11 additions & 2 deletions mqtt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ pub mod tls;
pub use w5500_tls;

pub use client_id::ClientId;
use connect::send_connect;
pub use connect::ConnectReasonCode;
use connect::{send_connect, LoginCredentials};
use hl::{
io::{Read, Seek, Write},
ll::{net::SocketAddrV4, Registers, Sn, SocketInterrupt, SocketInterruptMask},
Expand Down Expand Up @@ -325,6 +325,8 @@ pub struct Client<'a> {
state_timeout: StateTimeout,
/// Packet ID for subscribing
pkt_id: u16,
/// Login credentials
credentials: Option<LoginCredentials<'a>>,
}

impl<'a> Client<'a> {
Expand Down Expand Up @@ -364,6 +366,7 @@ impl<'a> Client<'a> {
},
client_id: None,
pkt_id: 1,
credentials: None,
}
}

Expand Down Expand Up @@ -399,6 +402,11 @@ impl<'a> Client<'a> {
self.client_id = Some(client_id)
}

/// Set the MQTT login credentials.
pub fn set_credentials(&mut self, username: &'a str, password: &'a str) {
self.credentials = Some(LoginCredentials::new(username, password));
}

fn next_pkt_id(&mut self) -> u16 {
self.pkt_id = self.pkt_id.checked_add(1).unwrap_or(1);
self.pkt_id
Expand Down Expand Up @@ -687,7 +695,8 @@ impl<'a> Client<'a> {
.size_in_bytes() as u16;

let writer: TcpWriter<W5500> = w5500.tcp_writer(self.sn)?;
send_connect(writer, &self.client_id, rx_max).map_err(Error::map_w5500)?;
send_connect(writer, &self.client_id, &self.credentials, rx_max)
.map_err(Error::map_w5500)?;
Ok(self
.state_timeout
.set_state_with_timeout(State::WaitConAck, monotonic_secs))
Expand Down
18 changes: 13 additions & 5 deletions mqtt/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
//! [`w5500-tls`]: https://github.com/newAM/w5500-rs/blob/main/tls/README.md

use crate::{
connect::send_connect,
connect::{send_connect, LoginCredentials},
hl::{
ll::{net::SocketAddrV4, Registers, Sn},
Error as HlError, Hostname,
Expand Down Expand Up @@ -88,16 +88,18 @@ fn map_tls_writer_err<E>(e: w5500_tls::Error) -> Error<E> {
///
/// The methods are nearly identical to [`crate::Client`], see [`crate::Client`]
/// for additional documentation and examples.
pub struct Client<'id, 'hn, 'psk, 'b, const N: usize> {
pub struct Client<'id, 'hn, 'psk, 'b, 'cred, const N: usize> {
tls: TlsClient<'hn, 'psk, 'b, N>,
client_id: Option<ClientId<'id>>,
/// State and Timeout tracker
state_timeout: StateTimeout,
/// Packet ID for subscribing
pkt_id: u16,
/// Login credentials
credentials: Option<LoginCredentials<'cred>>,
}

impl<'id, 'hn, 'psk, 'b, const N: usize> Client<'id, 'hn, 'psk, 'b, N> {
impl<'id, 'hn, 'psk, 'b, 'cred, const N: usize> Client<'id, 'hn, 'psk, 'b, 'cred, N> {
/// Create a new MQTT client.
///
/// # Arguments
Expand Down Expand Up @@ -151,8 +153,8 @@ impl<'id, 'hn, 'psk, 'b, const N: usize> Client<'id, 'hn, 'psk, 'b, N> {
timeout: None,
},
client_id: None,

pkt_id: 1,
credentials: None,
}
}

Expand All @@ -161,6 +163,11 @@ impl<'id, 'hn, 'psk, 'b, const N: usize> Client<'id, 'hn, 'psk, 'b, N> {
self.client_id = Some(client_id)
}

/// Set the MQTT login credentials.
pub fn set_credentials(&mut self, username: &'cred str, password: &'cred str) {
self.credentials = Some(LoginCredentials::new(username, password));
}

fn next_pkt_id(&mut self) -> u16 {
self.pkt_id = self.pkt_id.checked_add(1).unwrap_or(1);
self.pkt_id
Expand Down Expand Up @@ -280,7 +287,8 @@ impl<'id, 'hn, 'psk, 'b, const N: usize> Client<'id, 'hn, 'psk, 'b, N> {

let rx_max: u16 = (N as u16) - TLS_OVERHEAD;
let writer: TlsWriter<W5500> = self.tls.writer(w5500).map_err(map_tls_writer_err)?;
send_connect(writer, &self.client_id, rx_max).map_err(Error::map_w5500)?;
send_connect(writer, &self.client_id, &self.credentials, rx_max)
.map_err(Error::map_w5500)?;
Ok(self
.state_timeout
.set_state_with_timeout(State::WaitConAck, monotonic_secs))
Expand Down
41 changes: 41 additions & 0 deletions mqtt/tests/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,47 @@ fn connect_no_client_id() {
}));
}

#[test]
fn connect_with_login() {
const PORT: u16 = 12345;
let mut client: Client =
Client::new(Sn0, SRC_PORT, SocketAddrV4::new(Ipv4Addr::LOCALHOST, PORT));
client.set_credentials("mqtt-user", "password");

let mut fixture = Fixture::with_client(client, PORT);
assert!(matches!(
fixture.client_process().unwrap(),
Event::CallAfter(10)
));
fixture.server.accept();
assert!(matches!(
fixture.client_process().unwrap(),
Event::CallAfter(10)
));
fixture.server_expect(Packet::Connect(Connect {
protocol: V5,
keep_alive: 900,
client_id: "".to_string(),
clean_session: true,
last_will: None,
login: Some(mqttbytes::v5::Login {
username: "mqtt-user".to_string(),
password: "password".to_string(),
}),
properties: Some(ConnectProperties {
session_expiry_interval: None,
receive_maximum: None,
max_packet_size: Some(2048),
topic_alias_max: None,
request_response_info: None,
request_problem_info: None,
user_properties: vec![],
authentication_method: None,
authentication_data: None,
}),
}));
}

#[test]
fn connect_fail() {
const PORT: u16 = 12344;
Expand Down