Skip to content

Fix done message parse #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 32 additions & 85 deletions src/done.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
// SPDX-License-Identifier: MIT

use std::mem::size_of;

use byteorder::{ByteOrder, NativeEndian};
use netlink_packet_utils::DecodeError;

Expand All @@ -11,99 +9,52 @@ const CODE: Field = 0..4;
const EXTENDED_ACK: Rest = 4..;
const DONE_HEADER_LEN: usize = EXTENDED_ACK.start;

#[derive(Debug, PartialEq, Eq, Clone)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct DoneBuffer<T> {
buffer: T,
pub struct DoneMessage {
pub payload: Vec<u8>,
}

impl<T: AsRef<[u8]>> DoneBuffer<T> {
pub fn new(buffer: T) -> DoneBuffer<T> {
DoneBuffer { buffer }
}

/// Consume the packet, returning the underlying buffer.
pub fn into_inner(self) -> T {
self.buffer
}

pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
let packet = Self::new(buffer);
packet.check_buffer_length()?;
Ok(packet)
}

fn check_buffer_length(&self) -> Result<(), DecodeError> {
let len = self.buffer.as_ref().len();
if len < DONE_HEADER_LEN {
Err(format!(
"invalid DoneBuffer: length is {len} but DoneBuffer are \
at least {DONE_HEADER_LEN} bytes"
)
.into())
impl DoneMessage {
pub fn code(&self) -> Option<i32> {
if self.payload.len() < DONE_HEADER_LEN {
None
} else {
Ok(())
Some(NativeEndian::read_i32(&self.payload[CODE]))
}
}

/// Return the error code
pub fn code(&self) -> i32 {
let data = self.buffer.as_ref();
NativeEndian::read_i32(&data[CODE])
}
}

impl<'a, T: AsRef<[u8]> + ?Sized> DoneBuffer<&'a T> {
/// Return a pointer to the extended ack attributes.
pub fn extended_ack(&self) -> &'a [u8] {
let data = self.buffer.as_ref();
&data[EXTENDED_ACK]
}
}

impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> DoneBuffer<&'a mut T> {
/// Return a mutable pointer to the extended ack attributes.
pub fn extended_ack_mut(&mut self) -> &mut [u8] {
let data = self.buffer.as_mut();
&mut data[EXTENDED_ACK]
pub fn extended_ack(&self) -> Option<&[u8]> {
if self.payload.len() < DONE_HEADER_LEN {
None
} else {
Some(&self.payload[EXTENDED_ACK])
}
}
}

impl<T: AsRef<[u8]> + AsMut<[u8]>> DoneBuffer<T> {
/// set the error code field
pub fn set_code(&mut self, value: i32) {
let data = self.buffer.as_mut();
NativeEndian::write_i32(&mut data[CODE], value)
pub fn new_with_code<T: AsRef<[u8]>>(code: i32, extend_ack: &T) -> Self {
let mut payload = vec![0; DONE_HEADER_LEN + extend_ack.as_ref().len()];
NativeEndian::write_i32(&mut payload, code);
payload[CODE.end..].copy_from_slice(extend_ack.as_ref());
Self { payload }
}
}

#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct DoneMessage {
pub code: i32,
pub extended_ack: Vec<u8>,
}

impl Emitable for DoneMessage {
fn buffer_len(&self) -> usize {
size_of::<i32>() + self.extended_ack.len()
self.payload.len()
}
fn emit(&self, buffer: &mut [u8]) {
let mut buffer = DoneBuffer::new(buffer);
buffer.set_code(self.code);
buffer
.extended_ack_mut()
.copy_from_slice(&self.extended_ack);
buffer.copy_from_slice(&self.payload);
}
}

impl<T: AsRef<[u8]>> Parseable<DoneBuffer<&T>> for DoneMessage {
impl<T: AsRef<[u8]>> Parseable<T> for DoneMessage {
type Error = DecodeError;

fn parse(buf: &DoneBuffer<&T>) -> Result<DoneMessage, Self::Error> {
fn parse(buf: &T) -> Result<DoneMessage, Self::Error> {
Ok(DoneMessage {
code: buf.code(),
extended_ack: buf.extended_ack().to_vec(),
payload: buf.as_ref().to_vec(),
})
}
}
Expand All @@ -114,22 +65,18 @@ mod tests {

#[test]
fn serialize_and_parse() {
let expected = DoneMessage {
code: 5,
extended_ack: vec![1, 2, 3],
};

let expected = DoneMessage::new_with_code(5, &[1, 2, 3]);
let len = expected.buffer_len();
assert_eq!(len, size_of::<i32>() + expected.extended_ack.len());
assert_eq!(
len,
size_of::<i32>() + expected.extended_ack().unwrap().len()
);

let mut buf = vec![0; len];
expected.emit(&mut buf);

let done_buf = DoneBuffer::new(&buf);
assert_eq!(done_buf.code(), expected.code);
assert_eq!(done_buf.extended_ack(), &expected.extended_ack);

let got = DoneMessage::parse(&done_buf).unwrap();
assert_eq!(got, expected);
let got = DoneMessage::parse(&buf);
assert!(got.is_ok());
assert_eq!(got.unwrap(), expected);
}
}
29 changes: 12 additions & 17 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use netlink_packet_utils::DecodeError;

use crate::{
payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload,
NetlinkSerializable, Parseable,
DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer,
NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable,
Parseable, NLM_F_MULTIPART,
};

/// Represent a netlink message.
Expand Down Expand Up @@ -103,10 +103,11 @@ where
Error(msg)
}
NLMSG_NOOP => Noop,
NLMSG_DONE => {
let msg = DoneBuffer::new_checked(&bytes)
.and_then(|buf| DoneMessage::parse(&buf))?;
Done(msg)
// only parse message_type of NLMSG_DONE when flag has
// NLM_F_MULTIPART because some special netlink like
// connector use NLMSG_DONE for all the message
NLMSG_DONE if header.flags & NLM_F_MULTIPART == NLM_F_MULTIPART => {
Done(DoneMessage::parse(&bytes)?)
}
NLMSG_OVERRUN => Overrun(bytes.to_vec()),
message_type => match I::deserialize(&header, bytes) {
Expand Down Expand Up @@ -205,11 +206,9 @@ mod tests {

#[test]
fn test_done() {
let header = NetlinkHeader::default();
let done_msg = DoneMessage {
code: 0,
extended_ack: vec![6, 7, 8, 9],
};
let mut header = NetlinkHeader::default();
header.flags |= NLM_F_MULTIPART;
let done_msg = DoneMessage::new_with_code(0, &[6, 7, 8, 9]);
let mut want = NetlinkMessage::new(
header,
NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
Expand All @@ -221,16 +220,12 @@ mod tests {
len,
header.buffer_len()
+ size_of::<i32>()
+ done_msg.extended_ack.len()
+ done_msg.extended_ack().unwrap().len()
);

let mut buf = vec![1; len];
want.emit(&mut buf);

let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
assert_eq!(done_buf.code(), done_msg.code);
assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
assert_eq!(got, want);
}
Expand Down
Loading