diff --git a/crates/fhe-math/src/rq/mod.rs b/crates/fhe-math/src/rq/mod.rs index 6b9da69b..294ff3a7 100644 --- a/crates/fhe-math/src/rq/mod.rs +++ b/crates/fhe-math/src/rq/mod.rs @@ -877,7 +877,7 @@ mod tests { for i in 1..=16 { let p = Poly::small(&ctx, Representation::PowerBasis, i, &mut rng)?; let coefficients = p.coefficients().to_slice().unwrap(); - let v = unsafe { q.center_vec_vt(coefficients) }; + let v = q.center_vec(coefficients); assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 2 * i as i64); } @@ -889,7 +889,7 @@ mod tests { let mut rng = rand::rng(); let p = Poly::small(&ctx, Representation::PowerBasis, 16, &mut rng)?; let coefficients = p.coefficients().to_slice().unwrap(); - let v = unsafe { q.center_vec_vt(coefficients) }; + let v = q.center_vec(coefficients); assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 32); assert_eq!(variance(&v).round(), 16.0); diff --git a/crates/fhe-math/src/zq/mod.rs b/crates/fhe-math/src/zq/mod.rs index b20c1ad7..055611b6 100644 --- a/crates/fhe-math/src/zq/mod.rs +++ b/crates/fhe-math/src/zq/mod.rs @@ -440,34 +440,27 @@ impl Modulus { .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.reduce(*ai))) } - /// Center a value modulo p as i64 in variable time. - /// TODO: To test and to make constant time? + /// Center a value modulo p as i64 in constant time. /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the value being centered. - const unsafe fn center_vt(&self, a: u64) -> i64 { + /// The output is in the interval `[-p/2, p/2)`. + /// Aborts if `a >= p` in debug mode. + #[must_use] + pub const fn center(&self, a: u64) -> i64 { debug_assert!(a < self.p); - if a >= self.p >> 1 { - (a as i64) - (self.p as i64) - } else { - a as i64 - } + let threshold = self.p >> 1; + let cond = a >= threshold; + let on_true = (a as i64).wrapping_sub(self.p as i64) as u64; + let on_false = a; + + const_time_cond_select(on_true, on_false, cond) as i64 } - /// Center a vector in variable time. - /// - /// # Safety - /// This function is not constant time and its timing may reveal information - /// about the values being centered. + /// Center a vector in constant time. #[must_use] - pub unsafe fn center_vec_vt(&self, a: &[u64]) -> Vec { - self.arch.dispatch(|| { - a.iter() - .map(|ai| unsafe { self.center_vt(*ai) }) - .collect_vec() - }) + pub fn center_vec(&self, a: &[u64]) -> Vec { + self.arch + .dispatch(|| a.iter().map(|ai| self.center(*ai)).collect_vec()) } /// Reduce a vector in place in variable time. @@ -1098,6 +1091,29 @@ mod tests { prop_assert_eq!(a, c); } + #[test] + fn center(p in valid_moduli(), a: u64) { + let a = p.reduce(a); + let b = p.center(a); + if a >= *p >> 1 { + prop_assert_eq!(b, (a as i64) - (*p as i64)); + } else { + prop_assert_eq!(b, a as i64); + } + prop_assert_eq!(p.reduce_i64(b), a); + } + + #[test] + fn center_vec(p in valid_moduli(), a: Vec) { + let mut a = a.clone(); + p.reduce_vec(&mut a); + let b = p.center_vec(&a); + prop_assert_eq!(b.len(), a.len()); + for (ai, bi) in izip!(a.iter(), b.iter()) { + prop_assert_eq!(p.center(*ai), *bi); + } + } + #[test] fn mul_opt(p in valid_moduli_opt(), mut a: u64, mut b: u64) { a = p.reduce(a); @@ -1114,9 +1130,7 @@ mod tests { prop_assert!(std::panic::catch_unwind(|| p.mul_opt(a, *p + 1)).is_err()); } } - } - proptest! { #[test] fn pow(p in valid_moduli(), mut a: u64, mut b: u64) { a = p.reduce(a); diff --git a/crates/fhe/src/bfv/plaintext.rs b/crates/fhe/src/bfv/plaintext.rs index ae31abd7..2390ee66 100644 --- a/crates/fhe/src/bfv/plaintext.rs +++ b/crates/fhe/src/bfv/plaintext.rs @@ -258,7 +258,7 @@ impl FheDecoder for Vec<i64> { E: Into<Option<Encoding>>, { let v = Vec::<u64>::try_decode(pt, encoding)?; - Ok(unsafe { pt.par.plaintext.center_vec_vt(&v) }) + Ok(pt.par.plaintext.center_vec(&v)) } type Error = Error; @@ -322,7 +322,7 @@ mod tests { let b = Vec::<u64>::try_decode(&plaintext?, Encoding::simd())?; assert_eq!(b, a); - let a = unsafe { params.plaintext.center_vec_vt(&a) }; + let a = params.plaintext.center_vec(&a); let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params); assert!(plaintext.is_ok()); let b = Vec::<i64>::try_decode(&plaintext?, Encoding::poly())?;