diff --git a/src/done.rs b/src/done.rs index 1cba4d6..c8b5e06 100644 --- a/src/done.rs +++ b/src/done.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: MIT -use std::mem::size_of; - use byteorder::{ByteOrder, NativeEndian}; use netlink_packet_utils::DecodeError; @@ -11,99 +9,52 @@ const CODE: Field = 0..4; const EXTENDED_ACK: Rest = 4..; const DONE_HEADER_LEN: usize = EXTENDED_ACK.start; -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] #[non_exhaustive] -pub struct DoneBuffer { - buffer: T, +pub struct DoneMessage { + pub payload: Vec, } -impl> DoneBuffer { - pub fn new(buffer: T) -> DoneBuffer { - DoneBuffer { buffer } - } - - /// Consume the packet, returning the underlying buffer. - pub fn into_inner(self) -> T { - self.buffer - } - - pub fn new_checked(buffer: T) -> Result { - let packet = Self::new(buffer); - packet.check_buffer_length()?; - Ok(packet) - } - - fn check_buffer_length(&self) -> Result<(), DecodeError> { - let len = self.buffer.as_ref().len(); - if len < DONE_HEADER_LEN { - Err(format!( - "invalid DoneBuffer: length is {len} but DoneBuffer are \ - at least {DONE_HEADER_LEN} bytes" - ) - .into()) +impl DoneMessage { + pub fn code(&self) -> Option { + if self.payload.len() < DONE_HEADER_LEN { + None } else { - Ok(()) + Some(NativeEndian::read_i32(&self.payload[CODE])) } } - /// Return the error code - pub fn code(&self) -> i32 { - let data = self.buffer.as_ref(); - NativeEndian::read_i32(&data[CODE]) - } -} - -impl<'a, T: AsRef<[u8]> + ?Sized> DoneBuffer<&'a T> { - /// Return a pointer to the extended ack attributes. - pub fn extended_ack(&self) -> &'a [u8] { - let data = self.buffer.as_ref(); - &data[EXTENDED_ACK] - } -} - -impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> DoneBuffer<&'a mut T> { - /// Return a mutable pointer to the extended ack attributes. - pub fn extended_ack_mut(&mut self) -> &mut [u8] { - let data = self.buffer.as_mut(); - &mut data[EXTENDED_ACK] + pub fn extended_ack(&self) -> Option<&[u8]> { + if self.payload.len() < DONE_HEADER_LEN { + None + } else { + Some(&self.payload[EXTENDED_ACK]) + } } -} -impl + AsMut<[u8]>> DoneBuffer { - /// set the error code field - pub fn set_code(&mut self, value: i32) { - let data = self.buffer.as_mut(); - NativeEndian::write_i32(&mut data[CODE], value) + pub fn new_with_code>(code: i32, extend_ack: &T) -> Self { + let mut payload = vec![0; DONE_HEADER_LEN + extend_ack.as_ref().len()]; + NativeEndian::write_i32(&mut payload, code); + payload[CODE.end..].copy_from_slice(extend_ack.as_ref()); + Self { payload } } } -#[derive(Debug, Default, Clone, PartialEq, Eq)] -#[non_exhaustive] -pub struct DoneMessage { - pub code: i32, - pub extended_ack: Vec, -} - impl Emitable for DoneMessage { fn buffer_len(&self) -> usize { - size_of::() + self.extended_ack.len() + self.payload.len() } fn emit(&self, buffer: &mut [u8]) { - let mut buffer = DoneBuffer::new(buffer); - buffer.set_code(self.code); - buffer - .extended_ack_mut() - .copy_from_slice(&self.extended_ack); + buffer.copy_from_slice(&self.payload); } } -impl> Parseable> for DoneMessage { +impl> Parseable for DoneMessage { type Error = DecodeError; - fn parse(buf: &DoneBuffer<&T>) -> Result { + fn parse(buf: &T) -> Result { Ok(DoneMessage { - code: buf.code(), - extended_ack: buf.extended_ack().to_vec(), + payload: buf.as_ref().to_vec(), }) } } @@ -114,22 +65,18 @@ mod tests { #[test] fn serialize_and_parse() { - let expected = DoneMessage { - code: 5, - extended_ack: vec![1, 2, 3], - }; - + let expected = DoneMessage::new_with_code(5, &[1, 2, 3]); let len = expected.buffer_len(); - assert_eq!(len, size_of::() + expected.extended_ack.len()); + assert_eq!( + len, + size_of::() + expected.extended_ack().unwrap().len() + ); let mut buf = vec![0; len]; expected.emit(&mut buf); - let done_buf = DoneBuffer::new(&buf); - assert_eq!(done_buf.code(), expected.code); - assert_eq!(done_buf.extended_ack(), &expected.extended_ack); - - let got = DoneMessage::parse(&done_buf).unwrap(); - assert_eq!(got, expected); + let got = DoneMessage::parse(&buf); + assert!(got.is_ok()); + assert_eq!(got.unwrap(), expected); } } diff --git a/src/message.rs b/src/message.rs index b3b0541..f6cbe03 100644 --- a/src/message.rs +++ b/src/message.rs @@ -6,9 +6,9 @@ use netlink_packet_utils::DecodeError; use crate::{ payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN}, - DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, - NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload, - NetlinkSerializable, Parseable, + DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer, + NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, + Parseable, NLM_F_MULTIPART, }; /// Represent a netlink message. @@ -103,10 +103,11 @@ where Error(msg) } NLMSG_NOOP => Noop, - NLMSG_DONE => { - let msg = DoneBuffer::new_checked(&bytes) - .and_then(|buf| DoneMessage::parse(&buf))?; - Done(msg) + // only parse message_type of NLMSG_DONE when flag has + // NLM_F_MULTIPART because some special netlink like + // connector use NLMSG_DONE for all the message + NLMSG_DONE if header.flags & NLM_F_MULTIPART == NLM_F_MULTIPART => { + Done(DoneMessage::parse(&bytes)?) } NLMSG_OVERRUN => Overrun(bytes.to_vec()), message_type => match I::deserialize(&header, bytes) { @@ -205,11 +206,9 @@ mod tests { #[test] fn test_done() { - let header = NetlinkHeader::default(); - let done_msg = DoneMessage { - code: 0, - extended_ack: vec![6, 7, 8, 9], - }; + let mut header = NetlinkHeader::default(); + header.flags |= NLM_F_MULTIPART; + let done_msg = DoneMessage::new_with_code(0, &[6, 7, 8, 9]); let mut want = NetlinkMessage::new( header, NetlinkPayload::::Done(done_msg.clone()), @@ -221,16 +220,12 @@ mod tests { len, header.buffer_len() + size_of::() - + done_msg.extended_ack.len() + + done_msg.extended_ack().unwrap().len() ); let mut buf = vec![1; len]; want.emit(&mut buf); - let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]); - assert_eq!(done_buf.code(), done_msg.code); - assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack); - let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap(); assert_eq!(got, want); }