From bbdff1fff49a6f816aa35d9b84c79b2603f0961b Mon Sep 17 00:00:00 2001
From: Ross MacArthur <ross@macarthur.io>
Date: Tue, 21 Jun 2022 08:57:02 +0200
Subject: [PATCH] Add `Iterator::next_chunk`

---
 library/core/src/array/mod.rs              | 79 +++++++++++++---------
 library/core/src/iter/traits/iterator.rs   | 42 ++++++++++++
 library/core/tests/iter/traits/iterator.rs |  9 +++
 library/core/tests/lib.rs                  |  1 +
 4 files changed, 100 insertions(+), 31 deletions(-)

diff --git a/library/core/src/array/mod.rs b/library/core/src/array/mod.rs
index 2ea4458bf6427..c9823a136bc42 100644
--- a/library/core/src/array/mod.rs
+++ b/library/core/src/array/mod.rs
@@ -780,8 +780,8 @@ where
 }
 
 /// Pulls `N` items from `iter` and returns them as an array. If the iterator
-/// yields fewer than `N` items, `None` is returned and all already yielded
-/// items are dropped.
+/// yields fewer than `N` items, `Err` is returned containing an iterator over
+/// the already yielded items.
 ///
 /// Since the iterator is passed as a mutable reference and this function calls
 /// `next` at most `N` times, the iterator can still be used afterwards to
@@ -789,7 +789,10 @@ where
 ///
 /// If `iter.next()` panicks, all items already yielded by the iterator are
 /// dropped.
-fn try_collect_into_array<I, T, R, const N: usize>(iter: &mut I) -> Option<R::TryType>
+#[inline]
+fn try_collect_into_array<I, T, R, const N: usize>(
+    iter: &mut I,
+) -> Result<R::TryType, IntoIter<T, N>>
 where
     I: Iterator,
     I::Item: Try<Output = T, Residual = R>,
