Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge PCS and LDT #16

Merged
merged 10 commits into from
Nov 26, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Merge branch 'main' into recmo/merge-ldt
recmo committed Nov 26, 2024
commit 6c74f8541f6f8a2fb68f39193583cd8971df4fa0
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -7,3 +7,4 @@ outputs/temp/
*.pdf
scripts/__pycache__/
.DS_Store
.idea
35 changes: 0 additions & 35 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 1 addition & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -41,11 +41,4 @@ parallel = [
]
rayon = ["dep:rayon"]

[patch.crates-io]
ark-std = { git = "https://github.com/arkworks-rs/std" }
ark-crypto-primitives = { git = "https://github.com/arkworks-rs/crypto-primitives" }
ark-test-curves = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" }
ark-ff = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" }
ark-poly = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" }
ark-serialize = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" }
ark-ec = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" }

2 changes: 1 addition & 1 deletion src/bin/benchmark.rs
Original file line number Diff line number Diff line change
@@ -348,7 +348,7 @@ fn run_whir<F, MerkleConfig>(
.collect();
let evaluations = points
.iter()
.map(|point| polynomial.evaluate_at_extension(&point))
.map(|point| polynomial.evaluate_at_extension(point))
.collect();
let statement = Statement {
points,
6 changes: 3 additions & 3 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ use ark_crypto_primitives::{
merkle_tree::Config,
};
use ark_ff::FftField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_serialize::CanonicalSerialize;
use nimue::{DefaultHash, IOPattern};
use whir::{
cmdline_utils::{AvailableFields, AvailableMerkle, WhirType},
@@ -202,7 +202,7 @@ fn run_whir_as_ldt<F, MerkleConfig>(
{
use whir::whir::{
committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover,
verifier::Verifier, WhirProof,
verifier::Verifier,
};

// Runs as a LDT
@@ -369,7 +369,7 @@ fn run_whir_pcs<F, MerkleConfig>(
.collect();
let evaluations = points
.iter()
.map(|point| polynomial.evaluate_at_extension(&point))
.map(|point| polynomial.evaluate_at_extension(point))
.collect();

let statement = Statement {
6 changes: 3 additions & 3 deletions src/crypto/merkle_tree/blake3.rs
Original file line number Diff line number Diff line change
@@ -122,8 +122,8 @@ pub fn default_config<F: CanonicalSerialize + Send>(
<LeafH<F> as CRHScheme>::Parameters,
<CompressH as TwoToOneCRHScheme>::Parameters,
) {
let leaf_hash_params = <LeafH<F> as CRHScheme>::setup(rng).unwrap();
let two_to_one_params = <CompressH as TwoToOneCRHScheme>::setup(rng).unwrap();
<LeafH<F> as CRHScheme>::setup(rng).unwrap();
<CompressH as TwoToOneCRHScheme>::setup(rng).unwrap();

(leaf_hash_params, two_to_one_params)
((), ())
}
10 changes: 5 additions & 5 deletions src/crypto/merkle_tree/keccak.rs
Original file line number Diff line number Diff line change
@@ -83,8 +83,8 @@ impl TwoToOneCRHScheme for KeccakTwoToOneCRHScheme {
right_input: T,
) -> Result<Self::Output, ark_crypto_primitives::Error> {
let mut h = sha3::Keccak256::new();
h.update(&left_input.borrow().0);
h.update(&right_input.borrow().0);
h.update(left_input.borrow().0);
h.update(right_input.borrow().0);
let mut output = [0; 32];
output.copy_from_slice(&h.finalize()[..]);
HashCounter::add();
@@ -123,8 +123,8 @@ pub fn default_config<F: CanonicalSerialize + Send>(
<LeafH<F> as CRHScheme>::Parameters,
<CompressH as TwoToOneCRHScheme>::Parameters,
) {
let leaf_hash_params = <LeafH<F> as CRHScheme>::setup(rng).unwrap();
let two_to_one_params = <CompressH as TwoToOneCRHScheme>::setup(rng).unwrap();
<LeafH<F> as CRHScheme>::setup(rng).unwrap();
<CompressH as TwoToOneCRHScheme>::setup(rng).unwrap();

(leaf_hash_params, two_to_one_params)
((), ())
}
12 changes: 7 additions & 5 deletions src/crypto/merkle_tree/mock.rs
Original file line number Diff line number Diff line change
@@ -58,10 +58,12 @@ pub fn default_config<F: CanonicalSerialize + Send>(
<LeafH<F> as CRHScheme>::Parameters,
<CompressH as TwoToOneCRHScheme>::Parameters,
) {
let leaf_hash_params = <LeafH<F> as CRHScheme>::setup(rng).unwrap();
let two_to_one_params = <CompressH as TwoToOneCRHScheme>::setup(rng)
.unwrap()
.clone();
<LeafH<F> as CRHScheme>::setup(rng).unwrap();
{
<CompressH as TwoToOneCRHScheme>::setup(rng)
.unwrap();

};

(leaf_hash_params, two_to_one_params)
((), ())
}
12 changes: 6 additions & 6 deletions src/ntt/transpose.rs
Original file line number Diff line number Diff line change
@@ -54,9 +54,9 @@ fn transpose_copy<F: Sized + Copy + Send>(src: MatrixMut<F>, dst: MatrixMut<F>)

/// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible.
#[cfg(feature = "parallel")]
fn transpose_copy_parallel<'a, 'b, F: Sized + Copy + Send>(
src: MatrixMut<'a, F>,
mut dst: MatrixMut<'b, F>,
fn transpose_copy_parallel<F: Sized + Copy + Send>(
src: MatrixMut<'_, F>,
mut dst: MatrixMut<'_, F>,
) {
assert_eq!(src.rows(), dst.cols());
assert_eq!(src.cols(), dst.rows());
@@ -85,9 +85,9 @@ fn transpose_copy_parallel<'a, 'b, F: Sized + Copy + Send>(

/// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible.
/// This is the non-parallel version
fn transpose_copy_not_parallel<'a, 'b, F: Sized + Copy>(
src: MatrixMut<'a, F>,
mut dst: MatrixMut<'b, F>,
fn transpose_copy_not_parallel<F: Sized + Copy>(
src: MatrixMut<'_, F>,
mut dst: MatrixMut<'_, F>,
) {
assert_eq!(src.rows(), dst.cols());
assert_eq!(src.cols(), dst.rows());
1 change: 0 additions & 1 deletion src/ntt/utils.rs
Original file line number Diff line number Diff line change
@@ -142,7 +142,6 @@ mod tests {
);
let should_not_work = std::panic::catch_unwind(|| {
as_chunks_exact_mut::<_, 2>(&mut [1, 2, 3]);
return;
});
assert!(should_not_work.is_err())
}
6 changes: 3 additions & 3 deletions src/poly_utils/fold.rs
Original file line number Diff line number Diff line change
@@ -179,7 +179,7 @@ mod tests {

// Evaluate the polynomial on the domain
let domain_evaluations: Vec<_> = (0..domain_size)
.map(|w| root_of_unity.pow([w as u64]))
.map(|w| root_of_unity.pow([w]))
.map(|point| {
poly.evaluate(&MultilinearPoint::expand_from_univariate(
point,
@@ -199,10 +199,10 @@ mod tests {
);

let num = domain_size / folding_factor_exp;
let coset_gen_inv = root_of_unity_inv.pow(&[num]);
let coset_gen_inv = root_of_unity_inv.pow([num]);

for index in 0..num {
let offset_inv = root_of_unity_inv.pow(&[index]);
let offset_inv = root_of_unity_inv.pow([index]);
let span =
(index * folding_factor_exp) as usize..((index + 1) * folding_factor_exp) as usize;

3 changes: 2 additions & 1 deletion src/poly_utils/sequential_lag_poly.rs
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ use super::{hypercube::BinaryHypercubePoint, MultilinearPoint};

/// There is an alternative (possibly more efficient) implementation that iterates over the x in Gray code ordering.
///
/// LagrangePolynomialIterator for a given multilinear n-dimensional `point` iterates over pairs (x, y)
/// where x ranges over all possible {0,1}^n
/// and y equals the product y_1 * ... * y_n where
@@ -60,7 +61,7 @@ impl<F: Field> Iterator for LagrangePolynomialIterator<F> {
// Iterator implementation for the struct
fn next(&mut self) -> Option<Self::Item> {
// a) Check if this is the first iteration
if self.last_position == None {
if self.last_position.is_none() {
// Initialize last position
self.last_position = Some(0);
// Return the top of the stack
2 changes: 1 addition & 1 deletion src/poly_utils/streaming_evaluation_helper.rs
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ impl<F: Field> Iterator for TermPolynomialIterator<F> {
// Iterator implementation for the struct
fn next(&mut self) -> Option<Self::Item> {
// a) Check if this is the first iteration
if self.last_position == None {
if self.last_position.is_none() {
// Initialize last position
self.last_position = Some(0);
// Return the top of the stack
10 changes: 5 additions & 5 deletions src/sumcheck/mod.rs
Original file line number Diff line number Diff line change
@@ -99,7 +99,7 @@ mod tests {
// First, check that is sums to the right value over the hypercube
assert_eq!(poly_1.sum_over_hypercube(), claimed_value);

let combination_randomness = vec![F::from(293), F::from(42)];
let combination_randomness = [F::from(293), F::from(42)];
let folding_randomness = MultilinearPoint(vec![F::from(335), F::from(222)]);

let new_eval_point = MultilinearPoint(vec![F::from(32); num_variables - folding_factor]);
@@ -146,7 +146,7 @@ mod tests {
let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)];
let folding_randomness_1 = MultilinearPoint(vec![F::from(11), F::from(31)]);
let fold_point = MultilinearPoint(vec![F::from(31), F::from(15)]);
let combination_randomness = vec![F::from(31), F::from(4999)];
let combination_randomness = [F::from(31), F::from(4999)];
let folding_randomness_2 = MultilinearPoint(vec![F::from(97), F::from(36)]);

let mut prover = SumcheckCore::new(
@@ -184,7 +184,7 @@ mod tests {
);

let full_folding =
MultilinearPoint(vec![folding_randomness_2.0.clone(), folding_randomness_1.0].concat());
MultilinearPoint([folding_randomness_2.0.clone(), folding_randomness_1.0].concat());
let eval_coeff = folded_poly_1.fold(&folding_randomness_2).coeffs()[0];
assert_eq!(
sumcheck_poly_2.evaluate_at_point(&folding_randomness_2),
@@ -217,8 +217,8 @@ mod tests {
let fold_point_12 =
MultilinearPoint(vec![F::from(1231), F::from(15), F::from(4231), F::from(15)]);
let fold_point_2 = MultilinearPoint(vec![F::from(311), F::from(115)]);
let combination_randomness_1 = vec![F::from(1289), F::from(3281), F::from(10921)];
let combination_randomness_2 = vec![F::from(3281), F::from(3232)];
let combination_randomness_1 = [F::from(1289), F::from(3281), F::from(10921)];
let combination_randomness_2 = [F::from(3281), F::from(3232)];

let mut prover = SumcheckCore::new(
polynomial.clone(),
2 changes: 1 addition & 1 deletion src/sumcheck/proof.rs
Original file line number Diff line number Diff line change
@@ -90,7 +90,7 @@ mod tests {
let num_evaluation_points = 3_usize.pow(num_variables as u32);
let evaluations = (0..num_evaluation_points as u64).map(F::from).collect();

let poly = SumcheckPolynomial::new(evaluations, num_variables as usize);
let poly = SumcheckPolynomial::new(evaluations, num_variables);

for i in 0..num_evaluation_points {
let decomp = base_decomposition(i, 3, num_variables);
10 changes: 5 additions & 5 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -121,11 +121,11 @@ mod tests {

#[test]
fn test_is_power_of_two() {
assert_eq!(is_power_of_two(0), false);
assert_eq!(is_power_of_two(1), true);
assert_eq!(is_power_of_two(2), true);
assert_eq!(is_power_of_two(3), false);
assert_eq!(is_power_of_two(usize::MAX), false);
assert!(!is_power_of_two(0));
assert!(is_power_of_two(1));
assert!(is_power_of_two(2));
assert!(!is_power_of_two(3));
assert!(!is_power_of_two(usize::MAX));
}

#[test]
26 changes: 26 additions & 0 deletions src/whir/fs_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use crate::utils::dedup;
use nimue::{ByteChallenges, ProofResult};

/// Samples `num_queries` STIR query indices from the Fiat-Shamir transcript.
///
/// The indices are drawn uniformly (modulo bias from the byte reduction) from
/// the folded evaluation domain of size `domain_size / 2^folding_factor`, then
/// sorted and deduplicated before being returned.
///
/// NOTE: the per-index byte width below must remain identical to the
/// computation used when declaring the IO pattern (`whir::iopattern`), or the
/// prover/verifier transcripts will diverge.
pub fn get_challenge_stir_queries<T>(
    domain_size: usize,
    folding_factor: usize,
    num_queries: usize,
    transcript: &mut T,
) -> ProofResult<Vec<usize>>
where
    T: ByteChallenges,
{
    let folded_domain_size = domain_size / (1 << folding_factor);
    // Smallest number of bytes that can encode any index in
    // [0, folded_domain_size): ceil(ceil(log2(folded_domain_size)) / 8).
    let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8;
    // Draw all challenge bytes for every query in a single transcript call.
    let mut query_bytes = vec![0u8; num_queries * domain_size_bytes];
    transcript.fill_challenge_bytes(&mut query_bytes)?;
    // Decode each fixed-width chunk as a big-endian integer and reduce it
    // into the folded domain.
    let indices = query_bytes.chunks_exact(domain_size_bytes).map(|chunk| {
        chunk
            .iter()
            .fold(0usize, |acc, &byte| (acc << 8) | byte as usize)
            % folded_domain_size
    });
    Ok(dedup(indices))
}
10 changes: 8 additions & 2 deletions src/whir/iopattern.rs
Original file line number Diff line number Diff line change
@@ -57,18 +57,24 @@ where
.pow(params.starting_folding_pow_bits);
}

let mut folded_domain_size = params.starting_domain.folded_size(params.folding_factor);

for r in &params.round_parameters {
let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8;
self = self
.add_bytes(32, "merkle_digest")
.add_ood(r.ood_samples)
.challenge_bytes(32, "stir_queries_seed")
.challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries")
.pow(r.pow_bits)
.challenge_scalars(1, "combination_randomness")
.add_sumcheck(params.folding_factor, r.folding_pow_bits);
folded_domain_size /= 2;
}

let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8;

self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs")
.challenge_bytes(32, "final_queries_seed")
.challenge_bytes(domain_size_bytes * params.final_queries, "final_queries")
.pow(params.final_pow_bits)
.add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits)
}
1 change: 1 addition & 0 deletions src/whir/mod.rs
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ pub mod iopattern;
pub mod parameters;
pub mod prover;
pub mod verifier;
mod fs_utils;

#[derive(Debug, Clone, Default)]
pub struct Statement<F> {
17 changes: 8 additions & 9 deletions src/whir/parameters.rs
Original file line number Diff line number Diff line change
@@ -401,7 +401,8 @@ where
num_queries: usize,
) -> f64 {
let num_queries = num_queries as f64;
let bits_of_sec_queries = match soundness_type {

match soundness_type {
SoundnessType::UniqueDecoding => {
let rate = 1. / ((1 << log_inv_rate) as f64);
let denom = -(0.5 * (1. + rate)).log2();
@@ -410,9 +411,7 @@ where
}
SoundnessType::ProvableList => num_queries * 0.5 * log_inv_rate as f64,
SoundnessType::ConjectureList => num_queries * log_inv_rate as f64,
};

bits_of_sec_queries
}
}

pub fn rbr_soundness_queries_combination(
@@ -504,7 +503,7 @@ where
writeln!(
f,
"{:.1} bits -- (x{}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}",
prox_gaps_error.min(sumcheck_error) + self.starting_folding_pow_bits as f64,
prox_gaps_error.min(sumcheck_error) + self.starting_folding_pow_bits,
self.folding_factor,
prox_gaps_error,
sumcheck_error,
@@ -545,7 +544,7 @@ where
writeln!(
f,
"{:.1} bits -- query error: {:.1}, combination: {:.1}, pow: {:.1}",
query_error.min(combination_error) + r.pow_bits as f64,
query_error.min(combination_error) + r.pow_bits,
query_error,
combination_error,
r.pow_bits,
@@ -569,7 +568,7 @@ where
writeln!(
f,
"{:.1} bits -- (x{}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}",
prox_gaps_error.min(sumcheck_error) + r.folding_pow_bits as f64,
prox_gaps_error.min(sumcheck_error) + r.folding_pow_bits,
self.folding_factor,
prox_gaps_error,
sumcheck_error,
@@ -587,7 +586,7 @@ where
writeln!(
f,
"{:.1} bits -- query error: {:.1}, pow: {:.1}",
query_error + self.final_pow_bits as f64,
query_error + self.final_pow_bits,
query_error,
self.final_pow_bits,
)?;
@@ -597,7 +596,7 @@ where
writeln!(
f,
"{:.1} bits -- (x{}) combination: {:.1}, pow: {:.1}",
combination_error + self.final_pow_bits as f64,
combination_error + self.final_pow_bits,
self.final_sumcheck_rounds,
combination_error,
self.final_folding_pow_bits,
32 changes: 15 additions & 17 deletions src/whir/prover.rs
Original file line number Diff line number Diff line change
@@ -15,12 +15,11 @@ use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath};
use ark_ff::FftField;
use ark_poly::EvaluationDomain;
use nimue::{
plugins::ark::{FieldChallenges, FieldWriter},
ByteChallenges, ByteWriter, Merlin, ProofResult,
plugins::ark::{FieldChallenges, FieldWriter}, ByteWriter, Merlin, ProofResult,
};
use nimue_pow::{self, PoWChallenge};
use rand::{Rng, SeedableRng};

use crate::whir::fs_utils::get_challenge_stir_queries;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

@@ -167,13 +166,13 @@ where
merlin.add_scalars(folded_coefficients.coeffs())?;

// Final verifier queries and answers
let mut queries_seed = [0u8; 32];
merlin.fill_challenge_bytes(&mut queries_seed)?;
let mut final_gen = rand_chacha::ChaCha20Rng::from_seed(queries_seed);
let final_challenge_indexes = utils::dedup((0..self.0.final_queries).map(|_| {
final_gen.gen_range(0..round_state.domain.folded_size(self.0.folding_factor))
}));

let final_challenge_indexes = get_challenge_stir_queries(
round_state.domain.size(),
self.0.folding_factor,
self.0.final_queries,
merlin,
)?;
let merkle_proof = round_state
.prev_merkle
.generate_multi_proof(final_challenge_indexes.clone())
@@ -255,13 +254,12 @@ where
}

// STIR queries
let mut stir_queries_seed = [0u8; 32];
merlin.fill_challenge_bytes(&mut stir_queries_seed)?;
let mut stir_gen = rand_chacha::ChaCha20Rng::from_seed(stir_queries_seed);
let stir_challenges_indexes =
utils::dedup((0..round_params.num_queries).map(|_| {
stir_gen.gen_range(0..round_state.domain.folded_size(self.0.folding_factor))
}));
let stir_challenges_indexes = get_challenge_stir_queries(
round_state.domain.size(),
self.0.folding_factor,
round_params.num_queries,
merlin,
)?;
let domain_scaled_gen = round_state
.domain
.backing_domain
38 changes: 18 additions & 20 deletions src/whir/verifier.rs
Original file line number Diff line number Diff line change
@@ -5,20 +5,19 @@ use ark_ff::FftField;
use ark_poly::EvaluationDomain;
use nimue::{
plugins::ark::{FieldChallenges, FieldReader},
Arthur, ByteChallenges, ByteReader, ProofError, ProofResult,
Arthur, ByteReader, ProofError, ProofResult,
};
use nimue_pow::{self, PoWChallenge};
use rand::{Rng, SeedableRng};

use super::{parameters::WhirConfig, Statement, WhirProof};
use crate::whir::fs_utils::get_challenge_stir_queries;
use crate::{
parameters::FoldType,
poly_utils::{coeffs::CoefficientList, eq_poly_outside, fold::compute_fold, MultilinearPoint},
sumcheck::proof::SumcheckPolynomial,
utils::{self, expand_randomness},
utils::{expand_randomness},
};

use super::{parameters::WhirConfig, Statement, WhirProof};

pub struct Verifier<F, MerkleConfig, PowStrategy>
where
F: FftField,
@@ -166,13 +165,13 @@ where
arthur.fill_next_scalars(&mut ood_answers)?;
}

let mut stir_queries_seed = [0u8; 32];
arthur.fill_challenge_bytes(&mut stir_queries_seed)?;
let mut stir_gen = rand_chacha::ChaCha20Rng::from_seed(stir_queries_seed);
let folded_domain_size = domain_size / (1 << self.params.folding_factor);
let stir_challenges_indexes = utils::dedup(
(0..round_params.num_queries).map(|_| stir_gen.gen_range(0..folded_domain_size)),
);
let stir_challenges_indexes = get_challenge_stir_queries(
domain_size,
self.params.folding_factor,
round_params.num_queries,
arthur,
)?;

let stir_challenges_points = stir_challenges_indexes
.iter()
.map(|index| exp_domain_gen.pow([*index as u64]))
@@ -241,13 +240,12 @@ where
let final_coefficients = CoefficientList::new(final_coefficients);

// Final queries verify
let mut queries_seed = [0u8; 32];
arthur.fill_challenge_bytes(&mut queries_seed)?;
let mut final_gen = rand_chacha::ChaCha20Rng::from_seed(queries_seed);
let folded_domain_size = domain_size / (1 << self.params.folding_factor);
let final_randomness_indexes = utils::dedup(
(0..self.params.final_queries).map(|_| final_gen.gen_range(0..folded_domain_size)),
);
let final_randomness_indexes = get_challenge_stir_queries(
domain_size,
self.params.folding_factor,
self.params.final_queries,
arthur,
)?;
let final_randomness_points = final_randomness_indexes
.iter()
.map(|index| exp_domain_gen.pow([*index as u64]))
@@ -355,7 +353,7 @@ where
.map(|(point, rand)| point * rand)
.sum();

value = value + sum_of_claims;
value += sum_of_claims;
}

value
You are viewing a condensed version of this merge commit. You can view the full changes here.