diff --git a/crates/fhe-math/src/zq/mod.rs b/crates/fhe-math/src/zq/mod.rs index b20c1ad7..62c1a8e9 100644 --- a/crates/fhe-math/src/zq/mod.rs +++ b/crates/fhe-math/src/zq/mod.rs @@ -803,10 +803,11 @@ impl Modulus { #[cfg(test)] mod tests { use super::Modulus; + use fhe_util::is_prime; use itertools::{Itertools, izip}; use proptest::collection::vec as prop_vec; use proptest::prelude::{BoxedStrategy, Just, Strategy, any}; - use rand::{RngCore, rng}; + use rand::rng; // Utility functions for the proptests. @@ -814,6 +815,12 @@ mod tests { any::().prop_filter_map("filter invalid moduli", |p| Modulus::new(p).ok()) } + fn prime_moduli() -> BoxedStrategy { + proptest::sample::select(vec![2u64, 3, 17, 1987, 4611686018326724609]) + .prop_map(|p| Modulus::new(p).unwrap()) + .boxed() + } + fn valid_moduli_opt() -> impl Strategy { valid_moduli().prop_filter("filter moduli not supporting opt", |p| p.supports_opt) } @@ -1098,6 +1105,27 @@ mod tests { prop_assert_eq!(a, c); } + #[test] + fn inv(p in prop_oneof![valid_moduli().boxed(), prime_moduli()], mut a: u64) { + a = p.reduce(a); + let b = p.inv(a); + + if !is_prime(*p) || a == 0 { + prop_assert!(b.is_none()); + } else { + prop_assert!(b.is_some()); + prop_assert_eq!(p.mul(a, b.unwrap()), 1); + } + + #[cfg(debug_assertions)] + { + if is_prime(*p) { + prop_assert!(std::panic::catch_unwind(|| p.inv(*p)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.inv(*p << 1)).is_err()); + } + } + } + #[test] fn mul_opt(p in valid_moduli_opt(), mut a: u64, mut b: u64) { a = p.reduce(a); @@ -1114,9 +1142,7 @@ mod tests { prop_assert!(std::panic::catch_unwind(|| p.mul_opt(a, *p + 1)).is_err()); } } - } - proptest! { #[test] fn pow(p in valid_moduli(), mut a: u64, mut b: u64) { a = p.reduce(a); @@ -1144,37 +1170,4 @@ mod tests { } } } - - // TODO: Make a proptest. - #[test] - fn inv() { - let ntests = 100; - let mut rng = rand::rng(); - - for p in [2u64, 3, 17, 1987, 4611686018326724609] { - let q = Modulus::new(p).unwrap(); - - assert!(q.inv(0).is_none()); - assert_eq!(q.inv(1).unwrap(), 1); - assert_eq!(q.inv(p - 1).unwrap(), p - 1); - - #[cfg(debug_assertions)] - { - assert!(std::panic::catch_unwind(|| q.inv(p)).is_err()); - assert!(std::panic::catch_unwind(|| q.inv(p << 1)).is_err()); - } - - for _ in 0..ntests { - let a = rng.next_u64() % p; - let b = q.inv(a); - - if a == 0 { - assert!(b.is_none()) - } else { - assert!(b.is_some()); - assert_eq!(q.mul(a, b.unwrap()), 1) - } - } - } - } }