Skip to content

Commit

Permalink
Add Shl and Shr
Browse files Browse the repository at this point in the history
  • Loading branch information
Pratyush committed Dec 28, 2023
1 parent e8f8b3b commit fd16996
Show file tree
Hide file tree
Showing 5 changed files with 411 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/uint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ mod convert;
mod eq;
mod not;
mod or;
mod shl;
mod shr;
mod rotate;
mod select;
mod xor;
Expand Down Expand Up @@ -81,7 +83,7 @@ impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {

for i in 0..N {
bits[i] = Boolean::constant((bit_values & T::one()) == T::one());
bit_values = bit_values >> 1;
bit_values = bit_values >> 1u8;
}

Self {
Expand Down
40 changes: 40 additions & 0 deletions src/uint/prim_uint.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,39 @@
use core::usize;
use core::ops::{Shl, ShlAssign, Shr, ShrAssign};

#[doc(hidden)]
// Adapted from <https://github.com/rust-num/num-traits/pull/224>
pub trait PrimUInt:
core::fmt::Debug
+ num_traits::PrimInt
+ num_traits::WrappingAdd
+ num_traits::SaturatingAdd
+ Shl<usize, Output = Self>
+ Shl<u8, Output = Self>
+ Shl<u16, Output = Self>
+ Shl<u32, Output = Self>
+ Shl<u64, Output = Self>
+ Shl<u128, Output = Self>
+ Shr<usize, Output = Self>
+ Shr<u8, Output = Self>
+ Shr<u16, Output = Self>
+ Shr<u32, Output = Self>
+ Shr<u64, Output = Self>
+ Shr<u128, Output = Self>
+ ShlAssign<usize>
+ ShlAssign<u8>
+ ShlAssign<u16>
+ ShlAssign<u32>
+ ShlAssign<u64>
+ ShlAssign<u128>
+ ShrAssign<usize>
+ ShrAssign<u8>
+ ShrAssign<u16>
+ ShrAssign<u32>
+ ShrAssign<u64>
+ ShrAssign<u128>
+ Into<u128>
+ _private::Sealed
+ ark_std::UniformRand
{
type Bytes: NumBytes;
Expand Down Expand Up @@ -134,3 +163,14 @@ pub trait NumBytes:

#[doc(hidden)]
impl<const N: usize> NumBytes for [u8; N] {}


mod _private {
pub trait Sealed {}

impl Sealed for u8 {}
impl Sealed for u16 {}
impl Sealed for u32 {}
impl Sealed for u64 {}
impl Sealed for u128 {}
}
160 changes: 160 additions & 0 deletions src/uint/shl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use ark_ff::PrimeField;
use ark_relations::r1cs::SynthesisError;
use ark_std::{ops::Shl, ops::ShlAssign};

use crate::boolean::Boolean;

use super::{PrimUInt, UInt};

impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
fn _shl_u128(&self, other: u128) -> Result<Self, SynthesisError> {
if other < N as u128 {
let mut bits = [Boolean::FALSE; N];
for (a, b) in bits[other as usize..].iter_mut().zip(&self.bits) {
*a = b.clone();
}

let value = self.value.and_then(|a| Some(a << other));
Ok(Self {
bits,
value,
})
} else {
panic!("attempt to shift left with overflow")
}
}
}

impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shl<T2> for UInt<N, T, F> {
type Output = Self;

/// Output `self << other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = 1;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?;
///
/// (a << 1).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shl(self, other: T2) -> Self::Output {
self._shl_u128(other.into()).unwrap()
}
}

impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shl<T2> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shl(self, other: T2) -> Self::Output {
self._shl_u128(other.into()).unwrap()
}
}

impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> ShlAssign<T2> for UInt<N, T, F> {
/// Sets `self = self << other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = 1;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?;
///
/// a <<= b;
/// a.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shl_assign(&mut self, other: T2) {
let result = self._shl_u128(other.into()).unwrap();
*self = result;
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;

fn uint_shl<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: T,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let b = b.into() % (N as u128);
let computed = &a << b;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected = UInt::<N, T, F>::new_variable(
cs.clone(),
|| Ok(a.value()? << b),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}

#[test]
fn u8_shl() {
run_binary_exhaustive_with_native(uint_shl::<u8, 8, Fr>).unwrap()
}

#[test]
fn u16_shl() {
run_binary_random_with_native::<1000, 16, _, _>(uint_shl::<u16, 16, Fr>).unwrap()
}

#[test]
fn u32_shl() {
run_binary_random_with_native::<1000, 32, _, _>(uint_shl::<u32, 32, Fr>).unwrap()
}

#[test]
fn u64_shl() {
run_binary_random_with_native::<1000, 64, _, _>(uint_shl::<u64, 64, Fr>).unwrap()
}

#[test]
fn u128_shl() {
run_binary_random_with_native::<1000, 128, _, _>(uint_shl::<u128, 128, Fr>).unwrap()
}
}
160 changes: 160 additions & 0 deletions src/uint/shr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use ark_ff::PrimeField;
use ark_relations::r1cs::SynthesisError;
use ark_std::{ops::Shr, ops::ShrAssign};

use crate::boolean::Boolean;

use super::{PrimUInt, UInt};

impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
fn _shr_u128(&self, other: u128) -> Result<Self, SynthesisError> {
if other < N as u128 {
let mut bits = [Boolean::FALSE; N];
for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) {
*a = b.clone();
}

let value = self.value.and_then(|a| Some(a >> other));
Ok(Self {
bits,
value,
})
} else {
panic!("attempt to shift right with overflow")
}
}
}

impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for UInt<N, T, F> {
type Output = Self;

/// Output `self >> other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = 1;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?;
///
/// (a >> 1).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shr(self, other: T2) -> Self::Output {
self._shr_u128(other.into()).unwrap()
}
}

impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shr(self, other: T2) -> Self::Output {
self._shr_u128(other.into()).unwrap()
}
}

impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> ShrAssign<T2> for UInt<N, T, F> {
/// Sets `self = self >> other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = 1;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?;
///
/// a >> b;
/// a.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shr_assign(&mut self, other: T2) {
let result = self._shr_u128(other.into()).unwrap();
*self = result;
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;

fn uint_shr<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: T,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let b = b.into() % (N as u128);
let computed = &a >> b;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected = UInt::<N, T, F>::new_variable(
cs.clone(),
|| Ok(a.value()? >> b),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}

#[test]
fn u8_shr() {
run_binary_exhaustive_with_native(uint_shr::<u8, 8, Fr>).unwrap()
}

#[test]
fn u16_shr() {
run_binary_random_with_native::<1000, 16, _, _>(uint_shr::<u16, 16, Fr>).unwrap()
}

#[test]
fn u32_shr() {
run_binary_random_with_native::<1000, 32, _, _>(uint_shr::<u32, 32, Fr>).unwrap()
}

#[test]
fn u64_shr() {
run_binary_random_with_native::<1000, 64, _, _>(uint_shr::<u64, 64, Fr>).unwrap()
}

#[test]
fn u128_shr() {
run_binary_random_with_native::<1000, 128, _, _>(uint_shr::<u128, 128, Fr>).unwrap()
}
}
Loading

0 comments on commit fd16996

Please sign in to comment.