Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement all mix/max functions in a (hopefully) more optimization amendable way #136307

Merged
merged 2 commits into from
Feb 1, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 137 additions & 33 deletions library/core/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,24 @@ pub trait Ord: Eq + PartialOrd<Self> {
/// assert_eq!(1.max(2), 2);
/// assert_eq!(2.max(2), 2);
/// ```
/// ```
/// use std::cmp::Ordering;
///
/// #[derive(Eq)]
/// struct Equal(&'static str);
///
/// impl PartialEq for Equal {
/// fn eq(&self, other: &Self) -> bool { true }
/// }
/// impl PartialOrd for Equal {
/// fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(Ordering::Equal) }
/// }
/// impl Ord for Equal {
/// fn cmp(&self, other: &Self) -> Ordering { Ordering::Equal }
/// }
///
/// assert_eq!(Equal("self").max(Equal("other")).0, "other");
/// ```
#[stable(feature = "ord_max_min", since = "1.21.0")]
#[inline]
#[must_use]
Expand All @@ -981,7 +999,7 @@ pub trait Ord: Eq + PartialOrd<Self> {
where
Self: Sized,
{
max_by(self, other, Ord::cmp)
if other < self { self } else { other }
}

/// Compares and returns the minimum of two values.
Expand All @@ -994,6 +1012,24 @@ pub trait Ord: Eq + PartialOrd<Self> {
/// assert_eq!(1.min(2), 1);
/// assert_eq!(2.min(2), 2);
/// ```
/// ```
/// use std::cmp::Ordering;
///
/// #[derive(Eq)]
/// struct Equal(&'static str);
///
/// impl PartialEq for Equal {
/// fn eq(&self, other: &Self) -> bool { true }
/// }
/// impl PartialOrd for Equal {
/// fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(Ordering::Equal) }
/// }
/// impl Ord for Equal {
/// fn cmp(&self, other: &Self) -> Ordering { Ordering::Equal }
/// }
///
/// assert_eq!(Equal("self").min(Equal("other")).0, "self");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto here, for emphasis:

Suggested change
/// assert_eq!(Equal("self").min(Equal("other")).0, "self");
/// assert_eq!(Equal("self").min(Equal("other")).0, "self");
/// assert_eq!(Equal("other").min(Equal("self")).0, "other");

(though I guess the values work less well here)

/// ```
#[stable(feature = "ord_max_min", since = "1.21.0")]
#[inline]
#[must_use]
Expand All @@ -1002,7 +1038,7 @@ pub trait Ord: Eq + PartialOrd<Self> {
where
Self: Sized,
{
min_by(self, other, Ord::cmp)
if other < self { other } else { self }
}

/// Restrict a value to a certain interval.
Expand Down Expand Up @@ -1414,6 +1450,24 @@ pub macro PartialOrd($item:item) {
/// assert_eq!(cmp::min(1, 2), 1);
/// assert_eq!(cmp::min(2, 2), 2);
/// ```
/// ```
/// use std::cmp::{self, Ordering};
///
/// #[derive(Eq)]
/// struct Equal(&'static str);
///
/// impl PartialEq for Equal {
/// fn eq(&self, other: &Self) -> bool { true }
/// }
/// impl PartialOrd for Equal {
/// fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(Ordering::Equal) }
/// }
/// impl Ord for Equal {
/// fn cmp(&self, other: &Self) -> Ordering { Ordering::Equal }
/// }
///
/// assert_eq!(cmp::min(Equal("v1"), Equal("v2")).0, "v1");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: also include

Suggested change
/// assert_eq!(cmp::min(Equal("v1"), Equal("v2")).0, "v1");
/// assert_eq!(cmp::min(Equal("v1"), Equal("v2")).0, "v1");
/// assert_eq!(cmp::min(Equal("v2"), Equal("v1")).0, "v2");

for emphasis, because cmp::min("v1", "v2") == "v1" too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values were chosen after the function's argument names, this feels confusing...

/// ```
#[inline]
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
Expand All @@ -1431,20 +1485,22 @@ pub fn min<T: Ord>(v1: T, v2: T) -> T {
/// ```
/// use std::cmp;
///
/// let result = cmp::min_by(-2, 1, |x: &i32, y: &i32| x.abs().cmp(&y.abs()));
/// assert_eq!(result, 1);
/// let abs_cmp = |x: &i32, y: &i32| x.abs().cmp(&y.abs());
///
/// let result = cmp::min_by(-2, 3, |x: &i32, y: &i32| x.abs().cmp(&y.abs()));
/// assert_eq!(result, -2);
/// let result = cmp::min_by(2, -1, abs_cmp);
/// assert_eq!(result, -1);
///
/// let result = cmp::min_by(2, -3, abs_cmp);
/// assert_eq!(result, 2);
///
/// let result = cmp::min_by(1, -1, abs_cmp);
/// assert_eq!(result, 1);
/// ```
#[inline]
#[must_use]
#[stable(feature = "cmp_min_max_by", since = "1.53.0")]
pub fn min_by<T, F: FnOnce(&T, &T) -> Ordering>(v1: T, v2: T, compare: F) -> T {
match compare(&v1, &v2) {
Ordering::Less | Ordering::Equal => v1,
Ordering::Greater => v2,
}
if compare(&v2, &v1).is_lt() { v2 } else { v1 }
}

/// Returns the element that gives the minimum value from the specified function.
Expand All @@ -1456,17 +1512,20 @@ pub fn min_by<T, F: FnOnce(&T, &T) -> Ordering>(v1: T, v2: T, compare: F) -> T {
/// ```
/// use std::cmp;
///
/// let result = cmp::min_by_key(-2, 1, |x: &i32| x.abs());
/// assert_eq!(result, 1);
/// let result = cmp::min_by_key(2, -1, |x: &i32| x.abs());
/// assert_eq!(result, -1);
///
/// let result = cmp::min_by_key(-2, 2, |x: &i32| x.abs());
/// assert_eq!(result, -2);
/// let result = cmp::min_by_key(2, -3, |x: &i32| x.abs());
/// assert_eq!(result, 2);
///
/// let result = cmp::min_by_key(1, -1, |x: &i32| x.abs());
/// assert_eq!(result, 1);
/// ```
#[inline]
#[must_use]
#[stable(feature = "cmp_min_max_by", since = "1.53.0")]
pub fn min_by_key<T, F: FnMut(&T) -> K, K: Ord>(v1: T, v2: T, mut f: F) -> T {
min_by(v1, v2, |v1, v2| f(v1).cmp(&f(v2)))
if f(&v2) < f(&v1) { v2 } else { v1 }
}

/// Compares and returns the maximum of two values.
Expand All @@ -1483,6 +1542,24 @@ pub fn min_by_key<T, F: FnMut(&T) -> K, K: Ord>(v1: T, v2: T, mut f: F) -> T {
/// assert_eq!(cmp::max(1, 2), 2);
/// assert_eq!(cmp::max(2, 2), 2);
/// ```
/// ```
/// use std::cmp::{self, Ordering};
///
/// #[derive(Eq)]
/// struct Equal(&'static str);
///
/// impl PartialEq for Equal {
/// fn eq(&self, other: &Self) -> bool { true }
/// }
/// impl PartialOrd for Equal {
/// fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(Ordering::Equal) }
/// }
/// impl Ord for Equal {
/// fn cmp(&self, other: &Self) -> Ordering { Ordering::Equal }
/// }
///
/// assert_eq!(cmp::max(Equal("v1"), Equal("v2")).0, "v2");
/// ```
#[inline]
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
Expand All @@ -1500,20 +1577,22 @@ pub fn max<T: Ord>(v1: T, v2: T) -> T {
/// ```
/// use std::cmp;
///
/// let result = cmp::max_by(-2, 1, |x: &i32, y: &i32| x.abs().cmp(&y.abs()));
/// let abs_cmp = |x: &i32, y: &i32| x.abs().cmp(&y.abs());
///
/// let result = cmp::max_by(3, -2, abs_cmp) ;
/// assert_eq!(result, 3);
///
/// let result = cmp::max_by(1, -2, abs_cmp);
/// assert_eq!(result, -2);
///
/// let result = cmp::max_by(-2, 2, |x: &i32, y: &i32| x.abs().cmp(&y.abs())) ;
/// assert_eq!(result, 2);
/// let result = cmp::max_by(1, -1, abs_cmp);
/// assert_eq!(result, -1);
/// ```
#[inline]
#[must_use]
#[stable(feature = "cmp_min_max_by", since = "1.53.0")]
pub fn max_by<T, F: FnOnce(&T, &T) -> Ordering>(v1: T, v2: T, compare: F) -> T {
match compare(&v1, &v2) {
Ordering::Less | Ordering::Equal => v2,
Ordering::Greater => v1,
}
if compare(&v2, &v1).is_lt() { v1 } else { v2 }
}

/// Returns the element that gives the maximum value from the specified function.
Expand All @@ -1525,17 +1604,20 @@ pub fn max_by<T, F: FnOnce(&T, &T) -> Ordering>(v1: T, v2: T, compare: F) -> T {
/// ```
/// use std::cmp;
///
/// let result = cmp::max_by_key(-2, 1, |x: &i32| x.abs());
/// let result = cmp::max_by_key(3, -2, |x: &i32| x.abs());
/// assert_eq!(result, 3);
///
/// let result = cmp::max_by_key(1, -2, |x: &i32| x.abs());
/// assert_eq!(result, -2);
///
/// let result = cmp::max_by_key(-2, 2, |x: &i32| x.abs());
/// assert_eq!(result, 2);
/// let result = cmp::max_by_key(1, -1, |x: &i32| x.abs());
/// assert_eq!(result, -1);
/// ```
#[inline]
#[must_use]
#[stable(feature = "cmp_min_max_by", since = "1.53.0")]
pub fn max_by_key<T, F: FnMut(&T) -> K, K: Ord>(v1: T, v2: T, mut f: F) -> T {
max_by(v1, v2, |v1, v2| f(v1).cmp(&f(v2)))
if f(&v2) < f(&v1) { v1 } else { v2 }
}

/// Compares and sorts two values, returning minimum and maximum.
Expand All @@ -1549,21 +1631,40 @@ pub fn max_by_key<T, F: FnMut(&T) -> K, K: Ord>(v1: T, v2: T, mut f: F) -> T {
/// use std::cmp;
///
/// assert_eq!(cmp::minmax(1, 2), [1, 2]);
/// assert_eq!(cmp::minmax(2, 2), [2, 2]);
/// assert_eq!(cmp::minmax(2, 1), [1, 2]);
///
/// // You can destructure the result using array patterns
/// let [min, max] = cmp::minmax(42, 17);
/// assert_eq!(min, 17);
/// assert_eq!(max, 42);
/// ```
/// ```
/// #![feature(cmp_minmax)]
/// use std::cmp::{self, Ordering};
///
/// #[derive(Eq)]
/// struct Equal(&'static str);
///
/// impl PartialEq for Equal {
/// fn eq(&self, other: &Self) -> bool { true }
/// }
/// impl PartialOrd for Equal {
/// fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(Ordering::Equal) }
/// }
/// impl Ord for Equal {
/// fn cmp(&self, other: &Self) -> Ordering { Ordering::Equal }
/// }
///
/// assert_eq!(cmp::minmax(Equal("v1"), Equal("v2")).map(|v| v.0), ["v1", "v2"]);
/// ```
#[inline]
#[must_use]
#[unstable(feature = "cmp_minmax", issue = "115939")]
pub fn minmax<T>(v1: T, v2: T) -> [T; 2]
where
T: Ord,
{
if v1 <= v2 { [v1, v2] } else { [v2, v1] }
if v2 < v1 { [v2, v1] } else { [v1, v2] }
}

/// Returns minimum and maximum values with respect to the specified comparison function.
Expand All @@ -1576,11 +1677,14 @@ where
/// #![feature(cmp_minmax)]
/// use std::cmp;
///
/// assert_eq!(cmp::minmax_by(-2, 1, |x: &i32, y: &i32| x.abs().cmp(&y.abs())), [1, -2]);
/// assert_eq!(cmp::minmax_by(-2, 2, |x: &i32, y: &i32| x.abs().cmp(&y.abs())), [-2, 2]);
/// let abs_cmp = |x: &i32, y: &i32| x.abs().cmp(&y.abs());
///
/// assert_eq!(cmp::minmax_by(-2, 1, abs_cmp), [1, -2]);
/// assert_eq!(cmp::minmax_by(-1, 2, abs_cmp), [-1, 2]);
/// assert_eq!(cmp::minmax_by(-2, 2, abs_cmp), [-2, 2]);
///
/// // You can destructure the result using array patterns
/// let [min, max] = cmp::minmax_by(-42, 17, |x: &i32, y: &i32| x.abs().cmp(&y.abs()));
/// let [min, max] = cmp::minmax_by(-42, 17, abs_cmp);
/// assert_eq!(min, 17);
/// assert_eq!(max, -42);
/// ```
Expand All @@ -1591,7 +1695,7 @@ pub fn minmax_by<T, F>(v1: T, v2: T, compare: F) -> [T; 2]
where
F: FnOnce(&T, &T) -> Ordering,
{
if compare(&v1, &v2).is_le() { [v1, v2] } else { [v2, v1] }
if compare(&v2, &v1).is_lt() { [v2, v1] } else { [v1, v2] }
}

/// Returns minimum and maximum values with respect to the specified key function.
Expand Down Expand Up @@ -1620,7 +1724,7 @@ where
F: FnMut(&T) -> K,
K: Ord,
{
minmax_by(v1, v2, |v1, v2| f(v1).cmp(&f(v2)))
if f(&v2) < f(&v1) { [v2, v1] } else { [v1, v2] }
}

// Implementation of PartialEq, Eq, PartialOrd and Ord for primitive types
Expand Down
Loading