Skip to content

Commit

Permalink
dependency: use mio to replace polling (#280)
Browse files Browse the repository at this point in the history
Poller.add() is unsafe in `polling` 3.x, but we want to keep this lib safe code only. Hence the change.
  • Loading branch information
keepsimple1 authored Dec 5, 2024
1 parent 99483b7 commit fcd31f3
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
os: [ubuntu-20.04, windows-latest, macos-latest]
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@1.65.0
- uses: dtolnay/rust-toolchain@1.70.0
with:
components: rustfmt, clippy
- name: Run rustfmt and fail if any warnings
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "mdns-sd"
version = "0.12.0"
authors = ["keepsimple <[email protected]>"]
edition = "2018"
rust-version = "1.65.0"
rust-version = "1.70.0"
license = "Apache-2.0 OR MIT"
repository = "https://github.com/keepsimple1/mdns-sd"
documentation = "https://docs.rs/mdns-sd"
Expand All @@ -21,7 +21,7 @@ fastrand = "2.1"
flume = { version = "0.11", default-features = false } # channel between threads
if-addrs = { version = "0.13", features = ["link-local"] } # get local IP addresses
log = { version = "0.4", optional = true } # logging
polling = "2.1" # select/poll sockets
mio = { version = "1.0", features = ["os-poll", "net"] } # select/poll sockets
socket2 = { version = "0.5.5", features = ["all"] } # socket APIs

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![Build](https://github.com/keepsimple1/mdns-sd/actions/workflows/build.yml/badge.svg)](https://github.com/keepsimple1/mdns-sd/actions)
[![Cargo](https://img.shields.io/crates/v/mdns-sd.svg)](https://crates.io/crates/mdns-sd)
[![docs.rs](https://img.shields.io/docsrs/mdns-sd)](https://docs.rs/mdns-sd/latest/mdns_sd/)
[![Rust version: 1.63+](https://img.shields.io/badge/rust%20version-1.63+-orange)](https://blog.rust-lang.org/2022/08/11/Rust-1.63.0.html)
[![Rust version: 1.70+](https://img.shields.io/badge/rust%20version-1.70+-orange)](https://blog.rust-lang.org/2022/08/11/Rust-1.70.0.html)

This is a small implementation of mDNS (Multicast DNS) based service discovery in safe Rust, with a small set of dependencies. Some highlights:

Expand Down
115 changes: 67 additions & 48 deletions src/service_daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ use crate::{
};
use flume::{bounded, Sender, TrySendError};
use if_addrs::{IfAddr, Interface};
use polling::Poller;
use socket2::{SockAddr, Socket};
use mio::{net::UdpSocket as MioUdpSocket, Poll};
use socket2::Socket;
use std::{
cmp::{self, Reverse},
collections::{BinaryHeap, HashMap, HashSet},
fmt,
io::Read,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket},
str, thread,
time::Duration,
Expand Down Expand Up @@ -179,14 +178,15 @@ impl ServiceDaemon {
.set_nonblocking(true)
.map_err(|e| e_fmt!("failed to set nonblocking for signal socket: {}", e))?;

let poller = Poller::new().map_err(|e| e_fmt!("Failed to create Poller: {}", e))?;
let poller = Poll::new().map_err(|e| e_fmt!("failed to create mio Poll: {e}"))?;

let (sender, receiver) = bounded(100);

// Spawn the daemon thread
let mio_sock = MioUdpSocket::from_std(signal_sock);
thread::Builder::new()
.name("mDNS_daemon".to_string())
.spawn(move || Self::daemon_thread(signal_sock, poller, receiver))
.spawn(move || Self::daemon_thread(mio_sock, poller, receiver))
.map_err(|e| e_fmt!("thread builder failed to spawn: {}", e))?;

Ok(Self {
Expand Down Expand Up @@ -417,7 +417,7 @@ impl ServiceDaemon {
self.send_cmd(Command::Verify(instance_fullname, timeout))
}

fn daemon_thread(signal_sock: UdpSocket, poller: Poller, receiver: Receiver<Command>) {
fn daemon_thread(signal_sock: MioUdpSocket, poller: Poll, receiver: Receiver<Command>) {
let zc = Zeroconf::new(signal_sock, poller);

if let Some(cmd) = Self::run(zc, receiver) {
Expand All @@ -436,35 +436,40 @@ impl ServiceDaemon {
}
}

fn handle_poller_events(zc: &mut Zeroconf, events: &[polling::Event]) {
fn handle_poller_events(zc: &mut Zeroconf, events: &mio::Events) {
for ev in events.iter() {
trace!("event received with key {}", ev.key);
if ev.key == SIGNAL_SOCK_EVENT_KEY {
trace!("event received with key {:?}", ev.token());
if ev.token().0 == SIGNAL_SOCK_EVENT_KEY {
// Drain signals as we will drain commands as well.
zc.signal_sock_drain();

if let Err(e) = zc
.poller
.modify(&zc.signal_sock, polling::Event::readable(ev.key))
{
if let Err(e) = zc.poller.registry().reregister(
&mut zc.signal_sock,
ev.token(),
mio::Interest::READABLE,
) {
debug!("failed to modify poller for signal socket: {}", e);
}
continue; // Next event.
}

// Read until no more packets available.
let intf = match zc.poll_ids.get(&ev.key) {
let intf = match zc.poll_ids.get(&ev.token().0) {
Some(interface) => interface.clone(),
None => {
debug!("Ip for event key {} not found", ev.key);
debug!("Ip for event key {} not found", ev.token().0);
break;
}
};
while zc.handle_read(&intf) {}

// we continue to monitor this socket.
if let Some(sock) = zc.intf_socks.get(&intf) {
if let Err(e) = zc.poller.modify(sock, polling::Event::readable(ev.key)) {
if let Some(sock) = zc.intf_socks.get_mut(&intf) {
if let Err(e) =
zc.poller
.registry()
.reregister(sock, ev.token(), mio::Interest::READABLE)
{
debug!("modify poller for interface {:?}: {}", &intf, e);
break;
}
Expand All @@ -482,20 +487,26 @@ impl ServiceDaemon {
/// 5. process retransmissions if any.
fn run(mut zc: Zeroconf, receiver: Receiver<Command>) -> Option<Command> {
// Add the daemon's signal socket to the poller.
if let Err(e) = zc.poller.add(
&zc.signal_sock,
polling::Event::readable(SIGNAL_SOCK_EVENT_KEY),
if let Err(e) = zc.poller.registry().register(
&mut zc.signal_sock,
mio::Token(SIGNAL_SOCK_EVENT_KEY),
mio::Interest::READABLE,
) {
debug!("failed to add signal socket to the poller: {}", e);
return None;
}

// Add mDNS sockets to the poller.
for (intf, sock) in zc.intf_socks.iter() {
for (intf, sock) in zc.intf_socks.iter_mut() {
let key =
Zeroconf::add_poll_impl(&mut zc.poll_ids, &mut zc.poll_id_count, intf.clone());
if let Err(e) = zc.poller.add(sock, polling::Event::readable(key)) {
debug!("add socket of {:?} to poller: {}", intf, e);

if let Err(e) =
zc.poller
.registry()
.register(sock, mio::Token(key), mio::Interest::READABLE)
{
debug!("add socket of {:?} to poller: {e}", intf);
return None;
}
}
Expand All @@ -507,7 +518,7 @@ impl ServiceDaemon {

// Start the run loop.

let mut events = Vec::new();
let mut events = mio::Events::with_capacity(1024);
loop {
let now = current_time_millis();

Expand All @@ -520,7 +531,7 @@ impl ServiceDaemon {

// Process incoming packets, command events and optional timeout.
events.clear();
match zc.poller.wait(&mut events, timeout) {
match zc.poller.poll(&mut events, timeout) {
Ok(_) => Self::handle_poller_events(&mut zc, &events),
Err(e) => debug!("failed to select from sockets: {}", e),
}
Expand Down Expand Up @@ -626,7 +637,7 @@ impl ServiceDaemon {
}

/// Creates a new UDP socket that uses `intf` to send and recv multicast.
fn new_socket_bind(intf: &Interface) -> Result<Socket> {
fn new_socket_bind(intf: &Interface) -> Result<MioUdpSocket> {
// Use the same socket for receiving and sending multicast packets.
// Such socket has to bind to INADDR_ANY or IN6ADDR_ANY.
let intf_ip = &intf.ip();
Expand All @@ -650,7 +661,7 @@ fn new_socket_bind(intf: &Interface) -> Result<Socket> {
sock.send_to(&packet, &multicast_addr)
.map_err(|e| e_fmt!("send multicast packet on addr {}: {}", ip, e))?;
}
Ok(sock)
Ok(MioUdpSocket::from_std(UdpSocket::from(sock)))
}
IpAddr::V6(ip) => {
let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), MDNS_PORT, 0, 0);
Expand All @@ -668,7 +679,7 @@ fn new_socket_bind(intf: &Interface) -> Result<Socket> {
// be many IPv6 interfaces on a host and could cause such send error:
// "No buffer space available (os error 55)".

Ok(sock)
Ok(MioUdpSocket::from_std(UdpSocket::from(sock)))
}
}
}
Expand Down Expand Up @@ -834,7 +845,7 @@ struct IfSelection {
/// A struct holding the state. It was inspired by `zeroconf` package in Python.
struct Zeroconf {
/// Local interfaces with sockets to recv/send on these interfaces.
intf_socks: HashMap<Interface, Socket>,
intf_socks: HashMap<Interface, MioUdpSocket>,

/// Map poll id to Interface.
poll_ids: HashMap<usize, Interface>,
Expand Down Expand Up @@ -865,7 +876,7 @@ struct Zeroconf {
counters: Metrics,

/// Waits for incoming packets.
poller: Poller,
poller: Poll,

/// Channels to notify events.
monitors: Vec<Sender<DaemonEvent>>,
Expand All @@ -877,7 +888,7 @@ struct Zeroconf {
if_selections: Vec<IfSelection>,

/// Socket for signaling.
signal_sock: UdpSocket,
signal_sock: MioUdpSocket,

/// Timestamps marking where we need another iteration of the run loop,
/// to react to events like retransmissions, cache refreshes, interface IP address changes, etc.
Expand All @@ -896,7 +907,7 @@ struct Zeroconf {
}

impl Zeroconf {
fn new(signal_sock: UdpSocket, poller: Poller) -> Self {
fn new(signal_sock: MioUdpSocket, poller: Poll) -> Self {
// Get interfaces.
let my_ifaddrs = my_ip_interfaces();

Expand Down Expand Up @@ -1087,8 +1098,8 @@ impl Zeroconf {
}
} else {
// Remove the interface
if let Some(sock) = self.intf_socks.remove(&intf) {
if let Err(e) = self.poller.delete(&sock) {
if let Some(mut sock) = self.intf_socks.remove(&intf) {
if let Err(e) = self.poller.registry().deregister(&mut sock) {
debug!("process_if_selections: poller.delete {:?}: {}", &intf, e);
}
// Remove from poll_ids
Expand All @@ -1108,10 +1119,10 @@ impl Zeroconf {
// Remove unused sockets in the poller.
let deleted_addrs = self
.intf_socks
.iter()
.iter_mut()
.filter_map(|(intf, sock)| {
if !my_ifaddrs.contains(intf) {
if let Err(e) = poller.delete(sock) {
if let Err(e) = poller.registry().deregister(sock) {
debug!("check_ip_changes: poller.delete {:?}: {}", intf, e);
}
// Remove from poll_ids
Expand Down Expand Up @@ -1139,7 +1150,7 @@ impl Zeroconf {
fn add_new_interface(&mut self, intf: Interface) {
// Bind the new interface.
let new_ip = intf.ip();
let sock = match new_socket_bind(&intf) {
let mut sock = match new_socket_bind(&intf) {
Ok(s) => s,
Err(e) => {
debug!("bind a socket to {}: {}. Skipped.", &intf.ip(), e);
Expand All @@ -1149,7 +1160,11 @@ impl Zeroconf {

// Add the new interface into the poller.
let key = self.add_poll(intf.clone());
if let Err(e) = self.poller.add(&sock, polling::Event::readable(key)) {
if let Err(e) =
self.poller
.registry()
.register(&mut sock, mio::Token(key), mio::Interest::READABLE)
{
debug!("check_ip_changes: poller add ip {}: {}", new_ip, e);
return;
}
Expand Down Expand Up @@ -1417,7 +1432,12 @@ impl Zeroconf {
}
}

fn unregister_service(&self, info: &ServiceInfo, intf: &Interface, sock: &Socket) -> Vec<u8> {
fn unregister_service(
&self,
info: &ServiceInfo,
intf: &Interface,
sock: &MioUdpSocket,
) -> Vec<u8> {
let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
out.add_answer_at_time(
DnsPointer::new(
Expand Down Expand Up @@ -1559,7 +1579,7 @@ impl Zeroconf {
// be truncated by the socket layer depending on the platform's libc.
// In any case, such large datagram will not be decoded properly and
// this function should return false but should not crash.
let sz = match sock.read(&mut buf) {
let sz = match sock.recv(&mut buf) {
Ok(sz) => sz,
Err(e) => {
if e.kind() != std::io::ErrorKind::WouldBlock {
Expand All @@ -1573,8 +1593,8 @@ impl Zeroconf {

// If sz is 0, it means sock reached End-of-File.
if sz == 0 {
debug!("socket {:?} was likely shutdown", sock);
if let Err(e) = self.poller.delete(&*sock) {
debug!("socket {:?} was likely shutdown", &sock);
if let Err(e) = self.poller.registry().deregister(sock) {
debug!("failed to remove sock {:?} from poller: {}", sock, &e);
}

Expand Down Expand Up @@ -3019,7 +3039,7 @@ fn my_ip_interfaces() -> Vec<Interface> {
}

/// Send an outgoing mDNS query or response, and returns the packet bytes.
fn send_dns_outgoing(out: &DnsOutgoing, intf: &Interface, sock: &Socket) -> Vec<Vec<u8>> {
fn send_dns_outgoing(out: &DnsOutgoing, intf: &Interface, sock: &MioUdpSocket) -> Vec<Vec<u8>> {
let qtype = if out.is_query() { "query" } else { "response" };
trace!(
"send outgoing {}: {} questions {} answers {} authorities {} additional",
Expand All @@ -3037,7 +3057,7 @@ fn send_dns_outgoing(out: &DnsOutgoing, intf: &Interface, sock: &Socket) -> Vec<
}

/// Sends a multicast packet, and returns the packet bytes.
fn multicast_on_intf(packet: &[u8], intf: &Interface, socket: &Socket) {
fn multicast_on_intf(packet: &[u8], intf: &Interface, socket: &MioUdpSocket) {
if packet.len() > MAX_MSG_ABSOLUTE {
debug!("Drop over-sized packet ({})", packet.len());
return;
Expand All @@ -3056,9 +3076,8 @@ fn multicast_on_intf(packet: &[u8], intf: &Interface, socket: &Socket) {
}

/// Sends out `packet` to `addr` on the socket in `intf_sock`.
fn send_packet(packet: &[u8], addr: SocketAddr, intf: &Interface, sock: &Socket) {
let sockaddr = SockAddr::from(addr);
match sock.send_to(packet, &sockaddr) {
fn send_packet(packet: &[u8], addr: SocketAddr, intf: &Interface, sock: &MioUdpSocket) {
match sock.send_to(packet, addr) {
Ok(sz) => trace!("sent out {} bytes on interface {:?}", sz, intf),
Err(e) => debug!("Failed to send to {} via {:?}: {}", addr, &intf, e),
}
Expand Down Expand Up @@ -3224,7 +3243,7 @@ fn announce_service_on_intf(
dns_registry: &mut DnsRegistry,
info: &ServiceInfo,
intf: &Interface,
sock: &Socket,
sock: &MioUdpSocket,
) -> bool {
if let Some(out) = prepare_announce(info, intf, dns_registry) {
send_dns_outgoing(&out, intf, sock);
Expand Down

0 comments on commit fcd31f3

Please sign in to comment.