Skip to content

Commit

Permalink
fix(math): fix edge cases of ln and pow
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewWestberg committed Oct 31, 2024
1 parent 0dd9bcd commit 2db4170
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 9 deletions.
1 change: 1 addition & 0 deletions pallas-math/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ thiserror = "1.0.61"
quickcheck = "1.0"
quickcheck_macros = "1.0"
rand = "0.8"
proptest = "1.5"
67 changes: 65 additions & 2 deletions pallas-math/src/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
# Cardano Math functions
*/

use once_cell::sync::Lazy;
use std::fmt::{Debug, Display};
use std::ops::{Div, Mul, Neg, Sub};

use thiserror::Error;

pub type FixedDecimal = crate::math_malachite::Decimal;

pub static ZERO: Lazy<FixedDecimal> = Lazy::new(|| FixedDecimal::from(0u64));
pub static MINUS_ONE: Lazy<FixedDecimal> = Lazy::new(|| FixedDecimal::from(-1i64));
pub static ONE: Lazy<FixedDecimal> = Lazy::new(|| FixedDecimal::from(1u64));

#[derive(Debug, Error)]
pub enum Error {
#[error("error in regex")]
Expand Down Expand Up @@ -38,7 +42,7 @@ pub trait FixedPrecision:

/// Entry point for 'ln' approximation. First does the necessary scaling, and
/// then calls the continued fraction calculation. For any value outside the
/// domain, i.e., 'x in (-inf,0]', the function returns '-INFINITY'.
/// domain, i.e., 'x in (-inf,0]', the function panics.
fn ln(&self) -> Self;

/// Entry point for 'pow' function. x^y = exp(y * ln x)
Expand Down Expand Up @@ -91,6 +95,9 @@ pub struct ExpCmpOrdering {
#[cfg(test)]
mod tests {
use super::*;
use malachite_base::num::arithmetic::traits::Abs;
use proptest::prelude::Strategy;
use proptest::proptest;
use std::fs::File;
use std::io::BufRead;
use std::path::PathBuf;
Expand Down Expand Up @@ -479,4 +486,60 @@ mod tests {
assert_eq!(res.iterations.to_string(), expected_iterations);
}
}

#[test]
#[should_panic(expected = "ln of a value in (-inf,0] is undefined")]
fn ln_of_0_should_be_undefined() {
ZERO.ln();
}

#[test]
#[should_panic(expected = "ln of a value in (-inf,0] is undefined")]
fn ln_of_negative_should_be_undefined() {
MINUS_ONE.ln();
}

#[test]
fn pow_of_zero_to_any_positive_power_should_be_zero() {
proptest!(|(y in 1u64..=u64::MAX)| {
assert_eq!(ZERO.pow(&FixedDecimal::from(y)), *ZERO);
});
}

#[test]
#[should_panic(expected = "zero to a negative power is undefined")]
fn pow_of_zero_to_neg_power_should_be_undefined() {
let y = FixedDecimal::from(-1i64);
ZERO.pow(&y);
}

#[test]
fn pow_of_any_to_power_0_should_be_1() {
proptest!(|(x in i64::MIN..=i64::MAX)| {
assert_eq!(FixedDecimal::from(x).pow(&*ZERO), *ONE);
});
}

#[test]
fn pow_of_any_to_power_1_should_be_same() {
proptest!(|(x in i64::MIN..=i64::MAX)| {
assert_eq!(FixedDecimal::from(x).pow(&*ONE), FixedDecimal::from(x));
});
}

#[test]
fn pow_to_positive_times_pow_to_negative_should_be_1() {
let epsilon = FixedDecimal::from_str("1000000000000000000", 34).unwrap();
proptest!(|(x in (-5i64..=5i64).prop_filter("Exclude zero", |&x| x != 0), y in 1i64..=25i64)| {
let x = FixedDecimal::from(x);
let y = FixedDecimal::from(y);
let minus_y = -&y;
let x_to_y = x.pow(&y);
let x_to_minus_y = x.pow(&minus_y);
let result = &x_to_y * &x_to_minus_y;
let diff = (&result - &*ONE).abs();
// println!("x: {}, y: {}, x^y: {}, x^-y: {}, x^y * x^-y: {}, diff: {}", x, y, x_to_y, x_to_minus_y, result, diff);
assert!(diff <= epsilon);
});
}
}
77 changes: 70 additions & 7 deletions pallas-math/src/math_malachite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use malachite::num::basic::traits::One;
use malachite::platform_64::Limb;
use malachite::rounding_modes::RoundingMode;
use malachite::{Integer, Natural};
use malachite_base::num::arithmetic::traits::Sign;
use malachite_base::num::arithmetic::traits::{Parity, Sign};
use once_cell::sync::Lazy;
use regex::Regex;
use std::cmp::Ordering;
Expand Down Expand Up @@ -135,6 +135,27 @@ impl<'a> Neg for &'a Decimal {
}
}