@@ -797,7 +800,7 @@ where
 {
     if N == 0 {
         // SAFETY: An empty array is always inhabited and has no validity invariants.
-        return unsafe { Some(Try::from_output(mem::zeroed())) };
+        return Ok(Try::from_output(unsafe { mem::zeroed() }));
     }
 
     struct Guard<'a, T, const N: usize> {
@@ -821,35 +824,49 @@ where
     let mut array = MaybeUninit::uninit_array::<N>();
     let mut guard = Guard { array_mut: &mut array, initialized: 0 };
 
-    while let Some(item_rslt) = iter.next() {
-        let item = match item_rslt.branch() {
-            ControlFlow::Break(r) => {
-                return Some(FromResidual::from_residual(r));
+    for _ in 0..N {
+        match iter.next() {
+            Some(item_rslt) => {
+                let item = match item_rslt.branch() {
+                    ControlFlow::Break(r) => {
+                        return Ok(FromResidual::from_residual(r));
+                    }
+                    ControlFlow::Continue(elem) => elem,
+                };
+
+                // SAFETY: `guard.initialized` starts at 0, is increased by one in the
+                // loop and the loop is aborted once it reaches N (which is
+                // `array.len()`).
+                unsafe {
+                    guard.array_mut.get_unchecked_mut(guard.initialized).write(item);
+                }
+                guard.initialized += 1;
+            }
+            None => {
+                let alive = 0..guard.initialized;
+                mem::forget(guard);
+                // SAFETY: `array` was initialized with exactly `initialized`
+                // number of elements.
+                return Err(unsafe { IntoIter::new_unchecked(array, alive) });
             }
-            ControlFlow::Continue(elem) => elem,
-        };
-
-        // SAFETY: `guard.initialized` starts at 0, is increased by one in the
-        // loop and the loop is aborted once it reaches N (which is
-        // `array.len()`).
-        unsafe {
-            guard.array_mut.get_unchecked_mut(guard.initialized).write(item);
-        }
-        guard.initialized += 1;
-
-        // Check if the whole array was initialized.
-        if guard.initialized == N {
-            mem::forget(guard);
-
-            // SAFETY: the condition above asserts that all elements are
-            // initialized.
-            let out = unsafe { MaybeUninit::array_assume_init(array) };
-            return Some(Try::from_output(out));
         }
     }
 
-    // This is only reached if the iterator is exhausted before
-    // `guard.initialized` reaches `N`. Also note that `guard` is dropped here,
-    // dropping all already initialized elements.
-    None
+    mem::forget(guard);
+    // SAFETY: All elements of the array were populated in the loop above.
+    let output = unsafe { MaybeUninit::array_assume_init(array) };
+    Ok(Try::from_output(output))
+}
+
+/// Returns the next chunk of `N` items from the iterator or errors with an
+/// iterator over the remainder. Used for `Iterator::next_chunk`.
+#[inline]
+pub(crate) fn iter_next_chunk<I, const N: usize>(
+    iter: &mut I,
+) -> Result<[I::Item; N], IntoIter<I::Item, N>>
+where
+    I: Iterator,
+{
+    let mut map = iter.map(NeverShortCircuit);
+    try_collect_into_array(&mut map).map(|NeverShortCircuit(arr)| arr)
 }
diff --git a/library/core/src/iter/traits/iterator.rs b/library/core/src/iter/traits/iterator.rs
index 1cc9133fc3dc4..326b98ec947d2 100644
--- a/library/core/src/iter/traits/iterator.rs
+++ b/library/core/src/iter/traits/iterator.rs
@@ -1,3 +1,4 @@
+use crate::array;
 use crate::cmp::{self, Ordering};
 use crate::ops::{ChangeOutputType, ControlFlow, FromResidual, Residual, Try};
 
@@ -102,6 +103,47 @@ pub trait Iterator {
     #[stable(feature = "rust1", since = "1.0.0")]
     fn next(&mut self) -> Option<Self::Item>;
 
+    /// Advances the iterator and returns an array containing the next `N` values.
+    ///
+    /// If there are not enough elements to fill the array then `Err` is returned
+    /// containing an iterator over the remaining elements.
+    ///
+    /// # Examples
+    ///
+    /// Basic usage:
+    ///
+    /// ```
+    /// #![feature(iter_next_chunk)]
+    ///
+    /// let mut iter = "lorem".chars();
+    ///
+    /// assert_eq!(iter.next_chunk().unwrap(), ['l', 'o']);              // N is inferred as 2
+    /// assert_eq!(iter.next_chunk().unwrap(), ['r', 'e', 'm']);         // N is inferred as 3
+    /// assert_eq!(iter.next_chunk::<4>().unwrap_err().as_slice(), &[]); // N is explicitly 4
+    /// ```
+    ///
+    /// Split a string and get the first three items.
+    ///
+    /// ```
+    /// #![feature(iter_next_chunk)]
+    ///
+    /// let quote = "not all those who wander are lost";
+    /// let [first, second, third] = quote.split_whitespace().next_chunk().unwrap();
+    /// assert_eq!(first, "not");
+    /// assert_eq!(second, "all");
+    /// assert_eq!(third, "those");
+    /// ```
+    #[inline]
+    #[unstable(feature = "iter_next_chunk", reason = "recently added", issue = "98326")]
+    fn next_chunk<const N: usize>(
+        &mut self,
+    ) -> Result<[Self::Item; N], array::IntoIter<Self::Item, N>>
+    where
+        Self: Sized,
+    {
+        array::iter_next_chunk(self)
+    }
+
     /// Returns the bounds on the remaining length of the iterator.
     ///
     /// Specifically, `size_hint()` returns a tuple where the first element
diff --git a/library/core/tests/iter/traits/iterator.rs b/library/core/tests/iter/traits/iterator.rs
index 731b1592d4193..37345c1d38142 100644
--- a/library/core/tests/iter/traits/iterator.rs
+++ b/library/core/tests/iter/traits/iterator.rs
@@ -575,6 +575,15 @@ fn iter_try_collect_uses_try_fold_not_next() {
     // validation is just that it didn't panic.
 }
 
+#[test]
+fn test_next_chunk() {
+    let mut it = 0..12;
+    assert_eq!(it.next_chunk().unwrap(), [0, 1, 2, 3]);
+    assert_eq!(it.next_chunk().unwrap(), []);
+    assert_eq!(it.next_chunk().unwrap(), [4, 5, 6, 7, 8, 9]);
+    assert_eq!(it.next_chunk::<4>().unwrap_err().as_slice(), &[10, 11]);
+}
+
 // just tests by whether or not this compiles
 fn _empty_impl_all_auto_traits<T>() {
     use std::panic::{RefUnwindSafe, UnwindSafe};
diff --git a/library/core/tests/lib.rs b/library/core/tests/lib.rs
index 63c9602abe75c..9611e197a41c4 100644
--- a/library/core/tests/lib.rs
+++ b/library/core/tests/lib.rs
@@ -67,6 +67,7 @@
 #![feature(iter_partition_in_place)]
 #![feature(iter_intersperse)]
 #![feature(iter_is_partitioned)]
+#![feature(iter_next_chunk)]
 #![feature(iter_order_by)]
 #![feature(iterator_try_collect)]
 #![feature(iterator_try_reduce)]