Skip to content

Commit

Permalink
Merge pull request #24 from reilabs/generic-merkle
Browse files Browse the repository at this point in the history
Generalize the type of merkle roots
  • Loading branch information
WizardOfMenlo authored Nov 29, 2024
2 parents 9a328c5 + 654cbf9 commit 2a52497
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 36 deletions.
7 changes: 6 additions & 1 deletion src/bin/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ark_crypto_primitives::{
};
use ark_ff::{FftField, Field};
use ark_serialize::CanonicalSerialize;
use nimue::{DefaultHash, IOPattern};
use nimue::{Arthur, DefaultHash, IOPattern, Merlin};
use nimue_pow::blake3::Blake3PoW;
use whir::{
cmdline_utils::{AvailableFields, AvailableMerkle},
Expand All @@ -25,6 +25,8 @@ use whir::{
use serde::Serialize;

use clap::Parser;
use whir::whir::fs_utils::{DigestReader, DigestWriter};
use whir::whir::iopattern::DigestIOPattern;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
Expand Down Expand Up @@ -211,6 +213,9 @@ fn run_whir<F, MerkleConfig>(
F: FftField + CanonicalSerialize,
MerkleConfig: Config<Leaf = [F]> + Clone,
MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>,
IOPattern: DigestIOPattern<MerkleConfig>,
Merlin: DigestWriter<MerkleConfig>,
for<'a> Arthur<'a>: DigestReader<MerkleConfig>,
{
let security_level = args.security_level;
let pow_bits = args.pow_bits.unwrap();
Expand Down
13 changes: 12 additions & 1 deletion src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ark_crypto_primitives::{
};
use ark_ff::FftField;
use ark_serialize::CanonicalSerialize;
use nimue::{DefaultHash, IOPattern};
use nimue::{Arthur, DefaultHash, IOPattern, Merlin};
use whir::{
cmdline_utils::{AvailableFields, AvailableMerkle, WhirType},
crypto::{
Expand All @@ -21,6 +21,8 @@ use whir::{
use nimue_pow::blake3::Blake3PoW;

use clap::Parser;
use whir::whir::fs_utils::{DigestReader, DigestWriter};
use whir::whir::iopattern::DigestIOPattern;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
Expand Down Expand Up @@ -182,6 +184,9 @@ fn run_whir<F, MerkleConfig>(
F: FftField + CanonicalSerialize,
MerkleConfig: Config<Leaf = [F]> + Clone,
MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>,
IOPattern: DigestIOPattern<MerkleConfig>,
Merlin: DigestWriter<MerkleConfig>,
for<'a> Arthur<'a>: DigestReader<MerkleConfig>,
{
match args.protocol_type {
WhirType::PCS => run_whir_pcs::<F, MerkleConfig>(args, leaf_hash_params, two_to_one_params),
Expand All @@ -199,6 +204,9 @@ fn run_whir_as_ldt<F, MerkleConfig>(
F: FftField + CanonicalSerialize,
MerkleConfig: Config<Leaf = [F]> + Clone,
MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>,
IOPattern: DigestIOPattern<MerkleConfig>,
Merlin: DigestWriter<MerkleConfig>,
for<'a> Arthur<'a>: DigestReader<MerkleConfig>,
{
use whir::whir::{
committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover,
Expand Down Expand Up @@ -303,6 +311,9 @@ fn run_whir_pcs<F, MerkleConfig>(
F: FftField + CanonicalSerialize,
MerkleConfig: Config<Leaf = [F]> + Clone,
MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>,
IOPattern: DigestIOPattern<MerkleConfig>,
Merlin: DigestWriter<MerkleConfig>,
for<'a> Arthur<'a>: DigestReader<MerkleConfig>,
{
use whir::whir::{
committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover,
Expand Down
24 changes: 24 additions & 0 deletions src/crypto/merkle_tree/blake3.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use std::{borrow::Borrow, marker::PhantomData};

use super::{HashCounter, IdentityDigestConverter};
use crate::whir::fs_utils::{DigestReader, DigestWriter};
use crate::whir::iopattern::DigestIOPattern;
use ark_crypto_primitives::{
crh::{CRHScheme, TwoToOneCRHScheme},
merkle_tree::Config,
sponge::Absorb,
};
use ark_ff::Field;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use nimue::{Arthur, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, ProofResult};
use rand::RngCore;

#[derive(
Expand Down Expand Up @@ -127,3 +131,23 @@ pub fn default_config<F: CanonicalSerialize + Send>(

((), ())
}

impl<F: Field> DigestIOPattern<MerkleTreeParams<F>> for IOPattern {
fn add_digest(self, label: &str) -> Self {
self.add_bytes(32, label)
}
}

impl<F: Field> DigestWriter<MerkleTreeParams<F>> for Merlin {
fn add_digest(&mut self, digest: Blake3Digest) -> ProofResult<()> {
self.add_bytes(&digest.0).map_err(ProofError::InvalidIO)
}
}

impl <'a, F: Field> DigestReader<MerkleTreeParams<F>> for Arthur<'a> {
fn read_digest(&mut self) -> ProofResult<Blake3Digest> {
let mut digest = [0; 32];
self.fill_next_bytes(&mut digest)?;
Ok(Blake3Digest(digest))
}
}
24 changes: 24 additions & 0 deletions src/crypto/merkle_tree/keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ use ark_crypto_primitives::{
merkle_tree::Config,
sponge::Absorb,
};
use ark_ff::Field;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use nimue::{Arthur, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, ProofResult};
use rand::RngCore;
use sha3::Digest;
use crate::whir::fs_utils::{DigestReader, DigestWriter};
use crate::whir::iopattern::DigestIOPattern;

#[derive(
Debug, Default, Clone, Copy, Eq, PartialEq, Hash, CanonicalSerialize, CanonicalDeserialize,
Expand Down Expand Up @@ -128,3 +132,23 @@ pub fn default_config<F: CanonicalSerialize + Send>(

((), ())
}

impl <F: Field> DigestIOPattern<MerkleTreeParams<F>> for IOPattern {
fn add_digest(self, label: &str) -> Self {
self.add_bytes(32, label)
}
}

impl <F: Field> DigestWriter<MerkleTreeParams<F>> for Merlin {
fn add_digest(&mut self, digest: KeccakDigest) -> ProofResult<()> {
self.add_bytes(&digest.0).map_err(ProofError::InvalidIO)
}
}

impl <'a, F: Field> DigestReader<MerkleTreeParams<F>> for Arthur<'a> {
fn read_digest(&mut self) -> ProofResult<KeccakDigest> {
let mut digest = [0; 32];
self.fill_next_bytes(&mut digest)?;
Ok(KeccakDigest(digest))
}
}
2 changes: 1 addition & 1 deletion src/fs_utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ark_ff::Field;
use nimue::{plugins::ark::FieldIOPattern, IOPattern};
use nimue::plugins::ark::FieldIOPattern;
use nimue_pow::PoWIOPattern;
pub trait OODIOPattern<F: Field> {
fn add_ood(self, num_samples: usize) -> Self;
Expand Down
2 changes: 1 addition & 1 deletion src/sumcheck/prover_not_skipping.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ark_ff::Field;
use nimue::{plugins::ark::{FieldChallenges, FieldIOPattern, FieldWriter}, IOPattern, ProofResult};
use nimue::{plugins::ark::{FieldChallenges, FieldIOPattern, FieldWriter}, ProofResult};
use nimue_pow::{PoWChallenge, PowStrategy};

use crate::{
Expand Down
10 changes: 5 additions & 5 deletions src/whir/committer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use ark_ff::FftField;
use ark_poly::EvaluationDomain;
use nimue::{
plugins::ark::{FieldChallenges, FieldWriter},
ByteWriter, Merlin, ProofResult,
ByteWriter, ProofResult,
};

use crate::whir::fs_utils::DigestWriter;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

Expand All @@ -34,8 +35,7 @@ where
impl<F, MerkleConfig, PowStrategy> Committer<F, MerkleConfig, PowStrategy>
where
F: FftField,
MerkleConfig: Config<Leaf = [F]>,
MerkleConfig::InnerDigest: AsRef<[u8]>,
MerkleConfig: Config<Leaf = [F]>
{
pub fn new(config: WhirConfig<F, MerkleConfig, PowStrategy>) -> Self {
Self(config)
Expand All @@ -47,7 +47,7 @@ where
polynomial: CoefficientList<F::BasePrimeField>,
) -> ProofResult<Witness<F, MerkleConfig>>
where
Merlin: FieldWriter<F> + FieldChallenges<F> + ByteWriter,
Merlin: FieldWriter<F> + FieldChallenges<F> + ByteWriter + DigestWriter<MerkleConfig>,
{
let base_domain = self.0.starting_domain.base_domain.unwrap();
let expansion = base_domain.size() / polynomial.num_coeffs();
Expand Down Expand Up @@ -88,7 +88,7 @@ where

let root = merkle_tree.root();

merlin.add_bytes(root.as_ref())?;
merlin.add_digest(root)?;

let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples];
let mut ood_answers = Vec::with_capacity(self.0.committment_ood_samples);
Expand Down
9 changes: 9 additions & 0 deletions src/whir/fs_utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use ark_crypto_primitives::merkle_tree::Config;
use crate::utils::dedup;
use nimue::{ByteChallenges, ProofResult};

Expand All @@ -24,3 +25,11 @@ where
});
Ok(dedup(indices))
}

pub trait DigestWriter<MerkleConfig: Config> {
fn add_digest(&mut self, digest: MerkleConfig::InnerDigest) -> ProofResult<()>;
}

pub trait DigestReader<MerkleConfig: Config> {
fn read_digest(&mut self) -> ProofResult<MerkleConfig::InnerDigest>;
}
28 changes: 16 additions & 12 deletions src/whir/iopattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,44 @@ use crate::{

use super::parameters::WhirConfig;

pub trait WhirIOPattern<F: FftField> {
fn commit_statement<MerkleConfig: Config, PowStrategy>(
self,
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
) -> Self;
fn add_whir_proof<MerkleConfig: Config, PowStrategy>(
pub trait DigestIOPattern<MerkleConfig: Config> {
fn add_digest(self, label: &str) -> Self;
}

pub trait WhirIOPattern<F: FftField, MerkleConfig: Config> {
fn commit_statement<PowStrategy>(
self,
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
) -> Self;
fn add_whir_proof<PowStrategy>(self, params: &WhirConfig<F, MerkleConfig, PowStrategy>)
-> Self;
}

impl<F, IOPattern> WhirIOPattern<F> for IOPattern
impl<F, MerkleConfig, IOPattern> WhirIOPattern<F, MerkleConfig> for IOPattern
where
F: FftField,
MerkleConfig: Config,
IOPattern: ByteIOPattern
+ FieldIOPattern<F>
+ SumcheckNotSkippingIOPattern<F>
+ WhirPoWIOPattern
+ OODIOPattern<F>,
+ OODIOPattern<F>
+ DigestIOPattern<MerkleConfig>,
{
fn commit_statement<MerkleConfig: Config, PowStrategy>(
fn commit_statement<PowStrategy>(
self,
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
) -> Self {
// TODO: Add params
let mut this = self.add_bytes(32, "merkle_digest");
let mut this = self.add_digest("merkle_digest");
if params.committment_ood_samples > 0 {
assert!(params.initial_statement);
this = this.add_ood(params.committment_ood_samples);
}
this
}

fn add_whir_proof<MerkleConfig: Config, PowStrategy>(
fn add_whir_proof<PowStrategy>(
mut self,
params: &WhirConfig<F, MerkleConfig, PowStrategy>,
) -> Self {
Expand All @@ -62,7 +66,7 @@ where
for r in &params.round_parameters {
let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8;
self = self
.add_bytes(32, "merkle_digest")
.add_digest("merkle_digest")
.add_ood(r.ood_samples)
.challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries")
.pow(r.pow_bits)
Expand Down
2 changes: 1 addition & 1 deletion src/whir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub mod iopattern;
pub mod parameters;
pub mod prover;
pub mod verifier;
mod fs_utils;
pub mod fs_utils;

#[derive(Debug, Clone, Default)]
pub struct Statement<F> {
Expand Down
9 changes: 4 additions & 5 deletions src/whir/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use nimue::{
};
use nimue_pow::{self, PoWChallenge};

use crate::whir::fs_utils::get_challenge_stir_queries;
use crate::whir::fs_utils::{get_challenge_stir_queries, DigestWriter};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

Expand All @@ -33,7 +33,6 @@ impl<F, MerkleConfig, PowStrategy> Prover<F, MerkleConfig, PowStrategy>
where
F: FftField,
MerkleConfig: Config<Leaf = [F]>,
MerkleConfig::InnerDigest: AsRef<[u8]>,
PowStrategy: nimue_pow::PowStrategy,
{
fn validate_parameters(&self) -> bool {
Expand Down Expand Up @@ -73,7 +72,7 @@ where
witness: Witness<F, MerkleConfig>,
) -> ProofResult<WhirProof<MerkleConfig, F>>
where
Merlin: FieldChallenges<F> + FieldWriter<F> + ByteChallenges + ByteWriter + PoWChallenge,
Merlin: FieldChallenges<F> + FieldWriter<F> + ByteChallenges + ByteWriter + PoWChallenge + DigestWriter<MerkleConfig>,
{
assert!(self.validate_parameters());
assert!(self.validate_statement(&statement));
Expand Down Expand Up @@ -154,7 +153,7 @@ where
mut round_state: RoundState<F, MerkleConfig>,
) -> ProofResult<WhirProof<MerkleConfig, F>>
where
Merlin: FieldChallenges<F> + ByteChallenges + FieldWriter<F> + ByteWriter + PoWChallenge,
Merlin: FieldChallenges<F> + ByteChallenges + FieldWriter<F> + ByteWriter + PoWChallenge + DigestWriter<MerkleConfig>,
{
// Fold the coefficients
let folded_coefficients = round_state
Expand Down Expand Up @@ -241,7 +240,7 @@ where
.unwrap();

let root = merkle_tree.root();
merlin.add_bytes(root.as_ref())?;
merlin.add_digest(root)?;

// OOD Samples
let mut ood_points = vec![F::ZERO; round_params.ood_samples];
Expand Down
Loading

0 comments on commit 2a52497

Please sign in to comment.