From 6bbc3b65e6b19212c4f7fc4f40c20daf6f452deb Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Thu, 19 Sep 2024 08:44:16 +0800 Subject: [PATCH] fix: Use After Free in PacketReader (#67) * refactor: use `Bytes` for buffer * fix: deal with remain bytes * refactor: remove unused clone * chore: remove another unused clone * depend: relax `bytes` version require * chore: change it back&update some comment * test: add test for various packet size * tests: async read --- mysql/Cargo.toml | 1 + mysql/src/packet_reader.rs | 428 ++++++++++++++++++++++++++----------- mysql/src/tests/packet.rs | 10 +- 3 files changed, 315 insertions(+), 124 deletions(-) diff --git a/mysql/Cargo.toml b/mysql/Cargo.toml index 1c5699f..cfed920 100644 --- a/mysql/Cargo.toml +++ b/mysql/Cargo.toml @@ -20,6 +20,7 @@ tls = ["tokio-rustls", "pin-project-lite"] [dependencies] async-trait = "0.1.52" byteorder = "1.4.3" +bytes = "1.7.0" chrono = "0.4.20" mysql_common = { version = "0.32.0", features = ["chrono"] } nom = "7.1.0" diff --git a/mysql/src/packet_reader.rs b/mysql/src/packet_reader.rs index b8e0b6f..815d2af 100644 --- a/mysql/src/packet_reader.rs +++ b/mysql/src/packet_reader.rs @@ -15,25 +15,68 @@ use std::io; use std::io::prelude::*; +use std::iter::Enumerate; +use std::marker::PhantomData; +use std::ops::RangeFrom; + +use bytes::BytesMut; +use nom::Needed; use tokio::io::AsyncRead; use tokio::io::AsyncReadExt; const PACKET_BUFFER_SIZE: usize = 4_096; const PACKET_LARGE_BUFFER_SIZE: usize = 1_048_576; +/// Calculate the new buffer size for the next read +fn calc_new_buf_size(last_buf_size: usize) -> usize { + if last_buf_size >= PACKET_BUFFER_SIZE * 2 { + // if packet is already too large, use larger buffer to avoid multiple allocation + PACKET_LARGE_BUFFER_SIZE + } else { + std::cmp::max(PACKET_BUFFER_SIZE, last_buf_size * 2) + } +} + +/// reuse old buffer if possible, otherwise create a new buffer +/// +/// return (the idx to start writing to the buffer, and the buffer itself) +/// +/// will copy the remain bytes from the old buffer to the new buffer if reusing +fn reuse_or_create_buf(old_buf: bytes::Bytes, last_buf_size: usize) -> (usize, BytesMut) { + let new_buf_size = calc_new_buf_size(last_buf_size); + match old_buf.try_into_mut() { + Ok(mut unique) => { + let len = unique.len(); + let resize_buf = if new_buf_size <= len { + // if new buffer is smaller than old buffer, just double the size + len * 2 + } else { + new_buf_size + }; + debug_assert!(len < resize_buf); + // resize will save old bytes unchanged and fill the rest with 0 + unique.resize(resize_buf, 0); + (len, unique) + } + Err(remain) => { + let mut buf = BytesMut::with_capacity(new_buf_size); + buf.resize(new_buf_size, 0); + // if old buffer still contain bytes unread, need to save those bytes too + buf[0..remain.len()].copy_from_slice(&remain); + (remain.len(), buf) + } + } +} + pub struct PacketReader { - bytes: Vec, - start: usize, - remaining: usize, + bytes: bytes::Bytes, pub r: R, } impl PacketReader { pub fn new(r: R) -> Self { PacketReader { - bytes: Vec::new(), - start: 0, - remaining: 0, + bytes: bytes::Bytes::new(), r, } } @@ -42,49 +85,37 @@ impl PacketReader { impl PacketReader { #[allow(dead_code)] pub fn next(&mut self) -> io::Result)>> { - self.start = self.bytes.len() - self.remaining; - loop { - if self.remaining != 0 { - let bytes = { - // NOTE: this is all sorts of unfortunate. what we really want to do is to give - // &self.bytes[self.start..] to `packet()`, and the lifetimes should all work - // out. however, without NLL, borrowck doesn't realize that self.bytes is no - // longer borrowed after the match, and so can be mutated. - let bytes = &self.bytes[self.start..]; - unsafe { ::std::slice::from_raw_parts(bytes.as_ptr(), bytes.len()) } - }; - - match packet(bytes) { + let last_buffer_size = self.bytes.len(); + if !self.bytes.is_empty() { + // coping `bytes::Bytes` are very cheap, just move the pointer and increase the ref count. + match packet(self.bytes.clone().into()) { Ok((rest, p)) => { - self.remaining = rest.len(); + // most time the `rest` is either empty or very small, so it's cheap to copy it later into next buffer + self.bytes = rest.into(); return Ok(Some(p)); } - Err(nom::Err::Incomplete(_)) | Err(nom::Err::Error(_)) => {} + Err(nom::Err::Incomplete(_)) | Err(nom::Err::Error(_)) => { + } Err(nom::Err::Failure(ctx)) => { let err = Err(io::Error::new( io::ErrorKind::InvalidData, format!("{:?}", ctx), )); - self.bytes.truncate(self.remaining); return err; } } } - // we need to read some more - self.bytes.drain(0..self.start); - self.start = 0; - let end = self.bytes.len(); - self.bytes.resize(std::cmp::max(4096, end * 2), 0); - let read = { - let buf = &mut self.bytes[end..]; - self.r.read(buf)? - }; - self.bytes.truncate(end + read); - self.remaining = self.bytes.len(); + // read more buffer + let (start, mut buf) = + reuse_or_create_buf(std::mem::take(&mut self.bytes), last_buffer_size); + let read_cnt = self.r.read(&mut buf[start..])?; + buf.truncate(start + read_cnt); + self.bytes = buf.freeze(); - if read == 0 { + // for a [TcpStream], returning zero indicates the connection was shut down correctly. + if read_cnt == 0 { if self.bytes.is_empty() { return Ok(None); } else { @@ -104,11 +135,10 @@ impl AsyncRead for PacketReader { cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - if self.remaining != 0 { - buf.put_slice(&self.bytes[self.start..]); + // if our buffer have content, send those immediately + if !self.bytes.is_empty() { + buf.put_slice(&self.bytes); self.bytes.clear(); - self.start = 0; - self.remaining = 0; std::task::Poll::Ready(Ok(())) } else { std::pin::Pin::new(&mut self.r).poll_read(cx, buf) @@ -118,31 +148,16 @@ impl AsyncRead for PacketReader { impl PacketReader { pub async fn next_async(&mut self) -> io::Result)>> { - self.start = self.bytes.len() - self.remaining; - - let mut buffer_size = PACKET_BUFFER_SIZE; loop { - if self.remaining != 0 { - let bytes = { - // NOTE: this is all sorts of unfortunate. what we really want to do is to give - // &self.bytes[self.start..] to `packet()`, and the lifetimes should all work - // out. however, without NLL, borrowck doesn't realize that self.bytes is no - // longer borrowed after the match, and so can be mutated. - let bytes = &self.bytes[self.start..]; - unsafe { ::std::slice::from_raw_parts(bytes.as_ptr(), self.remaining) } - }; - match packet(bytes) { + let last_buffer_size = self.bytes.len(); + if !self.bytes.is_empty() { + match packet(self.bytes.clone().into()) { Ok((rest, p)) => { - self.remaining = rest.len(); - if self.remaining > 0 { - self.bytes = rest.to_vec(); - self.start = 0; - } + self.bytes = rest.into(); return Ok(Some(p)); } Err(nom::Err::Incomplete(_)) | Err(nom::Err::Error(_)) => {} Err(nom::Err::Failure(ctx)) => { - self.bytes.truncate(self.remaining); return Err(io::Error::new( io::ErrorKind::InvalidData, format!("{:?}", ctx), @@ -151,25 +166,16 @@ impl PacketReader { } } - // we need to read some more - self.bytes.drain(0..self.start); - self.start = 0; - let end = self.remaining; + // read more buffer + let (start, mut buf) = + reuse_or_create_buf(std::mem::take(&mut self.bytes), last_buffer_size); - if self.bytes.len() - end < buffer_size { - let new_len = std::cmp::max(buffer_size, end * 2); - self.bytes.resize(new_len, 0); - } - let read = { - let buf = &mut self.bytes[end..]; - self.r.read(buf).await? - }; - self.remaining = end + read; - // use a larger buffer size to reduce bytes resize times. - buffer_size = PACKET_LARGE_BUFFER_SIZE; + let read_cnt = self.r.read(&mut buf[start..]).await?; + buf.truncate(start + read_cnt); + + self.bytes = buf.freeze(); - if read == 0 { - self.bytes.truncate(self.remaining); + if read_cnt == 0 { if self.bytes.is_empty() { return Ok(None); } else { @@ -183,50 +189,134 @@ impl PacketReader { } } -pub fn fullpacket(i: &[u8]) -> nom::IResult<&[u8], (u8, &[u8])> { - let (i, _) = nom::bytes::complete::tag(&[0xff, 0xff, 0xff])(i)?; +pub fn fullpacket(i: NomBytes) -> nom::IResult { + let (i, _) = nom::bytes::complete::tag(&[0xff, 0xff, 0xff][..])(i)?; let (i, seq) = nom::bytes::complete::take(1u8)(i)?; let (i, bytes) = nom::bytes::complete::take(U24_MAX)(i)?; - Ok((i, (seq[0], bytes))) + Ok((i, (seq.as_ref()[0], bytes))) } -pub fn onepacket(i: &[u8]) -> nom::IResult<&[u8], (u8, &[u8])> { +pub fn onepacket(i: NomBytes) -> nom::IResult { let (i, length) = nom::number::complete::le_u24(i)?; let (i, seq) = nom::bytes::complete::take(1u8)(i)?; let (i, bytes) = nom::bytes::complete::take(length)(i)?; - Ok((i, (seq[0], bytes))) + Ok((i, (seq.as_ref()[0], bytes))) } -// Clone because of https://github.com/Geal/nom/issues/1008 -#[derive(Clone)] -pub struct Packet<'a>(&'a [u8], Vec); +/// Bytes wrapper for nom, allowing nom to parse bytes::Bytes +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NomBytes(bytes::Bytes); -impl<'a> Packet<'a> { - fn extend(&mut self, bytes: &'a [u8]) { - if self.0.is_empty() { - if self.1.is_empty() { - // first extend - self.0 = bytes; - } else { - // later extend - self.1.extend(bytes); - } - } else { - assert!(self.1.is_empty()); - let mut v = self.0.to_vec(); - v.extend(bytes); - self.1 = v; - self.0 = &[]; - } +impl NomBytes { + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn len(&self) -> usize { + self.0.len() + } +} + +impl From<&[u8]> for NomBytes { + fn from(value: &[u8]) -> Self { + NomBytes(bytes::Bytes::copy_from_slice(value)) + } +} + +impl From for NomBytes { + fn from(value: bytes::Bytes) -> Self { + NomBytes(value) } } -impl<'a> AsRef<[u8]> for Packet<'a> { +impl From for bytes::Bytes { + fn from(value: NomBytes) -> Self { + value.0 + } +} + +impl AsRef<[u8]> for NomBytes { fn as_ref(&self) -> &[u8] { - if self.1.is_empty() { - self.0 + self.0.as_ref() + } +} + +impl nom::InputTake for NomBytes { + fn take(&self, count: usize) -> Self { + NomBytes(self.0.slice(0..count)) + } + + fn take_split(&self, count: usize) -> (Self, Self) { + let mut prefix = self.0.clone(); + let suffix = prefix.split_off(count); + (NomBytes(suffix), NomBytes(prefix)) + } +} + +impl nom::Compare<&[u8]> for NomBytes { + fn compare(&self, t: &[u8]) -> nom::CompareResult { + self.0.as_ref().compare(t) + } + + fn compare_no_case(&self, t: &[u8]) -> nom::CompareResult { + self.0.as_ref().compare_no_case(t) + } +} + +impl nom::InputLength for NomBytes { + fn input_len(&self) -> usize { + self.0.len() + } +} + +impl nom::InputIter for NomBytes { + type Item = u8; + type Iter = Enumerate; + type IterElem = bytes::buf::IntoIter; + + #[inline] + fn iter_indices(&self) -> Self::Iter { + self.iter_elements().enumerate() + } + #[inline] + fn iter_elements(&self) -> Self::IterElem { + self.0.clone().into_iter() + } + #[inline] + fn position

(&self, predicate: P) -> Option + where + P: Fn(Self::Item) -> bool, + { + self.0.iter().position(|b| predicate(*b)) + } + #[inline] + fn slice_index(&self, count: usize) -> Result { + if self.0.len() >= count { + Ok(count) } else { - &self.1 + Err(Needed::new(count - self.0.len())) + } + } +} + +impl nom::Slice> for NomBytes { + fn slice(&self, range: RangeFrom) -> Self { + NomBytes(self.0.slice(range)) + } +} + +// a simple wrapper around bytes::Bytes to make sure interface stays the same +#[derive(Clone)] +pub struct Packet<'a> { + bytes: bytes::Bytes, + _lifetime: PhantomData<&'a ()>, // NOTE: the lifetime can be removed since Bytes mangaes the lifetime by itself +} + +impl<'a> Packet<'a> { + fn from_bytes(bytes: bytes::Bytes) -> Self { + Packet { + bytes, + _lifetime: PhantomData, } } } @@ -237,39 +327,139 @@ use std::ops::Deref; impl<'a> Deref for Packet<'a> { type Target = [u8]; fn deref(&self) -> &Self::Target { - self.as_ref() + self.bytes.as_ref() } } -pub(crate) fn packet(i: &[u8]) -> nom::IResult<&[u8], (u8, Packet<'_>)> { +// note that for small packet, this function is zero-copy, but for packet >= 2^24 it currently copy stuff, this await further optimization +pub(crate) fn packet<'a>(i: NomBytes) -> nom::IResult)> { nom::combinator::map( nom::sequence::pair( nom::multi::fold_many0( fullpacket, || (0, None), - |(seq, pkt): (_, Option>), (nseq, p)| { + |(seq, pkt): (_, Option), (nseq, p)| { let pkt = if let Some(mut pkt) = pkt { assert_eq!(nseq, seq + 1); - pkt.extend(p); + pkt.extend_from_slice(p.as_ref()); Some(pkt) } else { - Some(Packet(p, Vec::new())) + // TODO: avoid copy + Some(BytesMut::from(p.0)) }; (nseq, pkt) }, ), - onepacket, + nom::combinator::opt(onepacket), ), - move |(full, last)| { - let seq = last.0; - let pkt = if let Some(mut pkt) = full.1 { - assert_eq!(last.0, full.0 + 1); - pkt.extend(last.1); - pkt - } else { - Packet(last.1, Vec::new()) - }; - (seq, pkt) + move |((full_seq, full_pkt), last)| match (full_pkt, last) { + (Some(mut full_pkt), Some((last_seq, last_pkt))) => { + assert_eq!(last_seq, full_seq + 1); + full_pkt.extend_from_slice(last_pkt.as_ref()); + let final_pkt = full_pkt.freeze(); + Ok((last_seq, Packet::from_bytes(final_pkt))) + } + (Some(full_pkt), None) => Ok((full_seq, Packet::from_bytes(full_pkt.freeze()))), + (None, Some((last_seq, last_pkt))) => Ok((last_seq, Packet::from_bytes(last_pkt.0))), + // TODO: might know length + (None, None) => Err(nom::Err::Incomplete(Needed::Unknown)), }, )(i) + .map(|(rest, parsed)| match parsed { + Ok(parsed) => Ok((rest, parsed)), + Err(e) => Err(e), + })? +} + +#[cfg(test)] +mod test { + use bytes::{Buf, BufMut}; + + use super::*; + + fn mock_packet(mut data: bytes::Bytes, start_seq: u8) -> bytes::Bytes { + let mut buf = BytesMut::new(); + let mut seq = start_seq; + while data.len() > U24_MAX { + buf.extend_from_slice(&[0xff, 0xff, 0xff]); + buf.put_u8(seq); + buf.put(&data[0..U24_MAX]); + data.advance(U24_MAX); + seq += 1; + } + if !data.is_empty() { + let le_u64: [u8; 8] = data.len().to_le_bytes(); + let le_u24 = &le_u64[0..3]; + buf.extend_from_slice(le_u24); + buf.put_u8(seq); + buf.put(data); + } + buf.freeze() + } + + #[tokio::test] + async fn test_various_packet_size() { + // test for off by one, and off by header size(3 bytes for length and 1 for seq num) + let testcases = [ + 0, + 1, + 2, + PACKET_BUFFER_SIZE - 1 - 4, + PACKET_BUFFER_SIZE - 1, + PACKET_BUFFER_SIZE, + PACKET_BUFFER_SIZE + 1, + PACKET_BUFFER_SIZE + 1 + 4, + PACKET_LARGE_BUFFER_SIZE - 4 - 1, + PACKET_LARGE_BUFFER_SIZE - 4, + PACKET_LARGE_BUFFER_SIZE - 1, + PACKET_LARGE_BUFFER_SIZE, + PACKET_LARGE_BUFFER_SIZE + 1, + PACKET_LARGE_BUFFER_SIZE + 4, + PACKET_LARGE_BUFFER_SIZE + 4 + 1, + U24_MAX - 4 - 1, + U24_MAX - 4, + U24_MAX - 1, + U24_MAX, + U24_MAX + 1, + U24_MAX + 4, + U24_MAX + 4 + 1, + U24_MAX * 2 - 4 - 1, + U24_MAX * 2 - 4, + U24_MAX * 2 - 1, + U24_MAX * 2, + U24_MAX * 2 + 1, + U24_MAX * 2 + 4, + U24_MAX * 2 + 4 + 1, + ]; + for input_size in testcases { + let large_data = bytes::Bytes::from(vec![0; input_size]); + let packet = mock_packet(large_data, 0); + let mut reader = PacketReader::new(packet.reader()); + let mut last_seq = 0; + let mut total_size = 0; + while let Some((seq, packet)) = reader.next().unwrap() { + if seq != 0 { + assert!(seq > last_seq); + } + total_size += packet.len(); + last_seq = seq; + } + assert_eq!(total_size, input_size); + } + for input_size in testcases { + let large_data = bytes::Bytes::from(vec![0; input_size]); + let packet = mock_packet(large_data, 0); + let mut reader = PacketReader::new(packet.as_ref()); + let mut last_seq = 0; + let mut total_size = 0; + while let Some((seq, packet)) = reader.next_async().await.unwrap() { + if seq != 0 { + assert!(seq > last_seq); + } + total_size += packet.len(); + last_seq = seq; + } + assert_eq!(total_size, input_size); + } + } } diff --git a/mysql/src/tests/packet.rs b/mysql/src/tests/packet.rs index f2e6d98..48ba824 100644 --- a/mysql/src/tests/packet.rs +++ b/mysql/src/tests/packet.rs @@ -18,14 +18,14 @@ use crate::U24_MAX; #[test] fn test_one_ping() { assert_eq!( - onepacket(&[0x01, 0, 0, 0, 0x10]).unwrap().1, - (0, &[0x10][..]) + onepacket((&[0x01, 0, 0, 0, 0x10][..]).into()).unwrap().1, + (0, (&[0x10][..]).into()) ); } #[test] fn test_ping() { - let p = packet(&[0x01, 0, 0, 0, 0x10]).unwrap().1; + let p = packet((&[0x01, 0, 0, 0, 0x10][..]).into()).unwrap().1; assert_eq!(p.0, 0); assert_eq!(&*p.1, &[0x10][..]); } @@ -39,7 +39,7 @@ fn test_long_exact() { data.push(0x00); data.push(1); - let (rest, p) = packet(&data[..]).unwrap(); + let (rest, p) = packet((&data[..]).into()).unwrap(); assert!(rest.is_empty()); assert_eq!(p.0, 1); assert_eq!(p.1.len(), U24_MAX); @@ -56,7 +56,7 @@ fn test_long_more() { data.push(1); data.push(0x10); - let (rest, p) = packet(&data[..]).unwrap(); + let (rest, p) = packet((&data[..]).into()).unwrap(); assert!(rest.is_empty()); assert_eq!(p.0, 1); assert_eq!(p.1.len(), U24_MAX + 1);