diff --git a/Cargo.toml b/Cargo.toml
index 0aa783c..e94dbee 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,7 +14,6 @@ description = "netlink packet types"
 [dependencies]
 anyhow = "1.0.31"
 byteorder = "1.3.2"
-netlink-packet-utils = "0.6.0"
 
 [dev-dependencies]
 netlink-packet-route = "0.13.0"
diff --git a/src/buffer.rs b/src/buffer.rs
index 3801583..b20a415 100644
--- a/src/buffer.rs
+++ b/src/buffer.rs
@@ -1,9 +1,8 @@
 // SPDX-License-Identifier: MIT
 
 use byteorder::{ByteOrder, NativeEndian};
-use netlink_packet_utils::DecodeError;
 
-use crate::{Field, Rest};
+use crate::{DecodeError, ErrorContext, Field, Rest};
 
 const LENGTH: Field = 0..4;
 const MESSAGE_TYPE: Field = 4..6;
@@ -158,7 +157,7 @@ impl<T: AsRef<[u8]>> NetlinkBuffer<T> {
     /// ```
     pub fn new_checked(buffer: T) -> Result<NetlinkBuffer<T>, DecodeError> {
         let packet = Self::new(buffer);
-        packet.check_buffer_length()?;
+        packet.check_buffer_length().context("invalid netlink buffer length")?;
         Ok(packet)
     }
 
@@ -331,7 +330,7 @@ impl<'a, T: AsRef<[u8]> + ?Sized> NetlinkBuffer<&'a T> {
     }
 }
 
-impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NetlinkBuffer<&'a mut T> {
+impl<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NetlinkBuffer<&mut T> {
     /// Return a mutable pointer to the payload.
     ///
     /// # Panic
diff --git a/src/done.rs b/src/done.rs
index 1cba4d6..428dec9 100644
--- a/src/done.rs
+++ b/src/done.rs
@@ -3,9 +3,8 @@
 use std::mem::size_of;
 
 use byteorder::{ByteOrder, NativeEndian};
-use netlink_packet_utils::DecodeError;
 
-use crate::{Emitable, Field, Parseable, Rest};
+use crate::{DecodeError, Emitable, ErrorContext, Field, Parseable, Rest};
 
 const CODE: Field = 0..4;
 const EXTENDED_ACK: Rest = 4..;
@@ -29,7 +28,9 @@ impl<T: AsRef<[u8]>> DoneBuffer<T> {
 
     pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
         let packet = Self::new(buffer);
-        packet.check_buffer_length()?;
+        packet
+            .check_buffer_length()
+            .context("invalid DoneBuffer length")?;
         Ok(packet)
     }
 
@@ -61,7 +62,7 @@ impl<'a, T: AsRef<[u8]> + ?Sized> DoneBuffer<&'a T> {
     }
 }
 
-impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> DoneBuffer<&'a mut T> {
+impl<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> DoneBuffer<&mut T> {
     /// Return a mutable pointer to the extended ack attributes.
     pub fn extended_ack_mut(&mut self) -> &mut [u8] {
         let data = self.buffer.as_mut();
diff --git a/src/error.rs b/src/error.rs
index df510c1..4f2f3b6 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -3,7 +3,6 @@
 use std::{fmt, io, mem::size_of, num::NonZeroI32};
 
 use byteorder::{ByteOrder, NativeEndian};
-use netlink_packet_utils::DecodeError;
 
 use crate::{Emitable, Field, Parseable, Rest};
 
@@ -11,6 +10,103 @@ const CODE: Field = 0..4;
 const PAYLOAD: Rest = 4..;
 const ERROR_HEADER_LEN: usize = PAYLOAD.start;
 
+pub trait ErrorContext {
+    fn context(self, msg: &str) -> Self;
+}
+
+#[derive(Debug)]
+pub struct DecodeError {
+    msg: String,
+}
+
+impl ErrorContext for DecodeError {
+    fn context(self, msg: &str) -> Self {
+        Self {
+            msg: format!("{} caused by {}", msg, self.msg),
+        }
+    }
+}
+
+impl<T> ErrorContext for Result<T, DecodeError>
+where
+    T: Clone,
+{
+    fn context(self, msg: &str) -> Result<T, DecodeError> {
+        match self {
+            Ok(t) => Ok(t),
+            Err(e) => Err(e.context(msg)),
+        }
+    }
+}
+
+impl From<&str> for DecodeError {
+    fn from(msg: &str) -> Self {
+        Self {
+            msg: msg.to_string(),
+        }
+    }
+}
+
+impl From<String> for DecodeError {
+    fn from(msg: String) -> Self {
+        Self { msg }
+    }
+}
+
+impl From<std::string::FromUtf8Error> for DecodeError {
+    fn from(err: std::string::FromUtf8Error) -> Self {
+        Self {
+            msg: format!("Invalid UTF-8 sequence: {}", err),
+        }
+    }
+}
+
+impl std::fmt::Display for DecodeError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.msg)
+    }
+}
+
+impl std::error::Error for DecodeError {}
+
+impl DecodeError {
+    pub fn invalid_mac_address(received: usize) -> Self {
+        Self{
+            msg: format!("Invalid MAC address. Expected 6 bytes, received {received} bytes"),
+        }
+    }
+
+    pub fn invalid_ip_address(received: usize) -> Self {
+        Self{
+            msg: format!("Invalid IP address. Expected 4 or 16 bytes, received {received} bytes"),
+        }
+    }
+
+    pub fn invalid_number(expected: usize, received: usize) -> Self {
+        Self{
+            msg: format!("Invalid number. Expected {expected} bytes, received {received} bytes"),
+        }
+    }
+
+    pub fn nla_buffer_too_small(buffer_len: usize, nla_len: usize) -> Self {
+        Self{
+            msg: format!("buffer has length {buffer_len}, but an NLA header is {nla_len} bytes"),
+        }
+    }
+
+    pub fn nla_length_mismatch(buffer_len: usize, nla_len: usize) -> Self {
+        Self{
+            msg: format!("buffer has length: {buffer_len}, but the NLA is {nla_len} bytes"),
+        }
+    }
+
+    pub fn nla_invalid_length(buffer_len: usize, nla_len: usize) -> Self {
+        Self{
+            msg: format!("NLA has invalid length: {nla_len} (should be at least {buffer_len} bytes)"),
+        }
+    }
+}
+
 #[derive(Debug, PartialEq, Eq, Clone)]
 #[non_exhaustive]
 pub struct ErrorBuffer<T> {
@@ -29,18 +125,21 @@ impl<T: AsRef<[u8]>> ErrorBuffer<T> {
 
     pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
         let packet = Self::new(buffer);
-        packet.check_buffer_length()?;
+        packet
+            .check_buffer_length()
+            .context("invalid ErrorBuffer length")?;
         Ok(packet)
     }
 
     fn check_buffer_length(&self) -> Result<(), DecodeError> {
         let len = self.buffer.as_ref().len();
         if len < ERROR_HEADER_LEN {
-            Err(format!(
-                "invalid ErrorBuffer: length is {len} but ErrorBuffer are \
+            Err(DecodeError {
+                msg: format!(
+                    "invalid ErrorBuffer: length is {len} but ErrorBuffer are \
                  at least {ERROR_HEADER_LEN} bytes"
-            )
-            .into())
+                ),
+            })
         } else {
             Ok(())
         }
@@ -65,7 +164,7 @@ impl<'a, T: AsRef<[u8]> + ?Sized> ErrorBuffer<&'a T> {
     }
 }
 
-impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> ErrorBuffer<&'a mut T> {
+impl<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> ErrorBuffer<&mut T> {
     /// Return a mutable pointer to the payload.
pub fn payload_mut(&mut self) -> &mut [u8] { let data = self.buffer.as_mut(); @@ -199,8 +298,7 @@ mod tests { #[test] fn parse_nack() { // SAFETY: value is non-zero. - const ERROR_CODE: NonZeroI32 = - unsafe { NonZeroI32::new_unchecked(-1234) }; + const ERROR_CODE: NonZeroI32 = NonZeroI32::new(-1234).unwrap(); let mut bytes = vec![0, 0, 0, 0]; NativeEndian::write_i32(&mut bytes, ERROR_CODE.get()); let msg = ErrorBuffer::new_checked(&bytes) diff --git a/src/header.rs b/src/header.rs index dd37753..67db7e3 100644 --- a/src/header.rs +++ b/src/header.rs @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -use netlink_packet_utils::DecodeError; - -use crate::{buffer::NETLINK_HEADER_LEN, Emitable, NetlinkBuffer, Parseable}; +use crate::{ + buffer::NETLINK_HEADER_LEN, DecodeError, Emitable, NetlinkBuffer, Parseable, +}; /// A Netlink header representation. A netlink header has the following /// structure: @@ -111,7 +111,7 @@ mod tests { port_number: 0, }; assert_eq!(repr.buffer_len(), 16); - let mut buf = vec![0; 16]; + let mut buf = [0; 16]; repr.emit(&mut buf[..]); assert_eq!(&buf[..], &IP_LINK_SHOW_PKT[..16]); } diff --git a/src/lib.rs b/src/lib.rs index f7f1508..910ae27 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -268,5 +268,11 @@ pub use self::message::*; pub mod constants; pub use self::constants::*; -pub(crate) use self::utils::traits::*; -pub(crate) use netlink_packet_utils as utils; +pub mod nla; +pub use self::nla::*; + +pub mod parsers; +pub use self::parsers::*; + +#[macro_use] +mod macros; diff --git a/src/macros.rs b/src/macros.rs new file mode 100644 index 0000000..b779194 --- /dev/null +++ b/src/macros.rs @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT + +#[macro_export(local_inner_macros)] +macro_rules! getter { + ($buffer: ident, $name:ident, slice, $offset:expr) => { + impl<'a, T: AsRef<[u8]> + ?Sized> $buffer<&'a T> { + pub fn $name(&self) -> &'a [u8] { + &self.buffer.as_ref()[$offset] + } + } + }; + ($buffer: ident, $name:ident, $ty:tt, $offset:expr) => { + impl<'a, T: AsRef<[u8]>> $buffer { + getter!($name, $ty, $offset); + } + }; + ($name:ident, u8, $offset:expr) => { + pub fn $name(&self) -> u8 { + self.buffer.as_ref()[$offset] + } + }; + ($name:ident, u16, $offset:expr) => { + pub fn $name(&self) -> u16 { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::read_u16(&self.buffer.as_ref()[$offset]) + } + }; + ($name:ident, u32, $offset:expr) => { + pub fn $name(&self) -> u32 { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::read_u32(&self.buffer.as_ref()[$offset]) + } + }; + ($name:ident, u64, $offset:expr) => { + pub fn $name(&self) -> u64 { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::read_u64(&self.buffer.as_ref()[$offset]) + } + }; + ($name:ident, i8, $offset:expr) => { + pub fn $name(&self) -> i8 { + self.buffer.as_ref()[$offset] + } + }; + ($name:ident, i16, $offset:expr) => { + pub fn $name(&self) -> i16 { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::read_i16(&self.buffer.as_ref()[$offset]) + } + }; + ($name:ident, i32, $offset:expr) => { + pub fn $name(&self) -> i32 { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::read_i32(&self.buffer.as_ref()[$offset]) + } + }; + ($name:ident, i64, $offset:expr) => { + pub fn $name(&self) -> i64 { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::read_i64(&self.buffer.as_ref()[$offset]) + } + }; +} + +#[macro_export(local_inner_macros)] +macro_rules! 
setter { + ($buffer: ident, $name:ident, slice, $offset:expr) => { + impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> $buffer<&'a mut T> { + $crate::pastey::item! { + pub fn [<$name _mut>](&mut self) -> &mut [u8] { + &mut self.buffer.as_mut()[$offset] + } + } + } + }; + ($buffer: ident, $name:ident, $ty:tt, $offset:expr) => { + impl<'a, T: AsRef<[u8]> + AsMut<[u8]>> $buffer { + setter!($name, $ty, $offset); + } + }; + ($name:ident, u8, $offset:expr) => { + $crate::pastey::item! { + pub fn [](&mut self, value: u8) { + self.buffer.as_mut()[$offset] = value; + } + } + }; + ($name:ident, u16, $offset:expr) => { + $crate::pastey::item! { + pub fn [](&mut self, value: u16) { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::write_u16(&mut self.buffer.as_mut()[$offset], value) + } + } + }; + ($name:ident, u32, $offset:expr) => { + $crate::pastey::item! { + pub fn [](&mut self, value: u32) { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::write_u32(&mut self.buffer.as_mut()[$offset], value) + } + } + }; + ($name:ident, u64, $offset:expr) => { + $crate::pastey::item! { + pub fn [](&mut self, value: u64) { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::write_u64(&mut self.buffer.as_mut()[$offset], value) + } + } + }; + ($name:ident, i8, $offset:expr) => { + $crate::pastey::item! { + pub fn [](&mut self, value: i8) { + self.buffer.as_mut()[$offset] = value; + } + } + }; + ($name:ident, i16, $offset:expr) => { + $crate::pastey::item! { + pub fn [](&mut self, value: i16) { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::write_i16(&mut self.buffer.as_mut()[$offset], value) + } + } + }; + ($name:ident, i32, $offset:expr) => { + $crate::pastey::item! { + pub fn [](&mut self, value: i32) { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::write_i32(&mut self.buffer.as_mut()[$offset], value) + } + } + }; + ($name:ident, i64, $offset:expr) => { + $crate::pastey::item! { + pub fn [](&mut self, value: i64) { + use $crate::byteorder::{ByteOrder, NativeEndian}; + NativeEndian::write_i64(&mut self.buffer.as_mut()[$offset], value) + } + } + }; +} + +#[macro_export(local_inner_macros)] +macro_rules! buffer { + ($name:ident($buffer_len:expr) { $($field:ident : ($ty:tt, $offset:expr)),* $(,)? }) => { + buffer!($name { $($field: ($ty, $offset),)* }); + buffer_check_length!($name($buffer_len)); + }; + + ($name:ident { $($field:ident : ($ty:tt, $offset:expr)),* $(,)? }) => { + buffer_common!($name); + fields!($name { + $($field: ($ty, $offset),)* + }); + }; + + ($name:ident, $buffer_len:expr) => { + buffer_common!($name); + buffer_check_length!($name($buffer_len)); + }; + + ($name:ident) => { + buffer_common!($name); + }; +} + +#[macro_export(local_inner_macros)] +macro_rules! fields { + ($buffer:ident { $($name:ident : ($ty:tt, $offset:expr)),* $(,)? }) => { + $( + getter!($buffer, $name, $ty, $offset); + )* + + $( + setter!($buffer, $name, $ty, $offset); + )* + } +} + +#[macro_export] +macro_rules! buffer_check_length { + ($name:ident($buffer_len:expr)) => { + impl> $name { + pub fn new_checked(buffer: T) -> Result { + let packet = Self::new(buffer); + packet.check_buffer_length()?; + Ok(packet) + } + + fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.buffer.as_ref().len(); + if len < $buffer_len { + Err(DecodeError::InvalidBuffer { + name: stringify!($name), + received: len, + minimum_length: $buffer_len, + }) + } else { + Ok(()) + } + } + } + }; +} + +#[macro_export] +macro_rules! 
buffer_common { + ($name:ident) => { + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub struct $name { + buffer: T, + } + + impl> $name { + pub fn new(buffer: T) -> Self { + Self { buffer } + } + + pub fn into_inner(self) -> T { + self.buffer + } + } + + impl<'a, T: AsRef<[u8]> + ?Sized> $name<&'a T> { + pub fn inner(&self) -> &'a [u8] { + &self.buffer.as_ref()[..] + } + } + + impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> $name<&'a mut T> { + pub fn inner_mut(&mut self) -> &mut [u8] { + &mut self.buffer.as_mut()[..] + } + } + }; +} diff --git a/src/message.rs b/src/message.rs index b3b0541..73236e2 100644 --- a/src/message.rs +++ b/src/message.rs @@ -2,13 +2,11 @@ use std::fmt::Debug; -use netlink_packet_utils::DecodeError; - use crate::{ payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN}, - DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, - NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload, - NetlinkSerializable, Parseable, + DecodeError, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorContext, + ErrorMessage, NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, + NetlinkPayload, NetlinkSerializable, Parseable, }; /// Represent a netlink message. @@ -39,7 +37,8 @@ where { /// Parse the given buffer as a netlink message pub fn deserialize(buffer: &[u8]) -> Result { - let netlink_buffer = NetlinkBuffer::new_checked(&buffer)?; + let netlink_buffer = NetlinkBuffer::new_checked(&buffer) + .context("failed deserializing NetlinkMessage")?; >>::parse(&netlink_buffer) } } @@ -93,27 +92,31 @@ where use self::NetlinkPayload::*; let header = - >>::parse(buf)?; + >>::parse(buf) + .context("failed parsing NetlinkHeader")?; let bytes = buf.payload(); let payload = match header.message_type { NLMSG_ERROR => { let msg = ErrorBuffer::new_checked(&bytes) - .and_then(|buf| ErrorMessage::parse(&buf))?; + .and_then(|buf| ErrorMessage::parse(&buf)) + .context("failed parsing NLMSG_ERROR")?; Error(msg) } NLMSG_NOOP => Noop, NLMSG_DONE => { let msg = DoneBuffer::new_checked(&bytes) - .and_then(|buf| DoneMessage::parse(&buf))?; + .and_then(|buf| DoneMessage::parse(&buf)) + .context("failed parsing NLMSG_DONE")?; Done(msg) } NLMSG_OVERRUN => Overrun(bytes.to_vec()), message_type => match I::deserialize(&header, bytes) { Err(e) => { - return Err(DecodeError::Other( - format!("Failed to parse message with type {message_type}: {e}").into()), - ); + return Err(format!( + "Failed to parse message with type {message_type}: {e}" + ) + .into()) } Ok(inner_msg) => InnerMessage(inner_msg), }, @@ -238,8 +241,7 @@ mod tests { #[test] fn test_error() { // SAFETY: value is non-zero. - const ERROR_CODE: NonZeroI32 = - unsafe { NonZeroI32::new_unchecked(-8765) }; + const ERROR_CODE: NonZeroI32 = NonZeroI32::new(-8765).unwrap(); let header = NetlinkHeader::default(); let error_msg = ErrorMessage { diff --git a/src/nla.rs b/src/nla.rs new file mode 100644 index 0000000..a438c89 --- /dev/null +++ b/src/nla.rs @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: MIT + +use crate::{ + traits::{Emitable, Parseable}, + DecodeError, +}; +use byteorder::{ByteOrder, NativeEndian}; +use core::ops::Range; + +/// Represent a multi-bytes field with a fixed size in a packet +type Field = Range; + +/// Identify the bits that represent the "nested" flag of a netlink attribute. +pub const NLA_F_NESTED: u16 = 0x8000; +/// Identify the bits that represent the "byte order" flag of a netlink +/// attribute. 
+pub const NLA_F_NET_BYTEORDER: u16 = 0x4000; +/// Identify the bits that represent the type of a netlink attribute. +pub const NLA_TYPE_MASK: u16 = !(NLA_F_NET_BYTEORDER | NLA_F_NESTED); +/// NlA(RTA) align size +pub const NLA_ALIGNTO: usize = 4; +/// NlA(RTA) header size. (unsigned short rta_len) + (unsigned short rta_type) +pub const NLA_HEADER_SIZE: usize = 4; + +#[macro_export] +macro_rules! nla_align { + ($len: expr) => { + ($len + NLA_ALIGNTO - 1) & !(NLA_ALIGNTO - 1) + }; +} + +const LENGTH: Field = 0..2; +const TYPE: Field = 2..4; +#[allow(non_snake_case)] +fn VALUE(length: usize) -> Field { + TYPE.end..TYPE.end + length +} + +// with Copy, NlaBuffer<&'buffer T> can be copied, which turns out to be pretty +// conveninent. And since it's boils down to copying a reference it's pretty +// cheap +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct NlaBuffer> { + buffer: T, +} + +impl> NlaBuffer { + pub fn new(buffer: T) -> NlaBuffer { + NlaBuffer { buffer } + } + + pub fn new_checked(buffer: T) -> Result, DecodeError> { + let buffer = Self::new(buffer); + buffer.check_buffer_length()?; + Ok(buffer) + } + + pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.buffer.as_ref().len(); + if len < TYPE.end { + Err(DecodeError::nla_buffer_too_small(len, TYPE.end)) + } else if len < self.length() as usize { + Err(DecodeError::nla_length_mismatch( + len, + self.length() as usize, + )) + } else if (self.length() as usize) < TYPE.end { + Err(DecodeError::nla_invalid_length(len, self.length() as usize)) + } else { + Ok(()) + } + } + + /// Consume the buffer, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return a reference to the underlying buffer + pub fn inner(&mut self) -> &T { + &self.buffer + } + + /// Return a mutable reference to the underlying buffer + pub fn inner_mut(&mut self) -> &mut T { + &mut self.buffer + } + + /// Return the `type` field + pub fn kind(&self) -> u16 { + let data = self.buffer.as_ref(); + NativeEndian::read_u16(&data[TYPE]) & NLA_TYPE_MASK + } + + pub fn nested_flag(&self) -> bool { + let data = self.buffer.as_ref(); + (NativeEndian::read_u16(&data[TYPE]) & NLA_F_NESTED) != 0 + } + + pub fn network_byte_order_flag(&self) -> bool { + let data = self.buffer.as_ref(); + (NativeEndian::read_u16(&data[TYPE]) & NLA_F_NET_BYTEORDER) != 0 + } + + /// Return the `length` field. The `length` field corresponds to the length + /// of the nla header (type and length fields, and the value field). + /// However, it does not account for the potential padding that follows + /// the value field. + pub fn length(&self) -> u16 { + let data = self.buffer.as_ref(); + NativeEndian::read_u16(&data[LENGTH]) + } + + /// Return the length of the `value` field + /// + /// # Panic + /// + /// This panics if the length field value is less than the attribut header + /// size. 
+ pub fn value_length(&self) -> usize { + self.length() as usize - TYPE.end + } +} + +impl + AsMut<[u8]>> NlaBuffer { + /// Set the `type` field + pub fn set_kind(&mut self, kind: u16) { + let data = self.buffer.as_mut(); + NativeEndian::write_u16(&mut data[TYPE], kind & NLA_TYPE_MASK) + } + + pub fn set_nested_flag(&mut self) { + let kind = self.kind(); + let data = self.buffer.as_mut(); + NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NESTED) + } + + pub fn set_network_byte_order_flag(&mut self) { + let kind = self.kind(); + let data = self.buffer.as_mut(); + NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NET_BYTEORDER) + } + + /// Set the `length` field + pub fn set_length(&mut self, length: u16) { + let data = self.buffer.as_mut(); + NativeEndian::write_u16(&mut data[LENGTH], length) + } +} + +impl + ?Sized> NlaBuffer<&T> { + /// Return the `value` field + pub fn value(&self) -> &[u8] { + &self.buffer.as_ref()[VALUE(self.value_length())] + } +} + +impl + AsMut<[u8]> + ?Sized> NlaBuffer<&mut T> { + /// Return the `value` field + pub fn value_mut(&mut self) -> &mut [u8] { + let length = VALUE(self.value_length()); + &mut self.buffer.as_mut()[length] + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DefaultNla { + kind: u16, + value: Vec, +} + +impl DefaultNla { + pub fn new(kind: u16, value: Vec) -> Self { + Self { kind, value } + } +} + +impl Nla for DefaultNla { + fn value_len(&self) -> usize { + self.value.len() + } + fn kind(&self) -> u16 { + self.kind + } + fn emit_value(&self, buffer: &mut [u8]) { + buffer.copy_from_slice(self.value.as_slice()); + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for DefaultNla +{ + type Error = DecodeError; + + fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { + let mut kind = buf.kind(); + + if buf.network_byte_order_flag() { + kind |= NLA_F_NET_BYTEORDER; + } + + if buf.nested_flag() { + kind |= NLA_F_NESTED; + } + + Ok(DefaultNla { + kind, + value: buf.value().to_vec(), + }) + } +} + +pub trait Nla { + fn value_len(&self) -> usize; + fn kind(&self) -> u16; + fn emit_value(&self, buffer: &mut [u8]); + + #[inline] + fn is_nested(&self) -> bool { + (self.kind() & NLA_F_NESTED) != 0 + } + + #[inline] + fn is_network_byteorder(&self) -> bool { + (self.kind() & NLA_F_NET_BYTEORDER) != 0 + } +} + +impl Emitable for T { + fn buffer_len(&self) -> usize { + nla_align!(self.value_len()) + NLA_HEADER_SIZE + } + fn emit(&self, buffer: &mut [u8]) { + let mut buffer = NlaBuffer::new(buffer); + buffer.set_kind(self.kind()); + + if self.is_network_byteorder() { + buffer.set_network_byte_order_flag() + } + + if self.is_nested() { + buffer.set_nested_flag() + } + + // do not include the padding here, but do include the header + buffer.set_length(self.value_len() as u16 + NLA_HEADER_SIZE as u16); + + self.emit_value(buffer.value_mut()); + + let padding = nla_align!(self.value_len()) - self.value_len(); + for i in 0..padding { + buffer.inner_mut()[NLA_HEADER_SIZE + self.value_len() + i] = 0; + } + } +} + +// FIXME: whern specialization lands, why can actually have +// +// impl<'a, T: Nla, I: Iterator> Emitable for I { ...} +// +// The reason this does not work today is because it conflicts with +// +// impl Emitable for T { ... 
} +impl Emitable for &[T] { + fn buffer_len(&self) -> usize { + self.iter().fold(0, |acc, nla| { + assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0); + acc + nla.buffer_len() + }) + } + + fn emit(&self, buffer: &mut [u8]) { + let mut start = 0; + let mut end: usize; + for nla in self.iter() { + let attr_len = nla.buffer_len(); + assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0); + end = start + attr_len; + nla.emit(&mut buffer[start..end]); + start = end; + } + } +} + +/// An iterator that iteratates over nlas without decoding them. This is useful +/// when looking for specific nlas. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NlasIterator { + position: usize, + buffer: T, +} + +impl NlasIterator { + pub fn new(buffer: T) -> Self { + NlasIterator { + position: 0, + buffer, + } + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator + for NlasIterator<&'buffer T> +{ + type Item = Result, DecodeError>; + + fn next(&mut self) -> Option { + if self.position >= self.buffer.as_ref().len() { + return None; + } + + match NlaBuffer::new_checked(&self.buffer.as_ref()[self.position..]) { + Ok(nla_buffer) => { + self.position += nla_align!(nla_buffer.length() as usize); + Some(Ok(nla_buffer)) + } + Err(e) => { + // Make sure next time we call `next()`, we return None. We + // don't try to continue iterating after we + // failed to return a buffer. + self.position = self.buffer.as_ref().len(); + Some(Err(e)) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn network_byteorder() { + // The IPSET_ATTR_TIMEOUT attribute should have the network byte order + // flag set. IPSET_ATTR_TIMEOUT(3600) + static TEST_ATTRIBUTE: &[u8] = + &[0x08, 0x00, 0x06, 0x40, 0x00, 0x00, 0x0e, 0x10]; + let buffer = NlaBuffer::new(TEST_ATTRIBUTE); + let buffer_is_net = buffer.network_byte_order_flag(); + let buffer_is_nest = buffer.nested_flag(); + + let nla = DefaultNla::parse(&buffer).unwrap(); + let mut emitted_buffer = vec![0; nla.buffer_len()]; + + nla.emit(&mut emitted_buffer); + + let attr_is_net = nla.is_network_byteorder(); + let attr_is_nest = nla.is_nested(); + + let emit = NlaBuffer::new(emitted_buffer); + let emit_is_net = emit.network_byte_order_flag(); + let emit_is_nest = emit.nested_flag(); + + assert_eq!( + [buffer_is_net, buffer_is_nest], + [attr_is_net, attr_is_nest] + ); + assert_eq!([attr_is_net, attr_is_nest], [emit_is_net, emit_is_nest]); + } + + fn get_len() -> usize { + // usize::MAX + 18446744073709551615 + } + + #[test] + fn test_align() { + assert_eq!(nla_align!(13), 16); + assert_eq!(nla_align!(16), 16); + assert_eq!(nla_align!(0), 0); + assert_eq!(nla_align!(1), 4); + assert_eq!(nla_align!(get_len() - 4), usize::MAX - 3); + } + #[test] + #[should_panic] + fn test_align_overflow() { + assert_eq!(nla_align!(get_len() - 3), usize::MAX); + } +} diff --git a/src/parsers.rs b/src/parsers.rs new file mode 100644 index 0000000..8e09a5e --- /dev/null +++ b/src/parsers.rs @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT + +use std::{ + mem::size_of, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, +}; + +use byteorder::{BigEndian, ByteOrder, NativeEndian}; + +use crate::DecodeError; + +pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> { + if payload.len() != 6 { + return Err(DecodeError::invalid_mac_address(payload.len())); + } + let mut address: [u8; 6] = [0; 6]; + for (i, byte) in payload.iter().enumerate() { + address[i] = *byte; + } + Ok(address) +} + +pub fn parse_ipv6(payload: &[u8]) -> Result<[u8; 16], DecodeError> { + if payload.len() != 16 { + return 
Err(DecodeError::invalid_ip_address(payload.len())); + } + let mut address: [u8; 16] = [0; 16]; + for (i, byte) in payload.iter().enumerate() { + address[i] = *byte; + } + Ok(address) +} + +pub fn parse_ip(payload: &[u8]) -> Result { + match payload.len() { + 4 => Ok( + Ipv4Addr::new(payload[0], payload[1], payload[2], payload[3]) + .into(), + ), + 16 => Ok(Ipv6Addr::from([ + payload[0], + payload[1], + payload[2], + payload[3], + payload[4], + payload[5], + payload[6], + payload[7], + payload[8], + payload[9], + payload[10], + payload[11], + payload[12], + payload[13], + payload[14], + payload[15], + ]) + .into()), + other => Err(DecodeError::invalid_ip_address(other)), + } +} + +pub fn parse_string(payload: &[u8]) -> Result { + if payload.is_empty() { + return Ok(String::new()); + } + // iproute2 is a bit inconsistent with null-terminated strings. + let slice = if payload[payload.len() - 1] == 0 { + &payload[..payload.len() - 1] + } else { + &payload[..payload.len()] + }; + let s = String::from_utf8(slice.to_vec())?; + Ok(s) +} + +pub fn parse_u8(payload: &[u8]) -> Result { + if payload.len() != 1 { + return Err(DecodeError::invalid_number(1, payload.len())); + } + Ok(payload[0]) +} + +pub fn parse_i8(payload: &[u8]) -> Result { + if payload.len() != 1 { + return Err(DecodeError::invalid_number(1, payload.len())); + } + Ok(payload[0] as i8) +} + +pub fn parse_u32(payload: &[u8]) -> Result { + if payload.len() != size_of::() { + return Err(DecodeError::invalid_number( + size_of::(), + payload.len(), + )); + } + Ok(NativeEndian::read_u32(payload)) +} + +pub fn parse_u64(payload: &[u8]) -> Result { + if payload.len() != size_of::() { + return Err(DecodeError::invalid_number( + size_of::(), + payload.len(), + )); + } + Ok(NativeEndian::read_u64(payload)) +} +pub fn parse_u128(payload: &[u8]) -> Result { + if payload.len() != size_of::() { + return Err(DecodeError::invalid_number( + size_of::(), + payload.len(), + )); + } + Ok(NativeEndian::read_u128(payload)) +} + +pub fn parse_u16(payload: &[u8]) -> Result { + if payload.len() != size_of::() { + return Err(DecodeError::invalid_number( + size_of::(), + payload.len(), + )); + } + Ok(NativeEndian::read_u16(payload)) +} + +pub fn parse_i32(payload: &[u8]) -> Result { + if payload.len() != 4 { + return Err(DecodeError::invalid_number(4, payload.len())); + } + Ok(NativeEndian::read_i32(payload)) +} + +pub fn parse_i64(payload: &[u8]) -> Result { + if payload.len() != 8 { + return Err(format!("invalid i64: {payload:?}").into()); + } + Ok(NativeEndian::read_i64(payload)) +} + +pub fn parse_u16_be(payload: &[u8]) -> Result { + if payload.len() != size_of::() { + return Err(DecodeError::invalid_number( + size_of::(), + payload.len(), + )); + } + Ok(BigEndian::read_u16(payload)) +} + +pub fn parse_u32_be(payload: &[u8]) -> Result { + if payload.len() != size_of::() { + return Err(DecodeError::invalid_number( + size_of::(), + payload.len(), + )); + } + Ok(BigEndian::read_u32(payload)) +} diff --git a/src/traits.rs b/src/traits.rs index dc55331..4bd09cb 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -41,3 +41,46 @@ pub trait NetlinkSerializable { /// This method panics if the buffer is not big enough. fn serialize(&self, buffer: &mut [u8]); } + +/// A type that implements `Emitable` can be serialized. +pub trait Emitable { + /// Return the length of the serialized data. + fn buffer_len(&self) -> usize; + + /// Serialize this types and write the serialized data into the given + /// buffer. 
+    ///
+    /// # Panic
+    ///
+    /// This method panic if the buffer is not big enough. You **must** make
+    /// sure the buffer is big enough before calling this method. You can
+    /// use [`buffer_len()`](trait.Emitable.html#method.buffer_len) to check
+    /// how big the storage needs to be.
+    fn emit(&self, buffer: &mut [u8]);
+}
+
+/// A `Parseable` type can be used to deserialize data from the type `T` for
+/// which it is implemented.
+pub trait Parseable<T>
+where
+    Self: Sized,
+    T: ?Sized,
+{
+    type Error;
+
+    /// Deserialize the current type.
+    fn parse(buf: &T) -> Result<Self, Self::Error>;
+}
+
+/// A `Parseable` type can be used to deserialize data from the type `T` for
+/// which it is implemented.
+pub trait ParseableParametrized<T, P>
+where
+    Self: Sized,
+    T: ?Sized,
+{
+    type Error;
+
+    /// Deserialize the current type.
+    fn parse_with_param(buf: &T, params: P) -> Result<Self, Self::Error>;
+}
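
Not part of the patch: a minimal sketch of how downstream code might exercise the pieces this change vendors into the crate (`DecodeError`, `ErrorContext`, `NlaBuffer`/`NlasIterator`, and the `parse_*` helpers). It assumes the crate is `netlink_packet_core` and that the new items are re-exported from the crate root, as the `lib.rs` hunk suggests; the function name and the sample attribute bytes are made up for illustration.

```rust
// Illustrative only: walk a raw attribute payload and decode the first
// attribute's value as a native-endian u32, using the error-context API
// the same way the patch wraps `check_buffer_length()` failures.
use netlink_packet_core::{
    parse_u32, DecodeError, ErrorContext, NlasIterator,
};

fn read_first_u32_attribute(payload: &[u8]) -> Result<u32, DecodeError> {
    // NlasIterator yields Result<NlaBuffer<&[u8]>, DecodeError> without
    // decoding each attribute up front.
    for nla in NlasIterator::new(payload) {
        let nla = nla.context("invalid attribute in payload")?;
        // `value()` is the attribute body after the 4-byte NLA header.
        return parse_u32(nla.value())
            .context("attribute value is not a 32-bit integer");
    }
    Err(DecodeError::from("payload contains no attributes"))
}

fn main() -> Result<(), DecodeError> {
    // One attribute: len = 8 (header + value), type = 1, 4-byte value.
    // Layout shown for a little-endian host, since the helpers use
    // native endianness.
    let bytes = [0x08, 0x00, 0x01, 0x00, 0x78, 0x56, 0x34, 0x12];
    let value = read_first_u32_attribute(&bytes)?;
    println!("first attribute holds {value:#x}");
    Ok(())
}
```

The chained `context(...)` calls mirror how the patch itself annotates `check_buffer_length()` errors, so a failed parse surfaces a breadcrumb trail ("... caused by ...") rather than a bare message.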