diff --git a/src/bits/boolean.rs b/src/bits/boolean.rs deleted file mode 100644 index 9f0be047..00000000 --- a/src/bits/boolean.rs +++ /dev/null @@ -1,1823 +0,0 @@ -use ark_ff::{BitIteratorBE, Field, PrimeField}; - -use crate::{fields::fp::FpVar, prelude::*, Assignment, ToConstraintFieldGadget, Vec}; -use ark_relations::r1cs::{ - ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, -}; -use core::borrow::Borrow; - -/// Represents a variable in the constraint system which is guaranteed -/// to be either zero or one. -/// -/// In general, one should prefer using `Boolean` instead of `AllocatedBool`, -/// as `Boolean` offers better support for constant values, and implements -/// more traits. -#[derive(Clone, Debug, Eq, PartialEq)] -#[must_use] -pub struct AllocatedBool { - variable: Variable, - cs: ConstraintSystemRef, -} - -pub(crate) fn bool_to_field(val: impl Borrow) -> F { - if *val.borrow() { - F::one() - } else { - F::zero() - } -} - -impl AllocatedBool { - /// Get the assigned value for `self`. - pub fn value(&self) -> Result { - let value = self.cs.assigned_value(self.variable).get()?; - if value.is_zero() { - Ok(false) - } else if value.is_one() { - Ok(true) - } else { - unreachable!("Incorrect value assigned: {:?}", value); - } - } - - /// Get the R1CS variable for `self`. - pub fn variable(&self) -> Variable { - self.variable - } - - /// Allocate a witness variable without a booleanity check. - pub(crate) fn new_witness_without_booleanity_check>( - cs: ConstraintSystemRef, - f: impl FnOnce() -> Result, - ) -> Result { - let variable = cs.new_witness_variable(|| f().map(bool_to_field))?; - Ok(Self { variable, cs }) - } - - /// Performs an XOR operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn xor(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? ^ b.value()?) - })?; - - // Constrain (a + a) * (b) = (a + b - c) - // Given that a and b are boolean constrained, if they - // are equal, the only solution for c is 0, and if they - // are different, the only solution for c is 1. - // - // ¬(a ∧ b) ∧ ¬(¬a ∧ ¬b) = c - // (1 - (a * b)) * (1 - ((1 - a) * (1 - b))) = c - // (1 - ab) * (1 - (1 - a - b + ab)) = c - // (1 - ab) * (a + b - ab) = c - // a + b - ab - (a^2)b - (b^2)a + (a^2)(b^2) = c - // a + b - ab - ab - ab + ab = c - // a + b - 2ab = c - // -2a * b = c - a - b - // 2a * b = a + b - c - // (a + a) * b = a + b - c - self.cs.enforce_constraint( - lc!() + self.variable + self.variable, - lc!() + b.variable, - lc!() + self.variable + b.variable - result.variable, - )?; - - Ok(result) - } - - /// Performs an AND operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn and(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? & b.value()?) - })?; - - // Constrain (a) * (b) = (c), ensuring c is 1 iff - // a AND b are both 1. - self.cs.enforce_constraint( - lc!() + self.variable, - lc!() + b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } - - /// Performs an OR operation over the two operands, returning - /// an `AllocatedBool`. - #[tracing::instrument(target = "r1cs")] - pub fn or(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? | b.value()?) - })?; - - // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff - // a and b are both false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + Variable::One - self.variable, - lc!() + Variable::One - b.variable, - lc!() + Variable::One - result.variable, - )?; - - Ok(result) - } - - /// Calculates `a AND (NOT b)`. - #[tracing::instrument(target = "r1cs")] - pub fn and_not(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(self.value()? & !b.value()?) - })?; - - // Constrain (a) * (1 - b) = (c), ensuring c is 1 iff - // a is true and b is false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + self.variable, - lc!() + Variable::One - b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } - - /// Calculates `(NOT a) AND (NOT b)`. - #[tracing::instrument(target = "r1cs")] - pub fn nor(&self, b: &Self) -> Result { - let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { - Ok(!(self.value()? | b.value()?)) - })?; - - // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff - // a and b are both false, and otherwise c is 0. - self.cs.enforce_constraint( - lc!() + Variable::One - self.variable, - lc!() + Variable::One - b.variable, - lc!() + result.variable, - )?; - - Ok(result) - } -} - -impl AllocVar for AllocatedBool { - /// Produces a new variable of the appropriate kind - /// (instance or witness), with a booleanity check. - /// - /// N.B.: we could omit the booleanity check when allocating `self` - /// as a new public input, but that places an additional burden on - /// protocol designers. Better safe than sorry! - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - if mode == AllocationMode::Constant { - let variable = if *f()?.borrow() { - Variable::One - } else { - Variable::Zero - }; - Ok(Self { variable, cs }) - } else { - let variable = if mode == AllocationMode::Input { - cs.new_input_variable(|| f().map(bool_to_field))? - } else { - cs.new_witness_variable(|| f().map(bool_to_field))? - }; - - // Constrain: (1 - a) * a = 0 - // This constrains a to be either 0 or 1. - - cs.enforce_constraint(lc!() + Variable::One - variable, lc!() + variable, lc!())?; - - Ok(Self { variable, cs }) - } - } -} - -impl CondSelectGadget for AllocatedBool { - #[tracing::instrument(target = "r1cs")] - fn conditionally_select( - cond: &Boolean, - true_val: &Self, - false_val: &Self, - ) -> Result { - let res = Boolean::conditionally_select( - cond, - &true_val.clone().into(), - &false_val.clone().into(), - )?; - match res { - Boolean::Is(a) => Ok(a), - _ => unreachable!("Impossible"), - } - } -} - -/// Represents a boolean value in the constraint system which is guaranteed -/// to be either zero or one. -#[derive(Clone, Debug, Eq, PartialEq)] -#[must_use] -pub enum Boolean { - /// Existential view of the boolean variable. - Is(AllocatedBool), - /// Negated view of the boolean variable. - Not(AllocatedBool), - /// Constant (not an allocated variable). - Constant(bool), -} - -impl R1CSVar for Boolean { - type Value = bool; - - fn cs(&self) -> ConstraintSystemRef { - match self { - Self::Is(a) | Self::Not(a) => a.cs.clone(), - _ => ConstraintSystemRef::None, - } - } - - fn value(&self) -> Result { - match self { - Boolean::Constant(c) => Ok(*c), - Boolean::Is(ref v) => v.value(), - Boolean::Not(ref v) => v.value().map(|b| !b), - } - } -} - -impl Boolean { - /// The constant `true`. - pub const TRUE: Self = Boolean::Constant(true); - - /// The constant `false`. - pub const FALSE: Self = Boolean::Constant(false); - - /// Constructs a `LinearCombination` from `Self`'s variables according - /// to the following map. - /// - /// * `Boolean::Constant(true) => lc!() + Variable::One` - /// * `Boolean::Constant(false) => lc!()` - /// * `Boolean::Is(v) => lc!() + v.variable()` - /// * `Boolean::Not(v) => lc!() + Variable::One - v.variable()` - pub fn lc(&self) -> LinearCombination { - match self { - Boolean::Constant(false) => lc!(), - Boolean::Constant(true) => lc!() + Variable::One, - Boolean::Is(v) => v.variable().into(), - Boolean::Not(v) => lc!() + Variable::One - v.variable(), - } - } - - /// Constructs a `Boolean` vector from a slice of constant `u8`. - /// The `u8`s are decomposed in little-endian manner. - /// - /// This *does not* create any new variables or constraints. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let t = Boolean::::TRUE; - /// let f = Boolean::::FALSE; - /// - /// let bits = vec![f, t]; - /// let generated_bits = Boolean::constant_vec_from_bytes(&[2]); - /// bits[..2].enforce_equal(&generated_bits[..2])?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn constant_vec_from_bytes(values: &[u8]) -> Vec { - let mut bits = vec![]; - for byte in values { - for i in 0..8 { - bits.push(Self::Constant(((byte >> i) & 1u8) == 1u8)); - } - } - bits - } - - /// Constructs a constant `Boolean` with value `b`. - /// - /// This *does not* create any new variables or constraints. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_r1cs_std::prelude::*; - /// - /// let true_var = Boolean::::TRUE; - /// let false_var = Boolean::::FALSE; - /// - /// true_var.enforce_equal(&Boolean::constant(true))?; - /// false_var.enforce_equal(&Boolean::constant(false))?; - /// # Ok(()) - /// # } - /// ``` - pub fn constant(b: bool) -> Self { - Boolean::Constant(b) - } - - /// Negates `self`. - /// - /// This *does not* create any new variables or constraints. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// a.not().enforce_equal(&b)?; - /// b.not().enforce_equal(&a)?; - /// - /// a.not().enforce_equal(&Boolean::FALSE)?; - /// b.not().enforce_equal(&Boolean::TRUE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn not(&self) -> Self { - match *self { - Boolean::Constant(c) => Boolean::Constant(!c), - Boolean::Is(ref v) => Boolean::Not(v.clone()), - Boolean::Not(ref v) => Boolean::Is(v.clone()), - } - } - - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// a.xor(&b)?.enforce_equal(&Boolean::TRUE)?; - /// b.xor(&a)?.enforce_equal(&Boolean::TRUE)?; - /// - /// a.xor(&a)?.enforce_equal(&Boolean::FALSE)?; - /// b.xor(&b)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn xor<'a>(&'a self, other: &'a Self) -> Result { - use Boolean::*; - match (self, other) { - (&Constant(false), x) | (x, &Constant(false)) => Ok(x.clone()), - (&Constant(true), x) | (x, &Constant(true)) => Ok(x.not()), - // a XOR (NOT b) = NOT(a XOR b) - (is @ &Is(_), not @ &Not(_)) | (not @ &Not(_), is @ &Is(_)) => { - Ok(is.xor(¬.not())?.not()) - }, - // a XOR b = (NOT a) XOR (NOT b) - (&Is(ref a), &Is(ref b)) | (&Not(ref a), &Not(ref b)) => Ok(Is(a.xor(b)?)), - } - } - - /// Outputs `self | other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// a.or(&b)?.enforce_equal(&Boolean::TRUE)?; - /// b.or(&a)?.enforce_equal(&Boolean::TRUE)?; - /// - /// a.or(&a)?.enforce_equal(&Boolean::TRUE)?; - /// b.or(&b)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn or<'a>(&'a self, other: &'a Self) -> Result { - use Boolean::*; - match (self, other) { - (&Constant(false), x) | (x, &Constant(false)) => Ok(x.clone()), - (&Constant(true), _) | (_, &Constant(true)) => Ok(Constant(true)), - // a OR b = NOT ((NOT a) AND (NOT b)) - (a @ &Is(_), b @ &Not(_)) | (b @ &Not(_), a @ &Is(_)) | (b @ &Not(_), a @ &Not(_)) => { - Ok(a.not().and(&b.not())?.not()) - }, - (&Is(ref a), &Is(ref b)) => a.or(b).map(From::from), - } - } - - /// Outputs `self & other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// a.and(&a)?.enforce_equal(&Boolean::TRUE)?; - /// - /// a.and(&b)?.enforce_equal(&Boolean::FALSE)?; - /// b.and(&a)?.enforce_equal(&Boolean::FALSE)?; - /// b.and(&b)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn and<'a>(&'a self, other: &'a Self) -> Result { - use Boolean::*; - match (self, other) { - // false AND x is always false - (&Constant(false), _) | (_, &Constant(false)) => Ok(Constant(false)), - // true AND x is always x - (&Constant(true), x) | (x, &Constant(true)) => Ok(x.clone()), - // a AND (NOT b) - (&Is(ref is), &Not(ref not)) | (&Not(ref not), &Is(ref is)) => Ok(Is(is.and_not(not)?)), - // (NOT a) AND (NOT b) = a NOR b - (&Not(ref a), &Not(ref b)) => Ok(Is(a.nor(b)?)), - // a AND b - (&Is(ref a), &Is(ref b)) => Ok(Is(a.and(b)?)), - } - } - - /// Outputs `bits[0] & bits[1] & ... & bits.last().unwrap()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// Boolean::kary_and(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// Boolean::kary_and(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_and(bits: &[Self]) -> Result { - assert!(!bits.is_empty()); - let mut cur: Option = None; - for next in bits { - cur = if let Some(b) = cur { - Some(b.and(next)?) - } else { - Some(next.clone()) - }; - } - - Ok(cur.expect("should not be 0")) - } - - /// Outputs `bits[0] | bits[1] | ... | bits.last().unwrap()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// Boolean::kary_or(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_or(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_or(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_or(bits: &[Self]) -> Result { - assert!(!bits.is_empty()); - let mut cur: Option = None; - for next in bits { - cur = if let Some(b) = cur { - Some(b.or(next)?) - } else { - Some(next.clone()) - }; - } - - Ok(cur.expect("should not be 0")) - } - - /// Outputs `(bits[0] & bits[1] & ... & bits.last().unwrap()).not()`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// Boolean::kary_nand(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// Boolean::kary_nand(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; - /// Boolean::kary_nand(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn kary_nand(bits: &[Self]) -> Result { - Ok(Self::kary_and(bits)?.not()) - } - - /// Enforces that `Self::kary_nand(bits).is_eq(&Boolean::TRUE)`. - /// - /// Informally, this means that at least one element in `bits` must be - /// `false`. - #[tracing::instrument(target = "r1cs")] - fn enforce_kary_nand(bits: &[Self]) -> Result<(), SynthesisError> { - use Boolean::*; - let r = Self::kary_nand(bits)?; - match r { - Constant(true) => Ok(()), - Constant(false) => Err(SynthesisError::AssignmentMissing), - Is(_) | Not(_) => { - r.cs() - .enforce_constraint(r.lc(), lc!() + Variable::One, lc!() + Variable::One) - }, - } - } - - /// Convert a little-endian bitwise representation of a field element to - /// `FpVar` - #[tracing::instrument(target = "r1cs", skip(bits))] - pub fn le_bits_to_fp_var(bits: &[Self]) -> Result, SynthesisError> - where - F: PrimeField, - { - // Compute the value of the `FpVar` variable via double-and-add. - let mut value = None; - let cs = bits.cs(); - // Assign a value only when `cs` is in setup mode, or if we are constructing - // a constant. - let should_construct_value = (!cs.is_in_setup_mode()) || bits.is_constant(); - if should_construct_value { - let bits = bits.iter().map(|b| b.value().unwrap()).collect::>(); - let bytes = bits - .chunks(8) - .map(|c| { - let mut value = 0u8; - for (i, &bit) in c.iter().enumerate() { - value += (bit as u8) << i; - } - value - }) - .collect::>(); - value = Some(F::from_le_bytes_mod_order(&bytes)); - } - - if bits.is_constant() { - Ok(FpVar::constant(value.unwrap())) - } else { - let mut power = F::one(); - // Compute a linear combination for the new field variable, again - // via double and add. - let mut combined_lc = LinearCombination::zero(); - bits.iter().for_each(|b| { - combined_lc = &combined_lc + (power, b.lc()); - power.double_in_place(); - }); - // Allocate the new variable as a SymbolicLc - let variable = cs.new_lc(combined_lc)?; - // If the number of bits is less than the size of the field, - // then we do not need to enforce that the element is less than - // the modulus. - if bits.len() >= F::MODULUS_BIT_SIZE as usize { - Self::enforce_in_field_le(bits)?; - } - Ok(crate::fields::fp::AllocatedFp::new(value, variable, cs.clone()).into()) - } - } - - /// Enforces that `bits`, when interpreted as a integer, is less than - /// `F::characteristic()`, That is, interpret bits as a little-endian - /// integer, and enforce that this integer is "in the field Z_p", where - /// `p = F::characteristic()` . - #[tracing::instrument(target = "r1cs")] - pub fn enforce_in_field_le(bits: &[Self]) -> Result<(), SynthesisError> { - // `bits` < F::characteristic() <==> `bits` <= F::characteristic() -1 - let mut b = F::characteristic().to_vec(); - assert_eq!(b[0] % 2, 1); - b[0] -= 1; // This works, because the LSB is one, so there's no borrows. - let run = Self::enforce_smaller_or_equal_than_le(bits, b)?; - - // We should always end in a "run" of zeros, because - // the characteristic is an odd prime. So, this should - // be empty. - assert!(run.is_empty()); - - Ok(()) - } - - /// Enforces that `bits` is less than or equal to `element`, - /// when both are interpreted as (little-endian) integers. - #[tracing::instrument(target = "r1cs", skip(element))] - pub fn enforce_smaller_or_equal_than_le<'a>( - bits: &[Self], - element: impl AsRef<[u64]>, - ) -> Result, SynthesisError> { - let b: &[u64] = element.as_ref(); - - let mut bits_iter = bits.iter().rev(); // Iterate in big-endian - - // Runs of ones in r - let mut last_run = Boolean::constant(true); - let mut current_run = vec![]; - - let mut element_num_bits = 0; - for _ in BitIteratorBE::without_leading_zeros(b) { - element_num_bits += 1; - } - - if bits.len() > element_num_bits { - let mut or_result = Boolean::constant(false); - for should_be_zero in &bits[element_num_bits..] { - or_result = or_result.or(should_be_zero)?; - let _ = bits_iter.next().unwrap(); - } - or_result.enforce_equal(&Boolean::constant(false))?; - } - - for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) { - if b { - // This is part of a run of ones. - current_run.push(a.clone()); - } else { - if !current_run.is_empty() { - // This is the start of a run of zeros, but we need - // to k-ary AND against `last_run` first. - - current_run.push(last_run.clone()); - last_run = Self::kary_and(¤t_run)?; - current_run.truncate(0); - } - - // If `last_run` is true, `a` must be false, or it would - // not be in the field. - // - // If `last_run` is false, `a` can be true or false. - // - // Ergo, at least one of `last_run` and `a` must be false. - Self::enforce_kary_nand(&[last_run.clone(), a.clone()])?; - } - } - assert!(bits_iter.next().is_none()); - - Ok(current_run) - } - - /// Conditionally selects one of `first` and `second` based on the value of - /// `self`: - /// - /// If `self.is_eq(&Boolean::TRUE)`, this outputs `first`; else, it outputs - /// `second`. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// - /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; - /// - /// let cond = Boolean::new_witness(cs.clone(), || Ok(true))?; - /// - /// cond.select(&a, &b)?.enforce_equal(&Boolean::TRUE)?; - /// cond.select(&b, &a)?.enforce_equal(&Boolean::FALSE)?; - /// - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs", skip(first, second))] - pub fn select>( - &self, - first: &T, - second: &T, - ) -> Result { - T::conditionally_select(&self, first, second) - } -} - -impl From> for Boolean { - fn from(b: AllocatedBool) -> Self { - Boolean::Is(b) - } -} - -impl AllocVar for Boolean { - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - if mode == AllocationMode::Constant { - Ok(Boolean::Constant(*f()?.borrow())) - } else { - AllocatedBool::new_variable(cs, f, mode).map(Boolean::from) - } - } -} - -impl EqGadget for Boolean { - #[tracing::instrument(target = "r1cs")] - fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - // self | other | XNOR(self, other) | self == other - // -----|-------|-------------------|-------------- - // 0 | 0 | 1 | 1 - // 0 | 1 | 0 | 0 - // 1 | 0 | 0 | 0 - // 1 | 1 | 1 | 1 - Ok(self.xor(other)?.not()) - } - - #[tracing::instrument(target = "r1cs")] - fn conditional_enforce_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - use Boolean::*; - let one = Variable::One; - let difference = match (self, other) { - // 1 == 1; 0 == 0 - (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()), - // false != true - (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), - // 1 - a - (Constant(true), Is(a)) | (Is(a), Constant(true)) => lc!() + one - a.variable(), - // a - 0 = a - (Constant(false), Is(a)) | (Is(a), Constant(false)) => lc!() + a.variable(), - // 1 - !a = 1 - (1 - a) = a - (Constant(true), Not(a)) | (Not(a), Constant(true)) => lc!() + a.variable(), - // !a - 0 = !a = 1 - a - (Constant(false), Not(a)) | (Not(a), Constant(false)) => lc!() + one - a.variable(), - // b - a, - (Is(a), Is(b)) => lc!() + b.variable() - a.variable(), - // !b - a = (1 - b) - a - (Is(a), Not(b)) | (Not(b), Is(a)) => lc!() + one - b.variable() - a.variable(), - // !b - !a = (1 - b) - (1 - a) = a - b, - (Not(a), Not(b)) => lc!() + a.variable() - b.variable(), - }; - - if condition != &Constant(false) { - let cs = self.cs().or(other.cs()).or(condition.cs()); - cs.enforce_constraint(lc!() + difference, condition.lc(), lc!())?; - } - Ok(()) - } - - #[tracing::instrument(target = "r1cs")] - fn conditional_enforce_not_equal( - &self, - other: &Self, - should_enforce: &Boolean, - ) -> Result<(), SynthesisError> { - use Boolean::*; - let one = Variable::One; - let difference = match (self, other) { - // 1 != 0; 0 != 1 - (Constant(true), Constant(false)) | (Constant(false), Constant(true)) => return Ok(()), - // false == false and true == true - (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), - // 1 - a - (Constant(true), Is(a)) | (Is(a), Constant(true)) => lc!() + one - a.variable(), - // a - 0 = a - (Constant(false), Is(a)) | (Is(a), Constant(false)) => lc!() + a.variable(), - // 1 - !a = 1 - (1 - a) = a - (Constant(true), Not(a)) | (Not(a), Constant(true)) => lc!() + a.variable(), - // !a - 0 = !a = 1 - a - (Constant(false), Not(a)) | (Not(a), Constant(false)) => lc!() + one - a.variable(), - // b - a, - (Is(a), Is(b)) => lc!() + b.variable() - a.variable(), - // !b - a = (1 - b) - a - (Is(a), Not(b)) | (Not(b), Is(a)) => lc!() + one - b.variable() - a.variable(), - // !b - !a = (1 - b) - (1 - a) = a - b, - (Not(a), Not(b)) => lc!() + a.variable() - b.variable(), - }; - - if should_enforce != &Constant(false) { - let cs = self.cs().or(other.cs()).or(should_enforce.cs()); - cs.enforce_constraint(difference, should_enforce.lc(), should_enforce.lc())?; - } - Ok(()) - } -} - -impl ToBytesGadget for Boolean { - /// Outputs `1u8` if `self` is true, and `0u8` otherwise. - #[tracing::instrument(target = "r1cs")] - fn to_bytes(&self) -> Result>, SynthesisError> { - let value = self.value().map(u8::from).ok(); - let mut bits = [Boolean::FALSE; 8]; - bits[0] = self.clone(); - Ok(vec![UInt8 { bits, value }]) - } -} - -impl ToConstraintFieldGadget for Boolean { - #[tracing::instrument(target = "r1cs")] - fn to_constraint_field(&self) -> Result>, SynthesisError> { - let var = From::from(self.clone()); - Ok(vec![var]) - } -} - -impl CondSelectGadget for Boolean { - #[tracing::instrument(target = "r1cs")] - fn conditionally_select( - cond: &Boolean, - true_val: &Self, - false_val: &Self, - ) -> Result { - use Boolean::*; - match cond { - Constant(true) => Ok(true_val.clone()), - Constant(false) => Ok(false_val.clone()), - cond @ Not(_) => Self::conditionally_select(&cond.not(), false_val, true_val), - cond @ Is(_) => match (true_val, false_val) { - (x, &Constant(false)) => cond.and(x), - (&Constant(false), x) => cond.not().and(x), - (&Constant(true), x) => cond.or(x), - (x, &Constant(true)) => cond.not().or(x), - (a, b) => { - let cs = cond.cs(); - let result: Boolean = - AllocatedBool::new_witness_without_booleanity_check(cs.clone(), || { - let cond = cond.value()?; - Ok(if cond { a.value()? } else { b.value()? }) - })? - .into(); - // a = self; b = other; c = cond; - // - // r = c * a + (1 - c) * b - // r = b + c * (a - b) - // c * (a - b) = r - b - // - // If a, b, cond are all boolean, so is r. - // - // self | other | cond | result - // -----|-------|---------------- - // 0 | 0 | 1 | 0 - // 0 | 1 | 1 | 0 - // 1 | 0 | 1 | 1 - // 1 | 1 | 1 | 1 - // 0 | 0 | 0 | 0 - // 0 | 1 | 0 | 1 - // 1 | 0 | 0 | 0 - // 1 | 1 | 0 | 1 - cs.enforce_constraint( - cond.lc(), - lc!() + a.lc() - b.lc(), - lc!() + result.lc() - b.lc(), - )?; - - Ok(result) - }, - }, - } - } -} - -#[cfg(test)] -mod test { - use super::{AllocatedBool, Boolean}; - use crate::prelude::*; - use ark_ff::{ - AdditiveGroup, BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand, Zero, - }; - use ark_relations::r1cs::{ConstraintSystem, Namespace, SynthesisError}; - use ark_test_curves::bls12_381::Fr; - - #[test] - fn test_boolean_to_byte() -> Result<(), SynthesisError> { - for val in [true, false].iter() { - let cs = ConstraintSystem::::new_ref(); - let a = Boolean::new_witness(cs.clone(), || Ok(*val))?; - let bytes = a.to_bytes()?; - assert_eq!(bytes.len(), 1); - let byte = &bytes[0]; - assert_eq!(byte.value()?, *val as u8); - - for (i, bit) in byte.bits.iter().enumerate() { - assert_eq!(bit.value()?, (byte.value()? >> i) & 1 == 1); - } - } - Ok(()) - } - - #[test] - fn test_xor() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::xor(&a, &b)?; - assert_eq!(c.value()?, a_val ^ b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val ^ b_val)); - } - } - Ok(()) - } - - #[test] - fn test_or() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::or(&a, &b)?; - assert_eq!(c.value()?, a_val | b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val | b_val)); - } - } - Ok(()) - } - - #[test] - fn test_and() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::and(&a, &b)?; - assert_eq!(c.value()?, a_val & b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val & b_val)); - } - } - Ok(()) - } - - #[test] - fn test_and_not() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::and_not(&a, &b)?; - assert_eq!(c.value()?, a_val & !b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (a_val & !b_val)); - } - } - Ok(()) - } - - #[test] - fn test_nor() -> Result<(), SynthesisError> { - for a_val in [false, true].iter().copied() { - for b_val in [false, true].iter().copied() { - let cs = ConstraintSystem::::new_ref(); - let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; - let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; - let c = AllocatedBool::nor(&a, &b)?; - assert_eq!(c.value()?, !a_val & !b_val); - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(a.value()?, (a_val)); - assert_eq!(b.value()?, (b_val)); - assert_eq!(c.value()?, (!a_val & !b_val)); - } - } - Ok(()) - } - - #[test] - fn test_enforce_equal() -> Result<(), SynthesisError> { - for a_bool in [false, true].iter().cloned() { - for b_bool in [false, true].iter().cloned() { - for a_neg in [false, true].iter().cloned() { - for b_neg in [false, true].iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let mut a = Boolean::new_witness(cs.clone(), || Ok(a_bool))?; - let mut b = Boolean::new_witness(cs.clone(), || Ok(b_bool))?; - - if a_neg { - a = a.not(); - } - if b_neg { - b = b.not(); - } - - a.enforce_equal(&b)?; - - assert_eq!( - cs.is_satisfied().unwrap(), - (a_bool ^ a_neg) == (b_bool ^ b_neg) - ); - } - } - } - } - Ok(()) - } - - #[test] - fn test_conditional_enforce_equal() -> Result<(), SynthesisError> { - for a_bool in [false, true].iter().cloned() { - for b_bool in [false, true].iter().cloned() { - for a_neg in [false, true].iter().cloned() { - for b_neg in [false, true].iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - // First test if constraint system is satisfied - // when we do want to enforce the condition. - let mut a = Boolean::new_witness(cs.clone(), || Ok(a_bool))?; - let mut b = Boolean::new_witness(cs.clone(), || Ok(b_bool))?; - - if a_neg { - a = a.not(); - } - if b_neg { - b = b.not(); - } - - a.conditional_enforce_equal(&b, &Boolean::constant(true))?; - - assert_eq!( - cs.is_satisfied().unwrap(), - (a_bool ^ a_neg) == (b_bool ^ b_neg) - ); - - // Now test if constraint system is satisfied even - // when we don't want to enforce the condition. - let cs = ConstraintSystem::::new_ref(); - - let mut a = Boolean::new_witness(cs.clone(), || Ok(a_bool))?; - let mut b = Boolean::new_witness(cs.clone(), || Ok(b_bool))?; - - if a_neg { - a = a.not(); - } - if b_neg { - b = b.not(); - } - - let false_cond = - Boolean::new_witness(ark_relations::ns!(cs, "cond"), || Ok(false))?; - a.conditional_enforce_equal(&b, &false_cond)?; - - assert!(cs.is_satisfied().unwrap()); - } - } - } - } - Ok(()) - } - - #[test] - fn test_boolean_negation() -> Result<(), SynthesisError> { - let cs = ConstraintSystem::::new_ref(); - - let mut b = Boolean::new_witness(cs.clone(), || Ok(true))?; - assert!(matches!(b, Boolean::Is(_))); - - b = b.not(); - assert!(matches!(b, Boolean::Not(_))); - - b = b.not(); - assert!(matches!(b, Boolean::Is(_))); - - b = Boolean::Constant(true); - assert!(matches!(b, Boolean::Constant(true))); - - b = b.not(); - assert!(matches!(b, Boolean::Constant(false))); - - b = b.not(); - assert!(matches!(b, Boolean::Constant(true))); - Ok(()) - } - - #[derive(Eq, PartialEq, Copy, Clone, Debug)] - enum OpType { - True, - False, - AllocatedTrue, - AllocatedFalse, - NegatedAllocatedTrue, - NegatedAllocatedFalse, - } - - const VARIANTS: [OpType; 6] = [ - OpType::True, - OpType::False, - OpType::AllocatedTrue, - OpType::AllocatedFalse, - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedFalse, - ]; - - fn construct( - ns: Namespace, - operand: OpType, - ) -> Result, SynthesisError> { - let cs = ns.cs(); - - let b = match operand { - OpType::True => Boolean::constant(true), - OpType::False => Boolean::constant(false), - OpType::AllocatedTrue => Boolean::new_witness(cs, || Ok(true))?, - OpType::AllocatedFalse => Boolean::new_witness(cs, || Ok(false))?, - OpType::NegatedAllocatedTrue => Boolean::new_witness(cs, || Ok(true))?.not(), - OpType::NegatedAllocatedFalse => Boolean::new_witness(cs, || Ok(false))?.not(), - }; - Ok(b) - } - - #[test] - fn test_boolean_xor() -> Result<(), SynthesisError> { - for first_operand in VARIANTS.iter().cloned() { - for second_operand in VARIANTS.iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let a = construct(ark_relations::ns!(cs, "a"), first_operand)?; - let b = construct(ark_relations::ns!(cs, "b"), second_operand)?; - let c = Boolean::xor(&a, &b)?; - - assert!(cs.is_satisfied().unwrap()); - - match (first_operand, second_operand, c) { - (OpType::True, OpType::True, Boolean::Constant(false)) => (), - (OpType::True, OpType::False, Boolean::Constant(true)) => (), - (OpType::True, OpType::AllocatedTrue, Boolean::Not(_)) => (), - (OpType::True, OpType::AllocatedFalse, Boolean::Not(_)) => (), - (OpType::True, OpType::NegatedAllocatedTrue, Boolean::Is(_)) => (), - (OpType::True, OpType::NegatedAllocatedFalse, Boolean::Is(_)) => (), - - (OpType::False, OpType::True, Boolean::Constant(true)) => (), - (OpType::False, OpType::False, Boolean::Constant(false)) => (), - (OpType::False, OpType::AllocatedTrue, Boolean::Is(_)) => (), - (OpType::False, OpType::AllocatedFalse, Boolean::Is(_)) => (), - (OpType::False, OpType::NegatedAllocatedTrue, Boolean::Not(_)) => (), - (OpType::False, OpType::NegatedAllocatedFalse, Boolean::Not(_)) => (), - - (OpType::AllocatedTrue, OpType::True, Boolean::Not(_)) => (), - (OpType::AllocatedTrue, OpType::False, Boolean::Is(_)) => (), - (OpType::AllocatedTrue, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedFalse, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedFalse, OpType::True, Boolean::Not(_)) => (), - (OpType::AllocatedFalse, OpType::False, Boolean::Is(_)) => (), - (OpType::AllocatedFalse, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedFalse, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::NegatedAllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::AllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedTrue, OpType::True, Boolean::Is(_)) => (), - (OpType::NegatedAllocatedTrue, OpType::False, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedTrue, OpType::AllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::NegatedAllocatedTrue, OpType::AllocatedFalse, Boolean::Not(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedTrue, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedFalse, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - - (OpType::NegatedAllocatedFalse, OpType::True, Boolean::Is(_)) => (), - (OpType::NegatedAllocatedFalse, OpType::False, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedFalse, OpType::AllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::AllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedTrue, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - - _ => unreachable!(), - } - } - } - Ok(()) - } - - #[test] - fn test_boolean_cond_select() -> Result<(), SynthesisError> { - for condition in VARIANTS.iter().cloned() { - for first_operand in VARIANTS.iter().cloned() { - for second_operand in VARIANTS.iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let cond = construct(ark_relations::ns!(cs, "cond"), condition)?; - let a = construct(ark_relations::ns!(cs, "a"), first_operand)?; - let b = construct(ark_relations::ns!(cs, "b"), second_operand)?; - let c = cond.select(&a, &b)?; - - assert!( - cs.is_satisfied().unwrap(), - "failed with operands: cond: {:?}, a: {:?}, b: {:?}", - condition, - first_operand, - second_operand, - ); - assert_eq!( - c.value()?, - if cond.value()? { - a.value()? - } else { - b.value()? - } - ); - } - } - } - Ok(()) - } - - #[test] - fn test_boolean_or() -> Result<(), SynthesisError> { - for first_operand in VARIANTS.iter().cloned() { - for second_operand in VARIANTS.iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let a = construct(ark_relations::ns!(cs, "a"), first_operand)?; - let b = construct(ark_relations::ns!(cs, "b"), second_operand)?; - let c = a.or(&b)?; - - assert!(cs.is_satisfied().unwrap()); - - match (first_operand, second_operand, c.clone()) { - (OpType::True, OpType::True, Boolean::Constant(true)) => (), - (OpType::True, OpType::False, Boolean::Constant(true)) => (), - (OpType::True, OpType::AllocatedTrue, Boolean::Constant(true)) => (), - (OpType::True, OpType::AllocatedFalse, Boolean::Constant(true)) => (), - (OpType::True, OpType::NegatedAllocatedTrue, Boolean::Constant(true)) => (), - (OpType::True, OpType::NegatedAllocatedFalse, Boolean::Constant(true)) => (), - - (OpType::False, OpType::True, Boolean::Constant(true)) => (), - (OpType::False, OpType::False, Boolean::Constant(false)) => (), - (OpType::False, OpType::AllocatedTrue, Boolean::Is(_)) => (), - (OpType::False, OpType::AllocatedFalse, Boolean::Is(_)) => (), - (OpType::False, OpType::NegatedAllocatedTrue, Boolean::Not(_)) => (), - (OpType::False, OpType::NegatedAllocatedFalse, Boolean::Not(_)) => (), - - (OpType::AllocatedTrue, OpType::True, Boolean::Constant(true)) => (), - (OpType::AllocatedTrue, OpType::False, Boolean::Is(_)) => (), - (OpType::AllocatedTrue, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedTrue, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedFalse, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::AllocatedFalse, OpType::True, Boolean::Constant(true)) => (), - (OpType::AllocatedFalse, OpType::False, Boolean::Is(_)) => (), - (OpType::AllocatedFalse, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedFalse, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::NegatedAllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::AllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedTrue, OpType::True, Boolean::Constant(true)) => (), - (OpType::NegatedAllocatedTrue, OpType::False, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedTrue, OpType::AllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - (OpType::NegatedAllocatedTrue, OpType::AllocatedFalse, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedTrue, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(true)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedFalse, OpType::True, Boolean::Constant(true)) => (), - (OpType::NegatedAllocatedFalse, OpType::False, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedFalse, OpType::AllocatedTrue, Boolean::Not(ref v)) => { - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::AllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedTrue, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Not(ref v), - ) => { - assert_eq!(v.value(), Ok(false)); - }, - - _ => panic!( - "this should never be encountered, in case: (a = {:?}, b = {:?}, c = {:?})", - a, b, c - ), - } - } - } - Ok(()) - } - - #[test] - fn test_boolean_and() -> Result<(), SynthesisError> { - for first_operand in VARIANTS.iter().cloned() { - for second_operand in VARIANTS.iter().cloned() { - let cs = ConstraintSystem::::new_ref(); - - let a = construct(ark_relations::ns!(cs, "a"), first_operand)?; - let b = construct(ark_relations::ns!(cs, "b"), second_operand)?; - let c = a.and(&b)?; - - assert!(cs.is_satisfied().unwrap()); - - match (first_operand, second_operand, c) { - (OpType::True, OpType::True, Boolean::Constant(true)) => (), - (OpType::True, OpType::False, Boolean::Constant(false)) => (), - (OpType::True, OpType::AllocatedTrue, Boolean::Is(_)) => (), - (OpType::True, OpType::AllocatedFalse, Boolean::Is(_)) => (), - (OpType::True, OpType::NegatedAllocatedTrue, Boolean::Not(_)) => (), - (OpType::True, OpType::NegatedAllocatedFalse, Boolean::Not(_)) => (), - - (OpType::False, OpType::True, Boolean::Constant(false)) => (), - (OpType::False, OpType::False, Boolean::Constant(false)) => (), - (OpType::False, OpType::AllocatedTrue, Boolean::Constant(false)) => (), - (OpType::False, OpType::AllocatedFalse, Boolean::Constant(false)) => (), - (OpType::False, OpType::NegatedAllocatedTrue, Boolean::Constant(false)) => (), - (OpType::False, OpType::NegatedAllocatedFalse, Boolean::Constant(false)) => (), - - (OpType::AllocatedTrue, OpType::True, Boolean::Is(_)) => (), - (OpType::AllocatedTrue, OpType::False, Boolean::Constant(false)) => (), - (OpType::AllocatedTrue, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - (OpType::AllocatedTrue, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedTrue, OpType::NegatedAllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - - (OpType::AllocatedFalse, OpType::True, Boolean::Is(_)) => (), - (OpType::AllocatedFalse, OpType::False, Boolean::Constant(false)) => (), - (OpType::AllocatedFalse, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::NegatedAllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::AllocatedFalse, OpType::NegatedAllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedTrue, OpType::True, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedTrue, OpType::False, Boolean::Constant(false)) => (), - (OpType::NegatedAllocatedTrue, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - (OpType::NegatedAllocatedTrue, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedTrue, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedTrue, - OpType::NegatedAllocatedFalse, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - - (OpType::NegatedAllocatedFalse, OpType::True, Boolean::Not(_)) => (), - (OpType::NegatedAllocatedFalse, OpType::False, Boolean::Constant(false)) => (), - (OpType::NegatedAllocatedFalse, OpType::AllocatedTrue, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - (OpType::NegatedAllocatedFalse, OpType::AllocatedFalse, Boolean::Is(ref v)) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedTrue, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::zero()); - assert_eq!(v.value(), Ok(false)); - }, - ( - OpType::NegatedAllocatedFalse, - OpType::NegatedAllocatedFalse, - Boolean::Is(ref v), - ) => { - assert_eq!(cs.assigned_value(v.variable()).unwrap(), Fr::one()); - assert_eq!(v.value(), Ok(true)); - }, - - _ => { - panic!( - "unexpected behavior at {:?} AND {:?}", - first_operand, second_operand - ); - }, - } - } - } - Ok(()) - } - - #[test] - fn test_smaller_than_or_equal_to() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - for _ in 0..1000 { - let mut r = Fr::rand(&mut rng); - let mut s = Fr::rand(&mut rng); - if r > s { - core::mem::swap(&mut r, &mut s) - } - - let cs = ConstraintSystem::::new_ref(); - - let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect(); - let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?; - Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?; - - assert!(cs.is_satisfied().unwrap()); - } - - for _ in 0..1000 { - let r = Fr::rand(&mut rng); - if r == -Fr::one() { - continue; - } - let s = r + Fr::one(); - let s2 = r.double(); - let cs = ConstraintSystem::::new_ref(); - - let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect(); - let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?; - Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?; - if r < s2 { - Boolean::enforce_smaller_or_equal_than_le(&bits, s2.into_bigint())?; - } - - assert!(cs.is_satisfied().unwrap()); - } - Ok(()) - } - - #[test] - fn test_enforce_in_field() -> Result<(), SynthesisError> { - { - let cs = ConstraintSystem::::new_ref(); - - let mut bits = vec![]; - for b in BitIteratorBE::new(Fr::characteristic()).skip(1) { - bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?); - } - bits.reverse(); - - Boolean::enforce_in_field_le(&bits)?; - - assert!(!cs.is_satisfied().unwrap()); - } - - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let r = Fr::rand(&mut rng); - let cs = ConstraintSystem::::new_ref(); - - let mut bits = vec![]; - for b in BitIteratorBE::new(r.into_bigint()).skip(1) { - bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?); - } - bits.reverse(); - - Boolean::enforce_in_field_le(&bits)?; - - assert!(cs.is_satisfied().unwrap()); - } - Ok(()) - } - - #[test] - fn test_enforce_nand() -> Result<(), SynthesisError> { - { - let cs = ConstraintSystem::::new_ref(); - - assert!( - Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), false)?]).is_ok() - ); - assert!( - Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), true)?]).is_err() - ); - } - - for i in 1..5 { - // with every possible assignment for them - for mut b in 0..(1 << i) { - // with every possible negation - for mut n in 0..(1 << i) { - let cs = ConstraintSystem::::new_ref(); - - let mut expected = true; - - let mut bits = vec![]; - for _ in 0..i { - expected &= b & 1 == 1; - - let bit = if n & 1 == 1 { - Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))? - } else { - Boolean::new_witness(cs.clone(), || Ok(b & 1 == 0))?.not() - }; - bits.push(bit); - - b >>= 1; - n >>= 1; - } - - let expected = !expected; - - Boolean::enforce_kary_nand(&bits)?; - - if expected { - assert!(cs.is_satisfied().unwrap()); - } else { - assert!(!cs.is_satisfied().unwrap()); - } - } - } - } - Ok(()) - } - - #[test] - fn test_kary_and() -> Result<(), SynthesisError> { - // test different numbers of operands - for i in 1..15 { - // with every possible assignment for them - for mut b in 0..(1 << i) { - let cs = ConstraintSystem::::new_ref(); - - let mut expected = true; - - let mut bits = vec![]; - for _ in 0..i { - expected &= b & 1 == 1; - bits.push(Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))?); - b >>= 1; - } - - let r = Boolean::kary_and(&bits)?; - - assert!(cs.is_satisfied().unwrap()); - - if let Boolean::Is(ref r) = r { - assert_eq!(r.value()?, expected); - } - } - } - Ok(()) - } - - #[test] - fn test_bits_to_fp() -> Result<(), SynthesisError> { - use AllocationMode::*; - let rng = &mut ark_std::test_rng(); - let cs = ConstraintSystem::::new_ref(); - - let modes = [Input, Witness, Constant]; - for &mode in modes.iter() { - for _ in 0..1000 { - let f = Fr::rand(rng); - let bits = BitIteratorLE::new(f.into_bigint()).collect::>(); - let bits: Vec<_> = - AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; - let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; - let claimed_f = Boolean::le_bits_to_fp_var(&bits)?; - claimed_f.enforce_equal(&f)?; - } - - for _ in 0..1000 { - let f = Fr::from(u64::rand(rng)); - let bits = BitIteratorLE::new(f.into_bigint()).collect::>(); - let bits: Vec<_> = - AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; - let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; - let claimed_f = Boolean::le_bits_to_fp_var(&bits)?; - claimed_f.enforce_equal(&f)?; - } - assert!(cs.is_satisfied().unwrap()); - } - - Ok(()) - } -} diff --git a/src/bits/uint.rs b/src/bits/uint.rs deleted file mode 100644 index b575af43..00000000 --- a/src/bits/uint.rs +++ /dev/null @@ -1,567 +0,0 @@ -macro_rules! make_uint { - ($name:ident, $size:expr, $native:ident, $mod_name:ident, $r1cs_doc_name:expr, $native_doc_name:expr, $num_bits_doc:expr) => { - #[doc = "This module contains the "] - #[doc = $r1cs_doc_name] - #[doc = "type, which is the R1CS equivalent of the "] - #[doc = $native_doc_name] - #[doc = " type."] - pub mod $mod_name { - use ark_ff::{Field, One, PrimeField, Zero}; - use core::{borrow::Borrow, convert::TryFrom}; - use num_bigint::BigUint; - use num_traits::cast::ToPrimitive; - - use ark_relations::r1cs::{ - ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, - }; - - use crate::{ - boolean::{AllocatedBool, Boolean}, - prelude::*, - Assignment, Vec, - }; - - #[doc = "This struct represent an unsigned"] - #[doc = $num_bits_doc] - #[doc = " bit integer as a sequence of "] - #[doc = $num_bits_doc] - #[doc = " `Boolean`s. \n"] - #[doc = "This is the R1CS equivalent of the native "] - #[doc = $native_doc_name] - #[doc = " unsigned integer type."] - #[derive(Clone, Debug)] - pub struct $name { - // Least significant bit first - bits: [Boolean; $size], - value: Option<$native>, - } - - impl R1CSVar for $name { - type Value = $native; - - fn cs(&self) -> ConstraintSystemRef { - self.bits.as_ref().cs() - } - - fn value(&self) -> Result { - let mut value = None; - for (i, bit) in self.bits.iter().enumerate() { - let b = $native::from(bit.value()?); - value = match value { - Some(value) => Some(value + (b << i)), - None => Some(b << i), - }; - } - debug_assert_eq!(self.value, value); - value.get() - } - } - - impl $name { - #[doc = "Construct a constant "] - #[doc = $r1cs_doc_name] - #[doc = " from the native "] - #[doc = $native_doc_name] - #[doc = " type."] - pub fn constant(value: $native) -> Self { - let mut bits = [Boolean::FALSE; $size]; - - let mut tmp = value; - for i in 0..$size { - bits[i] = Boolean::constant((tmp & 1) == 1); - tmp >>= 1; - } - - $name { - bits, - value: Some(value), - } - } - - /// Turns `self` into the underlying little-endian bits. - pub fn to_bits_le(&self) -> Vec> { - self.bits.to_vec() - } - - /// Construct `Self` from a slice of `Boolean`s. - /// - /// # Panics - #[doc = "This method panics if `bits.len() != "] - #[doc = $num_bits_doc] - #[doc = "`."] - pub fn from_bits_le(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), $size); - - let bits = <&[Boolean; $size]>::try_from(bits).unwrap().clone(); - - let mut value = Some(0); - for b in bits.iter().rev() { - value.as_mut().map(|v| *v <<= 1); - - match *b { - Boolean::Constant(b) => { - value.as_mut().map(|v| *v |= $native::from(b)); - }, - Boolean::Is(ref b) => match b.value() { - Ok(b) => { - value.as_mut().map(|v| *v |= $native::from(b)); - }, - Err(_) => value = None, - }, - Boolean::Not(ref b) => match b.value() { - Ok(b) => { - value.as_mut().map(|v| *v |= $native::from(!b)); - }, - Err(_) => value = None, - }, - } - } - - Self { value, bits } - } - - /// Rotates `self` to the right by `by` steps, wrapping around. - #[tracing::instrument(target = "r1cs", skip(self))] - pub fn rotr(&self, by: usize) -> Self { - let mut result = self.clone(); - let by = by % $size; - - let new_bits = self.bits.iter().skip(by).chain(&self.bits).take($size); - - for (res, new) in result.bits.iter_mut().zip(new_bits) { - *res = new.clone(); - } - - result.value = self - .value - .map(|v| v.rotate_right(u32::try_from(by).unwrap())); - result - } - - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this - /// method *does not* create any constraints or variables. - #[tracing::instrument(target = "r1cs", skip(self, other))] - pub fn xor(&self, other: &Self) -> Result { - let mut result = self.clone(); - result.value = match (self.value, other.value) { - (Some(a), Some(b)) => Some(a ^ b), - _ => None, - }; - - let new_bits = self.bits.iter().zip(&other.bits).map(|(a, b)| a.xor(b)); - - for (res, new) in result.bits.iter_mut().zip(new_bits) { - *res = new?; - } - - Ok(result) - } - - /// Perform modular addition of `operands`. - /// - /// The user must ensure that overflow does not occur. - #[tracing::instrument(target = "r1cs", skip(operands))] - pub fn addmany(operands: &[Self]) -> Result - where - F: PrimeField, - { - // Make some arbitrary bounds for ourselves to avoid overflows - // in the scalar field - assert!(F::MODULUS_BIT_SIZE >= 2 * $size); - - // Support up to 128 - assert!($size <= 128); - - assert!(operands.len() >= 1); - assert!($size + ark_std::log2(operands.len()) <= F::MODULUS_BIT_SIZE); - - if operands.len() == 1 { - return Ok(operands[0].clone()); - } - - // Compute the maximum value of the sum so we allocate enough bits for - // the result - let mut max_value = - BigUint::from($native::max_value()) * BigUint::from(operands.len()); - - // Keep track of the resulting value - let mut result_value = Some(BigUint::zero()); - - // This is a linear combination that we will enforce to be "zero" - let mut lc = LinearCombination::zero(); - - let mut all_constants = true; - - // Iterate over the operands - for op in operands { - // Accumulate the value - match op.value { - Some(val) => { - result_value.as_mut().map(|v| *v += BigUint::from(val)); - }, - - None => { - // If any of our operands have unknown value, we won't - // know the value of the result - result_value = None; - }, - } - - // Iterate over each bit_gadget of the operand and add the operand to - // the linear combination - let mut coeff = F::one(); - for bit in &op.bits { - match *bit { - Boolean::Is(ref bit) => { - all_constants = false; - - // Add coeff * bit_gadget - lc += (coeff, bit.variable()); - }, - Boolean::Not(ref bit) => { - all_constants = false; - - // Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * - // bit_gadget - lc = lc + (coeff, Variable::One) - (coeff, bit.variable()); - }, - Boolean::Constant(bit) => { - if bit { - lc += (coeff, Variable::One); - } - }, - } - - coeff.double_in_place(); - } - } - - // The value of the actual result is modulo 2^$size - let modular_value = result_value.clone().map(|v| { - let modulus = BigUint::from(1u64) << ($size as u32); - (v % modulus).to_u128().unwrap() as $native - }); - - if all_constants && modular_value.is_some() { - // We can just return a constant, rather than - // unpacking the result into allocated bits. - - return Ok($name::constant(modular_value.unwrap())); - } - let cs = operands.cs(); - - // Storage area for the resulting bits - let mut result_bits = vec![]; - - // Allocate each bit_gadget of the result - let mut coeff = F::one(); - let mut i = 0; - while max_value != BigUint::zero() { - // Allocate the bit_gadget - let b = AllocatedBool::new_witness(cs.clone(), || { - result_value - .clone() - .map(|v| (v >> i) & BigUint::one() == BigUint::one()) - .get() - })?; - - // Subtract this bit_gadget from the linear combination to ensure the sums - // balance out - lc = lc - (coeff, b.variable()); - - result_bits.push(b.into()); - - max_value >>= 1; - i += 1; - coeff.double_in_place(); - } - - // Enforce that the linear combination equals zero - cs.enforce_constraint(lc!(), lc!(), lc)?; - - // Discard carry bits that we don't care about - result_bits.truncate($size); - let bits = TryFrom::try_from(result_bits).unwrap(); - - Ok($name { - bits, - value: modular_value, - }) - } - } - - impl ToBytesGadget for $name { - #[tracing::instrument(target = "r1cs", skip(self))] - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self - .to_bits_le() - .chunks(8) - .map(UInt8::from_bits_le) - .collect()) - } - } - - impl EqGadget for $name { - #[tracing::instrument(target = "r1cs", skip(self))] - fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - self.bits.as_ref().is_eq(&other.bits) - } - - #[tracing::instrument(target = "r1cs", skip(self))] - fn conditional_enforce_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits.conditional_enforce_equal(&other.bits, condition) - } - - #[tracing::instrument(target = "r1cs", skip(self))] - fn conditional_enforce_not_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(&other.bits, condition) - } - } - - impl CondSelectGadget for $name { - #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] - fn conditionally_select( - cond: &Boolean, - true_value: &Self, - false_value: &Self, - ) -> Result { - let selected_bits = true_value - .bits - .iter() - .zip(&false_value.bits) - .map(|(t, f)| cond.select(t, f)); - let mut bits = [Boolean::FALSE; $size]; - for (result, new) in bits.iter_mut().zip(selected_bits) { - *result = new?; - } - - let value = cond.value().ok().and_then(|cond| { - if cond { - true_value.value().ok() - } else { - false_value.value().ok() - } - }); - Ok(Self { bits, value }) - } - } - - impl AllocVar<$native, ConstraintF> for $name { - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - let value = f().map(|f| *f.borrow()).ok(); - - let mut values = [None; $size]; - if let Some(val) = value { - values - .iter_mut() - .enumerate() - .for_each(|(i, v)| *v = Some((val >> i) & 1 == 1)); - } - - let mut bits = [Boolean::FALSE; $size]; - for (b, v) in bits.iter_mut().zip(&values) { - *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; - } - Ok(Self { bits, value }) - } - } - - #[cfg(test)] - mod test { - use super::$name; - use crate::{bits::boolean::Boolean, prelude::*, Vec}; - use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; - use ark_std::rand::Rng; - use ark_test_curves::mnt4_753::Fr; - - #[test] - fn test_from_bits() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let v = (0..$size) - .map(|_| Boolean::constant(rng.gen())) - .collect::>>(); - - let b = $name::from_bits_le(&v); - - for (i, bit) in b.bits.iter().enumerate() { - match bit { - &Boolean::Constant(bit) => { - assert_eq!(bit, ((b.value()? >> i) & 1 == 1)); - }, - _ => unreachable!(), - } - } - - let expected_to_be_same = b.to_bits_le(); - - for x in v.iter().zip(expected_to_be_same.iter()) { - match x { - (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, - (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, - _ => unreachable!(), - } - } - } - Ok(()) - } - - #[test] - fn test_xor() -> Result<(), SynthesisError> { - use Boolean::*; - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); - - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); - - let mut expected = a ^ b ^ c; - - let a_bit = $name::new_witness(cs.clone(), || Ok(a))?; - let b_bit = $name::constant(b); - let c_bit = $name::new_witness(cs.clone(), || Ok(c))?; - - let r = a_bit.xor(&b_bit).unwrap(); - let r = r.xor(&c_bit).unwrap(); - - assert!(cs.is_satisfied().unwrap()); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), - Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), - Constant(b) => assert_eq!(*b, (expected & 1 == 1)), - } - - expected >>= 1; - } - } - Ok(()) - } - - #[test] - fn test_addmany_constants() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); - - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); - - let a_bit = $name::new_constant(cs.clone(), a)?; - let b_bit = $name::new_constant(cs.clone(), b)?; - let c_bit = $name::new_constant(cs.clone(), c)?; - - let mut expected = a.wrapping_add(b).wrapping_add(c); - - let r = $name::addmany(&[a_bit, b_bit, c_bit]).unwrap(); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - Boolean::Is(_) => unreachable!(), - Boolean::Not(_) => unreachable!(), - Boolean::Constant(b) => assert_eq!(*b, (expected & 1 == 1)), - } - - expected >>= 1; - } - } - Ok(()) - } - - #[test] - fn test_addmany() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); - - let a: $native = rng.gen(); - let b: $native = rng.gen(); - let c: $native = rng.gen(); - let d: $native = rng.gen(); - - let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); - - let a_bit = $name::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a))?; - let b_bit = $name::constant(b); - let c_bit = $name::constant(c); - let d_bit = $name::new_witness(ark_relations::ns!(cs, "d_bit"), || Ok(d))?; - - let r = a_bit.xor(&b_bit).unwrap(); - let r = $name::addmany(&[r, c_bit, d_bit]).unwrap(); - - assert!(cs.is_satisfied().unwrap()); - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - Boolean::Is(b) => assert_eq!(b.value()?, (expected & 1 == 1)), - Boolean::Not(b) => assert_eq!(!b.value()?, (expected & 1 == 1)), - Boolean::Constant(_) => unreachable!(), - } - - expected >>= 1; - } - } - Ok(()) - } - - #[test] - fn test_rotr() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - let mut num = rng.gen(); - - let a: $name = $name::constant(num); - - for i in 0..$size { - let b = a.rotr(i); - - assert!(b.value.unwrap() == num); - - let mut tmp = num; - for b in &b.bits { - match b { - Boolean::Constant(b) => assert_eq!(*b, tmp & 1 == 1), - _ => unreachable!(), - } - - tmp >>= 1; - } - - num = num.rotate_right(1); - } - Ok(()) - } - } - } - }; -} diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs deleted file mode 100644 index a8dd57f4..00000000 --- a/src/bits/uint8.rs +++ /dev/null @@ -1,550 +0,0 @@ -use ark_ff::{Field, PrimeField, ToConstraintField}; - -use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; - -use crate::{ - fields::fp::{AllocatedFp, FpVar}, - prelude::*, - Assignment, ToConstraintFieldGadget, Vec, -}; -use core::{borrow::Borrow, convert::TryFrom}; - -/// Represents an interpretation of 8 `Boolean` objects as an -/// unsigned integer. -#[derive(Clone, Debug)] -pub struct UInt8 { - /// Little-endian representation: least significant bit first - pub(crate) bits: [Boolean; 8], - pub(crate) value: Option, -} - -impl R1CSVar for UInt8 { - type Value = u8; - - fn cs(&self) -> ConstraintSystemRef { - self.bits.as_ref().cs() - } - - fn value(&self) -> Result { - let mut value = None; - for (i, bit) in self.bits.iter().enumerate() { - let b = u8::from(bit.value()?); - value = match value { - Some(value) => Some(value + (b << i)), - None => Some(b << i), - }; - } - debug_assert_eq!(self.value, value); - value.get() - } -} - -impl UInt8 { - /// Construct a constant vector of `UInt8` from a vector of `u8` - /// - /// This *does not* create any new variables or constraints. - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let var = vec![UInt8::new_witness(cs.clone(), || Ok(2))?]; - /// - /// let constant = UInt8::constant_vec(&[2]); - /// var.enforce_equal(&constant)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn constant_vec(values: &[u8]) -> Vec { - let mut result = Vec::new(); - for value in values { - result.push(UInt8::constant(*value)); - } - result - } - - /// Construct a constant `UInt8` from a `u8` - /// - /// This *does not* create new variables or constraints. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let var = UInt8::new_witness(cs.clone(), || Ok(2))?; - /// - /// let constant = UInt8::constant(2); - /// var.enforce_equal(&constant)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn constant(value: u8) -> Self { - let mut bits = [Boolean::FALSE; 8]; - - let mut tmp = value; - for i in 0..8 { - // If last bit is one, push one. - bits[i] = Boolean::constant((tmp & 1) == 1); - tmp >>= 1; - } - - Self { - bits, - value: Some(value), - } - } - - /// Allocates a slice of `u8`'s as private witnesses. - pub fn new_witness_vec( - cs: impl Into>, - values: &[impl Into> + Copy], - ) -> Result, SynthesisError> { - let ns = cs.into(); - let cs = ns.cs(); - let mut output_vec = Vec::with_capacity(values.len()); - for value in values { - let byte: Option = Into::into(*value); - output_vec.push(Self::new_witness(cs.clone(), || byte.get())?); - } - Ok(output_vec) - } - - /// Allocates a slice of `u8`'s as public inputs by first packing them into - /// elements of `F`, (thus reducing the number of input allocations), - /// allocating these elements as public inputs, and then converting - /// these field variables `FpVar` variables back into bytes. - /// - /// From a user perspective, this trade-off adds constraints, but improves - /// verifier time and verification key size. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let two = UInt8::new_witness(cs.clone(), || Ok(2))?; - /// let var = vec![two.clone(); 32]; - /// - /// let c = UInt8::new_input_vec(cs.clone(), &[2; 32])?; - /// var.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - pub fn new_input_vec( - cs: impl Into>, - values: &[u8], - ) -> Result, SynthesisError> - where - F: PrimeField, - { - let ns = cs.into(); - let cs = ns.cs(); - let values_len = values.len(); - let field_elements: Vec = ToConstraintField::::to_field_elements(values).unwrap(); - - let max_size = 8 * ((F::MODULUS_BIT_SIZE - 1) / 8) as usize; - let mut allocated_bits = Vec::new(); - for field_element in field_elements.into_iter() { - let fe = AllocatedFp::new_input(cs.clone(), || Ok(field_element))?; - let fe_bits = fe.to_bits_le()?; - - // Remove the most significant bit, because we know it should be zero - // because `values.to_field_elements()` only - // packs field elements up to the penultimate bit. - // That is, the most significant bit (`ConstraintF::NUM_BITS`-th bit) is - // unset, so we can just pop it off. - allocated_bits.extend_from_slice(&fe_bits[0..max_size]); - } - - // Chunk up slices of 8 bit into bytes. - Ok(allocated_bits[0..(8 * values_len)] - .chunks(8) - .map(Self::from_bits_le) - .collect()) - } - - /// Converts a little-endian byte order representation of bits into a - /// `UInt8`. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let var = UInt8::new_witness(cs.clone(), || Ok(128))?; - /// - /// let f = Boolean::FALSE; - /// let t = Boolean::TRUE; - /// - /// // Construct [0, 0, 0, 0, 0, 0, 0, 1] - /// let mut bits = vec![f.clone(); 7]; - /// bits.push(t); - /// - /// let mut c = UInt8::from_bits_le(&bits); - /// var.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn from_bits_le(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), 8); - let bits = <&[Boolean; 8]>::try_from(bits).unwrap().clone(); - - let mut value = Some(0u8); - for (i, b) in bits.iter().enumerate() { - value = match b.value().ok() { - Some(b) => value.map(|v| v + (u8::from(b) << i)), - None => None, - } - } - - Self { value, bits } - } - - /// Outputs `self ^ other`. - /// - /// If at least one of `self` and `other` are constants, then this method - /// *does not* create any constraints or variables. - /// - /// ``` - /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { - /// // We'll use the BLS12-381 scalar field for our constraints. - /// use ark_test_curves::bls12_381::Fr; - /// use ark_relations::r1cs::*; - /// use ark_r1cs_std::prelude::*; - /// - /// let cs = ConstraintSystem::::new_ref(); - /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; - /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; - /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; - /// - /// a.xor(&b)?.enforce_equal(&c)?; - /// assert!(cs.is_satisfied().unwrap()); - /// # Ok(()) - /// # } - /// ``` - #[tracing::instrument(target = "r1cs")] - pub fn xor(&self, other: &Self) -> Result { - let mut result = self.clone(); - result.value = match (self.value, other.value) { - (Some(a), Some(b)) => Some(a ^ b), - _ => None, - }; - - let new_bits = self.bits.iter().zip(&other.bits).map(|(a, b)| a.xor(b)); - - for (res, new) in result.bits.iter_mut().zip(new_bits) { - *res = new?; - } - - Ok(result) - } -} - -impl EqGadget for UInt8 { - #[tracing::instrument(target = "r1cs")] - fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - self.bits.as_ref().is_eq(&other.bits) - } - - #[tracing::instrument(target = "r1cs")] - fn conditional_enforce_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits.conditional_enforce_equal(&other.bits, condition) - } - - #[tracing::instrument(target = "r1cs")] - fn conditional_enforce_not_equal( - &self, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(&other.bits, condition) - } -} - -impl CondSelectGadget for UInt8 { - #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] - fn conditionally_select( - cond: &Boolean, - true_value: &Self, - false_value: &Self, - ) -> Result { - let selected_bits = true_value - .bits - .iter() - .zip(&false_value.bits) - .map(|(t, f)| cond.select(t, f)); - let mut bits = [Boolean::FALSE; 8]; - for (result, new) in bits.iter_mut().zip(selected_bits) { - *result = new?; - } - - let value = cond.value().ok().and_then(|cond| { - if cond { - true_value.value().ok() - } else { - false_value.value().ok() - } - }); - Ok(Self { bits, value }) - } -} - -impl AllocVar for UInt8 { - fn new_variable>( - cs: impl Into>, - f: impl FnOnce() -> Result, - mode: AllocationMode, - ) -> Result { - let ns = cs.into(); - let cs = ns.cs(); - let value = f().map(|f| *f.borrow()).ok(); - - let mut values = [None; 8]; - if let Some(val) = value { - values - .iter_mut() - .enumerate() - .for_each(|(i, v)| *v = Some((val >> i) & 1 == 1)); - } - - let mut bits = [Boolean::FALSE; 8]; - for (b, v) in bits.iter_mut().zip(&values) { - *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; - } - Ok(Self { bits, value }) - } -} - -/// Parses the `Vec>` in fixed-sized -/// `ConstraintF::MODULUS_BIT_SIZE - 1` chunks and converts each chunk, which is -/// assumed to be little-endian, to its `FpVar` representation. -/// This is the gadget counterpart to the `[u8]` implementation of -/// [`ToConstraintField`]. -impl ToConstraintFieldGadget for [UInt8] { - #[tracing::instrument(target = "r1cs")] - fn to_constraint_field(&self) -> Result>, SynthesisError> { - let max_size = ((ConstraintF::MODULUS_BIT_SIZE - 1) / 8) as usize; - self.chunks(max_size) - .map(|chunk| Boolean::le_bits_to_fp_var(chunk.to_bits_le()?.as_slice())) - .collect::, SynthesisError>>() - } -} - -impl ToConstraintFieldGadget for Vec> { - #[tracing::instrument(target = "r1cs")] - fn to_constraint_field(&self) -> Result>, SynthesisError> { - self.as_slice().to_constraint_field() - } -} - -#[cfg(test)] -mod test { - use super::UInt8; - use crate::{ - fields::fp::FpVar, - prelude::{ - AllocationMode::{Constant, Input, Witness}, - *, - }, - ToConstraintFieldGadget, Vec, - }; - use ark_ff::{PrimeField, ToConstraintField}; - use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; - use ark_std::rand::{distributions::Uniform, Rng}; - use ark_test_curves::bls12_381::Fr; - - #[test] - fn test_uint8_from_bits_to_bits() -> Result<(), SynthesisError> { - let cs = ConstraintSystem::::new_ref(); - let byte_val = 0b01110001; - let byte = - UInt8::new_witness(ark_relations::ns!(cs, "alloc value"), || Ok(byte_val)).unwrap(); - let bits = byte.to_bits_le()?; - for (i, bit) in bits.iter().enumerate() { - assert_eq!(bit.value()?, (byte_val >> i) & 1 == 1) - } - Ok(()) - } - - #[test] - fn test_uint8_new_input_vec() -> Result<(), SynthesisError> { - let cs = ConstraintSystem::::new_ref(); - let byte_vals = (64u8..128u8).collect::>(); - let bytes = - UInt8::new_input_vec(ark_relations::ns!(cs, "alloc value"), &byte_vals).unwrap(); - dbg!(bytes.value())?; - for (native, variable) in byte_vals.into_iter().zip(bytes) { - let bits = variable.to_bits_le()?; - for (i, bit) in bits.iter().enumerate() { - assert_eq!( - bit.value()?, - (native >> i) & 1 == 1, - "native value {}: bit {:?}", - native, - i - ) - } - } - Ok(()) - } - - #[test] - fn test_uint8_from_bits() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let v = (0..8) - .map(|_| Boolean::::Constant(rng.gen())) - .collect::>(); - - let val = UInt8::from_bits_le(&v); - - for (i, bit) in val.bits.iter().enumerate() { - match bit { - Boolean::Constant(b) => assert!(*b == ((val.value()? >> i) & 1 == 1)), - _ => unreachable!(), - } - } - - let expected_to_be_same = val.to_bits_le()?; - - for x in v.iter().zip(expected_to_be_same.iter()) { - match x { - (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, - (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, - _ => unreachable!(), - } - } - } - Ok(()) - } - - #[test] - fn test_uint8_xor() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); - - let a: u8 = rng.gen(); - let b: u8 = rng.gen(); - let c: u8 = rng.gen(); - - let mut expected = a ^ b ^ c; - - let a_bit = UInt8::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a)).unwrap(); - let b_bit = UInt8::constant(b); - let c_bit = UInt8::new_witness(ark_relations::ns!(cs, "c_bit"), || Ok(c)).unwrap(); - - let r = a_bit.xor(&b_bit).unwrap(); - let r = r.xor(&c_bit).unwrap(); - - assert!(cs.is_satisfied().unwrap()); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - Boolean::Is(b) => assert!(b.value()? == (expected & 1 == 1)), - Boolean::Not(b) => assert!(!b.value()? == (expected & 1 == 1)), - Boolean::Constant(b) => assert!(*b == (expected & 1 == 1)), - } - - expected >>= 1; - } - } - Ok(()) - } - - #[test] - fn test_uint8_to_constraint_field() -> Result<(), SynthesisError> { - let mut rng = ark_std::test_rng(); - let max_size = ((::MODULUS_BIT_SIZE - 1) / 8) as usize; - - let modes = [Input, Witness, Constant]; - for mode in &modes { - for _ in 0..1000 { - let cs = ConstraintSystem::::new_ref(); - - let bytes: Vec = (&mut rng) - .sample_iter(&Uniform::new_inclusive(0, u8::max_value())) - .take(max_size * 3 + 5) - .collect(); - - let bytes_var = bytes - .iter() - .map(|byte| UInt8::new_variable(cs.clone(), || Ok(*byte), *mode)) - .collect::, SynthesisError>>()?; - - let f_vec: Vec = bytes.to_field_elements().unwrap(); - let f_var_vec: Vec> = bytes_var.to_constraint_field()?; - - assert!(cs.is_satisfied().unwrap()); - assert_eq!(f_vec, f_var_vec.value()?); - } - } - - Ok(()) - } - - #[test] - fn test_uint8_random_access() { - let mut rng = ark_std::test_rng(); - - for _ in 0..100 { - let cs = ConstraintSystem::::new_ref(); - - // value array - let values: Vec = (0..128).map(|_| rng.gen()).collect(); - let values_const: Vec> = values.iter().map(|x| UInt8::constant(*x)).collect(); - - // index array - let position: Vec = (0..7).map(|_| rng.gen()).collect(); - let position_var: Vec> = position - .iter() - .map(|b| { - Boolean::new_witness(ark_relations::ns!(cs, "index_arr_element"), || Ok(*b)) - .unwrap() - }) - .collect(); - - // index - let mut index = 0; - for x in position { - index *= 2; - index += if x { 1 } else { 0 }; - } - - assert_eq!( - UInt8::conditionally_select_power_of_two_vector(&position_var, &values_const) - .unwrap() - .value() - .unwrap(), - values[index] - ) - } - } -} diff --git a/src/boolean/allocated.rs b/src/boolean/allocated.rs new file mode 100644 index 00000000..9397f12d --- /dev/null +++ b/src/boolean/allocated.rs @@ -0,0 +1,334 @@ +use core::borrow::Borrow; + +use ark_ff::{Field, PrimeField}; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError, Variable}; + +use crate::{ + alloc::{AllocVar, AllocationMode}, + select::CondSelectGadget, + Assignment, +}; + +use super::Boolean; + +/// Represents a variable in the constraint system which is guaranteed +/// to be either zero or one. +/// +/// In general, one should prefer using `Boolean` instead of `AllocatedBool`, +/// as `Boolean` offers better support for constant values, and implements +/// more traits. +#[derive(Clone, Debug, Eq, PartialEq)] +#[must_use] +pub struct AllocatedBool { + pub(super) variable: Variable, + pub(super) cs: ConstraintSystemRef, +} + +pub(crate) fn bool_to_field(val: impl Borrow) -> F { + F::from(*val.borrow()) +} + +impl AllocatedBool { + /// Get the assigned value for `self`. + pub fn value(&self) -> Result { + let value = self.cs.assigned_value(self.variable).get()?; + if value.is_zero() { + Ok(false) + } else if value.is_one() { + Ok(true) + } else { + unreachable!("Incorrect value assigned: {:?}", value); + } + } + + /// Get the R1CS variable for `self`. + pub fn variable(&self) -> Variable { + self.variable + } + + /// Allocate a witness variable without a booleanity check. + #[doc(hidden)] + pub fn new_witness_without_booleanity_check>( + cs: ConstraintSystemRef, + f: impl FnOnce() -> Result, + ) -> Result { + let variable = cs.new_witness_variable(|| f().map(bool_to_field))?; + Ok(Self { variable, cs }) + } + + /// Performs an XOR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn not(&self) -> Result { + let variable = self.cs.new_lc(lc!() + Variable::One - self.variable)?; + Ok(Self { + variable, + cs: self.cs.clone(), + }) + } + + /// Performs an XOR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn xor(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? ^ b.value()?) + })?; + + // Constrain (a + a) * (b) = (a + b - c) + // Given that a and b are boolean constrained, if they + // are equal, the only solution for c is 0, and if they + // are different, the only solution for c is 1. + // + // ¬(a ∧ b) ∧ ¬(¬a ∧ ¬b) = c + // (1 - (a * b)) * (1 - ((1 - a) * (1 - b))) = c + // (1 - ab) * (1 - (1 - a - b + ab)) = c + // (1 - ab) * (a + b - ab) = c + // a + b - ab - (a^2)b - (b^2)a + (a^2)(b^2) = c + // a + b - ab - ab - ab + ab = c + // a + b - 2ab = c + // -2a * b = c - a - b + // 2a * b = a + b - c + // (a + a) * b = a + b - c + self.cs.enforce_constraint( + lc!() + self.variable + self.variable, + lc!() + b.variable, + lc!() + self.variable + b.variable - result.variable, + )?; + + Ok(result) + } + + /// Performs an AND operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn and(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? & b.value()?) + })?; + + // Constrain (a) * (b) = (c), ensuring c is 1 iff + // a AND b are both 1. + self.cs.enforce_constraint( + lc!() + self.variable, + lc!() + b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } + + /// Performs an OR operation over the two operands, returning + /// an `AllocatedBool`. + #[tracing::instrument(target = "r1cs")] + pub fn or(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? | b.value()?) + })?; + + // Constrain (1 - a) * (1 - b) = (1 - c), ensuring c is 0 iff + // a and b are both false, and otherwise c is 1. + self.cs.enforce_constraint( + lc!() + Variable::One - self.variable, + lc!() + Variable::One - b.variable, + lc!() + Variable::One - result.variable, + )?; + + Ok(result) + } + + /// Calculates `a AND (NOT b)`. + #[tracing::instrument(target = "r1cs")] + pub fn and_not(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(self.value()? & !b.value()?) + })?; + + // Constrain (a) * (1 - b) = (c), ensuring c is 1 iff + // a is true and b is false, and otherwise c is 0. + self.cs.enforce_constraint( + lc!() + self.variable, + lc!() + Variable::One - b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } + + /// Calculates `(NOT a) AND (NOT b)`. + #[tracing::instrument(target = "r1cs")] + pub fn nor(&self, b: &Self) -> Result { + let result = Self::new_witness_without_booleanity_check(self.cs.clone(), || { + Ok(!(self.value()? | b.value()?)) + })?; + + // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff + // a and b are both false, and otherwise c is 0. + self.cs.enforce_constraint( + lc!() + Variable::One - self.variable, + lc!() + Variable::One - b.variable, + lc!() + result.variable, + )?; + + Ok(result) + } +} + +impl AllocVar for AllocatedBool { + /// Produces a new variable of the appropriate kind + /// (instance or witness), with a booleanity check. + /// + /// N.B.: we could omit the booleanity check when allocating `self` + /// as a new public input, but that places an additional burden on + /// protocol designers. Better safe than sorry! + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + if mode == AllocationMode::Constant { + let variable = if *f()?.borrow() { + Variable::One + } else { + Variable::Zero + }; + Ok(Self { variable, cs }) + } else { + let variable = if mode == AllocationMode::Input { + cs.new_input_variable(|| f().map(bool_to_field))? + } else { + cs.new_witness_variable(|| f().map(bool_to_field))? + }; + + // Constrain: (1 - a) * a = 0 + // This constrains a to be either 0 or 1. + + cs.enforce_constraint(lc!() + Variable::One - variable, lc!() + variable, lc!())?; + + Ok(Self { variable, cs }) + } + } +} + +impl CondSelectGadget for AllocatedBool { + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_val: &Self, + false_val: &Self, + ) -> Result { + let res = Boolean::conditionally_select( + cond, + &true_val.clone().into(), + &false_val.clone().into(), + )?; + match res { + Boolean::Var(a) => Ok(a), + _ => unreachable!("Impossible"), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + use ark_relations::r1cs::ConstraintSystem; + use ark_test_curves::bls12_381::Fr; + #[test] + fn allocated_xor() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::xor(&a, &b)?; + assert_eq!(c.value()?, a_val ^ b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val ^ b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_or() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::or(&a, &b)?; + assert_eq!(c.value()?, a_val | b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val | b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_and() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::and(&a, &b)?; + assert_eq!(c.value()?, a_val & b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val & b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_and_not() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::and_not(&a, &b)?; + assert_eq!(c.value()?, a_val & !b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (a_val & !b_val)); + } + } + Ok(()) + } + + #[test] + fn allocated_nor() -> Result<(), SynthesisError> { + for a_val in [false, true].iter().copied() { + for b_val in [false, true].iter().copied() { + let cs = ConstraintSystem::::new_ref(); + let a = AllocatedBool::new_witness(cs.clone(), || Ok(a_val))?; + let b = AllocatedBool::new_witness(cs.clone(), || Ok(b_val))?; + let c = AllocatedBool::nor(&a, &b)?; + assert_eq!(c.value()?, !a_val & !b_val); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(a.value()?, (a_val)); + assert_eq!(b.value()?, (b_val)); + assert_eq!(c.value()?, (!a_val & !b_val)); + } + } + Ok(()) + } +} diff --git a/src/boolean/and.rs b/src/boolean/and.rs new file mode 100644 index 00000000..20251f3f --- /dev/null +++ b/src/boolean/and.rs @@ -0,0 +1,331 @@ +use ark_ff::{Field, PrimeField}; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitAnd, ops::BitAndAssign}; + +use crate::{fields::fp::FpVar, prelude::EqGadget}; + +use super::Boolean; + +impl Boolean { + fn _and(&self, other: &Self) -> Result { + use Boolean::*; + match (self, other) { + // false AND x is always false + (&Constant(false), _) | (_, &Constant(false)) => Ok(Constant(false)), + // true AND x is always x + (&Constant(true), x) | (x, &Constant(true)) => Ok(x.clone()), + (Var(ref x), Var(ref y)) => Ok(Var(x.and(y)?)), + } + } + + /// Outputs `!(self & other)`. + pub fn nand(&self, other: &Self) -> Result { + self._and(other).map(|x| !x) + } +} + +impl Boolean { + /// Outputs `bits[0] & bits[1] & ... & bits.last().unwrap()`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// Boolean::kary_and(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// Boolean::kary_and(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_and(bits: &[Self]) -> Result { + assert!(!bits.is_empty()); + if bits.len() <= 3 { + let mut cur: Option = None; + for next in bits { + cur = if let Some(b) = cur { + Some(b & next) + } else { + Some(next.clone()) + }; + } + + Ok(cur.expect("should not be 0")) + } else { + // b0 & b1 & ... & bN == 1 if and only if sum(b0, b1, ..., bN) == N + let sum_bits: FpVar<_> = bits.iter().map(|b| FpVar::from(b.clone())).sum(); + let num_bits = FpVar::Constant(F::from(bits.len() as u64)); + sum_bits.is_eq(&num_bits) + } + } + + /// Outputs `!(bits[0] & bits[1] & ... & bits.last().unwrap())`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// Boolean::kary_nand(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_nand(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// Boolean::kary_nand(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_nand(bits: &[Self]) -> Result { + Ok(!Self::kary_and(bits)?) + } + + /// Enforces that `!(bits[0] & bits[1] & ... ) == Boolean::TRUE`. + /// + /// Informally, this means that at least one element in `bits` must be + /// `false`. + #[tracing::instrument(target = "r1cs")] + pub fn enforce_kary_nand(bits: &[Self]) -> Result<(), SynthesisError> { + Self::kary_and(bits)?.enforce_equal(&Boolean::FALSE) + } +} + +impl<'a, F: Field> BitAnd for &'a Boolean { + type Output = Boolean; + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// (&a & &a).enforce_equal(&Boolean::TRUE)?; + /// + /// (&a & &b).enforce_equal(&Boolean::FALSE)?; + /// (&b & &a).enforce_equal(&Boolean::FALSE)?; + /// (&b & &b).enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Self) -> Self::Output { + self._and(other).unwrap() + } +} + +impl<'a, F: Field> BitAnd<&'a Self> for Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: &Self) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl<'a, F: Field> BitAnd> for &'a Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Boolean) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl BitAnd for Boolean { + type Output = Self; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Self) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl BitAndAssign for Boolean { + /// Sets `self = self & other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand_assign(&mut self, other: Self) { + let result = self._and(&other).unwrap(); + *self = result; + } +} + +impl<'a, F: Field> BitAndAssign<&'a Self> for Boolean { + /// Sets `self = self & other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand_assign(&mut self, other: &'a Self) { + let result = self._and(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_relations::r1cs::ConstraintSystem; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn and() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a & &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? & b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn nand() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.nand(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = Boolean::new_variable( + cs.clone(), + || Ok(!(a.value()? & b.value()?)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn enforce_nand() -> Result<(), SynthesisError> { + { + let cs = ConstraintSystem::::new_ref(); + + assert!( + Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), false)?]).is_ok() + ); + assert!( + Boolean::enforce_kary_nand(&[Boolean::new_constant(cs.clone(), true)?]).is_err() + ); + } + + for i in 1..5 { + // with every possible assignment for them + for mut b in 0..(1 << i) { + // with every possible negation + for mut n in 0..(1 << i) { + let cs = ConstraintSystem::::new_ref(); + + let mut expected = true; + + let mut bits = vec![]; + for _ in 0..i { + expected &= b & 1 == 1; + + let bit = if n & 1 == 1 { + Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))? + } else { + !Boolean::new_witness(cs.clone(), || Ok(b & 1 == 0))? + }; + bits.push(bit); + + b >>= 1; + n >>= 1; + } + + let expected = !expected; + + Boolean::enforce_kary_nand(&bits)?; + + if expected { + assert!(cs.is_satisfied().unwrap()); + } else { + assert!(!cs.is_satisfied().unwrap()); + } + } + } + } + Ok(()) + } + + #[test] + fn kary_and() -> Result<(), SynthesisError> { + // test different numbers of operands + for i in 1..15 { + // with every possible assignment for them + for mut b in 0..(1 << i) { + let cs = ConstraintSystem::::new_ref(); + + let mut expected = true; + + let mut bits = vec![]; + for _ in 0..i { + expected &= b & 1 == 1; + bits.push(Boolean::new_witness(cs.clone(), || Ok(b & 1 == 1))?); + b >>= 1; + } + + let r = Boolean::kary_and(&bits)?; + + assert!(cs.is_satisfied().unwrap()); + + if let Boolean::Var(ref r) = r { + assert_eq!(r.value()?, expected); + } + } + } + Ok(()) + } +} diff --git a/src/boolean/cmp.rs b/src/boolean/cmp.rs new file mode 100644 index 00000000..a8c133a8 --- /dev/null +++ b/src/boolean/cmp.rs @@ -0,0 +1,95 @@ +use crate::cmp::CmpGadget; + +use super::*; +use ark_ff::PrimeField; + +impl CmpGadget for Boolean { + fn is_ge(&self, other: &Self) -> Result, SynthesisError> { + // a | b | (a | !b) | a >= b + // --|---|--------|-------- + // 0 | 0 | 1 | 1 + // 1 | 0 | 1 | 1 + // 0 | 1 | 0 | 0 + // 1 | 1 | 1 | 1 + Ok(self | &(!other)) + } +} + +impl Boolean { + /// Enforces that `bits`, when interpreted as a integer, is less than + /// `F::characteristic()`, That is, interpret bits as a little-endian + /// integer, and enforce that this integer is "in the field Z_p", where + /// `p = F::characteristic()` . + #[tracing::instrument(target = "r1cs")] + pub fn enforce_in_field_le(bits: &[Self]) -> Result<(), SynthesisError> { + // `bits` < F::characteristic() <==> `bits` <= F::characteristic() -1 + let mut b = F::characteristic().to_vec(); + assert_eq!(b[0] % 2, 1); + b[0] -= 1; // This works, because the LSB is one, so there's no borrows. + let run = Self::enforce_smaller_or_equal_than_le(bits, b)?; + + // We should always end in a "run" of zeros, because + // the characteristic is an odd prime. So, this should + // be empty. + assert!(run.is_empty()); + + Ok(()) + } + + /// Enforces that `bits` is less than or equal to `element`, + /// when both are interpreted as (little-endian) integers. + #[tracing::instrument(target = "r1cs", skip(element))] + pub fn enforce_smaller_or_equal_than_le( + bits: &[Self], + element: impl AsRef<[u64]>, + ) -> Result, SynthesisError> { + let b: &[u64] = element.as_ref(); + + let mut bits_iter = bits.iter().rev(); // Iterate in big-endian + + // Runs of ones in r + let mut last_run = Boolean::constant(true); + let mut current_run = vec![]; + + let mut element_num_bits = 0; + for _ in BitIteratorBE::without_leading_zeros(b) { + element_num_bits += 1; + } + + if bits.len() > element_num_bits { + let mut or_result = Boolean::constant(false); + for should_be_zero in &bits[element_num_bits..] { + or_result |= should_be_zero; + let _ = bits_iter.next().unwrap(); + } + or_result.enforce_equal(&Boolean::constant(false))?; + } + + for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) { + if b { + // This is part of a run of ones. + current_run.push(a.clone()); + } else { + if !current_run.is_empty() { + // This is the start of a run of zeros, but we need + // to k-ary AND against `last_run` first. + + current_run.push(last_run.clone()); + last_run = Self::kary_and(¤t_run)?; + current_run.truncate(0); + } + + // If `last_run` is true, `a` must be false, or it would + // not be in the field. + // + // If `last_run` is false, `a` can be true or false. + // + // Ergo, at least one of `last_run` and `a` must be false. + Self::enforce_kary_nand(&[last_run.clone(), a.clone()])?; + } + } + assert!(bits_iter.next().is_none()); + + Ok(current_run) + } +} diff --git a/src/boolean/convert.rs b/src/boolean/convert.rs new file mode 100644 index 00000000..21e9fc09 --- /dev/null +++ b/src/boolean/convert.rs @@ -0,0 +1,21 @@ +use super::*; +use crate::convert::{ToBytesGadget, ToConstraintFieldGadget}; + +impl ToBytesGadget for Boolean { + /// Outputs `1u8` if `self` is true, and `0u8` otherwise. + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> Result>, SynthesisError> { + let value = self.value().map(u8::from).ok(); + let mut bits = [Boolean::FALSE; 8]; + bits[0] = self.clone(); + Ok(vec![UInt8 { bits, value }]) + } +} + +impl ToConstraintFieldGadget for Boolean { + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let var = From::from(self.clone()); + Ok(vec![var]) + } +} diff --git a/src/boolean/eq.rs b/src/boolean/eq.rs new file mode 100644 index 00000000..43bc7da2 --- /dev/null +++ b/src/boolean/eq.rs @@ -0,0 +1,229 @@ +use ark_relations::r1cs::SynthesisError; + +use crate::boolean::Boolean; +use crate::eq::EqGadget; + +use super::*; + +impl EqGadget for Boolean { + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + // self | other | XNOR(self, other) | self == other + // -----|-------|-------------------|-------------- + // 0 | 0 | 1 | 1 + // 0 | 1 | 0 | 0 + // 1 | 0 | 0 | 0 + // 1 | 1 | 1 | 1 + Ok(!(self ^ other)) + } + + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + use Boolean::*; + let one = Variable::One; + // We will use the following trick: a == b <=> a - b == 0 + // This works because a - b == 0 if and only if a = 0 and b = 0, or a = 1 and b = 1, + // which is exactly the definition of a == b. + let difference = match (self, other) { + // 1 == 1; 0 == 0 + (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()), + // false != true + (Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable), + // 1 - a + (Constant(true), Var(a)) | (Var(a), Constant(true)) => lc!() + one - a.variable(), + // a - 0 = a + (Constant(false), Var(a)) | (Var(a), Constant(false)) => lc!() + a.variable(), + // b - a, + (Var(a), Var(b)) => lc!() + b.variable() - a.variable(), + }; + + if condition != &Constant(false) { + let cs = self.cs().or(other.cs()).or(condition.cs()); + cs.enforce_constraint(lc!() + difference, condition.lc(), lc!())?; + } + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> Result<(), SynthesisError> { + use Boolean::*; + let one = Variable::One; + // We will use the following trick: a != b <=> a + b == 1 + // This works because a + b == 1 if and only if a = 0 and b = 1, or a = 1 and b = 0, + // which is exactly the definition of a != b. + let sum = match (self, other) { + // 1 != 0; 0 != 1 + (Constant(true), Constant(false)) | (Constant(false), Constant(true)) => return Ok(()), + // false == false and true == true + (Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable), + // 1 + a + (Constant(true), Var(a)) | (Var(a), Constant(true)) => lc!() + one + a.variable(), + // a + 0 = a + (Constant(false), Var(a)) | (Var(a), Constant(false)) => lc!() + a.variable(), + // b + a, + (Var(a), Var(b)) => lc!() + b.variable() + a.variable(), + }; + + if should_enforce != &Constant(false) { + let cs = self.cs().or(other.cs()).or(should_enforce.cs()); + cs.enforce_constraint(sum, should_enforce.lc(), lc!() + one)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::{run_binary_exhaustive, run_unary_exhaustive}, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn eq() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a.is_eq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn neq() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a.is_neq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn neq_and_eq_consistency() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let is_neq = &a.is_neq(&b)?; + let is_eq = &a.is_eq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected_is_neq = + Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?; + assert_eq!(expected_is_neq.value(), is_neq.value()); + assert_ne!(expected_is_neq.value(), is_eq.value()); + expected_is_neq.enforce_equal(is_neq)?; + expected_is_neq.enforce_equal(&!is_eq)?; + expected_is_neq.enforce_not_equal(is_eq)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn enforce_eq_and_enforce_neq_consistency() { + run_unary_exhaustive::(|a| { + let cs = a.cs(); + let not_a = !&a; + a.enforce_equal(&a)?; + not_a.enforce_equal(¬_a)?; + a.enforce_not_equal(¬_a)?; + not_a.enforce_not_equal(&a)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn eq_soundness() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a.is_eq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?; + assert_ne!(expected.value(), computed.value()); + expected.enforce_not_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } + + #[test] + fn neq_soundness() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a.is_neq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?; + assert_ne!(expected.value(), computed.value()); + expected.enforce_not_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/boolean/mod.rs b/src/boolean/mod.rs new file mode 100644 index 00000000..0193f58b --- /dev/null +++ b/src/boolean/mod.rs @@ -0,0 +1,337 @@ +use ark_ff::{BitIteratorBE, Field, PrimeField}; + +use crate::{fields::fp::FpVar, prelude::*, Vec}; +use ark_relations::r1cs::{ + ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, +}; +use core::borrow::Borrow; + +mod allocated; +mod and; +mod cmp; +mod convert; +mod eq; +mod not; +mod or; +mod select; +mod xor; + +pub use allocated::AllocatedBool; + +#[cfg(test)] +mod test_utils; + +/// Represents a boolean value in the constraint system which is guaranteed +/// to be either zero or one. +#[derive(Clone, Debug, Eq, PartialEq)] +#[must_use] +pub enum Boolean { + Var(AllocatedBool), + Constant(bool), +} + +impl R1CSVar for Boolean { + type Value = bool; + + fn cs(&self) -> ConstraintSystemRef { + match self { + Self::Var(a) => a.cs.clone(), + _ => ConstraintSystemRef::None, + } + } + + fn value(&self) -> Result { + match self { + Boolean::Constant(c) => Ok(*c), + Boolean::Var(ref v) => v.value(), + } + } +} + +impl Boolean { + /// The constant `true`. + pub const TRUE: Self = Boolean::Constant(true); + + /// The constant `false`. + pub const FALSE: Self = Boolean::Constant(false); + + /// Constructs a `Boolean` vector from a slice of constant `u8`. + /// The `u8`s are decomposed in little-endian manner. + /// + /// This *does not* create any new variables or constraints. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let t = Boolean::::TRUE; + /// let f = Boolean::::FALSE; + /// + /// let bits = vec![f, t]; + /// let generated_bits = Boolean::constant_vec_from_bytes(&[2]); + /// bits[..2].enforce_equal(&generated_bits[..2])?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + pub fn constant_vec_from_bytes(values: &[u8]) -> Vec { + let mut bits = vec![]; + for byte in values { + for i in 0..8 { + bits.push(Self::Constant(((byte >> i) & 1u8) == 1u8)); + } + } + bits + } + + /// Constructs a constant `Boolean` with value `b`. + /// + /// This *does not* create any new variables or constraints. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_r1cs_std::prelude::*; + /// + /// let true_var = Boolean::::TRUE; + /// let false_var = Boolean::::FALSE; + /// + /// true_var.enforce_equal(&Boolean::constant(true))?; + /// false_var.enforce_equal(&Boolean::constant(false))?; + /// # Ok(()) + /// # } + /// ``` + pub fn constant(b: bool) -> Self { + Boolean::Constant(b) + } + + /// Constructs a `LinearCombination` from `Self`'s variables according + /// to the following map. + /// + /// * `Boolean::TRUE => lc!() + Variable::One` + /// * `Boolean::FALSE => lc!()` + /// * `Boolean::Var(v) => lc!() + v.variable()` + pub fn lc(&self) -> LinearCombination { + match self { + &Boolean::Constant(false) => lc!(), + &Boolean::Constant(true) => lc!() + Variable::One, + Boolean::Var(v) => v.variable().into(), + } + } + + /// Convert a little-endian bitwise representation of a field element to + /// `FpVar` + /// + /// Wraps around if the bit representation is larger than the field modulus. + #[tracing::instrument(target = "r1cs", skip(bits))] + pub fn le_bits_to_fp(bits: &[Self]) -> Result, SynthesisError> + where + F: PrimeField, + { + // Compute the value of the `FpVar` variable via double-and-add. + let mut value = None; + let cs = bits.cs(); + // Assign a value only when `cs` is in setup mode, or if we are constructing + // a constant. + let should_construct_value = (!cs.is_in_setup_mode()) || bits.is_constant(); + if should_construct_value { + let bits = bits.iter().map(|b| b.value().unwrap()).collect::>(); + let bytes = bits + .chunks(8) + .map(|c| { + let mut value = 0u8; + for (i, &bit) in c.iter().enumerate() { + value += (bit as u8) << i; + } + value + }) + .collect::>(); + value = Some(F::from_le_bytes_mod_order(&bytes)); + } + + if bits.is_constant() { + Ok(FpVar::constant(value.unwrap())) + } else { + let mut power = F::one(); + // Compute a linear combination for the new field variable, again + // via double and add. + + let combined = bits + .iter() + .map(|b| { + let result = FpVar::from(b.clone()) * power; + power.double_in_place(); + result + }) + .sum(); + // If the number of bits is less than the size of the field, + // then we do not need to enforce that the element is less than + // the modulus. + if bits.len() >= F::MODULUS_BIT_SIZE as usize { + Self::enforce_in_field_le(bits)?; + } + Ok(combined) + } + } +} + +impl From> for Boolean { + fn from(b: AllocatedBool) -> Self { + Boolean::Var(b) + } +} + +impl AllocVar for Boolean { + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + if mode == AllocationMode::Constant { + Ok(Boolean::Constant(*f()?.borrow())) + } else { + AllocatedBool::new_variable(cs, f, mode).map(Boolean::Var) + } + } +} + +#[cfg(test)] +mod test { + use super::Boolean; + use crate::convert::ToBytesGadget; + use crate::prelude::*; + use ark_ff::{ + AdditiveGroup, BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand, + }; + use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn test_boolean_to_byte() -> Result<(), SynthesisError> { + for val in [true, false].iter() { + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::new_witness(cs.clone(), || Ok(*val))?; + let bytes = a.to_bytes()?; + assert_eq!(bytes.len(), 1); + let byte = &bytes[0]; + assert_eq!(byte.value()?, *val as u8); + + for (i, bit) in byte.bits.iter().enumerate() { + assert_eq!(bit.value()?, (byte.value()? >> i) & 1 == 1); + } + } + Ok(()) + } + + #[test] + fn test_smaller_than_or_equal_to() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + for _ in 0..1000 { + let mut r = Fr::rand(&mut rng); + let mut s = Fr::rand(&mut rng); + if r > s { + core::mem::swap(&mut r, &mut s) + } + + let cs = ConstraintSystem::::new_ref(); + + let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect(); + let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?; + Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?; + + assert!(cs.is_satisfied().unwrap()); + } + + for _ in 0..1000 { + let r = Fr::rand(&mut rng); + if r == -Fr::one() { + continue; + } + let s = r + Fr::one(); + let s2 = r.double(); + let cs = ConstraintSystem::::new_ref(); + + let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect(); + let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?; + Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?; + if r < s2 { + Boolean::enforce_smaller_or_equal_than_le(&bits, s2.into_bigint())?; + } + + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn test_enforce_in_field() -> Result<(), SynthesisError> { + { + let cs = ConstraintSystem::::new_ref(); + + let mut bits = vec![]; + for b in BitIteratorBE::new(Fr::characteristic()).skip(1) { + bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?); + } + bits.reverse(); + + Boolean::enforce_in_field_le(&bits)?; + + assert!(!cs.is_satisfied().unwrap()); + } + + let mut rng = ark_std::test_rng(); + + for _ in 0..1000 { + let r = Fr::rand(&mut rng); + let cs = ConstraintSystem::::new_ref(); + + let mut bits = vec![]; + for b in BitIteratorBE::new(r.into_bigint()).skip(1) { + bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?); + } + bits.reverse(); + + Boolean::enforce_in_field_le(&bits)?; + + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn test_bits_to_fp() -> Result<(), SynthesisError> { + use AllocationMode::*; + let rng = &mut ark_std::test_rng(); + let cs = ConstraintSystem::::new_ref(); + + let modes = [Input, Witness, Constant]; + for &mode in modes.iter() { + for _ in 0..1000 { + let f = Fr::rand(rng); + let bits = BitIteratorLE::new(f.into_bigint()).collect::>(); + let bits: Vec<_> = + AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; + let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; + let claimed_f = Boolean::le_bits_to_fp(&bits)?; + claimed_f.enforce_equal(&f)?; + } + + for _ in 0..1000 { + let f = Fr::from(u64::rand(rng)); + let bits = BitIteratorLE::new(f.into_bigint()).collect::>(); + let bits: Vec<_> = + AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?; + let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?; + let claimed_f = Boolean::le_bits_to_fp(&bits)?; + claimed_f.enforce_equal(&f)?; + } + assert!(cs.is_satisfied().unwrap()); + } + + Ok(()) + } +} diff --git a/src/boolean/not.rs b/src/boolean/not.rs new file mode 100644 index 00000000..1c7de2e6 --- /dev/null +++ b/src/boolean/not.rs @@ -0,0 +1,98 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::ops::Not; + +use super::Boolean; + +impl Boolean { + fn _not(&self) -> Result { + match *self { + Boolean::Constant(c) => Ok(Boolean::Constant(!c)), + Boolean::Var(ref v) => Ok(Boolean::Var(v.not().unwrap())), + } + } +} + +impl<'a, F: Field> Not for &'a Boolean { + type Output = Boolean; + /// Negates `self`. + /// + /// This *does not* create any new variables or constraints. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// (!&a).enforce_equal(&b)?; + /// (!&b).enforce_equal(&a)?; + /// + /// (!&a).enforce_equal(&Boolean::FALSE)?; + /// (!&b).enforce_equal(&Boolean::TRUE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +impl<'a, F: Field> Not for &'a mut Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +impl Not for Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_unary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn not() { + run_unary_exhaustive::(|a| { + let cs = a.cs(); + let computed = !&a; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = Boolean::new_variable(cs.clone(), || Ok(!a.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/boolean/or.rs b/src/boolean/or.rs new file mode 100644 index 00000000..8f8b41c1 --- /dev/null +++ b/src/boolean/or.rs @@ -0,0 +1,182 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitOr, ops::BitOrAssign}; + +use crate::{ + eq::EqGadget, + fields::{fp::FpVar, FieldVar}, +}; + +use super::Boolean; + +impl Boolean { + fn _or(&self, other: &Self) -> Result { + use Boolean::*; + match (self, other) { + (&Constant(false), x) | (x, &Constant(false)) => Ok(x.clone()), + (&Constant(true), _) | (_, &Constant(true)) => Ok(Constant(true)), + (Var(ref x), Var(ref y)) => Ok(Var(x.or(y)?)), + } + } + + /// Outputs `bits[0] | bits[1] | ... | bits.last().unwrap()`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// let c = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// Boolean::kary_or(&[a.clone(), b.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_or(&[a.clone(), c.clone()])?.enforce_equal(&Boolean::TRUE)?; + /// Boolean::kary_or(&[b.clone(), c.clone()])?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn kary_or(bits: &[Self]) -> Result { + assert!(!bits.is_empty()); + if bits.len() <= 3 { + let mut cur: Option = None; + for next in bits { + cur = if let Some(b) = cur { + Some(b | next) + } else { + Some(next.clone()) + }; + } + + Ok(cur.expect("should not be 0")) + } else { + // b0 | b1 | ... | bN == 1 if and only if not all of b0, b1, ..., bN are 0. + // We can enforce this by requiring that the sum of b0, b1, ..., bN is not 0. + let sum_bits: FpVar<_> = bits.iter().map(|b| FpVar::from(b.clone())).sum(); + sum_bits.is_neq(&FpVar::zero()) + } + } +} + +impl<'a, F: PrimeField> BitOr for &'a Boolean { + type Output = Boolean; + + /// Outputs `self | other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// (&a | &b).enforce_equal(&Boolean::TRUE)?; + /// (&b | &a).enforce_equal(&Boolean::TRUE)?; + /// + /// (&a | &a).enforce_equal(&Boolean::TRUE)?; + /// (&b | &b).enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Self) -> Self::Output { + self._or(other).unwrap() + } +} + +impl<'a, F: PrimeField> BitOr<&'a Self> for Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: &Self) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl<'a, F: PrimeField> BitOr> for &'a Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Boolean) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl BitOr for Boolean { + type Output = Self; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Self) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl BitOrAssign for Boolean { + /// Sets `self = self | other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor_assign(&mut self, other: Self) { + let result = self._or(&other).unwrap(); + *self = result; + } +} + +impl<'a, F: PrimeField> BitOrAssign<&'a Self> for Boolean { + /// Sets `self = self | other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor_assign(&mut self, other: &'a Self) { + let result = self._or(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn or() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a | &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? | b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/boolean/select.rs b/src/boolean/select.rs new file mode 100644 index 00000000..78eb0448 --- /dev/null +++ b/src/boolean/select.rs @@ -0,0 +1,134 @@ +use super::*; + +impl Boolean { + /// Conditionally selects one of `first` and `second` based on the value of + /// `self`: + /// + /// If `self.is_eq(&Boolean::TRUE)`, this outputs `first`; else, it outputs + /// `second`. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// let cond = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// + /// cond.select(&a, &b)?.enforce_equal(&Boolean::TRUE)?; + /// cond.select(&b, &a)?.enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(first, second))] + pub fn select>( + &self, + first: &T, + second: &T, + ) -> Result { + T::conditionally_select(&self, first, second) + } +} +impl CondSelectGadget for Boolean { + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_val: &Self, + false_val: &Self, + ) -> Result { + use Boolean::*; + match cond { + Constant(true) => Ok(true_val.clone()), + Constant(false) => Ok(false_val.clone()), + cond @ Var(_) => match (true_val, false_val) { + (x, &Constant(false)) => Ok(cond & x), + (&Constant(false), x) => Ok((!cond) & x), + (&Constant(true), x) => Ok(cond | x), + (x, &Constant(true)) => Ok((!cond) | x), + (a, b) => { + let cs = cond.cs(); + let result: Boolean = + AllocatedBool::new_witness_without_booleanity_check(cs.clone(), || { + let cond = cond.value()?; + Ok(if cond { a.value()? } else { b.value()? }) + })? + .into(); + // a = self; b = other; c = cond; + // + // r = c * a + (1 - c) * b + // r = b + c * (a - b) + // c * (a - b) = r - b + // + // If a, b, cond are all boolean, so is r. + // + // self | other | cond | result + // -----|-------|---------------- + // 0 | 0 | 1 | 0 + // 0 | 1 | 1 | 0 + // 1 | 0 | 1 | 1 + // 1 | 1 | 1 | 1 + // 0 | 0 | 0 | 0 + // 0 | 1 | 0 | 1 + // 1 | 0 | 0 | 0 + // 1 | 1 | 0 | 1 + cs.enforce_constraint( + cond.lc(), + lc!() + a.lc() - b.lc(), + lc!() + result.lc() - b.lc(), + )?; + + Ok(result) + }, + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn or() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + for cond in [true, false] { + let expected = Boolean::new_variable( + cs.clone(), + || Ok(if cond { a.value()? } else { b.value()? }), + expected_mode, + )?; + let cond = Boolean::new_variable(cs.clone(), || Ok(cond), expected_mode)?; + let computed = cond.select(&a, &b)?; + + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/boolean/test_utils.rs b/src/boolean/test_utils.rs new file mode 100644 index 00000000..9577e82d --- /dev/null +++ b/src/boolean/test_utils.rs @@ -0,0 +1,47 @@ +use crate::test_utils; + +use super::*; +use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; + +pub(crate) fn test_unary_op( + a: bool, + mode: AllocationMode, + test: impl FnOnce(Boolean) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::::new_variable(cs.clone(), || Ok(a), mode)?; + test(a) +} + +pub(crate) fn test_binary_op( + a: bool, + b: bool, + mode_a: AllocationMode, + mode_b: AllocationMode, + test: impl FnOnce(Boolean, Boolean) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::::new_variable(cs.clone(), || Ok(a), mode_a)?; + let b = Boolean::::new_variable(cs.clone(), || Ok(b), mode_b)?; + test(a, b) +} + +pub(crate) fn run_binary_exhaustive( + test: impl Fn(Boolean, Boolean) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> { + for (mode_a, a) in test_utils::combination([false, true].into_iter()) { + for (mode_b, b) in test_utils::combination([false, true].into_iter()) { + test_binary_op(a, b, mode_a, mode_b, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_unary_exhaustive( + test: impl Fn(Boolean) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> { + for (mode, a) in test_utils::combination([false, true].into_iter()) { + test_unary_op(a, mode, test)?; + } + Ok(()) +} diff --git a/src/boolean/xor.rs b/src/boolean/xor.rs new file mode 100644 index 00000000..67e45b36 --- /dev/null +++ b/src/boolean/xor.rs @@ -0,0 +1,132 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitXor, ops::BitXorAssign}; + +use super::Boolean; + +impl Boolean { + fn _xor(&self, other: &Self) -> Result { + use Boolean::*; + match (self, other) { + (&Constant(false), x) | (x, &Constant(false)) => Ok(x.clone()), + (&Constant(true), x) | (x, &Constant(true)) => Ok(!x), + (Var(ref x), Var(ref y)) => Ok(Var(x.xor(y)?)), + } + } +} + +impl<'a, F: Field> BitXor for &'a Boolean { + type Output = Boolean; + + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// + /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?; + /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?; + /// + /// (&a ^ &b).enforce_equal(&Boolean::TRUE)?; + /// (&b ^ &a).enforce_equal(&Boolean::TRUE)?; + /// + /// (&a ^ &a).enforce_equal(&Boolean::FALSE)?; + /// (&b ^ &b).enforce_equal(&Boolean::FALSE)?; + /// + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Self) -> Self::Output { + self._xor(other).unwrap() + } +} + +impl<'a, F: Field> BitXor<&'a Self> for Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: &Self) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl<'a, F: Field> BitXor> for &'a Boolean { + type Output = Boolean; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Boolean) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl BitXor for Boolean { + type Output = Self; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Self) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl BitXorAssign for Boolean { + /// Sets `self = self ^ other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor_assign(&mut self, other: Self) { + let result = self._xor(&other).unwrap(); + *self = result; + } +} + +impl<'a, F: Field> BitXorAssign<&'a Self> for Boolean { + /// Sets `self = self ^ other`. + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor_assign(&mut self, other: &'a Self) { + let result = self._xor(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + boolean::test_utils::run_binary_exhaustive, + prelude::EqGadget, + R1CSVar, + }; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn xor() { + run_binary_exhaustive::(|a, b| { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a ^ &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? ^ b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + }) + .unwrap() + } +} diff --git a/src/cmp.rs b/src/cmp.rs new file mode 100644 index 00000000..2765415f --- /dev/null +++ b/src/cmp.rs @@ -0,0 +1,21 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; + +use crate::boolean::Boolean; + +/// Specifies how to generate constraints for comparing two variables. +pub trait CmpGadget { + fn is_gt(&self, other: &Self) -> Result, SynthesisError> { + other.is_lt(self) + } + + fn is_ge(&self, other: &Self) -> Result, SynthesisError>; + + fn is_lt(&self, other: &Self) -> Result, SynthesisError> { + Ok(!self.is_ge(other)?) + } + + fn is_le(&self, other: &Self) -> Result, SynthesisError> { + other.is_ge(self) + } +} diff --git a/src/bits/mod.rs b/src/convert.rs similarity index 64% rename from src/bits/mod.rs rename to src/convert.rs index 94ad5fab..e04abc48 100644 --- a/src/bits/mod.rs +++ b/src/convert.rs @@ -1,23 +1,8 @@ -use crate::{ - bits::{boolean::Boolean, uint8::UInt8}, - Vec, -}; use ark_ff::Field; use ark_relations::r1cs::SynthesisError; +use ark_std::vec::Vec; -/// This module contains `Boolean`, a R1CS equivalent of the `bool` type. -pub mod boolean; -/// This module contains `UInt8`, a R1CS equivalent of the `u8` type. -pub mod uint8; -/// This module contains a macro for generating `UIntN` types, which are R1CS -/// equivalents of `N`-bit unsigned integers. -#[macro_use] -pub mod uint; - -make_uint!(UInt16, 16, u16, uint16, "`U16`", "`u16`", "16"); -make_uint!(UInt32, 32, u32, uint32, "`U32`", "`u32`", "32"); -make_uint!(UInt64, 64, u64, uint64, "`U64`", "`u64`", "64"); -make_uint!(UInt128, 128, u128, uint128, "`U128`", "`u128`", "128"); +use crate::{boolean::Boolean, uint8::UInt8}; /// Specifies constraints for conversion to a little-endian bit representation /// of `self`. @@ -65,21 +50,6 @@ impl ToBitsGadget for [Boolean] { } } -impl ToBitsGadget for UInt8 { - fn to_bits_le(&self) -> Result>, SynthesisError> { - Ok(self.bits.to_vec()) - } -} - -impl ToBitsGadget for [UInt8] { - /// Interprets `self` as an integer, and outputs the little-endian - /// bit-wise decomposition of that integer. - fn to_bits_le(&self) -> Result>, SynthesisError> { - let bits = self.iter().flat_map(|b| &b.bits).cloned().collect(); - Ok(bits) - } -} - impl ToBitsGadget for Vec where [T]: ToBitsGadget, @@ -110,26 +80,17 @@ pub trait ToBytesGadget { } } -impl ToBytesGadget for [UInt8] { - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self.to_vec()) - } -} - -impl ToBytesGadget for Vec> { - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self.clone()) - } -} - impl<'a, F: Field, T: 'a + ToBytesGadget> ToBytesGadget for &'a T { fn to_bytes(&self) -> Result>, SynthesisError> { (*self).to_bytes() } } -impl<'a, F: Field> ToBytesGadget for &'a [UInt8] { - fn to_bytes(&self) -> Result>, SynthesisError> { - Ok(self.to_vec()) - } +/// Specifies how to convert a variable of type `Self` to variables of +/// type `FpVar` +pub trait ToConstraintFieldGadget { + /// Converts `self` to `FpVar` variables. + fn to_constraint_field( + &self, + ) -> Result>, ark_relations::r1cs::SynthesisError>; } diff --git a/src/eq.rs b/src/eq.rs index f1184619..4f2c066b 100644 --- a/src/eq.rs +++ b/src/eq.rs @@ -1,5 +1,5 @@ use crate::{prelude::*, Vec}; -use ark_ff::Field; +use ark_ff::{Field, PrimeField}; use ark_relations::r1cs::SynthesisError; /// Specifies how to generate constraints that check for equality for two @@ -14,7 +14,7 @@ pub trait EqGadget { /// /// By default, this is defined as `self.is_eq(other)?.not()`. fn is_neq(&self, other: &Self) -> Result, SynthesisError> { - Ok(self.is_eq(other)?.not()) + Ok(!self.is_eq(other)?) } /// If `should_enforce == true`, enforce that `self` and `other` are equal; @@ -82,7 +82,7 @@ pub trait EqGadget { } } -impl + R1CSVar, F: Field> EqGadget for [T] { +impl + R1CSVar, F: PrimeField> EqGadget for [T] { #[tracing::instrument(target = "r1cs", skip(self, other))] fn is_eq(&self, other: &Self) -> Result, SynthesisError> { assert_eq!(self.len(), other.len()); @@ -116,7 +116,7 @@ impl + R1CSVar, F: Field> EqGadget for [T] { assert_eq!(self.len(), other.len()); let some_are_different = self.is_neq(other)?; if [&some_are_different, should_enforce].is_constant() { - assert!(some_are_different.value().unwrap()); + assert!(some_are_different.value()?); Ok(()) } else { let cs = [&some_are_different, should_enforce].cs(); diff --git a/src/fields/cubic_extension.rs b/src/fields/cubic_extension.rs index a4465819..20c040f4 100644 --- a/src/fields/cubic_extension.rs +++ b/src/fields/cubic_extension.rs @@ -6,9 +6,10 @@ use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use core::{borrow::Borrow, marker::PhantomData}; use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, fields::{fp::FpVar, FieldOpsBounds, FieldVar}, prelude::*, - ToConstraintFieldGadget, Vec, + Vec, }; /// This struct is the `R1CS` equivalent of the cubic extension field type @@ -372,7 +373,7 @@ where let b0 = self.c0.is_eq(&other.c0)?; let b1 = self.c1.is_eq(&other.c1)?; let b2 = self.c2.is_eq(&other.c2)?; - b0.and(&b1)?.and(&b2) + Ok(b0 & b1 & b2) } #[inline] @@ -396,9 +397,7 @@ where condition: &Boolean, ) -> Result<(), SynthesisError> { let is_equal = self.is_eq(other)?; - is_equal - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (is_equal & condition).enforce_equal(&Boolean::FALSE) } } diff --git a/src/fields/emulated_fp/allocated_field_var.rs b/src/fields/emulated_fp/allocated_field_var.rs index 6679425d..92b575fd 100644 --- a/src/fields/emulated_fp/allocated_field_var.rs +++ b/src/fields/emulated_fp/allocated_field_var.rs @@ -3,7 +3,11 @@ use super::{ reduce::{bigint_to_basefield, limbs_to_bigint, Reducer}, AllocatedMulResultVar, }; -use crate::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget}; +use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, + fields::fp::FpVar, + prelude::*, +}; use ark_ff::{BigInteger, PrimeField}; use ark_relations::{ ns, diff --git a/src/fields/emulated_fp/field_var.rs b/src/fields/emulated_fp/field_var.rs index 41c2b7af..09fcc22b 100644 --- a/src/fields/emulated_fp/field_var.rs +++ b/src/fields/emulated_fp/field_var.rs @@ -1,9 +1,10 @@ use super::{params::OptimizationType, AllocatedEmulatedFpVar, MulResultVar}; use crate::{ boolean::Boolean, + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, fields::{fp::FpVar, FieldVar}, prelude::*, - R1CSVar, ToConstraintFieldGadget, + R1CSVar, }; use ark_ff::{BigInteger, PrimeField}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, Result as R1CSResult, SynthesisError}; @@ -217,7 +218,7 @@ impl EqGadget for EmulatedFpVar CondSelectGadget false_value: &Self, ) -> R1CSResult { match cond { - Boolean::Constant(true) => Ok(true_value.clone()), - Boolean::Constant(false) => Ok(false_value.clone()), + &Boolean::Constant(true) => Ok(true_value.clone()), + &Boolean::Constant(false) => Ok(false_value.clone()), _ => { let cs = cond.cs(); let true_value = match true_value { diff --git a/src/fields/fp/cmp.rs b/src/fields/fp/cmp.rs index 4612d18f..f3304e4a 100644 --- a/src/fields/fp/cmp.rs +++ b/src/fields/fp/cmp.rs @@ -1,8 +1,8 @@ use crate::{ boolean::Boolean, + convert::ToBitsGadget, fields::{fp::FpVar, FieldVar}, prelude::*, - ToBitsGadget, }; use ark_ff::PrimeField; use ark_relations::r1cs::{SynthesisError, Variable}; diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index bd955bd7..ae01d9b3 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -7,9 +7,10 @@ use core::borrow::Borrow; use crate::{ boolean::AllocatedBool, + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, fields::{FieldOpsBounds, FieldVar}, prelude::*, - Assignment, ToConstraintFieldGadget, Vec, + Assignment, Vec, }; use ark_std::iter::Sum; @@ -50,6 +51,35 @@ pub enum FpVar { Var(AllocatedFp), } +impl FpVar { + /// Decomposes `self` into a vector of `bits` and a remainder `rest` such that + /// * `bits.len() == size`, and + /// * `rest == 0`. + pub fn to_bits_le_with_top_bits_zero( + &self, + size: usize, + ) -> Result<(Vec>, Self), SynthesisError> { + assert!(size <= F::MODULUS_BIT_SIZE as usize - 1); + let cs = self.cs(); + let mode = if self.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + + let value = self.value().map(|f| f.into_bigint()); + let lower_bits = (0..size) + .map(|i| { + Boolean::new_variable(cs.clone(), || value.map(|v| v.get_bit(i as usize)), mode) + }) + .collect::, _>>()?; + let lower_bits_fp = Boolean::le_bits_to_fp(&lower_bits)?; + let rest = self - &lower_bits_fp; + rest.enforce_equal(&Self::zero())?; + Ok((lower_bits, rest)) + } +} + impl R1CSVar for FpVar { type Value = F; @@ -130,13 +160,15 @@ impl AllocatedFp { /// /// This does not create any constraints and only creates one linear /// combination. - pub fn addmany<'a, I: Iterator>(iter: I) -> Self { + pub fn add_many, I: Iterator>(iter: I) -> Self { let mut cs = ConstraintSystemRef::None; let mut has_value = true; let mut value = F::zero(); let mut new_lc = lc!(); + let mut num_iters = 0; for variable in iter { + let variable = variable.borrow(); if !variable.cs.is_none() { cs = cs.or(variable.cs.clone()); } @@ -146,14 +178,16 @@ impl AllocatedFp { value += variable.value.unwrap(); } new_lc = new_lc + variable.variable; + num_iters += 1; } + assert_ne!(num_iters, 0); let variable = cs.new_lc(new_lc).unwrap(); if has_value { - AllocatedFp::new(Some(value), variable, cs.clone()) + AllocatedFp::new(Some(value), variable, cs) } else { - AllocatedFp::new(None, variable, cs.clone()) + AllocatedFp::new(None, variable, cs) } } @@ -324,7 +358,7 @@ impl AllocatedFp { /// This requires two constraints. #[tracing::instrument(target = "r1cs")] pub fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - Ok(self.is_neq(other)?.not()) + Ok(!self.is_neq(other)?) } /// Outputs the bit `self != other`. @@ -397,7 +431,7 @@ impl AllocatedFp { )?; self.cs.enforce_constraint( lc!() + self.variable - other.variable, - is_not_equal.not().lc(), + (!&is_not_equal).lc(), lc!(), )?; Ok(is_not_equal) @@ -560,8 +594,8 @@ impl CondSelectGadget for AllocatedFp { false_val: &Self, ) -> Result { match cond { - Boolean::Constant(true) => Ok(true_val.clone()), - Boolean::Constant(false) => Ok(false_val.clone()), + &Boolean::Constant(true) => Ok(true_val.clone()), + &Boolean::Constant(false) => Ok(false_val.clone()), _ => { let cs = cond.cs(); let result = Self::new_witness(cs.clone(), || { @@ -958,13 +992,13 @@ impl CondSelectGadget for FpVar { false_value: &Self, ) -> Result { match cond { - Boolean::Constant(true) => Ok(true_value.clone()), - Boolean::Constant(false) => Ok(false_value.clone()), + &Boolean::Constant(true) => Ok(true_value.clone()), + &Boolean::Constant(false) => Ok(false_value.clone()), _ => { match (true_value, false_value) { (Self::Constant(t), Self::Constant(f)) => { let is = AllocatedFp::from(cond.clone()); - let not = AllocatedFp::from(cond.not()); + let not = AllocatedFp::from(!cond); // cond * t + (1 - cond) * f Ok(is.mul_constant(*t).add(¬.mul_constant(*f)).into()) }, @@ -1056,7 +1090,23 @@ impl AllocVar for FpVar { impl<'a, F: PrimeField> Sum<&'a FpVar> for FpVar { fn sum>>(iter: I) -> FpVar { let mut sum_constants = F::zero(); - let sum_variables = FpVar::Var(AllocatedFp::::addmany(iter.filter_map(|x| match x { + let sum_variables = FpVar::Var(AllocatedFp::::add_many(iter.filter_map(|x| match x { + FpVar::Constant(c) => { + sum_constants += c; + None + }, + FpVar::Var(v) => Some(v), + }))); + + let sum = sum_variables + sum_constants; + sum + } +} + +impl<'a, F: PrimeField> Sum> for FpVar { + fn sum>>(iter: I) -> FpVar { + let mut sum_constants = F::zero(); + let sum_variables = FpVar::Var(AllocatedFp::::add_many(iter.filter_map(|x| match x { FpVar::Constant(c) => { sum_constants += c; None diff --git a/src/fields/mod.rs b/src/fields/mod.rs index bb82c5f2..bced7ffd 100644 --- a/src/fields/mod.rs +++ b/src/fields/mod.rs @@ -5,6 +5,7 @@ use core::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; +use crate::convert::{ToBitsGadget, ToBytesGadget}; use crate::prelude::*; /// This module contains a generic implementation of cubic extension field @@ -65,7 +66,7 @@ pub trait FieldOpsBounds<'a, F, T: 'a>: } /// A variable representing a field. Corresponds to the native type `F`. -pub trait FieldVar: +pub trait FieldVar: 'static + Clone + From> diff --git a/src/fields/quadratic_extension.rs b/src/fields/quadratic_extension.rs index 5e665bb4..da944cfd 100644 --- a/src/fields/quadratic_extension.rs +++ b/src/fields/quadratic_extension.rs @@ -6,9 +6,10 @@ use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use core::{borrow::Borrow, marker::PhantomData}; use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, fields::{fp::FpVar, FieldOpsBounds, FieldVar}, prelude::*, - ToConstraintFieldGadget, Vec, + Vec, }; /// This struct is the `R1CS` equivalent of the quadratic extension field type @@ -377,7 +378,7 @@ where fn is_eq(&self, other: &Self) -> Result, SynthesisError> { let b0 = self.c0.is_eq(&other.c0)?; let b1 = self.c1.is_eq(&other.c1)?; - b0.and(&b1) + Ok(b0 & b1) } #[inline] @@ -400,9 +401,7 @@ where condition: &Boolean, ) -> Result<(), SynthesisError> { let is_equal = self.is_eq(other)?; - is_equal - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (is_equal & condition).enforce_equal(&Boolean::FALSE) } } diff --git a/src/groups/curves/short_weierstrass/bls12/mod.rs b/src/groups/curves/short_weierstrass/bls12/mod.rs index 263a1bd5..7ac9f9b8 100644 --- a/src/groups/curves/short_weierstrass/bls12/mod.rs +++ b/src/groups/curves/short_weierstrass/bls12/mod.rs @@ -201,7 +201,7 @@ impl G2PreparedVar