impl Abs for Decimal {
type Output = Self;

fn abs(self) -> Self::Output {
let mut result = Decimal::new(self.precision);
result.data = self.data.abs();
result
}
}

// Implement Abs for a reference to Decimal
impl<'a> Abs for &'a Decimal {
type Output = Decimal;

fn abs(self) -> Self::Output {
let mut result = Decimal::new(self.precision);
result.data = (&self.data).abs();
result
}
}

impl Mul for Decimal {
type Output = Self;

Expand Down Expand Up @@ -310,10 +331,15 @@ impl FixedPrecision for Decimal {

fn ln(&self) -> Self {
let mut ln_x = Decimal::new(self.precision);
ref_ln(&mut ln_x.data, &self.data);
ln_x
if ref_ln(&mut ln_x.data, &self.data) {
ln_x
} else {
panic!("ln of a value in (-inf,0] is undefined")
}
}

/// Compute the power of a Decimal approximation using x^y = exp(y * ln x) formula
/// While not exact, this is a more performant way to compute the power of a Decimal
fn pow(&self, rhs: &Self) -> Self {
let mut pow_x = Decimal::new(self.precision);
ref_pow(&mut pow_x.data, &self.data, &rhs.data);
Expand Down Expand Up @@ -677,10 +703,47 @@ fn ref_ln(rop: &mut Integer, x: &Integer) -> bool {
fn ref_pow(rop: &mut Integer, base: &Integer, exponent: &Integer) {
/* x^y = exp(y * ln x) */
let mut tmp: Integer = Integer::from(0);
ref_ln(&mut tmp, base);
tmp *= exponent;
scale(&mut tmp);
ref_exp(rop, &tmp);

if exponent == &ZERO.value || base == &ONE.value {
// any base to the power of zero is one, or 1 to any power is 1
*rop = ONE.value.clone();
return;
}
if exponent == &ONE.value {
// any base to the power of one is the base
*rop = base.clone();
return;
}
if base == &ZERO.value && exponent > &ZERO.value {
// zero to any positive power is zero
*rop = &ZERO.value * &PRECISION.value;
return;
}
if base == &ZERO.value && exponent < &ZERO.value {
panic!("zero to a negative power is undefined");
}
if base < &ZERO.value {
// negate the base and calculate the power
let neg_base = base.neg();
let ref_ln = ref_ln(&mut tmp, &neg_base);
debug_assert!(ref_ln);
tmp *= exponent;
scale(&mut tmp);
let mut tmp_rop = Integer::from(0);
ref_exp(&mut tmp_rop, &tmp);
*rop = if (exponent / &PRECISION.value).even() {
tmp_rop
} else {
-tmp_rop
};
} else {
// base is positive, ref_ln result is valid
let ref_ln = ref_ln(&mut tmp, base);
debug_assert!(ref_ln);
tmp *= exponent;
scale(&mut tmp);
ref_exp(rop, &tmp);
}
}

/// `bound_x` is the bound for exp in the interval x is chosen from
Expand Down

0 comments on commit 2db4170

Please sign in to comment.