diff --git a/src/uint/mod.rs b/src/uint/mod.rs index 26598675..11c5efac 100644 --- a/src/uint/mod.rs +++ b/src/uint/mod.rs @@ -12,6 +12,8 @@ mod convert; mod eq; mod not; mod or; +mod shl; +mod shr; mod rotate; mod select; mod xor; @@ -81,7 +83,7 @@ impl UInt { for i in 0..N { bits[i] = Boolean::constant((bit_values & T::one()) == T::one()); - bit_values = bit_values >> 1; + bit_values = bit_values >> 1u8; } Self { diff --git a/src/uint/prim_uint.rs b/src/uint/prim_uint.rs index 963db0f7..7f8c21ca 100644 --- a/src/uint/prim_uint.rs +++ b/src/uint/prim_uint.rs @@ -1,3 +1,6 @@ +use core::usize; +use core::ops::{Shl, ShlAssign, Shr, ShrAssign}; + #[doc(hidden)] // Adapted from pub trait PrimUInt: @@ -5,6 +8,32 @@ pub trait PrimUInt: + num_traits::PrimInt + num_traits::WrappingAdd + num_traits::SaturatingAdd + + Shl + + Shl + + Shl + + Shl + + Shl + + Shl + + Shr + + Shr + + Shr + + Shr + + Shr + + Shr + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + Into + + _private::Sealed + ark_std::UniformRand { type Bytes: NumBytes; @@ -134,3 +163,14 @@ pub trait NumBytes: #[doc(hidden)] impl NumBytes for [u8; N] {} + + +mod _private { + pub trait Sealed {} + + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for u128 {} +} \ No newline at end of file diff --git a/src/uint/shl.rs b/src/uint/shl.rs new file mode 100644 index 00000000..2aaf474c --- /dev/null +++ b/src/uint/shl.rs @@ -0,0 +1,160 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::Shl, ops::ShlAssign}; + +use crate::boolean::Boolean; + +use super::{PrimUInt, UInt}; + +impl UInt { + fn _shl_u128(&self, other: u128) -> Result { + if other < N as u128 { + let mut bits = [Boolean::FALSE; N]; + for (a, b) in bits[other as usize..].iter_mut().zip(&self.bits) { + *a = b.clone(); + } + + let value = self.value.and_then(|a| Some(a << other)); + Ok(Self { + bits, + value, + }) + } else { + panic!("attempt to shift left with overflow") + } + } +} + +impl Shl for UInt { + type Output = Self; + + /// Output `self << other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; + /// + /// (a << 1).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl(self, other: T2) -> Self::Output { + self._shl_u128(other.into()).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shl for &'a UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl(self, other: T2) -> Self::Output { + self._shl_u128(other.into()).unwrap() + } +} + +impl ShlAssign for UInt { + /// Sets `self = self << other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; + /// + /// a <<= b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl_assign(&mut self, other: T2) { + let result = self._shl_u128(other.into()).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_shl( + a: UInt, + b: T, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let b = b.into() % (N as u128); + let computed = &a << b; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()? << b), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_shl() { + run_binary_exhaustive_with_native(uint_shl::).unwrap() + } + + #[test] + fn u16_shl() { + run_binary_random_with_native::<1000, 16, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u32_shl() { + run_binary_random_with_native::<1000, 32, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u64_shl() { + run_binary_random_with_native::<1000, 64, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u128_shl() { + run_binary_random_with_native::<1000, 128, _, _>(uint_shl::).unwrap() + } +} diff --git a/src/uint/shr.rs b/src/uint/shr.rs new file mode 100644 index 00000000..dbcdc8d7 --- /dev/null +++ b/src/uint/shr.rs @@ -0,0 +1,160 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::Shr, ops::ShrAssign}; + +use crate::boolean::Boolean; + +use super::{PrimUInt, UInt}; + +impl UInt { + fn _shr_u128(&self, other: u128) -> Result { + if other < N as u128 { + let mut bits = [Boolean::FALSE; N]; + for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) { + *a = b.clone(); + } + + let value = self.value.and_then(|a| Some(a >> other)); + Ok(Self { + bits, + value, + }) + } else { + panic!("attempt to shift right with overflow") + } + } +} + +impl Shr for UInt { + type Output = Self; + + /// Output `self >> other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; + /// + /// (a >> 1).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr(self, other: T2) -> Self::Output { + self._shr_u128(other.into()).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr for &'a UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr(self, other: T2) -> Self::Output { + self._shr_u128(other.into()).unwrap() + } +} + +impl ShrAssign for UInt { + /// Sets `self = self >> other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; + /// + /// a >> b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr_assign(&mut self, other: T2) { + let result = self._shr_u128(other.into()).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_shr( + a: UInt, + b: T, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let b = b.into() % (N as u128); + let computed = &a >> b; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()? >> b), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_shr() { + run_binary_exhaustive_with_native(uint_shr::).unwrap() + } + + #[test] + fn u16_shr() { + run_binary_random_with_native::<1000, 16, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u32_shr() { + run_binary_random_with_native::<1000, 32, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u64_shr() { + run_binary_random_with_native::<1000, 64, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u128_shr() { + run_binary_random_with_native::<1000, 128, _, _>(uint_shr::).unwrap() + } +} diff --git a/src/uint/test_utils.rs b/src/uint/test_utils.rs index 13716833..5b3503e8 100644 --- a/src/uint/test_utils.rs +++ b/src/uint/test_utils.rs @@ -28,6 +28,17 @@ pub(crate) fn test_binary_op( test(a, b) } +pub(crate) fn test_binary_op_with_native( + a: T, + b: T, + mode_a: AllocationMode, + test: impl FnOnce(UInt, T) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = UInt::::new_variable(cs.clone(), || Ok(a), mode_a)?; + test(a, b) +} + pub(crate) fn run_binary_random( test: impl Fn(UInt, UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError> @@ -65,6 +76,43 @@ where Ok(()) } +pub(crate) fn run_binary_random_with_native( + test: impl Fn(UInt, T) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, +{ + let mut rng = ark_std::test_rng(); + + for _ in 0..ITERATIONS { + for mode_a in modes() { + let a = T::rand(&mut rng); + let b = T::rand(&mut rng); + test_binary_op_with_native(a, b, mode_a, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_binary_exhaustive_with_native( + test: impl Fn(UInt, T) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, + RangeInclusive: Iterator, +{ + for (mode_a, a) in test_utils::combination(T::min_value()..=T::max_value()) { + for b in T::min_value()..=T::max_value() { + test_binary_op_with_native(a, b, mode_a, test)?; + } + } + Ok(()) +} + + + pub(crate) fn run_unary_random( test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, ) -> Result<(), SynthesisError>