{ let q = q.to_affine()?; let two_inv = P::Fp::one().double().inverse().unwrap(); // Enforce that `q` is not the point at infinity. - q.infinity.enforce_not_equal(&Boolean::Constant(true))?; + q.infinity.enforce_not_equal(&Boolean::TRUE)?; let mut ell_coeffs = vec![]; let mut r = q.clone(); diff --git a/src/groups/curves/short_weierstrass/mnt4/mod.rs b/src/groups/curves/short_weierstrass/mnt4/mod.rs index 1bd768e0..7908852e 100644 --- a/src/groups/curves/short_weierstrass/mnt4/mod.rs +++ b/src/groups/curves/short_weierstrass/mnt4/mod.rs @@ -6,6 +6,7 @@ use ark_ff::Field; use ark_relations::r1cs::{Namespace, SynthesisError}; use crate::{ + convert::ToBytesGadget, fields::{fp::FpVar, fp2::Fp2Var, FieldVar}, groups::curves::short_weierstrass::ProjectiveVar, pairing::mnt4::PairingVar, diff --git a/src/groups/curves/short_weierstrass/mnt6/mod.rs b/src/groups/curves/short_weierstrass/mnt6/mod.rs index 6d216e14..9e534298 100644 --- a/src/groups/curves/short_weierstrass/mnt6/mod.rs +++ b/src/groups/curves/short_weierstrass/mnt6/mod.rs @@ -6,6 +6,7 @@ use ark_ff::Field; use ark_relations::r1cs::{Namespace, SynthesisError}; use crate::{ + convert::ToBytesGadget, fields::{fp::FpVar, fp3::Fp3Var, FieldVar}, groups::curves::short_weierstrass::ProjectiveVar, pairing::mnt6::PairingVar, diff --git a/src/groups/curves/short_weierstrass/mod.rs b/src/groups/curves/short_weierstrass/mod.rs index 1499ffd0..82743058 100644 --- a/src/groups/curves/short_weierstrass/mod.rs +++ b/src/groups/curves/short_weierstrass/mod.rs @@ -7,8 +7,12 @@ use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use ark_std::{borrow::Borrow, marker::PhantomData, ops::Mul}; use non_zero_affine::NonZeroAffineVar; -use crate::fields::emulated_fp::EmulatedFpVar; -use crate::{fields::fp::FpVar, prelude::*, ToConstraintFieldGadget, Vec}; +use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, + fields::{emulated_fp::EmulatedFpVar, fp::FpVar}, + prelude::*, + Vec, +}; /// This module provides a generic implementation of G1 and G2 for /// the [\[BLS12]\]() family of bilinear groups. @@ -178,7 +182,7 @@ where // `z_inv * self.z = 0` if `self.is_zero()`. // // Thus, `z_inv * self.z = !self.is_zero()`. - z_inv.mul_equals(&self.z, &F::from(infinity.not()))?; + z_inv.mul_equals(&self.z, &F::from(!&infinity))?; let non_zero_x = &self.x * &z_inv; let non_zero_y = &self.y * &z_inv; @@ -755,9 +759,9 @@ where fn is_eq(&self, other: &Self) -> Result>, SynthesisError> { let x_equal = (&self.x * &other.z).is_eq(&(&other.x * &self.z))?; let y_equal = (&self.y * &other.z).is_eq(&(&other.y * &self.z))?; - let coordinates_equal = x_equal.and(&y_equal)?; - let both_are_zero = self.is_zero()?.and(&other.is_zero()?)?; - both_are_zero.or(&coordinates_equal) + let coordinates_equal = x_equal & y_equal; + let both_are_zero = self.is_zero()? & other.is_zero()?; + Ok(both_are_zero | coordinates_equal) } #[inline] @@ -769,12 +773,9 @@ where ) -> Result<(), SynthesisError> { let x_equal = (&self.x * &other.z).is_eq(&(&other.x * &self.z))?; let y_equal = (&self.y * &other.z).is_eq(&(&other.y * &self.z))?; - let coordinates_equal = x_equal.and(&y_equal)?; - let both_are_zero = self.is_zero()?.and(&other.is_zero()?)?; - both_are_zero - .or(&coordinates_equal)? - .conditional_enforce_equal(&Boolean::Constant(true), condition)?; - Ok(()) + let coordinates_equal = x_equal & y_equal; + let both_are_zero = self.is_zero()? & other.is_zero()?; + (both_are_zero | coordinates_equal).conditional_enforce_equal(&Boolean::TRUE, condition) } #[inline] @@ -785,9 +786,7 @@ where condition: &Boolean>, ) -> Result<(), SynthesisError> { let is_equal = self.is_eq(other)?; - is_equal - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (is_equal & condition).enforce_equal(&Boolean::FALSE) } } @@ -980,10 +979,10 @@ where mod test_sw_curve { use crate::{ alloc::AllocVar, + convert::ToBitsGadget, eq::EqGadget, fields::{emulated_fp::EmulatedFpVar, fp::FpVar}, groups::{curves::short_weierstrass::ProjectiveVar, CurveVar}, - ToBitsGadget, }; use ark_ec::{ short_weierstrass::{Projective, SWCurveConfig}, diff --git a/src/groups/curves/short_weierstrass/non_zero_affine.rs b/src/groups/curves/short_weierstrass/non_zero_affine.rs index c1de4161..7b894348 100644 --- a/src/groups/curves/short_weierstrass/non_zero_affine.rs +++ b/src/groups/curves/short_weierstrass/non_zero_affine.rs @@ -188,7 +188,7 @@ where ) -> Result::BasePrimeField>, SynthesisError> { let x_equal = self.x.is_eq(&other.x)?; let y_equal = self.y.is_eq(&other.y)?; - x_equal.and(&y_equal) + Ok(x_equal & y_equal) } #[inline] @@ -200,8 +200,8 @@ where ) -> Result<(), SynthesisError> { let x_equal = self.x.is_eq(&other.x)?; let y_equal = self.y.is_eq(&other.y)?; - let coordinates_equal = x_equal.and(&y_equal)?; - coordinates_equal.conditional_enforce_equal(&Boolean::Constant(true), condition)?; + let coordinates_equal = x_equal & y_equal; + coordinates_equal.conditional_enforce_equal(&Boolean::TRUE, condition)?; Ok(()) } @@ -221,9 +221,7 @@ where condition: &Boolean<::BasePrimeField>, ) -> Result<(), SynthesisError> { let is_equal = self.is_eq(other)?; - is_equal - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (is_equal & condition).enforce_equal(&Boolean::FALSE) } } diff --git a/src/groups/curves/twisted_edwards/mod.rs b/src/groups/curves/twisted_edwards/mod.rs index 62bce203..f88431d9 100644 --- a/src/groups/curves/twisted_edwards/mod.rs +++ b/src/groups/curves/twisted_edwards/mod.rs @@ -8,8 +8,12 @@ use ark_ec::{ use ark_ff::{BitIteratorBE, Field, One, PrimeField, Zero}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; -use crate::fields::emulated_fp::EmulatedFpVar; -use crate::{prelude::*, ToConstraintFieldGadget, Vec}; +use crate::{ + convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget}, + fields::emulated_fp::EmulatedFpVar, + prelude::*, + Vec, +}; use crate::fields::fp::FpVar; use ark_std::{borrow::Borrow, marker::PhantomData, ops::Mul}; @@ -348,7 +352,7 @@ where let x_coeffs = coords.iter().map(|p| p.0).collect::>(); let y_coeffs = coords.iter().map(|p| p.1).collect::>(); - let precomp = bits[0].and(&bits[1])?; + let precomp = &bits[0] & &bits[1]; let x = F::zero() + x_coeffs[0] @@ -413,7 +417,7 @@ where } fn is_zero(&self) -> Result>, SynthesisError> { - self.x.is_zero()?.and(&self.y.is_one()?) + Ok(self.x.is_zero()? & &self.y.is_one()?) } #[tracing::instrument(target = "r1cs", skip(cs, f))] @@ -859,7 +863,7 @@ where fn is_eq(&self, other: &Self) -> Result>, SynthesisError> { let x_equal = self.x.is_eq(&other.x)?; let y_equal = self.y.is_eq(&other.y)?; - x_equal.and(&y_equal) + Ok(x_equal & y_equal) } #[inline] @@ -881,9 +885,7 @@ where other: &Self, condition: &Boolean>, ) -> Result<(), SynthesisError> { - self.is_eq(other)? - .and(condition)? - .enforce_equal(&Boolean::Constant(false)) + (self.is_eq(other)? & condition).enforce_equal(&Boolean::FALSE) } } diff --git a/src/groups/mod.rs b/src/groups/mod.rs index 444cdf18..08edbd57 100644 --- a/src/groups/mod.rs +++ b/src/groups/mod.rs @@ -1,4 +1,8 @@ -use crate::{fields::emulated_fp::EmulatedFpVar, prelude::*}; +use crate::{ + convert::{ToBitsGadget, ToBytesGadget}, + fields::emulated_fp::EmulatedFpVar, + prelude::*, +}; use ark_ff::PrimeField; use ark_relations::r1cs::{Namespace, SynthesisError}; use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; diff --git a/src/lib.rs b/src/lib.rs index 9c6b7019..74c9dc21 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,44 +31,75 @@ pub(crate) use ark_std::vec::Vec; use ark_ff::Field; -/// This module implements gadgets related to bit manipulation, such as -/// `Boolean` and `UInt`s. -pub mod bits; -pub use self::bits::*; +/// This module contains `Boolean`, an R1CS equivalent of the `bool` type. +pub mod boolean; -/// This module implements gadgets related to field arithmetic. +/// Finite field arithmetic. pub mod fields; -/// This module implements gadgets related to group arithmetic, and specifically -/// elliptic curve arithmetic. +/// Implementations of elliptic curve group arithmetic for popular curve models. pub mod groups; -/// This module implements gadgets related to computing pairings in bilinear -/// groups. +/// Gadgets for computing pairings in bilinear groups. pub mod pairing; -/// This module describes a trait for allocating new variables in a constraint -/// system. +/// Utilities for allocating new variables in a constraint system. pub mod alloc; -/// This module describes a trait for checking equality of variables. + +/// Utilities for comparing variables. +pub mod cmp; + +/// Utilities for converting variables to other kinds of variables. +pub mod convert; + +/// Utilities for checking equality of variables. pub mod eq; -/// This module implements functions for manipulating polynomial variables over -/// finite fields. + +/// Definitions of polynomial variables over finite fields. pub mod poly; -/// This module describes traits for conditionally selecting a variable from a + +/// Contains traits for conditionally selecting a variable from a /// list of variables. pub mod select; +#[cfg(test)] +pub(crate) mod test_utils; + +/// This module contains `UInt8`, a R1CS equivalent of the `u8` type. +pub mod uint8; +/// This module contains a macro for generating `UIntN` types, which are R1CS +/// equivalents of `N`-bit unsigned integers. +#[macro_use] +pub mod uint; + +pub mod uint16 { + pub type UInt16 = super::uint::UInt<16, u16, F>; +} +pub mod uint32 { + pub type UInt32 = super::uint::UInt<32, u32, F>; +} +pub mod uint64 { + pub type UInt64 = super::uint::UInt<64, u64, F>; +} +pub mod uint128 { + pub type UInt128 = super::uint::UInt<128, u128, F>; +} + #[allow(missing_docs)] pub mod prelude { pub use crate::{ alloc::*, - bits::{boolean::Boolean, uint32::UInt32, uint8::UInt8, ToBitsGadget, ToBytesGadget}, + boolean::Boolean, eq::*, fields::{FieldOpsBounds, FieldVar}, groups::{CurveVar, GroupOpsBounds}, pairing::PairingVar, select::*, + uint128::UInt128, + uint16::UInt16, + uint32::UInt32, + uint64::UInt64, + uint8::UInt8, R1CSVar, }; } @@ -139,12 +170,3 @@ impl Assignment for Option { self.ok_or(ark_relations::r1cs::SynthesisError::AssignmentMissing) } } - -/// Specifies how to convert a variable of type `Self` to variables of -/// type `FpVar` -pub trait ToConstraintFieldGadget { - /// Converts `self` to `FpVar` variables. - fn to_constraint_field( - &self, - ) -> Result>, ark_relations::r1cs::SynthesisError>; -} diff --git a/src/pairing/mod.rs b/src/pairing/mod.rs index dbdb6483..ab081fb8 100644 --- a/src/pairing/mod.rs +++ b/src/pairing/mod.rs @@ -1,4 +1,4 @@ -use crate::prelude::*; +use crate::{convert::ToBytesGadget, prelude::*}; use ark_ec::pairing::Pairing; use ark_relations::r1cs::SynthesisError; use core::fmt::Debug; diff --git a/src/poly/domain/mod.rs b/src/poly/domain/mod.rs index 32411867..c30bd03e 100644 --- a/src/poly/domain/mod.rs +++ b/src/poly/domain/mod.rs @@ -129,7 +129,10 @@ mod tests { use ark_relations::r1cs::ConstraintSystem; use ark_std::{rand::Rng, test_rng}; - use crate::{alloc::AllocVar, fields::fp::FpVar, poly::domain::Radix2DomainVar, R1CSVar}; + use crate::{ + alloc::AllocVar, convert::ToBitsGadget, fields::fp::FpVar, poly::domain::Radix2DomainVar, + R1CSVar, + }; fn test_query_coset_template() { const COSET_DIM: u64 = 7; @@ -145,9 +148,11 @@ mod tests { let num_cosets = 1 << (COSET_DIM - LOCALIZATION); let coset_index = rng.gen_range(0..num_cosets); + println!("{:0b}", coset_index); let coset_index_var = UInt32::new_witness(cs.clone(), || Ok(coset_index)) .unwrap() .to_bits_le() + .unwrap() .into_iter() .take(COSET_DIM as usize) .collect::>(); diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 00000000..189eae69 --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,15 @@ +use core::iter; + +use crate::alloc::AllocationMode; + +pub(crate) fn modes() -> impl Iterator { + use AllocationMode::*; + [Constant, Input, Witness].into_iter() +} + +pub(crate) fn combination( + mut i: impl Iterator, +) -> impl Iterator { + iter::from_fn(move || i.next().map(|t| modes().map(move |mode| (mode, t.clone())))) + .flat_map(|x| x) +} diff --git a/src/uint/add/mod.rs b/src/uint/add/mod.rs new file mode 100644 index 00000000..d5eb7563 --- /dev/null +++ b/src/uint/add/mod.rs @@ -0,0 +1,50 @@ +use crate::fields::fp::FpVar; + +use super::*; + +mod saturating; +mod wrapping; + +impl UInt { + /// Adds up `operands`, returning the bit decomposition of the result, along with + /// the value of the result. If all the operands are constant, then the bit decomposition + /// is empty, and the value is the constant value of the result. + /// + /// # Panics + /// + /// This method panics if the result of addition could possibly exceed the field size. + #[tracing::instrument(target = "r1cs", skip(operands, adder))] + fn add_many_helper( + operands: &[Self], + adder: impl Fn(T, T) -> T, + ) -> Result<(Vec>, Option), SynthesisError> { + // Bounds on `N` to avoid overflows + + assert!(operands.len() >= 1); + let max_value_size = N as u32 + ark_std::log2(operands.len()); + assert!(max_value_size <= F::MODULUS_BIT_SIZE); + + if operands.len() == 1 { + return Ok((operands[0].bits.to_vec(), operands[0].value)); + } + + // Compute the value of the result. + let mut value = Some(T::zero()); + for op in operands { + value = value.and_then(|v| Some(adder(v, op.value?))); + } + if operands.is_constant() { + // If all operands are constant, then the result is also constant. + // In this case, we can return early. + return Ok((Vec::new(), value)); + } + + // Compute the full (non-wrapped) sum of the operands. + let result = operands + .iter() + .map(|op| Boolean::le_bits_to_fp(&op.bits).unwrap()) + .sum::>(); + let (result, _) = result.to_bits_le_with_top_bits_zero(max_value_size as usize)?; + Ok((result, value)) + } +} diff --git a/src/uint/add/saturating.rs b/src/uint/add/saturating.rs new file mode 100644 index 00000000..62c393eb --- /dev/null +++ b/src/uint/add/saturating.rs @@ -0,0 +1,117 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; + +use crate::uint::*; +use crate::{boolean::Boolean, R1CSVar}; + +impl UInt { + /// Compute `*self = self.wrapping_add(other)`. + pub fn saturating_add_in_place(&mut self, other: &Self) { + let result = Self::saturating_add_many(&[self.clone(), other.clone()]).unwrap(); + *self = result; + } + + /// Compute `self.wrapping_add(other)`. + pub fn saturating_add(&self, other: &Self) -> Self { + let mut result = self.clone(); + result.saturating_add_in_place(other); + result + } + + /// Perform wrapping addition of `operands`. + /// Computes `operands[0].wrapping_add(operands[1]).wrapping_add(operands[2])...`. + /// + /// The user must ensure that overflow does not occur. + #[tracing::instrument(target = "r1cs", skip(operands))] + pub fn saturating_add_many(operands: &[Self]) -> Result + where + F: PrimeField, + { + let (sum_bits, value) = Self::add_many_helper(operands, |a, b| a.saturating_add(b))?; + if operands.is_constant() { + // If all operands are constant, then the result is also constant. + // In this case, we can return early. + Ok(UInt::constant(value.unwrap())) + } else if sum_bits.len() == N { + // No overflow occurred. + Ok(UInt::from_bits_le(&sum_bits)) + } else { + // Split the sum into the bottom `N` bits and the top bits. + let (bottom_bits, top_bits) = sum_bits.split_at(N); + + // Construct a candidate result assuming that no overflow occurred. + let bits = TryFrom::try_from(bottom_bits.to_vec()).unwrap(); + let candidate_result = UInt { bits, value }; + + // Check if any of the top bits is set. + // If any of them is set, then overflow occurred. + let overflow_occurred = Boolean::kary_or(&top_bits)?; + + // If overflow occurred, return the maximum value. + overflow_occurred.select(&Self::MAX, &candidate_result) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_saturating_add( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.saturating_add(&b); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::new_variable( + cs.clone(), + || Ok(a.value()?.saturating_add(b.value()?)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_saturating_add() { + run_binary_exhaustive(uint_saturating_add::).unwrap() + } + + #[test] + fn u16_saturating_add() { + run_binary_random::<1000, 16, _, _>(uint_saturating_add::).unwrap() + } + + #[test] + fn u32_saturating_add() { + run_binary_random::<1000, 32, _, _>(uint_saturating_add::).unwrap() + } + + #[test] + fn u64_saturating_add() { + run_binary_random::<1000, 64, _, _>(uint_saturating_add::).unwrap() + } + + #[test] + fn u128_saturating_add() { + run_binary_random::<1000, 128, _, _>(uint_saturating_add::).unwrap() + } +} diff --git a/src/uint/add/wrapping.rs b/src/uint/add/wrapping.rs new file mode 100644 index 00000000..5dfe4496 --- /dev/null +++ b/src/uint/add/wrapping.rs @@ -0,0 +1,106 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; + +use crate::uint::*; +use crate::R1CSVar; + +impl UInt { + /// Compute `*self = self.wrapping_add(other)`. + pub fn wrapping_add_in_place(&mut self, other: &Self) { + let result = Self::wrapping_add_many(&[self.clone(), other.clone()]).unwrap(); + *self = result; + } + + /// Compute `self.wrapping_add(other)`. + pub fn wrapping_add(&self, other: &Self) -> Self { + let mut result = self.clone(); + result.wrapping_add_in_place(other); + result + } + + /// Perform wrapping addition of `operands`. + /// Computes `operands[0].wrapping_add(operands[1]).wrapping_add(operands[2])...`. + /// + /// The user must ensure that overflow does not occur. + #[tracing::instrument(target = "r1cs", skip(operands))] + pub fn wrapping_add_many(operands: &[Self]) -> Result + where + F: PrimeField, + { + let (mut sum_bits, value) = Self::add_many_helper(operands, |a, b| a.wrapping_add(&b))?; + if operands.is_constant() { + // If all operands are constant, then the result is also constant. + // In this case, we can return early. + Ok(UInt::constant(value.unwrap())) + } else { + sum_bits.truncate(N); + Ok(UInt { + bits: sum_bits.try_into().unwrap(), + value, + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_wrapping_add( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.wrapping_add(&b); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::new_variable( + cs.clone(), + || Ok(a.value()?.wrapping_add(&b.value()?)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_wrapping_add() { + run_binary_exhaustive(uint_wrapping_add::).unwrap() + } + + #[test] + fn u16_wrapping_add() { + run_binary_random::<1000, 16, _, _>(uint_wrapping_add::).unwrap() + } + + #[test] + fn u32_wrapping_add() { + run_binary_random::<1000, 32, _, _>(uint_wrapping_add::).unwrap() + } + + #[test] + fn u64_wrapping_add() { + run_binary_random::<1000, 64, _, _>(uint_wrapping_add::).unwrap() + } + + #[test] + fn u128_wrapping_add() { + run_binary_random::<1000, 128, _, _>(uint_wrapping_add::).unwrap() + } +} diff --git a/src/uint/and.rs b/src/uint/and.rs new file mode 100644 index 00000000..4fb1b5b6 --- /dev/null +++ b/src/uint/and.rs @@ -0,0 +1,263 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitAnd, ops::BitAndAssign}; + +use super::*; + +impl UInt { + fn _and(&self, other: &Self) -> Result { + let mut result = self.clone(); + for (a, b) in result.bits.iter_mut().zip(&other.bits) { + *a &= b; + } + result.value = self.value.and_then(|a| Some(a & other.value?)); + Ok(result) + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd for &'a UInt { + type Output = UInt; + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// (a & &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Self) -> Self::Output { + self._and(other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a Self> for UInt { + type Output = UInt; + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// (a & &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: &Self) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd> for &'a UInt { + type Output = UInt; + + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// (a & &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: UInt) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl BitAnd for UInt { + type Output = Self; + + /// Outputs `self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// (a & &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand(self, other: Self) -> Self::Output { + self._and(&other).unwrap() + } +} + +impl BitAndAssign for UInt { + /// Sets `self = self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// a &= &b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand_assign(&mut self, other: Self) { + let result = self._and(&other).unwrap(); + *self = result; + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> BitAndAssign<&'a Self> for UInt { + /// Sets `self = self & other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?; + /// + /// a &= &b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitand_assign(&mut self, other: &'a Self) { + let result = self._and(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_and( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a & &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()? & b.value()?), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_and() { + run_binary_exhaustive(uint_and::).unwrap() + } + + #[test] + fn u16_and() { + run_binary_random::<1000, 16, _, _>(uint_and::).unwrap() + } + + #[test] + fn u32_and() { + run_binary_random::<1000, 32, _, _>(uint_and::).unwrap() + } + + #[test] + fn u64_and() { + run_binary_random::<1000, 64, _, _>(uint_and::).unwrap() + } + + #[test] + fn u128_and() { + run_binary_random::<1000, 128, _, _>(uint_and::).unwrap() + } +} diff --git a/src/uint/cmp.rs b/src/uint/cmp.rs new file mode 100644 index 00000000..5a01b169 --- /dev/null +++ b/src/uint/cmp.rs @@ -0,0 +1,218 @@ +use crate::cmp::CmpGadget; + +use super::*; + +impl> CmpGadget for UInt { + fn is_ge(&self, other: &Self) -> Result, SynthesisError> { + if N + 1 < ((F::MODULUS_BIT_SIZE - 1) as usize) { + let a = self.to_fp()?; + let b = other.to_fp()?; + let (bits, _) = (a - b + F::from(T::max_value()) + F::one()) + .to_bits_le_with_top_bits_zero(N + 1)?; + Ok(bits.last().unwrap().clone()) + } else { + unimplemented!("bit sizes larger than modulus size not yet supported") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_gt>( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let computed = a.is_gt(&b)?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? > b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + fn uint_lt>( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let computed = a.is_lt(&b)?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? < b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + fn uint_ge>( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let computed = a.is_ge(&b)?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? >= b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + fn uint_le>( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let computed = a.is_le(&b)?; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? <= b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_gt() { + run_binary_exhaustive(uint_gt::).unwrap() + } + + #[test] + fn u16_gt() { + run_binary_random::<1000, 16, _, _>(uint_gt::).unwrap() + } + + #[test] + fn u32_gt() { + run_binary_random::<1000, 32, _, _>(uint_gt::).unwrap() + } + + #[test] + fn u64_gt() { + run_binary_random::<1000, 64, _, _>(uint_gt::).unwrap() + } + + #[test] + fn u128_gt() { + run_binary_random::<1000, 128, _, _>(uint_gt::).unwrap() + } + + #[test] + fn u8_lt() { + run_binary_exhaustive(uint_lt::).unwrap() + } + + #[test] + fn u16_lt() { + run_binary_random::<1000, 16, _, _>(uint_lt::).unwrap() + } + + #[test] + fn u32_lt() { + run_binary_random::<1000, 32, _, _>(uint_lt::).unwrap() + } + + #[test] + fn u64_lt() { + run_binary_random::<1000, 64, _, _>(uint_lt::).unwrap() + } + + #[test] + fn u128_lt() { + run_binary_random::<1000, 128, _, _>(uint_lt::).unwrap() + } + + #[test] + fn u8_le() { + run_binary_exhaustive(uint_le::).unwrap() + } + + #[test] + fn u16_le() { + run_binary_random::<1000, 16, _, _>(uint_le::).unwrap() + } + + #[test] + fn u32_le() { + run_binary_random::<1000, 32, _, _>(uint_le::).unwrap() + } + + #[test] + fn u64_le() { + run_binary_random::<1000, 64, _, _>(uint_le::).unwrap() + } + + #[test] + fn u128_le() { + run_binary_random::<1000, 128, _, _>(uint_le::).unwrap() + } + + #[test] + fn u8_ge() { + run_binary_exhaustive(uint_ge::).unwrap() + } + + #[test] + fn u16_ge() { + run_binary_random::<1000, 16, _, _>(uint_ge::).unwrap() + } + + #[test] + fn u32_ge() { + run_binary_random::<1000, 32, _, _>(uint_ge::).unwrap() + } + + #[test] + fn u64_ge() { + run_binary_random::<1000, 64, _, _>(uint_ge::).unwrap() + } + + #[test] + fn u128_ge() { + run_binary_random::<1000, 128, _, _>(uint_ge::).unwrap() + } +} diff --git a/src/uint/convert.rs b/src/uint/convert.rs new file mode 100644 index 00000000..45ceb3e5 --- /dev/null +++ b/src/uint/convert.rs @@ -0,0 +1,129 @@ +use crate::convert::*; +use crate::fields::fp::FpVar; + +use super::*; + +impl UInt { + /// Converts `self` into a field element. The elements comprising `self` are + /// interpreted as a little-endian bit order representation of a field element. + /// + /// # Panics + /// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. + pub fn to_fp(&self) -> Result, SynthesisError> + where + F: PrimeField, + { + assert!(N <= F::MODULUS_BIT_SIZE as usize - 1); + + Boolean::le_bits_to_fp(&self.bits) + } + + /// Converts a field element into its little-endian bit order representation. + /// + /// # Panics + /// + /// Assumes that `N` is at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise. + pub fn from_fp(other: &FpVar) -> Result<(Self, FpVar), SynthesisError> + where + F: PrimeField, + { + let (bits, rest) = other.to_bits_le_with_top_bits_zero(N)?; + let result = Self::from_bits_le(&bits); + Ok((result, rest)) + } + + /// Converts a little-endian byte order representation of bits into a + /// `UInt`. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let var = UInt8::new_witness(cs.clone(), || Ok(128))?; + /// + /// let f = Boolean::FALSE; + /// let t = Boolean::TRUE; + /// + /// // Construct [0, 0, 0, 0, 0, 0, 0, 1] + /// let mut bits = vec![f.clone(); 7]; + /// bits.push(t); + /// + /// let mut c = UInt8::from_bits_le(&bits); + /// var.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs")] + pub fn from_bits_le(bits: &[Boolean]) -> Self { + assert_eq!(bits.len(), N); + let bits = <&[Boolean; N]>::try_from(bits).unwrap().clone(); + let value_exists = bits.iter().all(|b| b.value().is_ok()); + let mut value = T::zero(); + for (i, b) in bits.iter().enumerate() { + if let Ok(b) = b.value() { + value = value + (T::from(b as u8).unwrap() << i); + } + } + let value = value_exists.then_some(value); + Self { bits, value } + } +} + +impl ToBitsGadget for UInt { + fn to_bits_le(&self) -> Result>, SynthesisError> { + Ok(self.bits.to_vec()) + } +} + +impl ToBitsGadget for [UInt] { + /// Interprets `self` as an integer, and outputs the little-endian + /// bit-wise decomposition of that integer. + fn to_bits_le(&self) -> Result>, SynthesisError> { + let bits = self.iter().flat_map(|b| &b.bits).cloned().collect(); + Ok(bits) + } +} + +/*****************************************************************************************/ +/********************************* Conversions to bytes. *********************************/ +/*****************************************************************************************/ + +impl ToBytesGadget + for UInt +{ + #[tracing::instrument(target = "r1cs", skip(self))] + fn to_bytes(&self) -> Result>, SynthesisError> { + Ok(self + .to_bits_le()? + .chunks(8) + .map(UInt8::from_bits_le) + .collect()) + } +} + +impl ToBytesGadget for [UInt] { + fn to_bytes(&self) -> Result>, SynthesisError> { + let mut bytes = Vec::with_capacity(self.len() * (N / 8)); + for elem in self { + bytes.extend_from_slice(&elem.to_bytes()?); + } + Ok(bytes) + } +} + +impl ToBytesGadget for Vec> { + fn to_bytes(&self) -> Result>, SynthesisError> { + self.as_slice().to_bytes() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> ToBytesGadget for &'a [UInt] { + fn to_bytes(&self) -> Result>, SynthesisError> { + (*self).to_bytes() + } +} diff --git a/src/uint/eq.rs b/src/uint/eq.rs new file mode 100644 index 00000000..1b7c386f --- /dev/null +++ b/src/uint/eq.rs @@ -0,0 +1,173 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::vec::Vec; + +use crate::boolean::Boolean; +use crate::eq::EqGadget; + +use super::*; + +impl EqGadget + for UInt +{ + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + let chunks_are_eq = self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + .map(|(a, b)| { + let a = Boolean::le_bits_to_fp(a)?; + let b = Boolean::le_bits_to_fp(b)?; + a.is_eq(&b) + }) + .collect::, _>>()?; + Boolean::kary_and(&chunks_are_eq) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + for (a, b) in self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + { + let a = Boolean::le_bits_to_fp(a)?; + let b = Boolean::le_bits_to_fp(b)?; + a.conditional_enforce_equal(&b, condition)?; + } + Ok(()) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_not_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap(); + for (a, b) in self + .bits + .chunks(chunk_size) + .zip(other.bits.chunks(chunk_size)) + { + let a = Boolean::le_bits_to_fp(a)?; + let b = Boolean::le_bits_to_fp(b)?; + a.conditional_enforce_not_equal(&b, condition)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_eq( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.is_eq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + fn uint_neq( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = a.is_neq(&b)?; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_eq() { + run_binary_exhaustive(uint_eq::).unwrap() + } + + #[test] + fn u16_eq() { + run_binary_random::<1000, 16, _, _>(uint_eq::).unwrap() + } + + #[test] + fn u32_eq() { + run_binary_random::<1000, 32, _, _>(uint_eq::).unwrap() + } + + #[test] + fn u64_eq() { + run_binary_random::<1000, 64, _, _>(uint_eq::).unwrap() + } + + #[test] + fn u128_eq() { + run_binary_random::<1000, 128, _, _>(uint_eq::).unwrap() + } + + #[test] + fn u8_neq() { + run_binary_exhaustive(uint_neq::).unwrap() + } + + #[test] + fn u16_neq() { + run_binary_random::<1000, 16, _, _>(uint_neq::).unwrap() + } + + #[test] + fn u32_neq() { + run_binary_random::<1000, 32, _, _>(uint_neq::).unwrap() + } + + #[test] + fn u64_neq() { + run_binary_random::<1000, 64, _, _>(uint_neq::).unwrap() + } + + #[test] + fn u128_neq() { + run_binary_random::<1000, 128, _, _>(uint_neq::).unwrap() + } +} diff --git a/src/uint/mod.rs b/src/uint/mod.rs new file mode 100644 index 00000000..1544f24a --- /dev/null +++ b/src/uint/mod.rs @@ -0,0 +1,160 @@ +use ark_ff::{Field, PrimeField}; +use core::{borrow::Borrow, convert::TryFrom, fmt::Debug}; + +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; + +use crate::{boolean::Boolean, prelude::*, Assignment, Vec}; + +mod add; +mod and; +mod cmp; +mod convert; +mod eq; +mod not; +mod or; +mod rotate; +mod select; +mod shl; +mod shr; +mod xor; + +#[doc(hidden)] +pub mod prim_uint; +pub use prim_uint::*; + +#[cfg(test)] +pub(crate) mod test_utils; + +/// This struct represent an unsigned `N` bit integer as a sequence of `N` [`Boolean`]s. +#[derive(Clone, Debug)] +pub struct UInt { + #[doc(hidden)] + pub bits: [Boolean; N], + #[doc(hidden)] + pub value: Option, +} + +impl R1CSVar for UInt { + type Value = T; + + fn cs(&self) -> ConstraintSystemRef { + self.bits.as_ref().cs() + } + + fn value(&self) -> Result { + let mut value = T::zero(); + for (i, bit) in self.bits.iter().enumerate() { + value = value + (T::from(bit.value()? as u8).unwrap() << i); + } + debug_assert_eq!(self.value, Some(value)); + Ok(value) + } +} + +impl UInt { + pub const MAX: Self = Self { + bits: [Boolean::TRUE; N], + value: Some(T::MAX), + }; + + /// Construct a constant [`UInt`] from the native unsigned integer type. + /// + /// This *does not* create new variables or constraints. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let var = UInt8::new_witness(cs.clone(), || Ok(2))?; + /// + /// let constant = UInt8::constant(2); + /// var.enforce_equal(&constant)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + pub fn constant(value: T) -> Self { + let mut bits = [Boolean::FALSE; N]; + let mut bit_values = value; + + for i in 0..N { + bits[i] = Boolean::constant((bit_values & T::one()) == T::one()); + bit_values = bit_values >> 1u8; + } + + Self { + bits, + value: Some(value), + } + } + + /// Construct a constant vector of [`UInt`] from a vector of the native type + /// + /// This *does not* create any new variables or constraints. + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let var = vec![UInt8::new_witness(cs.clone(), || Ok(2))?]; + /// + /// let constant = UInt8::constant_vec(&[2]); + /// var.enforce_equal(&constant)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + pub fn constant_vec(values: &[T]) -> Vec { + values.iter().map(|v| Self::constant(*v)).collect() + } + + /// Allocates a slice of `uN`'s as private witnesses. + pub fn new_witness_vec( + cs: impl Into>, + values: &[impl Into> + Copy], + ) -> Result, SynthesisError> { + let ns = cs.into(); + let cs = ns.cs(); + let mut output_vec = Vec::with_capacity(values.len()); + for value in values { + let byte: Option = Into::into(*value); + output_vec.push(Self::new_witness(cs.clone(), || byte.get())?); + } + Ok(output_vec) + } +} + +impl AllocVar + for UInt +{ + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + let value = f().map(|f| *f.borrow()).ok(); + + let mut values = [None; N]; + if let Some(val) = value { + values + .iter_mut() + .enumerate() + .for_each(|(i, v)| *v = Some(((val >> i) & T::one()) == T::one())); + } + + let mut bits = [Boolean::FALSE; N]; + for (b, v) in bits.iter_mut().zip(&values) { + *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; + } + Ok(Self { bits, value }) + } +} diff --git a/src/uint/not.rs b/src/uint/not.rs new file mode 100644 index 00000000..1bb883d5 --- /dev/null +++ b/src/uint/not.rs @@ -0,0 +1,131 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::ops::Not; + +use super::*; + +impl UInt { + fn _not(&self) -> Result { + let mut result = self.clone(); + for a in &mut result.bits { + *a = !&*a + } + result.value = self.value.map(Not::not); + Ok(result) + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> Not for &'a UInt { + type Output = UInt; + /// Outputs `!self`. + /// + /// If `self` is a constant, then this method *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(2))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(!2))?; + /// + /// (!a).enforce_equal(&b)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> Not for UInt { + type Output = UInt; + + /// Outputs `!self`. + /// + /// If `self` is a constant, then this method *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(2))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(!2))?; + /// + /// (!a).enforce_equal(&b)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + fn not(self) -> Self::Output { + self._not().unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_unary_exhaustive, run_unary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_not( + a: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let computed = !&a; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + UInt::::new_variable(cs.clone(), || Ok(!a.value()?), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_not() { + run_unary_exhaustive(uint_not::).unwrap() + } + + #[test] + fn u16_not() { + run_unary_random::<1000, 16, _, _>(uint_not::).unwrap() + } + + #[test] + fn u32_not() { + run_unary_random::<1000, 32, _, _>(uint_not::).unwrap() + } + + #[test] + fn u64_not() { + run_unary_random::<1000, 64, _, _>(uint_not::).unwrap() + } + + #[test] + fn u128() { + run_unary_random::<1000, 128, _, _>(uint_not::).unwrap() + } +} diff --git a/src/uint/or.rs b/src/uint/or.rs new file mode 100644 index 00000000..c69fc8db --- /dev/null +++ b/src/uint/or.rs @@ -0,0 +1,176 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitOr, ops::BitOrAssign}; + +use super::{PrimUInt, UInt}; + +impl UInt { + fn _or(&self, other: &Self) -> Result { + let mut result = self.clone(); + for (a, b) in result.bits.iter_mut().zip(&other.bits) { + *a |= b; + } + result.value = self.value.and_then(|a| Some(a | other.value?)); + Ok(result) + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr for &'a UInt { + type Output = UInt; + + /// Output `self | other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 | 17))?; + /// + /// (a | b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Self) -> Self::Output { + self._or(other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a Self> for UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: &Self) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr> for &'a UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: UInt) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl BitOr for UInt { + type Output = Self; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor(self, other: Self) -> Self::Output { + self._or(&other).unwrap() + } +} + +impl BitOrAssign for UInt { + /// Sets `self = self | other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 | 17))?; + /// + /// a |= b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor_assign(&mut self, other: Self) { + let result = self._or(&other).unwrap(); + *self = result; + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a Self> for UInt { + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitor_assign(&mut self, other: &'a Self) { + let result = self._or(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_or( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a | &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()? | b.value()?), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_or() { + run_binary_exhaustive(uint_or::).unwrap() + } + + #[test] + fn u16_or() { + run_binary_random::<1000, 16, _, _>(uint_or::).unwrap() + } + + #[test] + fn u32_or() { + run_binary_random::<1000, 32, _, _>(uint_or::).unwrap() + } + + #[test] + fn u64_or() { + run_binary_random::<1000, 64, _, _>(uint_or::).unwrap() + } + + #[test] + fn u128_or() { + run_binary_random::<1000, 128, _, _>(uint_or::).unwrap() + } +} diff --git a/src/uint/prim_uint.rs b/src/uint/prim_uint.rs new file mode 100644 index 00000000..3b3fff81 --- /dev/null +++ b/src/uint/prim_uint.rs @@ -0,0 +1,175 @@ +use core::ops::{Shl, ShlAssign, Shr, ShrAssign}; +use core::usize; + +#[doc(hidden)] +// Adapted from +pub trait PrimUInt: + core::fmt::Debug + + num_traits::PrimInt + + num_traits::WrappingAdd + + num_traits::SaturatingAdd + + Shl + + Shl + + Shl + + Shl + + Shl + + Shl + + Shr + + Shr + + Shr + + Shr + + Shr + + Shr + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShlAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + ShrAssign + + Into + + _private::Sealed + + ark_std::UniformRand +{ + type Bytes: NumBytes; + const MAX: Self; + #[doc(hidden)] + const MAX_VALUE_BIT_DECOMP: &'static [bool]; + + /// Return the memory representation of this number as a byte array in little-endian byte order. + /// + /// # Examples + /// + /// ``` + /// use ark_r1cs_std::uint::PrimUInt; + /// + /// let bytes = PrimUInt::to_le_bytes(&0x12345678u32); + /// assert_eq!(bytes, [0x78, 0x56, 0x34, 0x12]); + /// ``` + fn to_le_bytes(&self) -> Self::Bytes; + + /// Return the memory representation of this number as a byte array in big-endian byte order. + /// + /// # Examples + /// + /// ``` + /// use ark_r1cs_std::uint::PrimUInt; + /// + /// let bytes = PrimUInt::to_be_bytes(&0x12345678u32); + /// assert_eq!(bytes, [0x12, 0x34, 0x56, 0x78]); + /// ``` + fn to_be_bytes(&self) -> Self::Bytes; +} + +impl PrimUInt for u8 { + const MAX: Self = u8::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 8]; + type Bytes = [u8; 1]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u8::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u8::to_be_bytes(*self) + } +} + +impl PrimUInt for u16 { + const MAX: Self = u16::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 16]; + type Bytes = [u8; 2]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u16::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u16::to_be_bytes(*self) + } +} + +impl PrimUInt for u32 { + const MAX: Self = u32::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 32]; + type Bytes = [u8; 4]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u32::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u32::to_be_bytes(*self) + } +} + +impl PrimUInt for u64 { + const MAX: Self = u64::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 64]; + type Bytes = [u8; 8]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u64::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u64::to_be_bytes(*self) + } +} + +impl PrimUInt for u128 { + const MAX: Self = u128::MAX; + const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 128]; + type Bytes = [u8; 16]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + u128::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + u128::to_be_bytes(*self) + } +} + +#[doc(hidden)] +pub trait NumBytes: + core::fmt::Debug + + AsRef<[u8]> + + AsMut<[u8]> + + PartialEq + + Eq + + PartialOrd + + Ord + + core::hash::Hash + + core::borrow::Borrow<[u8]> + + core::borrow::BorrowMut<[u8]> +{ +} + +#[doc(hidden)] +impl NumBytes for [u8; N] {} + +mod _private { + pub trait Sealed {} + + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for u128 {} +} diff --git a/src/uint/rotate.rs b/src/uint/rotate.rs new file mode 100644 index 00000000..f2d50d5d --- /dev/null +++ b/src/uint/rotate.rs @@ -0,0 +1,174 @@ +use super::*; + +impl UInt { + /// Rotates `self` to the right by `by` steps, wrapping around. + /// + /// # Examples + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt32::new_witness(cs.clone(), || Ok(0xb301u32))?; + /// let b = UInt32::new_witness(cs.clone(), || Ok(0x10000b3))?; + /// + /// a.rotate_right(8).enforce_equal(&b)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn rotate_right(&self, by: usize) -> Self { + let by = by % N; + let mut result = self.clone(); + // `[T]::rotate_left` corresponds to a `rotate_right` of the bits. + result.bits.rotate_left(by); + result.value = self.value.map(|v| v.rotate_right(by as u32)); + result + } + + /// Rotates `self` to the left by `by` steps, wrapping around. + /// + /// # Examples + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt32::new_witness(cs.clone(), || Ok(0x10000b3))?; + /// let b = UInt32::new_witness(cs.clone(), || Ok(0xb301u32))?; + /// + /// a.rotate_left(8).enforce_equal(&b)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn rotate_left(&self, by: usize) -> Self { + let by = by % N; + let mut result = self.clone(); + // `[T]::rotate_right` corresponds to a `rotate_left` of the bits. + result.bits.rotate_right(by); + result.value = self.value.map(|v| v.rotate_left(by as u32)); + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_unary_exhaustive, run_unary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_rotate_left( + a: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + for shift in 0..N { + let computed = a.rotate_left(shift); + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()?.rotate_left(shift as u32)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + } + Ok(()) + } + + fn uint_rotate_right( + a: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + for shift in 0..N { + let computed = a.rotate_right(shift); + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()?.rotate_right(shift as u32)), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + } + Ok(()) + } + + #[test] + fn u8_rotate_left() { + run_unary_exhaustive(uint_rotate_left::).unwrap() + } + + #[test] + fn u16_rotate_left() { + run_unary_random::<1000, 16, _, _>(uint_rotate_left::).unwrap() + } + + #[test] + fn u32_rotate_left() { + run_unary_random::<1000, 32, _, _>(uint_rotate_left::).unwrap() + } + + #[test] + fn u64_rotate_left() { + run_unary_random::<200, 64, _, _>(uint_rotate_left::).unwrap() + } + + #[test] + fn u128_rotate_left() { + run_unary_random::<100, 128, _, _>(uint_rotate_left::).unwrap() + } + + #[test] + fn u8_rotate_right() { + run_unary_exhaustive(uint_rotate_right::).unwrap() + } + + #[test] + fn u16_rotate_right() { + run_unary_random::<1000, 16, _, _>(uint_rotate_right::).unwrap() + } + + #[test] + fn u32_rotate_right() { + run_unary_random::<1000, 32, _, _>(uint_rotate_right::).unwrap() + } + + #[test] + fn u64_rotate_right() { + run_unary_random::<200, 64, _, _>(uint_rotate_right::).unwrap() + } + + #[test] + fn u128_rotate_right() { + run_unary_random::<100, 128, _, _>(uint_rotate_right::).unwrap() + } +} diff --git a/src/uint/select.rs b/src/uint/select.rs new file mode 100644 index 00000000..6c061fe2 --- /dev/null +++ b/src/uint/select.rs @@ -0,0 +1,98 @@ +use super::*; +use crate::select::CondSelectGadget; + +impl CondSelectGadget + for UInt +{ + #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> Result { + let selected_bits = true_value + .bits + .iter() + .zip(&false_value.bits) + .map(|(t, f)| cond.select(t, f)); + let mut bits = [Boolean::FALSE; N]; + for (result, new) in bits.iter_mut().zip(selected_bits) { + *result = new?; + } + + let value = cond.value().ok().and_then(|cond| { + if cond { + true_value.value().ok() + } else { + false_value.value().ok() + } + }); + Ok(Self { bits, value }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_select( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + for cond in [true, false] { + let expected = UInt::new_variable( + cs.clone(), + || Ok(if cond { a.value()? } else { b.value()? }), + expected_mode, + )?; + let cond = Boolean::new_variable(cs.clone(), || Ok(cond), expected_mode)?; + let computed = cond.select(&a, &b)?; + + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + } + Ok(()) + } + + #[test] + fn u8_select() { + run_binary_exhaustive(uint_select::).unwrap() + } + + #[test] + fn u16_select() { + run_binary_random::<1000, 16, _, _>(uint_select::).unwrap() + } + + #[test] + fn u32_select() { + run_binary_random::<1000, 32, _, _>(uint_select::).unwrap() + } + + #[test] + fn u64_select() { + run_binary_random::<1000, 64, _, _>(uint_select::).unwrap() + } + + #[test] + fn u128_select() { + run_binary_random::<1000, 128, _, _>(uint_select::).unwrap() + } +} diff --git a/src/uint/shl.rs b/src/uint/shl.rs new file mode 100644 index 00000000..645a07a4 --- /dev/null +++ b/src/uint/shl.rs @@ -0,0 +1,154 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::Shl, ops::ShlAssign}; + +use crate::boolean::Boolean; + +use super::{PrimUInt, UInt}; + +impl UInt { + fn _shl_u128(&self, other: u128) -> Result { + if other < N as u128 { + let mut bits = [Boolean::FALSE; N]; + for (a, b) in bits[other as usize..].iter_mut().zip(&self.bits) { + *a = b.clone(); + } + + let value = self.value.and_then(|a| Some(a << other)); + Ok(Self { bits, value }) + } else { + panic!("attempt to shift left with overflow") + } + } +} + +impl Shl for UInt { + type Output = Self; + + /// Output `self << other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1u8; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; + /// + /// (a << b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl(self, other: T2) -> Self::Output { + self._shl_u128(other.into()).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shl for &'a UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl(self, other: T2) -> Self::Output { + self._shl_u128(other.into()).unwrap() + } +} + +impl ShlAssign for UInt { + /// Sets `self = self << other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1u8; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?; + /// + /// a <<= b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shl_assign(&mut self, other: T2) { + let result = self._shl_u128(other.into()).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_shl( + a: UInt, + b: T, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let b = b.into() % (N as u128); + let computed = &a << b; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + UInt::::new_variable(cs.clone(), || Ok(a.value()? << b), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_shl() { + run_binary_exhaustive_with_native(uint_shl::).unwrap() + } + + #[test] + fn u16_shl() { + run_binary_random_with_native::<1000, 16, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u32_shl() { + run_binary_random_with_native::<1000, 32, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u64_shl() { + run_binary_random_with_native::<1000, 64, _, _>(uint_shl::).unwrap() + } + + #[test] + fn u128_shl() { + run_binary_random_with_native::<1000, 128, _, _>(uint_shl::).unwrap() + } +} diff --git a/src/uint/shr.rs b/src/uint/shr.rs new file mode 100644 index 00000000..7630855c --- /dev/null +++ b/src/uint/shr.rs @@ -0,0 +1,154 @@ +use ark_ff::PrimeField; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::Shr, ops::ShrAssign}; + +use crate::boolean::Boolean; + +use super::{PrimUInt, UInt}; + +impl UInt { + fn _shr_u128(&self, other: u128) -> Result { + if other < N as u128 { + let mut bits = [Boolean::FALSE; N]; + for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) { + *a = b.clone(); + } + + let value = self.value.and_then(|a| Some(a >> other)); + Ok(Self { bits, value }) + } else { + panic!("attempt to shift right with overflow") + } + } +} + +impl Shr for UInt { + type Output = Self; + + /// Output `self >> other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1u8; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; + /// + /// (a >> b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr(self, other: T2) -> Self::Output { + self._shr_u128(other.into()).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr for &'a UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr(self, other: T2) -> Self::Output { + self._shr_u128(other.into()).unwrap() + } +} + +impl ShrAssign for UInt { + /// Sets `self = self >> other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = 1u8; + /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?; + /// + /// a >>= b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn shr_assign(&mut self, other: T2) { + let result = self._shr_u128(other.into()).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_shr( + a: UInt, + b: T, + ) -> Result<(), SynthesisError> { + let cs = a.cs(); + let b = b.into() % (N as u128); + let computed = &a >> b; + let expected_mode = if a.is_constant() { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = + UInt::::new_variable(cs.clone(), || Ok(a.value()? >> b), expected_mode)?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !a.is_constant() { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_shr() { + run_binary_exhaustive_with_native(uint_shr::).unwrap() + } + + #[test] + fn u16_shr() { + run_binary_random_with_native::<1000, 16, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u32_shr() { + run_binary_random_with_native::<1000, 32, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u64_shr() { + run_binary_random_with_native::<1000, 64, _, _>(uint_shr::).unwrap() + } + + #[test] + fn u128_shr() { + run_binary_random_with_native::<1000, 128, _, _>(uint_shr::).unwrap() + } +} diff --git a/src/uint/test_utils.rs b/src/uint/test_utils.rs new file mode 100644 index 00000000..0600bbeb --- /dev/null +++ b/src/uint/test_utils.rs @@ -0,0 +1,144 @@ +use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; +use std::ops::RangeInclusive; + +use crate::test_utils::{self, modes}; + +use super::*; + +pub(crate) fn test_unary_op( + a: T, + mode: AllocationMode, + test: impl FnOnce(UInt) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = UInt::::new_variable(cs.clone(), || Ok(a), mode)?; + test(a) +} + +pub(crate) fn test_binary_op( + a: T, + b: T, + mode_a: AllocationMode, + mode_b: AllocationMode, + test: impl FnOnce(UInt, UInt) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = UInt::::new_variable(cs.clone(), || Ok(a), mode_a)?; + let b = UInt::::new_variable(cs.clone(), || Ok(b), mode_b)?; + test(a, b) +} + +pub(crate) fn test_binary_op_with_native( + a: T, + b: T, + mode_a: AllocationMode, + test: impl FnOnce(UInt, T) -> Result<(), SynthesisError>, +) -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let a = UInt::::new_variable(cs.clone(), || Ok(a), mode_a)?; + test(a, b) +} + +pub(crate) fn run_binary_random( + test: impl Fn(UInt, UInt) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, +{ + let mut rng = ark_std::test_rng(); + + for _ in 0..ITERATIONS { + for mode_a in modes() { + let a = T::rand(&mut rng); + for mode_b in modes() { + let b = T::rand(&mut rng); + test_binary_op(a, b, mode_a, mode_b, test)?; + } + } + } + Ok(()) +} + +pub(crate) fn run_binary_exhaustive( + test: impl Fn(UInt, UInt) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, + RangeInclusive: Iterator, +{ + for (mode_a, a) in test_utils::combination(T::min_value()..=T::max_value()) { + for (mode_b, b) in test_utils::combination(T::min_value()..=T::max_value()) { + test_binary_op(a, b, mode_a, mode_b, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_binary_random_with_native( + test: impl Fn(UInt, T) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, +{ + let mut rng = ark_std::test_rng(); + + for _ in 0..ITERATIONS { + for mode_a in modes() { + let a = T::rand(&mut rng); + let b = T::rand(&mut rng); + test_binary_op_with_native(a, b, mode_a, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_binary_exhaustive_with_native( + test: impl Fn(UInt, T) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, + RangeInclusive: Iterator, +{ + for (mode_a, a) in test_utils::combination(T::min_value()..=T::max_value()) { + for b in T::min_value()..=T::max_value() { + test_binary_op_with_native(a, b, mode_a, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_unary_random( + test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, +{ + let mut rng = ark_std::test_rng(); + + for _ in 0..ITERATIONS { + for mode_a in modes() { + let a = T::rand(&mut rng); + test_unary_op(a, mode_a, test)?; + } + } + Ok(()) +} + +pub(crate) fn run_unary_exhaustive( + test: impl Fn(UInt) -> Result<(), SynthesisError> + Copy, +) -> Result<(), SynthesisError> +where + T: PrimUInt, + F: PrimeField, + RangeInclusive: Iterator, +{ + for (mode, a) in test_utils::combination(T::min_value()..=T::max_value()) { + test_unary_op(a, mode, test)?; + } + Ok(()) +} diff --git a/src/uint/xor.rs b/src/uint/xor.rs new file mode 100644 index 00000000..22f52398 --- /dev/null +++ b/src/uint/xor.rs @@ -0,0 +1,175 @@ +use ark_ff::Field; +use ark_relations::r1cs::SynthesisError; +use ark_std::{ops::BitXor, ops::BitXorAssign}; + +use super::*; + +impl UInt { + fn _xor(&self, other: &Self) -> Result { + let mut result = self.clone(); + for (a, b) in result.bits.iter_mut().zip(&other.bits) { + *a ^= b; + } + result.value = self.value.and_then(|a| Some(a ^ other.value?)); + Ok(result) + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor for &'a UInt { + type Output = UInt; + /// Outputs `self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// (a ^ &b).enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Self) -> Self::Output { + self._xor(other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a Self> for UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: &Self) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor> for &'a UInt { + type Output = UInt; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: UInt) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl BitXor for UInt { + type Output = Self; + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor(self, other: Self) -> Self::Output { + self._xor(&other).unwrap() + } +} + +impl BitXorAssign for UInt { + /// Sets `self = self ^ other`. + /// + /// If at least one of `self` and `other` are constants, then this method + /// *does not* create any constraints or variables. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?; + /// let b = UInt8::new_witness(cs.clone(), || Ok(17))?; + /// let c = UInt8::new_witness(cs.clone(), || Ok(1))?; + /// + /// a ^= b; + /// a.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor_assign(&mut self, other: Self) { + let result = self._xor(&other).unwrap(); + *self = result; + } +} + +impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a Self> for UInt { + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn bitxor_assign(&mut self, other: &'a Self) { + let result = self._xor(other).unwrap(); + *self = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alloc::{AllocVar, AllocationMode}, + prelude::EqGadget, + uint::test_utils::{run_binary_exhaustive, run_binary_random}, + R1CSVar, + }; + use ark_ff::PrimeField; + use ark_test_curves::bls12_381::Fr; + + fn uint_xor( + a: UInt, + b: UInt, + ) -> Result<(), SynthesisError> { + let cs = a.cs().or(b.cs()); + let both_constant = a.is_constant() && b.is_constant(); + let computed = &a ^ &b; + let expected_mode = if both_constant { + AllocationMode::Constant + } else { + AllocationMode::Witness + }; + let expected = UInt::::new_variable( + cs.clone(), + || Ok(a.value()? ^ b.value()?), + expected_mode, + )?; + assert_eq!(expected.value(), computed.value()); + expected.enforce_equal(&computed)?; + if !both_constant { + assert!(cs.is_satisfied().unwrap()); + } + Ok(()) + } + + #[test] + fn u8_xor() { + run_binary_exhaustive(uint_xor::).unwrap() + } + + #[test] + fn u16_xor() { + run_binary_random::<1000, 16, _, _>(uint_xor::).unwrap() + } + + #[test] + fn u32_xor() { + run_binary_random::<1000, 32, _, _>(uint_xor::).unwrap() + } + + #[test] + fn u64_xor() { + run_binary_random::<1000, 64, _, _>(uint_xor::).unwrap() + } + + #[test] + fn u128_xor() { + run_binary_random::<1000, 128, _, _>(uint_xor::).unwrap() + } +} diff --git a/src/uint8.rs b/src/uint8.rs new file mode 100644 index 00000000..1f952de4 --- /dev/null +++ b/src/uint8.rs @@ -0,0 +1,283 @@ +use ark_ff::{Field, PrimeField, ToConstraintField}; + +use ark_relations::r1cs::{Namespace, SynthesisError}; + +use crate::{ + convert::{ToBitsGadget, ToConstraintFieldGadget}, + fields::fp::{AllocatedFp, FpVar}, + prelude::*, + Vec, +}; + +pub type UInt8 = super::uint::UInt<8, u8, F>; + +impl UInt8 { + /// Allocates a slice of `u8`'s as public inputs by first packing them into + /// elements of `F`, (thus reducing the number of input allocations), + /// allocating these elements as public inputs, and then converting + /// these field variables `FpVar` variables back into bytes. + /// + /// From a user perspective, this trade-off adds constraints, but improves + /// verifier time and verification key size. + /// + /// ``` + /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { + /// // We'll use the BLS12-381 scalar field for our constraints. + /// use ark_test_curves::bls12_381::Fr; + /// use ark_relations::r1cs::*; + /// use ark_r1cs_std::prelude::*; + /// + /// let cs = ConstraintSystem::::new_ref(); + /// let two = UInt8::new_witness(cs.clone(), || Ok(2))?; + /// let var = vec![two.clone(); 32]; + /// + /// let c = UInt8::new_input_vec(cs.clone(), &[2; 32])?; + /// var.enforce_equal(&c)?; + /// assert!(cs.is_satisfied().unwrap()); + /// # Ok(()) + /// # } + /// ``` + pub fn new_input_vec( + cs: impl Into>, + values: &[u8], + ) -> Result, SynthesisError> + where + F: PrimeField, + { + let ns = cs.into(); + let cs = ns.cs(); + let values_len = values.len(); + let field_elements: Vec = ToConstraintField::::to_field_elements(values).unwrap(); + + let max_size = 8 * ((F::MODULUS_BIT_SIZE - 1) / 8) as usize; + let mut allocated_bits = Vec::new(); + for field_element in field_elements.into_iter() { + let fe = AllocatedFp::new_input(cs.clone(), || Ok(field_element))?; + let fe_bits = fe.to_bits_le()?; + + // Remove the most significant bit, because we know it should be zero + // because `values.to_field_elements()` only + // packs field elements up to the penultimate bit. + // That is, the most significant bit (`ConstraintF::NUM_BITS`-th bit) is + // unset, so we can just pop it off. + allocated_bits.extend_from_slice(&fe_bits[0..max_size]); + } + + // Chunk up slices of 8 bit into bytes. + Ok(allocated_bits[0..(8 * values_len)] + .chunks(8) + .map(Self::from_bits_le) + .collect()) + } +} + +/// Parses the `Vec>` in fixed-sized +/// `ConstraintF::MODULUS_BIT_SIZE - 1` chunks and converts each chunk, which is +/// assumed to be little-endian, to its `FpVar` representation. +/// This is the gadget counterpart to the `[u8]` implementation of +/// [`ToConstraintField``]. +impl ToConstraintFieldGadget for [UInt8] { + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + let max_size = ((ConstraintF::MODULUS_BIT_SIZE - 1) / 8) as usize; + self.chunks(max_size) + .map(|chunk| Boolean::le_bits_to_fp(chunk.to_bits_le()?.as_slice())) + .collect::, SynthesisError>>() + } +} + +impl ToConstraintFieldGadget for Vec> { + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> Result>, SynthesisError> { + self.as_slice().to_constraint_field() + } +} + +#[cfg(test)] +mod test { + use super::UInt8; + use crate::{ + convert::{ToBitsGadget, ToConstraintFieldGadget}, + fields::fp::FpVar, + prelude::{ + AllocationMode::{Constant, Input, Witness}, + *, + }, + Vec, + }; + use ark_ff::{PrimeField, ToConstraintField}; + use ark_relations::r1cs::{ConstraintSystem, SynthesisError}; + use ark_std::rand::{distributions::Uniform, Rng}; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn test_uint8_from_bits_to_bits() -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let byte_val = 0b01110001; + let byte = + UInt8::new_witness(ark_relations::ns!(cs, "alloc value"), || Ok(byte_val)).unwrap(); + let bits = byte.to_bits_le()?; + for (i, bit) in bits.iter().enumerate() { + assert_eq!(bit.value()?, (byte_val >> i) & 1 == 1) + } + Ok(()) + } + + #[test] + fn test_uint8_new_input_vec() -> Result<(), SynthesisError> { + let cs = ConstraintSystem::::new_ref(); + let byte_vals = (64u8..128u8).collect::>(); + let bytes = + UInt8::new_input_vec(ark_relations::ns!(cs, "alloc value"), &byte_vals).unwrap(); + for (native, variable) in byte_vals.into_iter().zip(bytes) { + let bits = variable.to_bits_le()?; + for (i, bit) in bits.iter().enumerate() { + assert_eq!( + bit.value()?, + (native >> i) & 1 == 1, + "native value {}: bit {:?}", + native, + i + ) + } + } + Ok(()) + } + + #[test] + fn test_uint8_from_bits() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + + for _ in 0..1000 { + let v = (0..8) + .map(|_| Boolean::::Constant(rng.gen())) + .collect::>(); + + let val = UInt8::from_bits_le(&v); + + let value = val.value()?; + for (i, bit) in val.bits.iter().enumerate() { + match bit { + Boolean::Constant(b) => assert_eq!(*b, ((value >> i) & 1 == 1)), + _ => unreachable!(), + } + } + + let expected_to_be_same = val.to_bits_le()?; + + for x in v.iter().zip(expected_to_be_same.iter()) { + match x { + (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, + (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, + _ => unreachable!(), + } + } + } + Ok(()) + } + + #[test] + fn test_uint8_xor() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + + for _ in 0..1000 { + let cs = ConstraintSystem::::new_ref(); + + let a: u8 = rng.gen(); + let b: u8 = rng.gen(); + let c: u8 = rng.gen(); + + let mut expected = a ^ b ^ c; + + let a_bit = UInt8::new_witness(ark_relations::ns!(cs, "a_bit"), || Ok(a)).unwrap(); + let b_bit = UInt8::constant(b); + let c_bit = UInt8::new_witness(ark_relations::ns!(cs, "c_bit"), || Ok(c)).unwrap(); + + let mut r = a_bit ^ b_bit; + r ^= &c_bit; + + assert!(cs.is_satisfied().unwrap()); + + assert_eq!(r.value, Some(expected)); + + for b in r.bits.iter() { + match b { + Boolean::Var(b) => assert!(b.value()? == (expected & 1 == 1)), + Boolean::Constant(b) => assert!(*b == (expected & 1 == 1)), + } + + expected >>= 1; + } + } + Ok(()) + } + + #[test] + fn test_uint8_to_constraint_field() -> Result<(), SynthesisError> { + let mut rng = ark_std::test_rng(); + let max_size = ((::MODULUS_BIT_SIZE - 1) / 8) as usize; + + let modes = [Input, Witness, Constant]; + for mode in &modes { + for _ in 0..1000 { + let cs = ConstraintSystem::::new_ref(); + + let bytes: Vec = (&mut rng) + .sample_iter(&Uniform::new_inclusive(0, u8::max_value())) + .take(max_size * 3 + 5) + .collect(); + + let bytes_var = bytes + .iter() + .map(|byte| UInt8::new_variable(cs.clone(), || Ok(*byte), *mode)) + .collect::, SynthesisError>>()?; + + let f_vec: Vec = bytes.to_field_elements().unwrap(); + let f_var_vec: Vec> = bytes_var.to_constraint_field()?; + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(f_vec, f_var_vec.value()?); + } + } + + Ok(()) + } + + #[test] + fn test_uint8_random_access() { + let mut rng = ark_std::test_rng(); + + for _ in 0..100 { + let cs = ConstraintSystem::::new_ref(); + + // value array + let values: Vec = (0..128).map(|_| rng.gen()).collect(); + let values_const: Vec> = values.iter().map(|x| UInt8::constant(*x)).collect(); + + // index array + let position: Vec = (0..7).map(|_| rng.gen()).collect(); + let position_var: Vec> = position + .iter() + .map(|b| { + Boolean::new_witness(ark_relations::ns!(cs, "index_arr_element"), || Ok(*b)) + .unwrap() + }) + .collect(); + + // index + let mut index = 0; + for x in position { + index *= 2; + index += if x { 1 } else { 0 }; + } + + assert_eq!( + UInt8::conditionally_select_power_of_two_vector(&position_var, &values_const) + .unwrap() + .value() + .unwrap(), + values[index] + ) + } + } +} diff --git a/tests/to_constraint_field_test.rs b/tests/to_constraint_field_test.rs index db985b9d..75c79c34 100644 --- a/tests/to_constraint_field_test.rs +++ b/tests/to_constraint_field_test.rs @@ -1,5 +1,5 @@ use ark_r1cs_std::{ - alloc::AllocVar, fields::emulated_fp::EmulatedFpVar, R1CSVar, ToConstraintFieldGadget, + alloc::AllocVar, convert::ToConstraintFieldGadget, fields::emulated_fp::EmulatedFpVar, R1CSVar, }; use ark_relations::r1cs::ConstraintSystem;