diff --git a/library/std/src/io/mod.rs b/library/std/src/io/mod.rs index 314cbb45d49e2..5cf0ea8d5c90e 100644 --- a/library/std/src/io/mod.rs +++ b/library/std/src/io/mod.rs @@ -1586,6 +1586,60 @@ impl<'a> Deref for IoSlice<'a> { } } +/// Limits a slice of buffers to at most `n` buffers and ensures that it has at +/// least one buffer, even if empty. +/// +/// When the slice contains over `n` buffers, ensure that at least one non-empty +/// buffer is in the truncated slice, if there is one. +#[allow(unused_macros)] // Not used on all platforms +pub(crate) macro limit_slices($bufs:expr, $n:expr) { + 'slices: { + let bufs: &[IoSlice<'_>] = $bufs; + let n: usize = $n; + // if bufs.len() > n || bufs.is_empty() + if core::intrinsics::unlikely(bufs.len().wrapping_sub(1) >= n) { + for (i, buf) in bufs.iter().enumerate() { + if !buf.is_empty() { + let len = cmp::min(bufs.len() - i, n); + break 'slices &bufs[i..i + len]; + } + } + // All buffers are empty. Since POSIX requires at least one buffer + // for [writev], but possibly bufs.is_empty(), return an empty write. + // [writev]: https://pubs.opengroup.org/onlinepubs/9799919799/functions/writev.html + return Ok(0); + } + bufs + } +} + +/// Limits a slice of buffers to at most `n` buffers and ensures that it has at +/// least one buffer, even if empty. +/// +/// When the slice contains over `n` buffers, ensure that at least one non-empty +/// buffer is in the truncated slice, if there is one. +#[allow(unused_macros)] // Not used on all platforms +pub(crate) macro limit_slices_mut($bufs:expr, $n:expr) { + 'slices: { + let bufs: &mut [IoSliceMut<'_>] = $bufs; + let n: usize = $n; + // if bufs.len() > n || bufs.is_empty() + if core::intrinsics::unlikely(bufs.len().wrapping_sub(1) >= n) { + for (i, buf) in bufs.iter().enumerate() { + if !buf.is_empty() { + let len = cmp::min(bufs.len() - i, n); + break 'slices &mut bufs[i..i + len]; + } + } + // All buffers are empty. Since POSIX requires at least one buffer + // for [readv], but possibly bufs.is_empty(), return an empty read. + // [readv]: https://pubs.opengroup.org/onlinepubs/9799919799/functions/readv.html + return Ok(0); + } + bufs + } +} + /// A trait for objects which are byte-oriented sinks. /// /// Implementors of the `Write` trait are sometimes called 'writers'. diff --git a/library/std/src/sys/fd/hermit.rs b/library/std/src/sys/fd/hermit.rs index 7e8ba065f1b96..793b2bd644885 100644 --- a/library/std/src/sys/fd/hermit.rs +++ b/library/std/src/sys/fd/hermit.rs @@ -1,6 +1,5 @@ #![unstable(reason = "not public", issue = "none", feature = "fd")] -use crate::cmp; use crate::io::{self, BorrowedCursor, IoSlice, IoSliceMut, Read, SeekFrom}; use crate::os::hermit::hermit_abi; use crate::os::hermit::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; @@ -39,11 +38,12 @@ impl FileDesc { } pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let bufs = io::limit_slices_mut!(bufs, max_iov()); let ret = cvt(unsafe { hermit_abi::readv( self.as_raw_fd(), bufs.as_mut_ptr() as *mut hermit_abi::iovec as *const hermit_abi::iovec, - cmp::min(bufs.len(), max_iov()), + bufs.len(), ) })?; Ok(ret as usize) @@ -66,11 +66,12 @@ impl FileDesc { } pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + let bufs = io::limit_slices!(bufs, max_iov()); let ret = cvt(unsafe { hermit_abi::writev( self.as_raw_fd(), bufs.as_ptr() as *const hermit_abi::iovec, - cmp::min(bufs.len(), max_iov()), + bufs.len(), ) })?; Ok(ret as usize) diff --git a/library/std/src/sys/fd/unix.rs b/library/std/src/sys/fd/unix.rs index 2042ea2c73d00..26b322b597ac4 100644 --- a/library/std/src/sys/fd/unix.rs +++ b/library/std/src/sys/fd/unix.rs @@ -109,11 +109,12 @@ impl FileDesc { target_os = "nuttx" )))] pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let bufs = io::limit_slices_mut!(bufs, max_iov()); let ret = cvt(unsafe { libc::readv( self.as_raw_fd(), bufs.as_mut_ptr() as *mut libc::iovec as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, ) })?; Ok(ret as usize) @@ -199,11 +200,12 @@ impl FileDesc { target_os = "openbsd", // OpenBSD 2.7 ))] pub fn read_vectored_at(&self, bufs: &mut [IoSliceMut<'_>], offset: u64) -> io::Result { + let bufs = io::limit_slices_mut!(bufs, max_iov()); let ret = cvt(unsafe { libc::preadv( self.as_raw_fd(), bufs.as_mut_ptr() as *mut libc::iovec as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, offset as _, ) })?; @@ -245,11 +247,12 @@ impl FileDesc { ) -> isize; ); + let bufs = io::limit_slices_mut!(bufs, max_iov()); let ret = cvt(unsafe { preadv( self.as_raw_fd(), bufs.as_mut_ptr() as *mut libc::iovec as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, offset as _, ) })?; @@ -272,11 +275,12 @@ impl FileDesc { match preadv64.get() { Some(preadv) => { + let bufs = io::limit_slices_mut!(bufs, max_iov()); let ret = cvt(unsafe { preadv( self.as_raw_fd(), bufs.as_mut_ptr() as *mut libc::iovec as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, offset as _, ) })?; @@ -308,11 +312,12 @@ impl FileDesc { match preadv.get() { Some(preadv) => { + let bufs = io::limit_slices_mut!(bufs, max_iov()); let ret = cvt(unsafe { preadv( self.as_raw_fd(), bufs.as_mut_ptr() as *mut libc::iovec as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, offset as _, ) })?; @@ -340,11 +345,12 @@ impl FileDesc { target_os = "nuttx" )))] pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + let bufs = io::limit_slices!(bufs, max_iov()); let ret = cvt(unsafe { libc::writev( self.as_raw_fd(), bufs.as_ptr() as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, ) })?; Ok(ret as usize) @@ -409,11 +415,12 @@ impl FileDesc { target_os = "openbsd", // OpenBSD 2.7 ))] pub fn write_vectored_at(&self, bufs: &[IoSlice<'_>], offset: u64) -> io::Result { + let bufs = io::limit_slices!(bufs, max_iov()); let ret = cvt(unsafe { libc::pwritev( self.as_raw_fd(), bufs.as_ptr() as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, offset as _, ) })?; @@ -455,11 +462,12 @@ impl FileDesc { ) -> isize; ); + let bufs = io::limit_slices!(bufs, max_iov()); let ret = cvt(unsafe { pwritev( self.as_raw_fd(), bufs.as_ptr() as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, offset as _, ) })?; @@ -479,11 +487,12 @@ impl FileDesc { match pwritev64.get() { Some(pwritev) => { + let bufs = io::limit_slices!(bufs, max_iov()); let ret = cvt(unsafe { pwritev( self.as_raw_fd(), bufs.as_ptr() as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, offset as _, ) })?; @@ -515,11 +524,12 @@ impl FileDesc { match pwritev.get() { Some(pwritev) => { + let bufs = io::limit_slices!(bufs, max_iov()); let ret = cvt(unsafe { pwritev( self.as_raw_fd(), bufs.as_ptr() as *const libc::iovec, - cmp::min(bufs.len(), max_iov()) as libc::c_int, + bufs.len() as libc::c_int, offset as _, ) })?; diff --git a/library/std/src/sys/fd/unix/tests.rs b/library/std/src/sys/fd/unix/tests.rs index fcd66c71707d9..9f82b65c00556 100644 --- a/library/std/src/sys/fd/unix/tests.rs +++ b/library/std/src/sys/fd/unix/tests.rs @@ -1,12 +1,32 @@ use core::mem::ManuallyDrop; -use super::FileDesc; +use super::{FileDesc, max_iov}; use crate::io::IoSlice; use crate::os::unix::io::FromRawFd; #[test] fn limit_vector_count() { + const IOV_MAX: usize = max_iov(); + + let stdout = ManuallyDrop::new(unsafe { FileDesc::from_raw_fd(1) }); + let mut bufs = vec![IoSlice::new(&[]); IOV_MAX * 2 + 1]; + assert_eq!(stdout.write_vectored(&bufs).unwrap(), 0); + + // The slice of buffers is truncated to IOV_MAX buffers. However, since the + // first IOV_MAX buffers are all empty, it is sliced starting at the first + // non-empty buffer to avoid erroneously returning Ok(0). In this case, that + // starts with the b"hello" buffer and ends just before the b"world!" + // buffer. + bufs[IOV_MAX] = IoSlice::new(b"hello"); + bufs[IOV_MAX * 2] = IoSlice::new(b"world!"); + assert_eq!(stdout.write_vectored(&bufs).unwrap(), b"hello".len()) +} + +#[test] +fn empty_vector() { + let stdin = ManuallyDrop::new(unsafe { FileDesc::from_raw_fd(0) }); + assert_eq!(stdin.read_vectored(&mut []).unwrap(), 0); + let stdout = ManuallyDrop::new(unsafe { FileDesc::from_raw_fd(1) }); - let bufs = (0..1500).map(|_| IoSlice::new(&[])).collect::>(); - assert!(stdout.write_vectored(&bufs).is_ok()); + assert_eq!(stdout.write_vectored(&[]).unwrap(), 0); } diff --git a/library/std/src/sys/net/connection/socket/solid.rs b/library/std/src/sys/net/connection/socket/solid.rs index 94bb605c1007c..0b979b3c04052 100644 --- a/library/std/src/sys/net/connection/socket/solid.rs +++ b/library/std/src/sys/net/connection/socket/solid.rs @@ -9,7 +9,7 @@ use crate::os::solid::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, Owne use crate::sys::abi; use crate::sys_common::{FromInner, IntoInner}; use crate::time::Duration; -use crate::{cmp, mem, ptr, str}; +use crate::{mem, ptr, str}; pub(super) mod netc { pub use crate::sys::abi::sockets::*; @@ -223,12 +223,9 @@ impl Socket { } pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let bufs = io::limit_slices_mut!(bufs, max_iov()); let ret = cvt(unsafe { - netc::readv( - self.as_raw_fd(), - bufs.as_ptr() as *const netc::iovec, - cmp::min(bufs.len(), max_iov()) as c_int, - ) + netc::readv(self.as_raw_fd(), bufs.as_ptr() as *const netc::iovec, bufs.len() as c_int) })?; Ok(ret as usize) } @@ -268,12 +265,9 @@ impl Socket { } pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + let bufs = io::limit_slices!(bufs, max_iov()); let ret = cvt(unsafe { - netc::writev( - self.as_raw_fd(), - bufs.as_ptr() as *const netc::iovec, - cmp::min(bufs.len(), max_iov()) as c_int, - ) + netc::writev(self.as_raw_fd(), bufs.as_ptr() as *const netc::iovec, bufs.len() as c_int) })?; Ok(ret as usize) } diff --git a/library/std/src/sys/net/connection/socket/windows.rs b/library/std/src/sys/net/connection/socket/windows.rs index ce975bb2289c2..d6dafb63e38cb 100644 --- a/library/std/src/sys/net/connection/socket/windows.rs +++ b/library/std/src/sys/net/connection/socket/windows.rs @@ -299,8 +299,6 @@ impl Socket { } fn recv_with_flags(&self, mut buf: BorrowedCursor<'_>, flags: c_int) -> io::Result<()> { - // On unix when a socket is shut down all further reads return 0, so we - // do the same on windows to map a shut down socket to returning EOF. let length = cmp::min(buf.capacity(), i32::MAX as usize) as i32; let result = unsafe { c::recv(self.as_raw(), buf.as_mut().as_mut_ptr() as *mut _, length, flags) }; @@ -309,6 +307,9 @@ impl Socket { c::SOCKET_ERROR => { let error = unsafe { c::WSAGetLastError() }; + // On Unix when a socket is shut down, all further reads return + // 0, so we do the same on Windows to map a shut down socket to + // returning EOF. if error == c::WSAESHUTDOWN { Ok(()) } else { @@ -333,8 +334,11 @@ impl Socket { } pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - // On unix when a socket is shut down all further reads return 0, so we - // do the same on windows to map a shut down socket to returning EOF. + // WSARecv requires at least one buffer. + if bufs.is_empty() { + return Ok(0); + } + let length = cmp::min(bufs.len(), u32::MAX as usize) as u32; let mut nread = 0; let mut flags = 0; @@ -355,6 +359,9 @@ impl Socket { _ => { let error = unsafe { c::WSAGetLastError() }; + // On Unix when a socket is shut down, all further reads return + // 0, so we do the same on Windows to map a shut down socket to + // returning EOF. if error == c::WSAESHUTDOWN { Ok(0) } else { @@ -384,8 +391,6 @@ impl Socket { let mut addrlen = size_of_val(&storage) as netc::socklen_t; let length = cmp::min(buf.len(), ::MAX as usize) as wrlen_t; - // On unix when a socket is shut down all further reads return 0, so we - // do the same on windows to map a shut down socket to returning EOF. let result = unsafe { c::recvfrom( self.as_raw(), @@ -401,6 +406,9 @@ impl Socket { c::SOCKET_ERROR => { let error = unsafe { c::WSAGetLastError() }; + // On Unix when a socket is shut down, all further reads return + // 0, so we do the same on Windows to map a shut down socket to + // returning EOF. if error == c::WSAESHUTDOWN { Ok((0, unsafe { socket_addr_from_c(&storage, addrlen as usize)? })) } else { @@ -420,6 +428,11 @@ impl Socket { } pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + // WSASend requires at least one buffer. + if bufs.is_empty() { + return Ok(0); + } + let length = cmp::min(bufs.len(), u32::MAX as usize) as u32; let mut nwritten = 0; let result = unsafe {