Skip to content

Commit

Permalink
halo2_poseidon: Refactor code so it compiles in its new crate
Browse files Browse the repository at this point in the history
  • Loading branch information
str4d committed Dec 16, 2024
1 parent de7219d commit 0d88368
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 58 deletions.
28 changes: 9 additions & 19 deletions halo2_gadgets/src/poseidon.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! The Poseidon algebraic hash function.
use std::convert::TryInto;
use std::fmt;
use std::marker::PhantomData;

Expand Down Expand Up @@ -148,13 +147,7 @@ impl<
pub fn new(chip: PoseidonChip, mut layouter: impl Layouter<F>) -> Result<Self, Error> {
chip.initial_state(&mut layouter).map(|state| Sponge {
chip,
mode: Absorbing(
(0..RATE)
.map(|_| None)
.collect::<Vec<_>>()
.try_into()
.unwrap(),
),
mode: Absorbing::init_empty(),
state,
_marker: PhantomData::default(),

Check warning on line 152 in halo2_gadgets/src/poseidon.rs

View workflow job for this annotation

GitHub Actions / Clippy (beta)

use of `default` to create a unit struct

warning: use of `default` to create a unit struct --> halo2_gadgets/src/poseidon.rs:152:33 | 152 | _marker: PhantomData::default(), | ^^^^^^^^^^^ help: remove this call to `default` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#default_constructed_unit_structs = note: `-W clippy::default-constructed-unit-structs` implied by `-W clippy::all` = help: to override `-W clippy::all` add `#[allow(clippy::default_constructed_unit_structs)]`

Check warning on line 152 in halo2_gadgets/src/poseidon.rs

View workflow job for this annotation

GitHub Actions / Clippy (beta)

use of `default` to create a unit struct

warning: use of `default` to create a unit struct --> halo2_gadgets/src/poseidon.rs:152:33 | 152 | _marker: PhantomData::default(), | ^^^^^^^^^^^ help: remove this call to `default` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#default_constructed_unit_structs = note: `-W clippy::default-constructed-unit-structs` implied by `-W clippy::all` = help: to override `-W clippy::all` add `#[allow(clippy::default_constructed_unit_structs)]`
})
Expand All @@ -166,12 +159,10 @@ impl<
mut layouter: impl Layouter<F>,
value: PaddedWord<F>,
) -> Result<(), Error> {
for entry in self.mode.0.iter_mut() {
if entry.is_none() {
*entry = Some(value);
return Ok(());
}
}
let value = match self.mode.absorb(value) {
Ok(()) => return Ok(()),
Err(value) => value,
};

// We've already absorbed as many elements as we can
let _ = poseidon_sponge(
Expand All @@ -180,7 +171,8 @@ impl<
&mut self.state,
Some(&self.mode),
)?;
self.mode = Absorbing::init_with(value);
self.mode = Absorbing::init_empty();
self.mode.absorb(value).expect("state is not full");

Ok(())
}
Expand Down Expand Up @@ -220,10 +212,8 @@ impl<
/// Squeezes an element from the sponge.
pub fn squeeze(&mut self, mut layouter: impl Layouter<F>) -> Result<AssignedCell<F, F>, Error> {
loop {
for entry in self.mode.0.iter_mut() {
if let Some(inner) = entry.take() {
return Ok(inner.into());
}
if let Some(value) = self.mode.squeeze() {
return Ok(value.into());
}

// We've already squeezed out all available elements
Expand Down
28 changes: 14 additions & 14 deletions halo2_gadgets/src/poseidon/pow5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,19 +340,20 @@ impl<
let initial_state = initial_state?;

// Load the input into this region.
let load_input_word = |i: usize| {
let (cell, value) = match input.0[i].clone() {
let load_input_word = |(i, input_word): (usize, &Option<PaddedWord<F>>)| {
let (cell, value) = match input_word {
Some(PaddedWord::Message(word)) => (word.cell(), word.value().copied()),
Some(PaddedWord::Padding(padding_value)) => {
let value = Value::known(*padding_value);
let cell = region
.assign_fixed(
|| format!("load pad_{}", i),
config.rc_b[i],
1,
|| Value::known(padding_value),
|| value,
)?
.cell();
(cell, Value::known(padding_value))
(cell, value)
}
_ => panic!("Input is not padded"),
};
Expand All @@ -366,7 +367,12 @@ impl<

Ok(StateWord(var))
};
let input: Result<Vec<_>, Error> = (0..RATE).map(load_input_word).collect();
let input: Result<Vec<_>, Error> = input
.expose_inner()
.iter()
.enumerate()
.map(load_input_word)
.collect();
let input = input?;

// Constrain the output.
Expand Down Expand Up @@ -394,14 +400,8 @@ impl<
}

fn get_output(state: &State<Self::Word, WIDTH>) -> Squeezing<Self::Word, RATE> {
Squeezing(
state[..RATE]
.iter()
.map(|word| Some(word.clone()))
.collect::<Vec<_>>()
.try_into()
.unwrap(),
)
let vals = state[..RATE].to_vec();
Squeezing::init_full(vals.try_into().expect("correct length"))
}
}

Expand Down Expand Up @@ -687,7 +687,7 @@ mod tests {
.try_into()
.unwrap();
let (round_constants, mds, _) = S::constants();
poseidon::permute::<_, S, WIDTH, RATE>(
poseidon::test_only_permute::<_, S, WIDTH, RATE>(
&mut expected_final_state,
&mds,
&round_constants,
Expand Down
88 changes: 83 additions & 5 deletions halo2_poseidon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ pub(crate) mod fq;
pub(crate) mod grain;
pub(crate) mod mds;

#[cfg(test)]
pub(crate) mod test_vectors;
#[cfg(any(test, feature = "test-dependencies"))]
pub mod test_vectors;

mod p128pow5t3;
pub use p128pow5t3::P128Pow5T3;

use grain::SboxType;

/// The type used to hold permutation state.
pub(crate) type State<F, const T: usize> = [F; T];
pub type State<F, const T: usize> = [F; T];

/// The type used to hold sponge rate.
pub(crate) type SpongeRate<F, const RATE: usize> = [Option<F>; RATE];
Expand Down Expand Up @@ -83,6 +83,18 @@ pub fn generate_constants<
(round_constants, mds, mds_inv)
}

/// Runs the Poseidon permutation on the given state.
///
/// Exposed for testing purposes only.
#[cfg(feature = "test-dependencies")]
pub fn test_only_permute<F: Field, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
state: &mut State<F, T>,
mds: &Mds<F, T>,
round_constants: &[[F; T]],
) {
permute::<F, S, T, RATE>(state, mds, round_constants);
}

/// Runs the Poseidon permutation on the given state.
pub(crate) fn permute<F: Field, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
state: &mut State<F, T>,
Expand Down Expand Up @@ -176,16 +188,82 @@ impl<F, const RATE: usize> SpongeMode for Squeezing<F, RATE> {}

impl<F: fmt::Debug, const RATE: usize> Absorbing<F, RATE> {
pub(crate) fn init_with(val: F) -> Self {
let mut state = Self::init_empty();
state.absorb(val).expect("state is not full");
state
}

/// Initializes an empty sponge in the absorbing state.
pub fn init_empty() -> Self {
Self(
iter::once(Some(val))
.chain((1..RATE).map(|_| None))
(0..RATE)
.map(|_| None)
.collect::<Vec<_>>()
.try_into()
.unwrap(),
)
}
}

impl<F, const RATE: usize> Absorbing<F, RATE> {
/// Attempts to absorb a value into the sponge state.
///
/// Returns the value if it was not absorbed because the sponge is full.
pub fn absorb(&mut self, value: F) -> Result<(), F> {
for entry in self.0.iter_mut() {
if entry.is_none() {
*entry = Some(value);
return Ok(());
}
}
// Sponge is full.
Err(value)
}

/// Exposes the inner state of the sponge.
///
/// This is a low-level API, requiring a detailed understanding of this specific
/// Poseidon implementation to use correctly and securely. It is exposed for use by
/// the circuit implementation in `halo2_gadgets`, and may be removed from the public
/// API if refactoring enables the circuit implementation to move into this crate.
pub fn expose_inner(&self) -> &SpongeRate<F, RATE> {
&self.0
}
}

impl<F: fmt::Debug, const RATE: usize> Squeezing<F, RATE> {
/// Initializes a full sponge in the squeezing state.
///
/// This is a low-level API, requiring a detailed understanding of this specific
/// Poseidon implementation to use correctly and securely. It is exposed for use by
/// the circuit implementation in `halo2_gadgets`, and may be removed from the public
/// API if refactoring enables the circuit implementation to move into this crate.
pub fn init_full(vals: [F; RATE]) -> Self {
Self(
vals.into_iter()
.map(Some)
.collect::<Vec<_>>()
.try_into()
.unwrap(),
)
}
}

impl<F, const RATE: usize> Squeezing<F, RATE> {
/// Attempts to squeeze a value from the sponge state.
///
/// Returns `None` if the sponge is empty.
pub fn squeeze(&mut self) -> Option<F> {
for entry in self.0.iter_mut() {
if let Some(inner) = entry.take() {
return Some(inner);
}
}
// Sponge is empty.
None
}
}

/// A Poseidon sponge.
pub(crate) struct Sponge<
F: Field,
Expand Down
14 changes: 6 additions & 8 deletions halo2_poseidon/src/p128pow5t3.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use halo2_proofs::arithmetic::Field;
use ff::Field;
use pasta_curves::{pallas::Base as Fp, vesta::Base as Fq};

use super::{Mds, Spec};
Expand Down Expand Up @@ -73,9 +73,7 @@ mod tests {
super::{fp, fq},
Fp, Fq,
};
use crate::poseidon::primitives::{
generate_constants, permute, ConstantLength, Hash, Mds, Spec,
};
use crate::{generate_constants, permute, ConstantLength, Hash, Mds, Spec};

/// The same Poseidon specification as poseidon::P128Pow5T3, but constructed
/// such that its constants will be generated at runtime.
Expand Down Expand Up @@ -257,7 +255,7 @@ mod tests {
{
let (round_constants, mds, _) = super::P128Pow5T3::constants();

for tv in crate::poseidon::primitives::test_vectors::fp::permute() {
for tv in crate::test_vectors::fp::permute() {
let mut state = [
Fp::from_repr(tv.initial_state[0]).unwrap(),
Fp::from_repr(tv.initial_state[1]).unwrap(),
Expand All @@ -275,7 +273,7 @@ mod tests {
{
let (round_constants, mds, _) = super::P128Pow5T3::constants();

for tv in crate::poseidon::primitives::test_vectors::fq::permute() {
for tv in crate::test_vectors::fq::permute() {
let mut state = [
Fq::from_repr(tv.initial_state[0]).unwrap(),
Fq::from_repr(tv.initial_state[1]).unwrap(),
Expand All @@ -293,7 +291,7 @@ mod tests {

#[test]
fn hash_test_vectors() {
for tv in crate::poseidon::primitives::test_vectors::fp::hash() {
for tv in crate::test_vectors::fp::hash() {
let message = [
Fp::from_repr(tv.input[0]).unwrap(),
Fp::from_repr(tv.input[1]).unwrap(),
Expand All @@ -305,7 +303,7 @@ mod tests {
assert_eq!(result.to_repr(), tv.output);
}

for tv in crate::poseidon::primitives::test_vectors::fq::hash() {
for tv in crate::test_vectors::fq::hash() {
let message = [
Fq::from_repr(tv.input[0]).unwrap(),
Fq::from_repr(tv.input[1]).unwrap(),
Expand Down
24 changes: 12 additions & 12 deletions halo2_poseidon/src/test_vectors.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
//! Test vectors for [`OrchardNullifier`].

Check warning on line 1 in halo2_poseidon/src/test_vectors.rs

View workflow job for this annotation

GitHub Actions / Intra-doc links

unresolved link to `OrchardNullifier`

Check warning on line 1 in halo2_poseidon/src/test_vectors.rs

View workflow job for this annotation

GitHub Actions / Intra-doc links

unresolved link to `OrchardNullifier`
pub(crate) struct PermuteTestVector {
pub(crate) initial_state: [[u8; 32]; 3],
pub(crate) final_state: [[u8; 32]; 3],
pub struct PermuteTestVector {
pub initial_state: [[u8; 32]; 3],
pub final_state: [[u8; 32]; 3],
}

pub(crate) struct HashTestVector {
pub(crate) input: [[u8; 32]; 2],
pub(crate) output: [u8; 32],
pub struct HashTestVector {
pub input: [[u8; 32]; 2],
pub output: [u8; 32],
}

pub(crate) mod fp {
pub mod fp {
use super::*;

pub(crate) fn permute() -> Vec<PermuteTestVector> {
pub fn permute() -> Vec<PermuteTestVector> {
use PermuteTestVector as TestVector;

// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/orchard_poseidon/permute/fp.py
Expand Down Expand Up @@ -417,7 +417,7 @@ pub(crate) mod fp {
]
}

pub(crate) fn hash() -> Vec<HashTestVector> {
pub fn hash() -> Vec<HashTestVector> {
use HashTestVector as TestVector;

// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/orchard_poseidon/hash/fp.py
Expand Down Expand Up @@ -635,10 +635,10 @@ pub(crate) mod fp {
}
}

pub(crate) mod fq {
pub mod fq {
use super::*;

pub(crate) fn permute() -> Vec<PermuteTestVector> {
pub fn permute() -> Vec<PermuteTestVector> {
use PermuteTestVector as TestVector;

// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/orchard_poseidon/permute/fq.py
Expand Down Expand Up @@ -1042,7 +1042,7 @@ pub(crate) mod fq {
]
}

pub(crate) fn hash() -> Vec<HashTestVector> {
pub fn hash() -> Vec<HashTestVector> {
use HashTestVector as TestVector;

// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/orchard_poseidon/hash/fq.py
Expand Down

0 comments on commit 0d88368

Please sign in to comment.