diff --git a/.gitignore b/.gitignore index 69fc9b0..3825e56 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ outputs/temp/ *.pdf scripts/__pycache__/ .DS_Store -outputs/ \ No newline at end of file +outputs/ +.idea diff --git a/Cargo.lock b/Cargo.lock index 9846ff1..374bd05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1066,38 +1066,3 @@ dependencies = [ "quote", "syn 2.0.87", ] - -[[patch.unused]] -name = "ark-crypto-primitives" -version = "0.4.0" -source = "git+https://github.com/arkworks-rs/crypto-primitives#b13983815e5b3a0fbeed0e7da0edec751beac270" - -[[patch.unused]] -name = "ark-ec" -version = "0.4.2" -source = "git+https://github.com/WizardOfMenlo/algebra?branch=fft_extensions#fe932806ba46425ba8adc7536dbc20f6d61ae13a" - -[[patch.unused]] -name = "ark-ff" -version = "0.4.2" -source = "git+https://github.com/WizardOfMenlo/algebra?branch=fft_extensions#fe932806ba46425ba8adc7536dbc20f6d61ae13a" - -[[patch.unused]] -name = "ark-poly" -version = "0.4.2" -source = "git+https://github.com/WizardOfMenlo/algebra?branch=fft_extensions#fe932806ba46425ba8adc7536dbc20f6d61ae13a" - -[[patch.unused]] -name = "ark-serialize" -version = "0.4.2" -source = "git+https://github.com/WizardOfMenlo/algebra?branch=fft_extensions#fe932806ba46425ba8adc7536dbc20f6d61ae13a" - -[[patch.unused]] -name = "ark-std" -version = "0.4.0" -source = "git+https://github.com/arkworks-rs/std#db4367e68ff60da31ac759831e38f60171f4e03d" - -[[patch.unused]] -name = "ark-test-curves" -version = "0.4.2" -source = "git+https://github.com/WizardOfMenlo/algebra?branch=fft_extensions#fe932806ba46425ba8adc7536dbc20f6d61ae13a" diff --git a/Cargo.toml b/Cargo.toml index a7a466e..cf6e503 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" default-run = "main" [dependencies] -ark-std = {version = "0.5", features = ["std"]} +ark-std = { version = "0.5", features = ["std"] } ark-ff = { version = "0.5", features = ["asm", "std"] } ark-serialize = "0.5" ark-crypto-primitives = { version = "0.5", features = ["merkle_tree"] } @@ -23,7 +23,7 @@ clap = { version = "4.4.17", features = ["derive"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" nimue = { git = "https://github.com/arkworks-rs/nimue", features = ["ark"] } -nimue-pow = { git = "https://github.com/arkworks-rs/nimue"} +nimue-pow = { git = "https://github.com/arkworks-rs/nimue" } lazy_static = "1.4" rayon = { version = "1.10.0", optional = true } @@ -41,11 +41,4 @@ parallel = [ ] rayon = ["dep:rayon"] -[patch.crates-io] -ark-std = { git = "https://github.com/arkworks-rs/std" } -ark-crypto-primitives = { git = "https://github.com/arkworks-rs/crypto-primitives" } -ark-test-curves = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" } -ark-ff = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" } -ark-poly = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" } -ark-serialize = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" } -ark-ec = { git = "https://github.com/WizardOfMenlo/algebra", branch = "fft_extensions" } + diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index 2a67c5d..d63e1f7 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -19,7 +19,7 @@ use whir::{ }, parameters::*, poly_utils::coeffs::CoefficientList, - whir::Statement, + whir::{iopattern::WhirIOPattern, Statement}, }; use serde::Serialize; @@ -228,6 +228,7 @@ fn run_whir( let mv_params = MultivariateParameters::::new(num_variables); let whir_params = WhirParameters:: { + initial_statement: true, security_level, pow_bits, folding_factor, @@ -253,11 +254,15 @@ fn run_whir( whir_ldt_verifier_hashes, ) = { // Run LDT - use whir::whir_ldt::{ + use whir::whir::{ committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, verifier::Verifier, whir_proof_size, }; + let whir_params = WhirParameters:: { + initial_statement: false, + ..whir_params.clone() + }; let params = WhirConfig::::new(mv_params, whir_params.clone()); if !params.check_pow_bits() { @@ -280,7 +285,9 @@ fn run_whir( let prover = Prover(params.clone()); - let proof = prover.prove(&mut merlin, witness).unwrap(); + let proof = prover + .prove(&mut merlin, Statement::default(), witness) + .unwrap(); let whir_ldt_prover_time = whir_ldt_prover_time.elapsed(); let whir_ldt_argument_size = whir_proof_size(merlin.transcript(), &proof); @@ -293,7 +300,9 @@ fn run_whir( let whir_ldt_verifier_time = Instant::now(); for _ in 0..reps { let mut arthur = io.to_arthur(merlin.transcript()); - verifier.verify(&mut arthur, &proof).unwrap(); + verifier + .verify(&mut arthur, &Statement::default(), &proof) + .unwrap(); } let whir_ldt_verifier_time = whir_ldt_verifier_time.elapsed(); @@ -339,7 +348,7 @@ fn run_whir( .collect(); let evaluations = points .iter() - .map(|point| polynomial.evaluate_at_extension(&point)) + .map(|point| polynomial.evaluate_at_extension(point)) .collect(); let statement = Statement { points, diff --git a/src/bin/main.rs b/src/bin/main.rs index a1d83b8..848878f 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -15,6 +15,7 @@ use whir::{ }, parameters::*, poly_utils::{coeffs::CoefficientList, MultilinearPoint}, + whir::Statement, }; use nimue_pow::blake3::Blake3PoW; @@ -199,9 +200,9 @@ fn run_whir_as_ldt( MerkleConfig: Config + Clone, MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, { - use whir::whir_ldt::{ + use whir::whir::{ committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, - verifier::Verifier, whir_proof_size, + verifier::Verifier, }; // Runs as a LDT @@ -223,6 +224,7 @@ fn run_whir_as_ldt( let mv_params = MultivariateParameters::::new(num_variables); let whir_params = WhirParameters:: { + initial_statement: false, security_level, pow_bits, folding_factor, @@ -234,12 +236,11 @@ fn run_whir_as_ldt( starting_log_inv_rate: starting_rate, }; - let params = WhirConfig::::new(mv_params, whir_params); + let params = WhirConfig::::new(mv_params, whir_params.clone()); let io = IOPattern::::new("🌪️") .commit_statement(¶ms) - .add_whir_proof(¶ms) - .clone(); + .add_whir_proof(¶ms); let mut merlin = io.to_merlin(); @@ -265,19 +266,30 @@ fn run_whir_as_ldt( let prover = Prover(params.clone()); - let proof = prover.prove(&mut merlin, witness).unwrap(); + let proof = prover + .prove(&mut merlin, Statement::default(), witness) + .unwrap(); dbg!(whir_prover_time.elapsed()); - dbg!(whir_proof_size(merlin.transcript(), &proof)); + + // Serialize proof + let transcript = merlin.transcript().to_vec(); + let mut proof_bytes = vec![]; + proof.serialize_compressed(&mut proof_bytes).unwrap(); + + let proof_size = transcript.len() + proof_bytes.len(); + dbg!(proof_size); // Just not to count that initial inversion (which could be precomputed) - let verifier = Verifier::new(params); + let verifier = Verifier::new(params.clone()); HashCounter::reset(); let whir_verifier_time = Instant::now(); for _ in 0..reps { - let mut arthur = io.to_arthur(merlin.transcript()); - verifier.verify(&mut arthur, &proof).unwrap(); + let mut arthur = io.to_arthur(&transcript); + verifier + .verify(&mut arthur, &Statement::default(), &proof) + .unwrap(); } dbg!(whir_verifier_time.elapsed() / reps as u32); dbg!(HashCounter::get() as f64 / reps as f64); @@ -317,6 +329,7 @@ fn run_whir_pcs( let mv_params = MultivariateParameters::::new(num_variables); let whir_params = WhirParameters:: { + initial_statement: true, security_level, pow_bits, folding_factor, @@ -356,7 +369,7 @@ fn run_whir_pcs( .collect(); let evaluations = points .iter() - .map(|point| polynomial.evaluate_at_extension(&point)) + .map(|point| polynomial.evaluate_at_extension(point)) .collect(); let statement = Statement { diff --git a/src/crypto/merkle_tree/blake3.rs b/src/crypto/merkle_tree/blake3.rs index 4f30592..6090915 100644 --- a/src/crypto/merkle_tree/blake3.rs +++ b/src/crypto/merkle_tree/blake3.rs @@ -122,8 +122,8 @@ pub fn default_config( as CRHScheme>::Parameters, ::Parameters, ) { - let leaf_hash_params = as CRHScheme>::setup(rng).unwrap(); - let two_to_one_params = ::setup(rng).unwrap(); + as CRHScheme>::setup(rng).unwrap(); + ::setup(rng).unwrap(); - (leaf_hash_params, two_to_one_params) + ((), ()) } diff --git a/src/crypto/merkle_tree/keccak.rs b/src/crypto/merkle_tree/keccak.rs index 93a97b9..c4f8558 100644 --- a/src/crypto/merkle_tree/keccak.rs +++ b/src/crypto/merkle_tree/keccak.rs @@ -83,8 +83,8 @@ impl TwoToOneCRHScheme for KeccakTwoToOneCRHScheme { right_input: T, ) -> Result { let mut h = sha3::Keccak256::new(); - h.update(&left_input.borrow().0); - h.update(&right_input.borrow().0); + h.update(left_input.borrow().0); + h.update(right_input.borrow().0); let mut output = [0; 32]; output.copy_from_slice(&h.finalize()[..]); HashCounter::add(); @@ -123,8 +123,8 @@ pub fn default_config( as CRHScheme>::Parameters, ::Parameters, ) { - let leaf_hash_params = as CRHScheme>::setup(rng).unwrap(); - let two_to_one_params = ::setup(rng).unwrap(); + as CRHScheme>::setup(rng).unwrap(); + ::setup(rng).unwrap(); - (leaf_hash_params, two_to_one_params) + ((), ()) } diff --git a/src/crypto/merkle_tree/mock.rs b/src/crypto/merkle_tree/mock.rs index 8ac9f3b..9102490 100644 --- a/src/crypto/merkle_tree/mock.rs +++ b/src/crypto/merkle_tree/mock.rs @@ -58,10 +58,12 @@ pub fn default_config( as CRHScheme>::Parameters, ::Parameters, ) { - let leaf_hash_params = as CRHScheme>::setup(rng).unwrap(); - let two_to_one_params = ::setup(rng) - .unwrap() - .clone(); + as CRHScheme>::setup(rng).unwrap(); + { + ::setup(rng) + .unwrap(); + + }; - (leaf_hash_params, two_to_one_params) + ((), ()) } diff --git a/src/lib.rs b/src/lib.rs index 153baf3..7bb62c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,4 +8,3 @@ pub mod poly_utils; // Utils for polynomials pub mod sumcheck; // Sumcheck specialised pub mod utils; // Utils in general pub mod whir; // The real prover -pub mod whir_ldt; // Whir as a LDT // Shared parameters diff --git a/src/ntt/transpose.rs b/src/ntt/transpose.rs index e6bc6c8..be81ebf 100644 --- a/src/ntt/transpose.rs +++ b/src/ntt/transpose.rs @@ -54,9 +54,9 @@ fn transpose_copy(src: MatrixMut, dst: MatrixMut) /// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible. #[cfg(feature = "parallel")] -fn transpose_copy_parallel<'a, 'b, F: Sized + Copy + Send>( - src: MatrixMut<'a, F>, - mut dst: MatrixMut<'b, F>, +fn transpose_copy_parallel( + src: MatrixMut<'_, F>, + mut dst: MatrixMut<'_, F>, ) { assert_eq!(src.rows(), dst.cols()); assert_eq!(src.cols(), dst.rows()); @@ -85,9 +85,9 @@ fn transpose_copy_parallel<'a, 'b, F: Sized + Copy + Send>( /// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible. /// This is the non-parallel version -fn transpose_copy_not_parallel<'a, 'b, F: Sized + Copy>( - src: MatrixMut<'a, F>, - mut dst: MatrixMut<'b, F>, +fn transpose_copy_not_parallel( + src: MatrixMut<'_, F>, + mut dst: MatrixMut<'_, F>, ) { assert_eq!(src.rows(), dst.cols()); assert_eq!(src.cols(), dst.rows()); diff --git a/src/ntt/utils.rs b/src/ntt/utils.rs index 11ccdc3..e177d4e 100644 --- a/src/ntt/utils.rs +++ b/src/ntt/utils.rs @@ -142,7 +142,6 @@ mod tests { ); let should_not_work = std::panic::catch_unwind(|| { as_chunks_exact_mut::<_, 2>(&mut [1, 2, 3]); - return; }); assert!(should_not_work.is_err()) } diff --git a/src/parameters.rs b/src/parameters.rs index e1f8f45..633670d 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -101,6 +101,7 @@ pub struct WhirParameters where MerkleConfig: Config, { + pub initial_statement: bool, pub starting_log_inv_rate: usize, pub folding_factor: usize, pub soundness_type: SoundnessType, diff --git a/src/poly_utils/fold.rs b/src/poly_utils/fold.rs index 44bf9a8..69f5e90 100644 --- a/src/poly_utils/fold.rs +++ b/src/poly_utils/fold.rs @@ -179,7 +179,7 @@ mod tests { // Evaluate the polynomial on the domain let domain_evaluations: Vec<_> = (0..domain_size) - .map(|w| root_of_unity.pow([w as u64])) + .map(|w| root_of_unity.pow([w])) .map(|point| { poly.evaluate(&MultilinearPoint::expand_from_univariate( point, @@ -199,10 +199,10 @@ mod tests { ); let num = domain_size / folding_factor_exp; - let coset_gen_inv = root_of_unity_inv.pow(&[num]); + let coset_gen_inv = root_of_unity_inv.pow([num]); for index in 0..num { - let offset_inv = root_of_unity_inv.pow(&[index]); + let offset_inv = root_of_unity_inv.pow([index]); let span = (index * folding_factor_exp) as usize..((index + 1) * folding_factor_exp) as usize; diff --git a/src/poly_utils/sequential_lag_poly.rs b/src/poly_utils/sequential_lag_poly.rs index c59355f..d9a203f 100644 --- a/src/poly_utils/sequential_lag_poly.rs +++ b/src/poly_utils/sequential_lag_poly.rs @@ -6,6 +6,7 @@ use super::{hypercube::BinaryHypercubePoint, MultilinearPoint}; /// There is an alternative (possibly more efficient) implementation that iterates over the x in Gray code ordering. +/// /// LagrangePolynomialIterator for a given multilinear n-dimensional `point` iterates over pairs (x, y) /// where x ranges over all possible {0,1}^n /// and y equals the product y_1 * ... * y_n where @@ -60,7 +61,7 @@ impl Iterator for LagrangePolynomialIterator { // Iterator implementation for the struct fn next(&mut self) -> Option { // a) Check if this is the first iteration - if self.last_position == None { + if self.last_position.is_none() { // Initialize last position self.last_position = Some(0); // Return the top of the stack diff --git a/src/poly_utils/streaming_evaluation_helper.rs b/src/poly_utils/streaming_evaluation_helper.rs index 322a6e5..8b2cfe3 100644 --- a/src/poly_utils/streaming_evaluation_helper.rs +++ b/src/poly_utils/streaming_evaluation_helper.rs @@ -37,7 +37,7 @@ impl Iterator for TermPolynomialIterator { // Iterator implementation for the struct fn next(&mut self) -> Option { // a) Check if this is the first iteration - if self.last_position == None { + if self.last_position.is_none() { // Initialize last position self.last_position = Some(0); // Return the top of the stack diff --git a/src/sumcheck/mod.rs b/src/sumcheck/mod.rs index eaeac85..0c1d45a 100644 --- a/src/sumcheck/mod.rs +++ b/src/sumcheck/mod.rs @@ -99,7 +99,7 @@ mod tests { // First, check that is sums to the right value over the hypercube assert_eq!(poly_1.sum_over_hypercube(), claimed_value); - let combination_randomness = vec![F::from(293), F::from(42)]; + let combination_randomness = [F::from(293), F::from(42)]; let folding_randomness = MultilinearPoint(vec![F::from(335), F::from(222)]); let new_eval_point = MultilinearPoint(vec![F::from(32); num_variables - folding_factor]); @@ -146,7 +146,7 @@ mod tests { let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; let folding_randomness_1 = MultilinearPoint(vec![F::from(11), F::from(31)]); let fold_point = MultilinearPoint(vec![F::from(31), F::from(15)]); - let combination_randomness = vec![F::from(31), F::from(4999)]; + let combination_randomness = [F::from(31), F::from(4999)]; let folding_randomness_2 = MultilinearPoint(vec![F::from(97), F::from(36)]); let mut prover = SumcheckCore::new( @@ -184,7 +184,7 @@ mod tests { ); let full_folding = - MultilinearPoint(vec![folding_randomness_2.0.clone(), folding_randomness_1.0].concat()); + MultilinearPoint([folding_randomness_2.0.clone(), folding_randomness_1.0].concat()); let eval_coeff = folded_poly_1.fold(&folding_randomness_2).coeffs()[0]; assert_eq!( sumcheck_poly_2.evaluate_at_point(&folding_randomness_2), @@ -217,8 +217,8 @@ mod tests { let fold_point_12 = MultilinearPoint(vec![F::from(1231), F::from(15), F::from(4231), F::from(15)]); let fold_point_2 = MultilinearPoint(vec![F::from(311), F::from(115)]); - let combination_randomness_1 = vec![F::from(1289), F::from(3281), F::from(10921)]; - let combination_randomness_2 = vec![F::from(3281), F::from(3232)]; + let combination_randomness_1 = [F::from(1289), F::from(3281), F::from(10921)]; + let combination_randomness_2 = [F::from(3281), F::from(3232)]; let mut prover = SumcheckCore::new( polynomial.clone(), diff --git a/src/sumcheck/proof.rs b/src/sumcheck/proof.rs index 36cfe27..2455541 100644 --- a/src/sumcheck/proof.rs +++ b/src/sumcheck/proof.rs @@ -90,7 +90,7 @@ mod tests { let num_evaluation_points = 3_usize.pow(num_variables as u32); let evaluations = (0..num_evaluation_points as u64).map(F::from).collect(); - let poly = SumcheckPolynomial::new(evaluations, num_variables as usize); + let poly = SumcheckPolynomial::new(evaluations, num_variables); for i in 0..num_evaluation_points { let decomp = base_decomposition(i, 3, num_variables); diff --git a/src/utils.rs b/src/utils.rs index 9fcc0aa..fc0104a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -121,11 +121,11 @@ mod tests { #[test] fn test_is_power_of_two() { - assert_eq!(is_power_of_two(0), false); - assert_eq!(is_power_of_two(1), true); - assert_eq!(is_power_of_two(2), true); - assert_eq!(is_power_of_two(3), false); - assert_eq!(is_power_of_two(usize::MAX), false); + assert!(!is_power_of_two(0)); + assert!(is_power_of_two(1)); + assert!(is_power_of_two(2)); + assert!(!is_power_of_two(3)); + assert!(!is_power_of_two(usize::MAX)); } #[test] diff --git a/src/whir/fs_utils.rs b/src/whir/fs_utils.rs new file mode 100644 index 0000000..e435b99 --- /dev/null +++ b/src/whir/fs_utils.rs @@ -0,0 +1,26 @@ +use crate::utils::dedup; +use nimue::{ByteChallenges, ProofResult}; + +pub fn get_challenge_stir_queries( + domain_size: usize, + folding_factor: usize, + num_queries: usize, + transcript: &mut T, +) -> ProofResult> +where + T: ByteChallenges, +{ + let folded_domain_size = domain_size / (1 << folding_factor); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + let mut queries = vec![0u8; num_queries * domain_size_bytes]; + transcript.fill_challenge_bytes(&mut queries)?; + let indices = queries.chunks_exact(domain_size_bytes).map(|chunk| { + let mut result = 0; + for byte in chunk { + result <<= 8; + result |= *byte as usize; + } + result % folded_domain_size + }); + Ok(dedup(indices)) +} diff --git a/src/whir/iopattern.rs b/src/whir/iopattern.rs index d0e2485..06e478a 100644 --- a/src/whir/iopattern.rs +++ b/src/whir/iopattern.rs @@ -34,8 +34,12 @@ where params: &WhirConfig, ) -> Self { // TODO: Add params - self.add_bytes(32, "merkle_digest") - .add_ood(params.committment_ood_samples) + let mut this = self.add_bytes(32, "merkle_digest"); + if params.committment_ood_samples > 0 { + assert!(params.initial_statement); + this = this.add_ood(params.committment_ood_samples); + } + this } fn add_whir_proof( @@ -43,22 +47,34 @@ where params: &WhirConfig, ) -> Self { // TODO: Add statement - self = self - .challenge_scalars(1, "initial_combination_randomness") - .add_sumcheck(params.folding_factor, params.starting_folding_pow_bits); + if params.initial_statement { + self = self + .challenge_scalars(1, "initial_combination_randomness") + .add_sumcheck(params.folding_factor, params.starting_folding_pow_bits); + } else { + self = self + .challenge_scalars(params.folding_factor, "folding_randomness") + .pow(params.starting_folding_pow_bits); + } + + let mut folded_domain_size = params.starting_domain.folded_size(params.folding_factor); for r in ¶ms.round_parameters { + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; self = self .add_bytes(32, "merkle_digest") .add_ood(r.ood_samples) - .challenge_bytes(32, "stir_queries_seed") + .challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries") .pow(r.pow_bits) .challenge_scalars(1, "combination_randomness") .add_sumcheck(params.folding_factor, r.folding_pow_bits); + folded_domain_size /= 2; } + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs") - .challenge_bytes(32, "final_queries_seed") + .challenge_bytes(domain_size_bytes * params.final_queries, "final_queries") .pow(params.final_pow_bits) .add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits) } diff --git a/src/whir/mod.rs b/src/whir/mod.rs index 8243508..51bab88 100644 --- a/src/whir/mod.rs +++ b/src/whir/mod.rs @@ -8,8 +8,9 @@ pub mod iopattern; pub mod parameters; pub mod prover; pub mod verifier; +mod fs_utils; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct Statement { pub points: Vec>, pub evaluations: Vec, @@ -69,6 +70,7 @@ mod tests { let mv_params = MultivariateParameters::::new(num_variables); let whir_params = WhirParameters:: { + initial_statement: true, security_level: 32, pow_bits, folding_factor, diff --git a/src/whir/parameters.rs b/src/whir/parameters.rs index 9ee3476..7a9d711 100644 --- a/src/whir/parameters.rs +++ b/src/whir/parameters.rs @@ -22,6 +22,7 @@ where pub(crate) max_pow_bits: usize, pub(crate) committment_ood_samples: usize, + pub(crate) initial_statement: bool, pub(crate) starting_domain: Domain, pub(crate) starting_log_inv_rate: usize, pub(crate) starting_folding_pow_bits: f64, @@ -86,29 +87,47 @@ where let field_size_bits = F::field_size_in_bits(); - let committment_ood_samples = Self::ood_samples( - whir_parameters.security_level, - whir_parameters.soundness_type, - mv_parameters.num_variables, - whir_parameters.starting_log_inv_rate, - Self::log_eta( + let committment_ood_samples = if whir_parameters.initial_statement { + Self::ood_samples( + whir_parameters.security_level, whir_parameters.soundness_type, + mv_parameters.num_variables, whir_parameters.starting_log_inv_rate, - ), - field_size_bits, - ); + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + field_size_bits, + ) + } else { + 0 + }; - let starting_folding_pow_bits = Self::folding_pow_bits( - whir_parameters.security_level, - whir_parameters.soundness_type, - field_size_bits, - mv_parameters.num_variables, - whir_parameters.starting_log_inv_rate, - Self::log_eta( + let starting_folding_pow_bits = if whir_parameters.initial_statement { + Self::folding_pow_bits( + whir_parameters.security_level, + whir_parameters.soundness_type, + field_size_bits, + mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + ) + } else { + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( whir_parameters.soundness_type, + field_size_bits, + mv_parameters.num_variables, whir_parameters.starting_log_inv_rate, - ), - ); + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + ) + (whir_parameters.folding_factor as f64).log2(); + 0_f64.max(whir_parameters.security_level as f64 - prox_gaps_error) + }; let mut round_parameters = Vec::with_capacity(num_rounds); let mut num_variables = mv_parameters.num_variables - whir_parameters.folding_factor; @@ -186,6 +205,7 @@ where WhirConfig { security_level: whir_parameters.security_level, max_pow_bits: whir_parameters.pow_bits, + initial_statement: whir_parameters.initial_statement, committment_ood_samples, mv_parameters, starting_domain, @@ -198,7 +218,7 @@ where final_pow_bits, final_sumcheck_rounds, final_folding_pow_bits, - pow_strategy: PhantomData::default(), + pow_strategy: PhantomData, fold_optimisation: whir_parameters.fold_optimisation, final_log_inv_rate: log_inv_rate, leaf_hash_params: whir_parameters.leaf_hash_params, @@ -240,14 +260,10 @@ where log_eta: f64, ) -> f64 { match soundness_type { - SoundnessType::ConjectureList => { - let result = (num_variables + log_inv_rate) as f64 - log_eta; - result - } + SoundnessType::ConjectureList => (num_variables + log_inv_rate) as f64 - log_eta, SoundnessType::ProvableList => { let log_inv_sqrt_rate: f64 = log_inv_rate as f64 / 2.; - let result = log_inv_sqrt_rate - (1. + log_eta); - result + log_inv_sqrt_rate - (1. + log_eta) } SoundnessType::UniqueDecoding => 0.0, } @@ -385,7 +401,8 @@ where num_queries: usize, ) -> f64 { let num_queries = num_queries as f64; - let bits_of_sec_queries = match soundness_type { + + match soundness_type { SoundnessType::UniqueDecoding => { let rate = 1. / ((1 << log_inv_rate) as f64); let denom = -(0.5 * (1. + rate)).log2(); @@ -394,9 +411,7 @@ where } SoundnessType::ProvableList => num_queries * 0.5 * log_inv_rate as f64, SoundnessType::ConjectureList => num_queries * log_inv_rate as f64, - }; - - bits_of_sec_queries + } } pub fn rbr_soundness_queries_combination( @@ -488,7 +503,7 @@ where writeln!( f, "{:.1} bits -- (x{}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}", - prox_gaps_error.min(sumcheck_error) + self.starting_folding_pow_bits as f64, + prox_gaps_error.min(sumcheck_error) + self.starting_folding_pow_bits, self.folding_factor, prox_gaps_error, sumcheck_error, @@ -529,7 +544,7 @@ where writeln!( f, "{:.1} bits -- query error: {:.1}, combination: {:.1}, pow: {:.1}", - query_error.min(combination_error) + r.pow_bits as f64, + query_error.min(combination_error) + r.pow_bits, query_error, combination_error, r.pow_bits, @@ -553,7 +568,7 @@ where writeln!( f, "{:.1} bits -- (x{}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}", - prox_gaps_error.min(sumcheck_error) + r.folding_pow_bits as f64, + prox_gaps_error.min(sumcheck_error) + r.folding_pow_bits, self.folding_factor, prox_gaps_error, sumcheck_error, @@ -571,7 +586,7 @@ where writeln!( f, "{:.1} bits -- query error: {:.1}, pow: {:.1}", - query_error + self.final_pow_bits as f64, + query_error + self.final_pow_bits, query_error, self.final_pow_bits, )?; @@ -581,7 +596,7 @@ where writeln!( f, "{:.1} bits -- (x{}) combination: {:.1}, pow: {:.1}", - combination_error + self.final_pow_bits as f64, + combination_error + self.final_pow_bits, self.final_sumcheck_rounds, combination_error, self.final_folding_pow_bits, diff --git a/src/whir/prover.rs b/src/whir/prover.rs index e4b4591..53264df 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -15,12 +15,11 @@ use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; use ark_ff::FftField; use ark_poly::EvaluationDomain; use nimue::{ - plugins::ark::{FieldChallenges, FieldWriter}, - ByteChallenges, ByteWriter, Merlin, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, ByteWriter, Merlin, ProofResult, }; use nimue_pow::{self, PoWChallenge}; -use rand::{Rng, SeedableRng}; +use crate::whir::fs_utils::get_challenge_stir_queries; #[cfg(feature = "parallel")] use rayon::prelude::*; @@ -42,13 +41,27 @@ where } fn validate_statement(&self, statement: &Statement) -> bool { - statement + if statement.points.len() != statement.evaluations.len() { + return false; + } + if !statement .points .iter() .all(|point| point.0.len() == self.0.mv_parameters.num_variables) + { + return false; + } + if !self.0.initial_statement && !statement.points.is_empty() { + return false; + } + true } fn validate_witness(&self, witness: &Witness) -> bool { + assert_eq!(witness.ood_points.len(), witness.ood_answers.len()); + if !self.0.initial_statement { + assert!(witness.ood_points.is_empty()); + } witness.polynomial.num_variables() == self.0.mv_parameters.num_variables } @@ -65,7 +78,6 @@ where assert!(self.validate_statement(&statement)); assert!(self.validate_witness(&witness)); - let [combination_randomness_gen] = merlin.challenge_scalars()?; let initial_claims: Vec<_> = witness .ood_points .into_iter() @@ -77,26 +89,49 @@ where }) .chain(statement.points) .collect(); - let combination_randomness = - expand_randomness(combination_randomness_gen, initial_claims.len()); let initial_answers: Vec<_> = witness .ood_answers .into_iter() .chain(statement.evaluations) .collect(); - let mut sumcheck_prover = SumcheckProverNotSkipping::new( - witness.polynomial.clone(), - &initial_claims, - &combination_randomness, - &initial_answers, - ); + if !self.0.initial_statement { + assert!( + initial_answers.is_empty(), + "Can not have initial answers without initial statement" + ); + } - let folding_randomness = sumcheck_prover.compute_sumcheck_polynomials::( - merlin, - self.0.folding_factor, - self.0.starting_folding_pow_bits, - )?; + let mut sumcheck_prover = None; + let folding_randomness = if self.0.initial_statement { + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, initial_claims.len()); + + sumcheck_prover = Some(SumcheckProverNotSkipping::new( + witness.polynomial.clone(), + &initial_claims, + &combination_randomness, + &initial_answers, + )); + + sumcheck_prover + .as_mut() + .unwrap() + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor, + self.0.starting_folding_pow_bits, + )? + } else { + let mut folding_randomness = vec![F::ZERO; self.0.folding_factor]; + merlin.fill_challenge_scalars(&mut folding_randomness)?; + + if self.0.starting_folding_pow_bits > 0. { + merlin.challenge_pow::(self.0.starting_folding_pow_bits)?; + } + MultilinearPoint(folding_randomness) + }; let round_state = RoundState { domain: self.0.starting_domain.clone(), @@ -131,13 +166,13 @@ where merlin.add_scalars(folded_coefficients.coeffs())?; // Final verifier queries and answers - let mut queries_seed = [0u8; 32]; - merlin.fill_challenge_bytes(&mut queries_seed)?; - let mut final_gen = rand_chacha::ChaCha20Rng::from_seed(queries_seed); - let final_challenge_indexes = utils::dedup((0..self.0.final_queries).map(|_| { - final_gen.gen_range(0..round_state.domain.folded_size(self.0.folding_factor)) - })); - + let final_challenge_indexes = get_challenge_stir_queries( + round_state.domain.size(), + self.0.folding_factor, + self.0.final_queries, + merlin, + )?; + let merkle_proof = round_state .prev_merkle .generate_multi_proof(final_challenge_indexes.clone()) @@ -157,13 +192,18 @@ where } // Final sumcheck - round_state - .sumcheck_prover - .compute_sumcheck_polynomials::( - merlin, - self.0.final_sumcheck_rounds, - self.0.final_folding_pow_bits, - )?; + if self.0.final_sumcheck_rounds > 0 { + round_state + .sumcheck_prover + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new(folded_coefficients.clone(), &[], &[], &[]) + }) + .compute_sumcheck_polynomials::( + merlin, + self.0.final_sumcheck_rounds, + self.0.final_folding_pow_bits, + )?; + } return Ok(WhirProof(round_state.merkle_proofs)); } @@ -214,13 +254,12 @@ where } // STIR queries - let mut stir_queries_seed = [0u8; 32]; - merlin.fill_challenge_bytes(&mut stir_queries_seed)?; - let mut stir_gen = rand_chacha::ChaCha20Rng::from_seed(stir_queries_seed); - let stir_challenges_indexes = - utils::dedup((0..round_params.num_queries).map(|_| { - stir_gen.gen_range(0..round_state.domain.folded_size(self.0.folding_factor)) - })); + let stir_challenges_indexes = get_challenge_stir_queries( + round_state.domain.size(), + self.0.folding_factor, + round_params.num_queries, + merlin, + )?; let domain_scaled_gen = round_state .domain .backing_domain @@ -288,24 +327,36 @@ where let combination_randomness = expand_randomness(combination_randomness_gen, stir_challenges.len()); - round_state.sumcheck_prover.add_new_equality( - &stir_challenges, - &combination_randomness, - &stir_evaluations, - ); - - let folding_randomness = round_state + let mut sumcheck_prover = round_state .sumcheck_prover - .compute_sumcheck_polynomials::( - merlin, - self.0.folding_factor, - round_params.folding_pow_bits, - )?; + .take() + .map(|mut sumcheck_prover| { + sumcheck_prover.add_new_equality( + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ); + sumcheck_prover + }) + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new( + folded_coefficients.clone(), + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ) + }); + + let folding_randomness = sumcheck_prover.compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor, + round_params.folding_pow_bits, + )?; let round_state = RoundState { round: round_state.round + 1, domain: new_domain, - sumcheck_prover: round_state.sumcheck_prover, + sumcheck_prover: Some(sumcheck_prover), folding_randomness, coefficients: folded_coefficients, // TODO: Is this redundant with `sumcheck_prover.coeff` ? prev_merkle: merkle_tree, @@ -324,7 +375,7 @@ where { round: usize, domain: Domain, - sumcheck_prover: SumcheckProverNotSkipping, + sumcheck_prover: Option>, folding_randomness: MultilinearPoint, coefficients: CoefficientList, prev_merkle: MerkleTree, diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index 2fb131b..77bf893 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -5,20 +5,19 @@ use ark_ff::FftField; use ark_poly::EvaluationDomain; use nimue::{ plugins::ark::{FieldChallenges, FieldReader}, - Arthur, ByteChallenges, ByteReader, ProofError, ProofResult, + Arthur, ByteReader, ProofError, ProofResult, }; use nimue_pow::{self, PoWChallenge}; -use rand::{Rng, SeedableRng}; +use super::{parameters::WhirConfig, Statement, WhirProof}; +use crate::whir::fs_utils::get_challenge_stir_queries; use crate::{ parameters::FoldType, poly_utils::{coeffs::CoefficientList, eq_poly_outside, fold::compute_fold, MultilinearPoint}, sumcheck::proof::SumcheckPolynomial, - utils::{self, expand_randomness}, + utils::{expand_randomness}, }; -use super::{parameters::WhirConfig, Statement, WhirProof}; - pub struct Verifier where F: FftField, @@ -104,28 +103,47 @@ where statement: &Statement, // Will be needed later whir_proof: &WhirProof, ) -> ProofResult> { - // Derive combination randomness and first sumcheck polynomial - let [combination_randomness_gen]: [F; 1] = arthur.challenge_scalars()?; - let initial_combination_randomness = expand_randomness( - combination_randomness_gen, - parsed_commitment.ood_points.len() + statement.points.len(), - ); + let mut sumcheck_rounds = Vec::new(); + let mut folding_randomness: MultilinearPoint; + let initial_combination_randomness; + if self.params.initial_statement { + // Derive combination randomness and first sumcheck polynomial + let [combination_randomness_gen]: [F; 1] = arthur.challenge_scalars()?; + initial_combination_randomness = expand_randomness( + combination_randomness_gen, + parsed_commitment.ood_points.len() + statement.points.len(), + ); - // Initial sumcheck - let mut sumcheck_rounds = Vec::with_capacity(self.params.folding_factor); - for _ in 0..self.params.folding_factor { - let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; - let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); - let [folding_randomness_single] = arthur.challenge_scalars()?; - sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + // Initial sumcheck + sumcheck_rounds.reserve_exact(self.params.folding_factor); + for _ in 0..self.params.folding_factor { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + } + + folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + } else { + assert_eq!(parsed_commitment.ood_points.len(), 0); + assert_eq!(statement.points.len(), 0); + + initial_combination_randomness = vec![F::ONE]; + + let mut folding_randomness_vec = vec![F::ZERO; self.params.folding_factor]; + arthur.fill_challenge_scalars(&mut folding_randomness_vec)?; + folding_randomness = MultilinearPoint(folding_randomness_vec); + // PoW if self.params.starting_folding_pow_bits > 0. { arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; } - } - - let mut folding_randomness = - MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + }; let mut prev_root = parsed_commitment.root.clone(); let domain_gen = self.params.starting_domain.backing_domain.group_gen(); @@ -147,13 +165,13 @@ where arthur.fill_next_scalars(&mut ood_answers)?; } - let mut stir_queries_seed = [0u8; 32]; - arthur.fill_challenge_bytes(&mut stir_queries_seed)?; - let mut stir_gen = rand_chacha::ChaCha20Rng::from_seed(stir_queries_seed); - let folded_domain_size = domain_size / (1 << self.params.folding_factor); - let stir_challenges_indexes = utils::dedup( - (0..round_params.num_queries).map(|_| stir_gen.gen_range(0..folded_domain_size)), - ); + let stir_challenges_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor, + round_params.num_queries, + arthur, + )?; + let stir_challenges_points = stir_challenges_indexes .iter() .map(|index| exp_domain_gen.pow([*index as u64])) @@ -222,13 +240,12 @@ where let final_coefficients = CoefficientList::new(final_coefficients); // Final queries verify - let mut queries_seed = [0u8; 32]; - arthur.fill_challenge_bytes(&mut queries_seed)?; - let mut final_gen = rand_chacha::ChaCha20Rng::from_seed(queries_seed); - let folded_domain_size = domain_size / (1 << self.params.folding_factor); - let final_randomness_indexes = utils::dedup( - (0..self.params.final_queries).map(|_| final_gen.gen_range(0..folded_domain_size)), - ); + let final_randomness_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor, + self.params.final_queries, + arthur, + )?; let final_randomness_points = final_randomness_indexes .iter() .map(|index| exp_domain_gen.pow([*index as u64])) @@ -336,7 +353,7 @@ where .map(|(point, rand)| point * rand) .sum(); - value = value + sum_of_claims; + value += sum_of_claims; } value @@ -453,29 +470,35 @@ where let computed_folds = self.compute_folds(&parsed); - // Check the first polynomial - let (mut prev_poly, mut randomness) = parsed.initial_sumcheck_rounds[0].clone(); - if prev_poly.sum_over_hypercube() - != parsed_commitment - .ood_answers - .iter() - .copied() - .chain(statement.evaluations.clone()) - .zip(&parsed.initial_combination_randomness) - .map(|(ans, rand)| ans * rand) - .sum() - { - return Err(ProofError::InvalidProof); - } - - // Check the rest of the rounds - for (sumcheck_poly, new_randomness) in &parsed.initial_sumcheck_rounds[1..] { - if sumcheck_poly.sum_over_hypercube() != prev_poly.evaluate_at_point(&randomness.into()) + let mut prev: Option<(SumcheckPolynomial, F)> = None; + if let Some(round) = parsed.initial_sumcheck_rounds.first() { + // Check the first polynomial + let (mut prev_poly, mut randomness) = round.clone(); + if prev_poly.sum_over_hypercube() + != parsed_commitment + .ood_answers + .iter() + .copied() + .chain(statement.evaluations.clone()) + .zip(&parsed.initial_combination_randomness) + .map(|(ans, rand)| ans * rand) + .sum() { return Err(ProofError::InvalidProof); } - prev_poly = sumcheck_poly.clone(); - randomness = *new_randomness; + + // Check the rest of the rounds + for (sumcheck_poly, new_randomness) in &parsed.initial_sumcheck_rounds[1..] { + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev_poly = sumcheck_poly.clone(); + randomness = *new_randomness; + } + + prev = Some((prev_poly, randomness)); } for (round, folds) in parsed.rounds.iter().zip(&computed_folds) { @@ -483,7 +506,12 @@ where let values = round.ood_answers.iter().copied().chain(folds.clone()); - let claimed_sum = prev_poly.evaluate_at_point(&randomness.into()) + let prev_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let claimed_sum = prev_eval + values .zip(&round.combination_randomness) .map(|(val, rand)| val * rand) @@ -493,18 +521,17 @@ where return Err(ProofError::InvalidProof); } - prev_poly = sumcheck_poly.clone(); - randomness = *new_randomness; + prev = Some((sumcheck_poly.clone(), *new_randomness)); // Check the rest of the round for (sumcheck_poly, new_randomness) in &round.sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); if sumcheck_poly.sum_over_hypercube() != prev_poly.evaluate_at_point(&randomness.into()) { return Err(ProofError::InvalidProof); } - prev_poly = sumcheck_poly.clone(); - randomness = *new_randomness; + prev = Some((sumcheck_poly.clone(), *new_randomness)); } } @@ -523,32 +550,42 @@ where // Check the final sumchecks if self.params.final_sumcheck_rounds > 0 { + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; let (sumcheck_poly, new_randomness) = &parsed.final_sumcheck_rounds[0].clone(); - let claimed_sum = prev_poly.evaluate_at_point(&randomness.into()); + let claimed_sum = prev_sumcheck_poly_eval; if sumcheck_poly.sum_over_hypercube() != claimed_sum { return Err(ProofError::InvalidProof); } - prev_poly = sumcheck_poly.clone(); - randomness = *new_randomness; + prev = Some((sumcheck_poly.clone(), *new_randomness)); // Check the rest of the round for (sumcheck_poly, new_randomness) in &parsed.final_sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); if sumcheck_poly.sum_over_hypercube() != prev_poly.evaluate_at_point(&randomness.into()) { return Err(ProofError::InvalidProof); } - prev_poly = sumcheck_poly.clone(); - randomness = *new_randomness; + prev = Some((sumcheck_poly.clone(), *new_randomness)); } } + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + // Check the final sumcheck evaluation let evaluation_of_v_poly = self.compute_v_poly(&parsed_commitment, statement, &parsed); - if prev_poly.evaluate_at_point(&randomness.into()) + if prev_sumcheck_poly_eval != evaluation_of_v_poly * parsed .final_coefficients diff --git a/src/whir_ldt/committer.rs b/src/whir_ldt/committer.rs deleted file mode 100644 index ec0c373..0000000 --- a/src/whir_ldt/committer.rs +++ /dev/null @@ -1,94 +0,0 @@ -use super::parameters::WhirConfig; -use crate::{ - ntt::expand_from_coeff, - poly_utils::{coeffs::CoefficientList, fold::restructure_evaluations}, - utils, -}; -use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; -use ark_ff::FftField; -use ark_poly::EvaluationDomain; -use nimue::{plugins::ark::FieldChallenges, ByteWriter, Merlin, ProofResult}; - -#[cfg(feature = "parallel")] -use rayon::prelude::*; - -pub struct Witness -where - MerkleConfig: Config, -{ - pub(crate) polynomial: CoefficientList, - pub(crate) merkle_tree: MerkleTree, - pub(crate) merkle_leaves: Vec, -} - -pub struct Committer(WhirConfig) -where - F: FftField, - MerkleConfig: Config; - -impl Committer -where - F: FftField, - MerkleConfig: Config, - MerkleConfig::InnerDigest: AsRef<[u8]>, -{ - pub fn new(config: WhirConfig) -> Self { - Self(config) - } - - pub fn commit( - &self, - merlin: &mut Merlin, - polynomial: CoefficientList, - ) -> ProofResult> - where - Merlin: FieldChallenges + ByteWriter, - { - let base_domain = self.0.starting_domain.base_domain.unwrap(); - let expansion = base_domain.size() / polynomial.num_coeffs(); - let evals = expand_from_coeff(polynomial.coeffs(), expansion); - // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. - // They also partially overlap and undo one another. We should merge them. - let folded_evals = utils::stack_evaluations(evals, self.0.folding_factor); - let folded_evals = restructure_evaluations( - folded_evals, - self.0.fold_optimisation, - base_domain.group_gen(), - base_domain.group_gen_inv(), - self.0.folding_factor, - ); - - // Convert to extension field. - // This is not necessary for the commit, but in further rounds - // we will need the extension field. For symplicity we do it here too. - // TODO: Commit to base field directly. - let folded_evals = folded_evals - .into_iter() - .map(F::from_base_prime_field) - .collect::>(); - - // Group folds together as a leaf. - let fold_size = 1 << self.0.folding_factor; - #[cfg(not(feature = "parallel"))] - let leafs_iter = folded_evals.chunks_exact(fold_size); - #[cfg(feature = "parallel")] - let leafs_iter = folded_evals.par_chunks_exact(fold_size); - let merkle_tree = MerkleTree::::new( - &self.0.leaf_hash_params, - &self.0.two_to_one_params, - leafs_iter, - ) - .unwrap(); - - let root = merkle_tree.root(); - - merlin.add_bytes(root.as_ref())?; - - let polynomial = polynomial.to_extension(); - Ok(Witness { - polynomial, - merkle_tree, - merkle_leaves: folded_evals, - }) - } -} diff --git a/src/whir_ldt/iopattern.rs b/src/whir_ldt/iopattern.rs deleted file mode 100644 index 2c6df0f..0000000 --- a/src/whir_ldt/iopattern.rs +++ /dev/null @@ -1,63 +0,0 @@ -use ark_crypto_primitives::merkle_tree::Config; -use ark_ff::FftField; -use nimue::plugins::ark::*; - -use crate::{ - fs_utils::{OODIOPattern, WhirPoWIOPattern}, - sumcheck::prover_not_skipping::SumcheckNotSkippingIOPattern, -}; - -use super::parameters::WhirConfig; - -pub trait WhirIOPattern { - fn commit_statement( - self, - params: &WhirConfig, - ) -> Self; - fn add_whir_proof( - self, - params: &WhirConfig, - ) -> Self; -} - -impl WhirIOPattern for IOPattern -where - F: FftField, - IOPattern: ByteIOPattern - + FieldIOPattern - + SumcheckNotSkippingIOPattern - + WhirPoWIOPattern - + OODIOPattern, -{ - fn commit_statement( - self, - _params: &WhirConfig, - ) -> Self { - self.add_bytes(32, "merkle_digest") - } - - fn add_whir_proof( - mut self, - params: &WhirConfig, - ) -> Self { - // TODO: Add statement - self = self - .challenge_scalars(params.folding_factor, "folding_randomness") - .pow(params.starting_folding_pow_bits); - - for r in ¶ms.round_parameters { - self = self - .add_bytes(32, "merkle_digest") - .add_ood(r.ood_samples) - .challenge_bytes(32, "stir_queries_seed") - .pow(r.pow_bits) - .challenge_scalars(1, "combination_randomness") - .add_sumcheck(params.folding_factor, r.folding_pow_bits); - } - - self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs") - .challenge_bytes(32, "final_queries_seed") - .pow(params.final_pow_bits) - .add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits) - } -} diff --git a/src/whir_ldt/mod.rs b/src/whir_ldt/mod.rs deleted file mode 100644 index e1aaa95..0000000 --- a/src/whir_ldt/mod.rs +++ /dev/null @@ -1,125 +0,0 @@ -use ark_crypto_primitives::merkle_tree::{Config, MultiPath}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; - -pub mod committer; -pub mod iopattern; -pub mod parameters; -pub mod prover; -pub mod verifier; - -// Only includes the authentication paths -#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct WhirProof(Vec<(MultiPath, Vec>)>) -where - MerkleConfig: Config, - F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize; - -pub fn whir_proof_size( - transcript: &[u8], - whir_proof: &WhirProof, -) -> usize -where - MerkleConfig: Config, - F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, -{ - transcript.len() + whir_proof.serialized_size(ark_serialize::Compress::Yes) -} - -#[cfg(test)] -mod tests { - use nimue::{DefaultHash, IOPattern}; - use nimue_pow::blake3::Blake3PoW; - - use crate::crypto::fields::Field64; - use crate::crypto::merkle_tree::blake3 as merkle_tree; - use crate::parameters::{FoldType, MultivariateParameters, SoundnessType, WhirParameters}; - use crate::poly_utils::coeffs::CoefficientList; - use crate::whir_ldt::{ - committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, - verifier::Verifier, - }; - - type MerkleConfig = merkle_tree::MerkleTreeParams; - type PowStrategy = Blake3PoW; - type F = Field64; - - fn make_whir_things( - num_variables: usize, - folding_factor: usize, - soundness_type: SoundnessType, - pow_bits: usize, - fold_type: FoldType, - ) { - let num_coeffs = 1 << num_variables; - - let mut rng = ark_std::test_rng(); - let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); - - let mv_params = MultivariateParameters::::new(num_variables); - - let whir_params = WhirParameters:: { - security_level: 32, - pow_bits, - folding_factor, - leaf_hash_params, - two_to_one_params, - fold_optimisation: fold_type, - soundness_type, - starting_log_inv_rate: 1, - _pow_parameters: Default::default(), - }; - - let params = WhirConfig::::new(mv_params, whir_params); - - let polynomial = CoefficientList::new(vec![F::from(1); num_coeffs]); - - let io = IOPattern::::new("🌪️") - .commit_statement(¶ms) - .add_whir_proof(¶ms) - .clone(); - - let mut merlin = io.to_merlin(); - - let committer = Committer::new(params.clone()); - let witness = committer.commit(&mut merlin, polynomial).unwrap(); - - let prover = Prover(params.clone()); - - let proof = prover.prove(&mut merlin, witness).unwrap(); - - let verifier = Verifier::new(params); - let mut arthur = io.to_arthur(merlin.transcript()); - assert!(verifier.verify(&mut arthur, &proof).is_ok()); - } - - #[test] - fn test_whir_ldt() { - let folding_factors = [1, 2, 3, 4]; - let fold_types = [FoldType::Naive, FoldType::ProverHelps]; - let soundness_type = [ - SoundnessType::ConjectureList, - SoundnessType::ProvableList, - SoundnessType::UniqueDecoding, - ]; - let pow_bits = [0, 5, 10]; - - for folding_factor in folding_factors { - let num_variables = folding_factor..=3 * folding_factor; - for num_variables in num_variables { - for fold_type in fold_types { - for soundness_type in soundness_type { - for pow_bits in pow_bits { - make_whir_things( - num_variables, - folding_factor, - soundness_type, - pow_bits, - fold_type, - ); - } - } - } - } - } - } -} diff --git a/src/whir_ldt/parameters.rs b/src/whir_ldt/parameters.rs deleted file mode 100644 index 2a579e6..0000000 --- a/src/whir_ldt/parameters.rs +++ /dev/null @@ -1,562 +0,0 @@ -use core::panic; -use std::{f64::consts::LOG2_10, fmt::Display, marker::PhantomData}; - -use ark_crypto_primitives::merkle_tree::{Config, LeafParam, TwoToOneParam}; -use ark_ff::FftField; - -use crate::{ - crypto::fields::FieldWithSize, - domain::Domain, - parameters::{FoldType, MultivariateParameters, SoundnessType, WhirParameters}, -}; - -#[derive(Clone)] -pub struct WhirConfig -where - F: FftField, - MerkleConfig: Config, -{ - pub(crate) mv_parameters: MultivariateParameters, - pub(crate) soundness_type: SoundnessType, - pub(crate) security_level: usize, - pub(crate) max_pow_bits: usize, - - pub(crate) starting_domain: Domain, - pub(crate) starting_log_inv_rate: usize, - pub(crate) starting_folding_pow_bits: f64, - - pub(crate) folding_factor: usize, - pub(crate) round_parameters: Vec, - pub(crate) fold_optimisation: FoldType, - - pub(crate) final_queries: usize, - pub(crate) final_pow_bits: f64, - pub(crate) final_log_inv_rate: usize, - pub(crate) final_sumcheck_rounds: usize, - pub(crate) final_folding_pow_bits: f64, - - // PoW parameters - pub(crate) pow_strategy: PhantomData, - - // Merkle tree parameters - pub(crate) leaf_hash_params: LeafParam, - pub(crate) two_to_one_params: TwoToOneParam, -} - -#[derive(Debug, Clone)] -pub(crate) struct RoundConfig { - pub(crate) pow_bits: f64, - pub(crate) folding_pow_bits: f64, - pub(crate) num_queries: usize, - pub(crate) ood_samples: usize, - pub(crate) log_inv_rate: usize, -} - -impl WhirConfig -where - F: FftField + FieldWithSize, - MerkleConfig: Config, -{ - pub fn new( - mv_parameters: MultivariateParameters, - whir_parameters: WhirParameters, - ) -> Self { - // We need to fold at least some time - assert!( - whir_parameters.folding_factor > 0, - "folding factor should be non zero" - ); - // If less, just send the damn polynomials - assert!(mv_parameters.num_variables >= whir_parameters.folding_factor); - let protocol_security_level = - 0.max(whir_parameters.security_level - whir_parameters.pow_bits); - - let starting_domain = Domain::new( - 1 << mv_parameters.num_variables, - whir_parameters.starting_log_inv_rate, - ) - .expect("Should have found an appropriate domain"); - - let final_sumcheck_rounds = mv_parameters.num_variables % whir_parameters.folding_factor; - let num_rounds = ((mv_parameters.num_variables - final_sumcheck_rounds) - / whir_parameters.folding_factor) - - 1; - - let field_size_bits = F::field_size_in_bits(); - - let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( - whir_parameters.soundness_type, - field_size_bits, - mv_parameters.num_variables, - whir_parameters.starting_log_inv_rate, - Self::log_eta(whir_parameters.starting_log_inv_rate), - ) + (whir_parameters.folding_factor as f64).log2(); - let starting_folding_pow_bits = - 0_f64.max(whir_parameters.security_level as f64 - prox_gaps_error); - - let mut round_parameters = Vec::with_capacity(num_rounds); - let mut num_variables = mv_parameters.num_variables - whir_parameters.folding_factor; - let mut log_inv_rate = whir_parameters.starting_log_inv_rate; - for _ in 0..num_rounds { - // Queries are set w.r.t. to old rate, while the rest to the new rate - let next_rate = log_inv_rate + (whir_parameters.folding_factor - 1); - - let log_next_eta = Self::log_eta(next_rate); - let num_queries = Self::queries( - whir_parameters.soundness_type, - protocol_security_level, - log_inv_rate, - ); - - let ood_samples = Self::ood_samples( - whir_parameters.security_level, - whir_parameters.soundness_type, - num_variables, - next_rate, - log_next_eta, - field_size_bits, - ); - - let query_error = - Self::rbr_queries(whir_parameters.soundness_type, log_inv_rate, num_queries); - let combination_error = Self::rbr_soundness_queries_combination( - whir_parameters.soundness_type, - field_size_bits, - num_variables, - next_rate, - log_next_eta, - ood_samples, - num_queries, - ); - - let pow_bits = 0_f64 - .max(whir_parameters.security_level as f64 - (query_error.min(combination_error))); - - let folding_pow_bits = Self::folding_pow_bits( - whir_parameters.security_level, - whir_parameters.soundness_type, - field_size_bits, - num_variables, - next_rate, - log_next_eta, - ); - - round_parameters.push(RoundConfig { - ood_samples, - num_queries, - pow_bits, - folding_pow_bits, - log_inv_rate, - }); - - num_variables -= whir_parameters.folding_factor; - log_inv_rate = next_rate; - } - - let final_queries = Self::queries( - whir_parameters.soundness_type, - protocol_security_level, - log_inv_rate, - ); - - let final_pow_bits = 0_f64.max( - whir_parameters.security_level as f64 - - Self::rbr_queries(whir_parameters.soundness_type, log_inv_rate, final_queries), - ); - - let final_folding_pow_bits = - 0_f64.max(whir_parameters.security_level as f64 - (field_size_bits - 1) as f64); - - WhirConfig { - security_level: whir_parameters.security_level, - max_pow_bits: whir_parameters.pow_bits, - mv_parameters, - starting_domain, - soundness_type: whir_parameters.soundness_type, - starting_log_inv_rate: whir_parameters.starting_log_inv_rate, - starting_folding_pow_bits, - folding_factor: whir_parameters.folding_factor, - round_parameters, - final_queries, - final_pow_bits, - final_sumcheck_rounds, - final_folding_pow_bits, - pow_strategy: PhantomData::default(), - fold_optimisation: whir_parameters.fold_optimisation, - final_log_inv_rate: log_inv_rate, - leaf_hash_params: whir_parameters.leaf_hash_params, - two_to_one_params: whir_parameters.two_to_one_params, - } - } - - pub fn n_rounds(&self) -> usize { - self.round_parameters.len() - } - - pub fn check_pow_bits(&self) -> bool { - [ - self.starting_folding_pow_bits, - self.final_pow_bits, - self.final_folding_pow_bits, - ] - .into_iter() - .all(|x| x <= self.max_pow_bits as f64) - && self.round_parameters.iter().all(|r| { - r.pow_bits <= self.max_pow_bits as f64 - && r.folding_pow_bits <= self.max_pow_bits as f64 - }) - } - - pub fn log_eta(log_inv_rate: usize) -> f64 { - -(log_inv_rate as f64 + LOG2_10) - } - - pub fn list_size_bits( - soundness_type: SoundnessType, - num_variables: usize, - log_inv_rate: usize, - log_eta: f64, - ) -> f64 { - match soundness_type { - SoundnessType::ConjectureList => { - let result = (num_variables + log_inv_rate) as f64 - log_eta; - result - } - SoundnessType::ProvableList => { - let log_inv_sqrt_rate: f64 = log_inv_rate as f64 / 2.; - let result = log_inv_sqrt_rate - (1. + log_eta); - result - } - SoundnessType::UniqueDecoding => 0.0, - } - } - - pub fn rbr_ood_sample( - soundness_type: SoundnessType, - num_variables: usize, - log_inv_rate: usize, - log_eta: f64, - field_size_bits: usize, - ood_samples: usize, - ) -> f64 { - let list_size_bits = - Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); - - let error = 2. * list_size_bits + (num_variables * ood_samples) as f64; - (ood_samples * field_size_bits) as f64 + 1. - error - } - - pub fn ood_samples( - security_level: usize, // We don't do PoW for OOD - soundness_type: SoundnessType, - num_variables: usize, - log_inv_rate: usize, - log_eta: f64, - field_size_bits: usize, - ) -> usize { - if matches!(soundness_type, SoundnessType::UniqueDecoding) { - 0 - } else { - for ood_samples in 1..64 { - if Self::rbr_ood_sample( - soundness_type, - num_variables, - log_inv_rate, - log_eta, - field_size_bits, - ood_samples, - ) >= security_level as f64 - { - return ood_samples; - } - } - - panic!("Could not find an appropriate number of OOD samples"); - } - } - - // Compute the proximity gaps term of the fold - pub fn rbr_soundness_fold_prox_gaps( - soundness_type: SoundnessType, - field_size_bits: usize, - num_variables: usize, - log_inv_rate: usize, - log_eta: f64, - ) -> f64 { - // Recall, at each round we are only folding by two at a time - let error = match soundness_type { - SoundnessType::ConjectureList => (num_variables + log_inv_rate) as f64 - log_eta, - SoundnessType::ProvableList => { - LOG2_10 + 3.5 * log_inv_rate as f64 + 2. * num_variables as f64 - } - SoundnessType::UniqueDecoding => (num_variables + log_inv_rate) as f64, - }; - - field_size_bits as f64 - error - } - - pub fn rbr_soundness_fold_sumcheck( - soundness_type: SoundnessType, - field_size_bits: usize, - num_variables: usize, - log_inv_rate: usize, - log_eta: f64, - ) -> f64 { - let list_size = Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); - - field_size_bits as f64 - (list_size + 1.) - } - - pub fn folding_pow_bits( - security_level: usize, - soundness_type: SoundnessType, - field_size_bits: usize, - num_variables: usize, - log_inv_rate: usize, - log_eta: f64, - ) -> f64 { - let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( - soundness_type, - field_size_bits, - num_variables, - log_inv_rate, - log_eta, - ); - let sumcheck_error = Self::rbr_soundness_fold_sumcheck( - soundness_type, - field_size_bits, - num_variables, - log_inv_rate, - log_eta, - ); - - let error = prox_gaps_error.min(sumcheck_error); - - 0_f64.max(security_level as f64 - error) - } - - // Used to select the number of queries - pub fn queries( - soundness_type: SoundnessType, - protocol_security_level: usize, - log_inv_rate: usize, - ) -> usize { - let num_queries_f = match soundness_type { - SoundnessType::UniqueDecoding => { - let rate = 1. / ((1 << log_inv_rate) as f64); - let denom = (0.5 * (1. + rate)).log2(); - - -(protocol_security_level as f64) / denom - } - SoundnessType::ProvableList => { - (2 * protocol_security_level) as f64 / log_inv_rate as f64 - } - SoundnessType::ConjectureList => protocol_security_level as f64 / log_inv_rate as f64, - }; - num_queries_f.ceil() as usize - } - - // This is the bits of security of the query step - pub fn rbr_queries( - soundness_type: SoundnessType, - log_inv_rate: usize, - num_queries: usize, - ) -> f64 { - let num_queries = num_queries as f64; - let bits_of_sec_queries = match soundness_type { - SoundnessType::UniqueDecoding => { - let rate = 1. / ((1 << log_inv_rate) as f64); - let denom = -(0.5 * (1. + rate)).log2(); - - num_queries * denom - } - SoundnessType::ProvableList => num_queries * 0.5 * log_inv_rate as f64, - SoundnessType::ConjectureList => num_queries * log_inv_rate as f64, - }; - - bits_of_sec_queries - } - - pub fn rbr_soundness_queries_combination( - soundness_type: SoundnessType, - field_size_bits: usize, - num_variables: usize, - log_inv_rate: usize, - log_eta: f64, - ood_samples: usize, - num_queries: usize, - ) -> f64 { - let list_size = Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); - - let log_combination = ((ood_samples + num_queries) as f64).log2(); - - field_size_bits as f64 - (log_combination + list_size + 1.) - } -} - -impl Display for WhirConfig -where - F: FftField, - MerkleConfig: Config, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.mv_parameters.fmt(f)?; - writeln!(f, ", folding factor: {}", self.folding_factor)?; - writeln!( - f, - "Security level: {} bits using {} security and {} bits of PoW", - self.security_level, self.soundness_type, self.max_pow_bits - )?; - - writeln!( - f, - "initial_folding_pow_bits: {}", - self.starting_folding_pow_bits - )?; - for r in &self.round_parameters { - r.fmt(f)?; - } - - writeln!( - f, - "final_queries: {}, final_rate: 2^-{}, final_pow_bits: {}, final_folding_pow_bits: {}", - self.final_queries, - self.final_log_inv_rate, - self.final_pow_bits, - self.final_folding_pow_bits, - )?; - - writeln!(f, "------------------------------------")?; - writeln!(f, "Round by round soundness analysis:")?; - writeln!(f, "------------------------------------")?; - - let field_size_bits = F::field_size_in_bits(); - let log_eta = Self::log_eta(self.starting_log_inv_rate); - let mut num_variables = self.mv_parameters.num_variables; - - let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( - self.soundness_type, - field_size_bits, - num_variables, - self.starting_log_inv_rate, - log_eta, - ) - (self.folding_factor as f64).log2(); - - writeln!( - f, - "{:.1} bits -- prox gaps: {:.1}, pow: {:.1}", - prox_gaps_error + self.starting_folding_pow_bits as f64, - prox_gaps_error, - self.starting_folding_pow_bits, - )?; - - num_variables -= self.folding_factor; - - for r in &self.round_parameters { - let next_rate = r.log_inv_rate + (self.folding_factor - 1); - let log_eta = Self::log_eta(next_rate); - - if r.ood_samples > 0 { - writeln!( - f, - "{:.1} bits -- OOD sample", - Self::rbr_ood_sample( - self.soundness_type, - num_variables, - next_rate, - log_eta, - field_size_bits, - r.ood_samples - ) - )?; - } - - let query_error = Self::rbr_queries(self.soundness_type, r.log_inv_rate, r.num_queries); - let combination_error = Self::rbr_soundness_queries_combination( - self.soundness_type, - field_size_bits, - num_variables, - next_rate, - log_eta, - r.ood_samples, - r.num_queries, - ); - writeln!( - f, - "{:.1} bits -- query error: {:.1}, combination: {:.1}, pow: {:.1}", - query_error.min(combination_error) + r.pow_bits as f64, - query_error, - combination_error, - r.pow_bits, - )?; - - let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( - self.soundness_type, - field_size_bits, - num_variables, - next_rate, - log_eta, - ); - let sumcheck_error = Self::rbr_soundness_fold_sumcheck( - self.soundness_type, - field_size_bits, - num_variables, - next_rate, - log_eta, - ); - - writeln!( - f, - "{:.1} bits -- (x{}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}", - prox_gaps_error.min(sumcheck_error) + r.folding_pow_bits as f64, - self.folding_factor, - prox_gaps_error, - sumcheck_error, - r.folding_pow_bits, - )?; - - num_variables -= self.folding_factor; - } - - let query_error = Self::rbr_queries( - self.soundness_type, - self.final_log_inv_rate, - self.final_queries, - ); - writeln!( - f, - "{:.1} bits -- query error: {:.1}, pow: {:.1}", - query_error + self.final_pow_bits as f64, - query_error, - self.final_pow_bits, - )?; - - if self.final_sumcheck_rounds > 0 { - let combination_error = field_size_bits as f64 - 1.; - writeln!( - f, - "{:.1} bits -- (x{}) combination: {:.1}, pow: {:.1}", - combination_error + self.final_pow_bits as f64, - self.final_sumcheck_rounds, - combination_error, - self.final_folding_pow_bits, - )?; - } - - Ok(()) - } -} - -impl Display for RoundConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!( - f, - "Num_queries: {}, rate: 2^-{}, pow_bits: {}, ood_samples: {}, folding_pow: {}", - self.num_queries, - self.log_inv_rate, - self.pow_bits, - self.ood_samples, - self.folding_pow_bits, - ) - } -} diff --git a/src/whir_ldt/prover.rs b/src/whir_ldt/prover.rs deleted file mode 100644 index f011060..0000000 --- a/src/whir_ldt/prover.rs +++ /dev/null @@ -1,319 +0,0 @@ -use super::{committer::Witness, parameters::WhirConfig, WhirProof}; -use crate::{ - domain::Domain, - ntt::expand_from_coeff, - parameters::FoldType, - poly_utils::{ - coeffs::CoefficientList, - fold::{compute_fold, restructure_evaluations}, - MultilinearPoint, - }, - sumcheck::prover_not_skipping::SumcheckProverNotSkipping, - utils::{self, expand_randomness}, -}; -use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; -use ark_ff::FftField; -use ark_poly::EvaluationDomain; -use nimue::{ - plugins::ark::{FieldChallenges, FieldWriter}, - ByteChallenges, ByteWriter, Merlin, ProofResult, -}; -use nimue_pow::{self, PoWChallenge}; -use rand::{Rng, SeedableRng}; - -#[cfg(feature = "parallel")] -use rayon::prelude::*; - -pub struct Prover(pub WhirConfig) -where - F: FftField, - MerkleConfig: Config; - -impl Prover -where - F: FftField, - MerkleConfig: Config, - MerkleConfig::InnerDigest: AsRef<[u8]>, - PowStrategy: nimue_pow::PowStrategy, -{ - fn validate_parameters(&self) -> bool { - self.0.mv_parameters.num_variables - == (self.0.n_rounds() + 1) * self.0.folding_factor + self.0.final_sumcheck_rounds - } - - fn validate_witness(&self, witness: &Witness) -> bool { - witness.polynomial.num_variables() == self.0.mv_parameters.num_variables - } - - pub fn prove( - &self, - merlin: &mut Merlin, - witness: Witness, - ) -> ProofResult> - where - Merlin: FieldChallenges + ByteWriter, - { - assert!(self.validate_parameters()); - assert!(self.validate_witness(&witness)); - - let mut folding_randomness = vec![F::ZERO; self.0.folding_factor]; - merlin.fill_challenge_scalars(&mut folding_randomness)?; - let folding_randomness = MultilinearPoint(folding_randomness); - - // PoW - if self.0.starting_folding_pow_bits > 0. { - merlin.challenge_pow::(self.0.starting_folding_pow_bits)?; - } - - let round_state = RoundState { - domain: self.0.starting_domain.clone(), - round: 0, - sumcheck_prover: None, - folding_randomness, - coefficients: witness.polynomial, - prev_merkle: witness.merkle_tree, - prev_merkle_answers: witness.merkle_leaves, - merkle_proofs: vec![], - }; - - self.round(merlin, round_state) - } - - fn round( - &self, - merlin: &mut Merlin, - mut round_state: RoundState, - ) -> ProofResult> { - // Fold the coefficients - let folded_coefficients = round_state - .coefficients - .fold(&round_state.folding_randomness); - - let num_variables = - self.0.mv_parameters.num_variables - (round_state.round + 1) * self.0.folding_factor; - - // Base case - if round_state.round == self.0.n_rounds() { - // Coefficients of the polynomial - merlin.add_scalars(folded_coefficients.coeffs())?; - - // Final verifier queries and answers - let mut queries_seed = [0u8; 32]; - merlin.fill_challenge_bytes(&mut queries_seed)?; - let mut final_gen = rand_chacha::ChaCha20Rng::from_seed(queries_seed); - let final_challenge_indexes = utils::dedup((0..self.0.final_queries).map(|_| { - final_gen.gen_range(0..round_state.domain.folded_size(self.0.folding_factor)) - })); - - let merkle_proof = round_state - .prev_merkle - .generate_multi_proof(final_challenge_indexes.clone()) - .unwrap(); - let fold_size = 1 << self.0.folding_factor; - let answers = final_challenge_indexes - .into_iter() - .map(|i| { - round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec() - }) - .collect(); - round_state.merkle_proofs.push((merkle_proof, answers)); - - // PoW - if self.0.final_pow_bits > 0. { - merlin.challenge_pow::(self.0.final_pow_bits)?; - } - - // Final sumcheck - if self.0.final_sumcheck_rounds > 0 { - round_state - .sumcheck_prover - .unwrap_or_else(|| { - SumcheckProverNotSkipping::new(folded_coefficients.clone(), &[], &[], &[]) - }) - .compute_sumcheck_polynomials::( - merlin, - self.0.final_sumcheck_rounds, - self.0.final_folding_pow_bits, - )?; - } - - return Ok(WhirProof(round_state.merkle_proofs)); - } - - let round_params = &self.0.round_parameters[round_state.round]; - - // Fold the coefficients, and compute fft of polynomial (and commit) - let new_domain = round_state.domain.scale(2); - let expansion = new_domain.size() / folded_coefficients.num_coeffs(); - let evals = expand_from_coeff(folded_coefficients.coeffs(), expansion); - // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. - // They also partially overlap and undo one another. We should merge them. - let folded_evals = utils::stack_evaluations(evals, self.0.folding_factor); - let folded_evals = restructure_evaluations( - folded_evals, - self.0.fold_optimisation, - new_domain.backing_domain.group_gen(), - new_domain.backing_domain.group_gen_inv(), - self.0.folding_factor, - ); - - #[cfg(not(feature = "parallel"))] - let leaf_iter = folded_evals.chunks_exact(1 << self.0.folding_factor); - - #[cfg(feature = "parallel")] - let leaf_iter = folded_evals.par_chunks_exact(1 << self.0.folding_factor); - - let merkle_tree = MerkleTree::::new( - &self.0.leaf_hash_params, - &self.0.two_to_one_params, - leaf_iter, - ) - .unwrap(); - - let root = merkle_tree.root(); - merlin.add_bytes(root.as_ref())?; - - // OOD Samples - let mut ood_points = vec![F::ZERO; round_params.ood_samples]; - let mut ood_answers = Vec::with_capacity(round_params.ood_samples); - if round_params.ood_samples > 0 { - merlin.fill_challenge_scalars(&mut ood_points)?; - ood_answers.extend(ood_points.iter().map(|ood_point| { - folded_coefficients.evaluate(&MultilinearPoint::expand_from_univariate( - *ood_point, - num_variables, - )) - })); - merlin.add_scalars(&ood_answers)?; - } - - // STIR queries - let mut stir_queries_seed = [0u8; 32]; - merlin.fill_challenge_bytes(&mut stir_queries_seed)?; - let mut stir_gen = rand_chacha::ChaCha20Rng::from_seed(stir_queries_seed); - let stir_challenges_indexes = - utils::dedup((0..round_params.num_queries).map(|_| { - stir_gen.gen_range(0..round_state.domain.folded_size(self.0.folding_factor)) - })); - let domain_scaled_gen = round_state - .domain - .backing_domain - .element(1 << self.0.folding_factor); - let stir_challenges: Vec<_> = ood_points - .into_iter() - .chain( - stir_challenges_indexes - .iter() - .map(|i| domain_scaled_gen.pow([*i as u64])), - ) - .map(|univariate| MultilinearPoint::expand_from_univariate(univariate, num_variables)) - .collect(); - - let merkle_proof = round_state - .prev_merkle - .generate_multi_proof(stir_challenges_indexes.clone()) - .unwrap(); - let fold_size = 1 << self.0.folding_factor; - let answers: Vec<_> = stir_challenges_indexes - .iter() - .map(|i| round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec()) - .collect(); - // Evaluate answers in the folding randomness. - let mut stir_evaluations = ood_answers.clone(); - match self.0.fold_optimisation { - FoldType::Naive => { - // See `Verifier::compute_folds_full` - let domain_size = round_state.domain.backing_domain.size(); - let domain_gen = round_state.domain.backing_domain.element(1); - let domain_gen_inv = domain_gen.inverse().unwrap(); - let coset_domain_size = 1 << self.0.folding_factor; - let coset_generator_inv = - domain_gen_inv.pow([(domain_size / coset_domain_size) as u64]); - stir_evaluations.extend(stir_challenges_indexes.iter().zip(&answers).map( - |(index, answers)| { - // The coset is w^index * - //let _coset_offset = domain_gen.pow(&[*index as u64]); - let coset_offset_inv = domain_gen_inv.pow([*index as u64]); - - compute_fold( - answers, - &round_state.folding_randomness.0, - coset_offset_inv, - coset_generator_inv, - F::from(2).inverse().unwrap(), - self.0.folding_factor, - ) - }, - )) - } - FoldType::ProverHelps => stir_evaluations.extend(answers.iter().map(|answers| { - CoefficientList::new(answers.to_vec()).evaluate(&round_state.folding_randomness) - })), - } - round_state.merkle_proofs.push((merkle_proof, answers)); - - // PoW - if round_params.pow_bits > 0. { - merlin.challenge_pow::(round_params.pow_bits)?; - } - - // Randomness for combination - let [combination_randomness_gen] = merlin.challenge_scalars()?; - let combination_randomness = - expand_randomness(combination_randomness_gen, stir_challenges.len()); - - let mut sumcheck_prover = round_state - .sumcheck_prover - .take() - .map(|mut sumcheck_prover| { - sumcheck_prover.add_new_equality( - &stir_challenges, - &combination_randomness, - &stir_evaluations, - ); - sumcheck_prover - }) - .unwrap_or_else(|| { - SumcheckProverNotSkipping::new( - folded_coefficients.clone(), - &stir_challenges, - &combination_randomness, - &stir_evaluations, - ) - }); - - let folding_randomness = sumcheck_prover.compute_sumcheck_polynomials::( - merlin, - self.0.folding_factor, - round_params.folding_pow_bits, - )?; - - let round_state = RoundState { - round: round_state.round + 1, - domain: new_domain, - sumcheck_prover: Some(sumcheck_prover), - folding_randomness, - coefficients: folded_coefficients, - prev_merkle: merkle_tree, - prev_merkle_answers: folded_evals, - merkle_proofs: round_state.merkle_proofs, - }; - - self.round(merlin, round_state) - } -} - -struct RoundState -where - F: FftField, - MerkleConfig: Config, -{ - round: usize, - domain: Domain, - sumcheck_prover: Option>, - folding_randomness: MultilinearPoint, - coefficients: CoefficientList, - prev_merkle: MerkleTree, - prev_merkle_answers: Vec, - merkle_proofs: Vec<(MultiPath, Vec>)>, -} diff --git a/src/whir_ldt/verifier.rs b/src/whir_ldt/verifier.rs deleted file mode 100644 index e5e47f8..0000000 --- a/src/whir_ldt/verifier.rs +++ /dev/null @@ -1,507 +0,0 @@ -use std::iter; - -use ark_crypto_primitives::merkle_tree::Config; -use ark_ff::FftField; -use ark_poly::EvaluationDomain; -use nimue::{ - plugins::ark::{FieldChallenges, FieldReader}, - Arthur, ByteChallenges, ByteReader, ProofError, ProofResult, -}; -use nimue_pow::{self, PoWChallenge}; -use rand::{Rng, SeedableRng}; - -use crate::{ - parameters::FoldType, - poly_utils::{coeffs::CoefficientList, eq_poly_outside, fold::compute_fold, MultilinearPoint}, - sumcheck::proof::SumcheckPolynomial, - utils::{self, expand_randomness}, -}; - -use super::{parameters::WhirConfig, WhirProof}; - -pub struct Verifier -where - F: FftField, - MerkleConfig: Config, -{ - params: WhirConfig, - two_inv: F, -} - -#[derive(Clone)] -struct ParsedCommitment { - root: D, -} - -#[derive(Clone)] -struct ParsedProof { - rounds: Vec>, - final_domain_gen_inv: F, - final_randomness_indexes: Vec, - final_randomness_points: Vec, - final_randomness_answers: Vec>, - final_folding_randomness: MultilinearPoint, - final_sumcheck_rounds: Vec<(SumcheckPolynomial, F)>, - final_sumcheck_randomness: MultilinearPoint, - final_coefficients: CoefficientList, -} - -#[derive(Debug, Clone)] -struct ParsedRound { - folding_randomness: MultilinearPoint, - ood_points: Vec, - ood_answers: Vec, - stir_challenges_indexes: Vec, - stir_challenges_points: Vec, - stir_challenges_answers: Vec>, - combination_randomness: Vec, - sumcheck_rounds: Vec<(SumcheckPolynomial, F)>, - domain_gen_inv: F, -} - -impl Verifier -where - F: FftField, - MerkleConfig: Config, - MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, - PowStrategy: nimue_pow::PowStrategy, -{ - pub fn new(params: WhirConfig) -> Self { - Verifier { - params, - two_inv: F::from(2).inverse().unwrap(), // The only inverse in the entire code :) - } - } - - fn parse_commitment( - &self, - arthur: &mut Arthur, - ) -> ProofResult> { - let root: [u8; 32] = arthur.next_bytes()?; - - Ok(ParsedCommitment { root: root.into() }) - } - - fn parse_proof( - &self, - arthur: &mut Arthur, - parsed_commitment: &ParsedCommitment, - whir_proof: &WhirProof, - ) -> ProofResult> { - // Derive initial combination randomness - let mut folding_randomness = vec![F::ZERO; self.params.folding_factor]; - arthur.fill_challenge_scalars(&mut folding_randomness)?; - let mut folding_randomness = MultilinearPoint(folding_randomness); - - // PoW - if self.params.starting_folding_pow_bits > 0. { - arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; - } - - let mut prev_root = parsed_commitment.root.clone(); - let domain_gen = self.params.starting_domain.backing_domain.group_gen(); - let mut exp_domain_gen = domain_gen.pow([1 << self.params.folding_factor]); - let mut domain_gen_inv = self.params.starting_domain.backing_domain.group_gen_inv(); - let mut domain_size = self.params.starting_domain.size(); - let mut rounds = vec![]; - - for r in 0..self.params.n_rounds() { - let (merkle_proof, answers) = &whir_proof.0[r]; - let round_params = &self.params.round_parameters[r]; - - let new_root: [u8; 32] = arthur.next_bytes()?; - - let mut ood_points = vec![F::ZERO; round_params.ood_samples]; - let mut ood_answers = vec![F::ZERO; round_params.ood_samples]; - if round_params.ood_samples > 0 { - arthur.fill_challenge_scalars(&mut ood_points)?; - arthur.fill_next_scalars(&mut ood_answers)?; - } - - let mut stir_queries_seed = [0u8; 32]; - arthur.fill_challenge_bytes(&mut stir_queries_seed)?; - let mut stir_gen = rand_chacha::ChaCha20Rng::from_seed(stir_queries_seed); - let folded_domain_size = domain_size / (1 << self.params.folding_factor); - let stir_challenges_indexes = utils::dedup( - (0..round_params.num_queries).map(|_| stir_gen.gen_range(0..folded_domain_size)), - ); - let stir_challenges_points = stir_challenges_indexes - .iter() - .map(|index| exp_domain_gen.pow([*index as u64])) - .collect(); - - if !merkle_proof - .verify( - &self.params.leaf_hash_params, - &self.params.two_to_one_params, - &prev_root, - answers.iter().map(|a| a.as_ref()), - ) - .unwrap() - || merkle_proof.leaf_indexes != stir_challenges_indexes - { - return Err(ProofError::InvalidProof); - } - - if round_params.pow_bits > 0. { - arthur.challenge_pow::(round_params.pow_bits)?; - } - - let [combination_randomness_gen] = arthur.challenge_scalars()?; - let combination_randomness = expand_randomness( - combination_randomness_gen, - stir_challenges_indexes.len() + round_params.ood_samples, - ); - - let mut sumcheck_rounds = Vec::with_capacity(self.params.folding_factor); - for _ in 0..self.params.folding_factor { - let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; - let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); - let [folding_randomness_single] = arthur.challenge_scalars()?; - sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); - - if round_params.folding_pow_bits > 0. { - arthur.challenge_pow::(round_params.folding_pow_bits)?; - } - } - - let new_folding_randomness = - MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); - - rounds.push(ParsedRound { - folding_randomness, - ood_points, - ood_answers, - stir_challenges_indexes, - stir_challenges_points, - stir_challenges_answers: answers.to_vec(), - combination_randomness, - sumcheck_rounds, - domain_gen_inv, - }); - - folding_randomness = new_folding_randomness; - - prev_root = new_root.into(); - exp_domain_gen = exp_domain_gen * exp_domain_gen; - domain_gen_inv = domain_gen_inv * domain_gen_inv; - domain_size /= 2; - } - - let mut final_coefficients = vec![F::ZERO; 1 << self.params.final_sumcheck_rounds]; - arthur.fill_next_scalars(&mut final_coefficients)?; - let final_coefficients = CoefficientList::new(final_coefficients); - - // Final queries verify - let mut queries_seed = [0u8; 32]; - arthur.fill_challenge_bytes(&mut queries_seed)?; - let mut final_gen = rand_chacha::ChaCha20Rng::from_seed(queries_seed); - let folded_domain_size = domain_size / (1 << self.params.folding_factor); - let final_randomness_indexes = utils::dedup( - (0..self.params.final_queries).map(|_| final_gen.gen_range(0..folded_domain_size)), - ); - let final_randomness_points = final_randomness_indexes - .iter() - .map(|index| exp_domain_gen.pow([*index as u64])) - .collect(); - - let (final_merkle_proof, final_randomness_answers) = &whir_proof.0[whir_proof.0.len() - 1]; - if !final_merkle_proof - .verify( - &self.params.leaf_hash_params, - &self.params.two_to_one_params, - &prev_root, - final_randomness_answers.iter().map(|a| a.as_ref()), - ) - .unwrap() - || final_merkle_proof.leaf_indexes != final_randomness_indexes - { - return Err(ProofError::InvalidProof); - } - - if self.params.final_pow_bits > 0. { - arthur.challenge_pow::(self.params.final_pow_bits)?; - } - - let mut final_sumcheck_rounds = Vec::with_capacity(self.params.final_sumcheck_rounds); - for _ in 0..self.params.final_sumcheck_rounds { - let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; - let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); - let [folding_randomness_single] = arthur.challenge_scalars()?; - final_sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); - - if self.params.final_folding_pow_bits > 0. { - arthur.challenge_pow::(self.params.final_folding_pow_bits)?; - } - } - let final_sumcheck_randomness = MultilinearPoint( - final_sumcheck_rounds - .iter() - .map(|&(_, r)| r) - .rev() - .collect(), - ); - - Ok(ParsedProof { - rounds, - final_domain_gen_inv: domain_gen_inv, - final_folding_randomness: folding_randomness, - final_randomness_indexes, - final_randomness_answers: final_randomness_answers.to_vec(), - final_coefficients, - final_randomness_points, - final_sumcheck_rounds, - final_sumcheck_randomness, - }) - } - - fn compute_v_poly(&self, proof: &ParsedProof) -> F { - let mut num_variables = self.params.mv_parameters.num_variables; - - let mut folding_randomness = MultilinearPoint( - iter::once(&proof.final_sumcheck_randomness.0) - .chain(iter::once(&proof.final_folding_randomness.0)) - .chain(proof.rounds.iter().rev().map(|r| &r.folding_randomness.0)) - .flatten() - .copied() - .collect(), - ); - - let mut value = F::ZERO; - - for round_proof in &proof.rounds { - num_variables -= self.params.folding_factor; - folding_randomness = MultilinearPoint(folding_randomness.0[..num_variables].to_vec()); - - let ood_points = &round_proof.ood_points; - let stir_challenges_points = &round_proof.stir_challenges_points; - let stir_challenges: Vec<_> = ood_points - .iter() - .chain(stir_challenges_points) - .cloned() - .map(|univariate| { - MultilinearPoint::expand_from_univariate(univariate, num_variables) - }) - .collect(); - - let sum_of_claims: F = stir_challenges - .into_iter() - .map(|point| eq_poly_outside(&point, &folding_randomness)) - .zip(&round_proof.combination_randomness) - .map(|(point, rand)| point * rand) - .sum(); - - value = value + sum_of_claims; - } - - value - } - - fn compute_folds(&self, parsed: &ParsedProof) -> Vec> { - match self.params.fold_optimisation { - FoldType::Naive => self.compute_folds_full(parsed), - FoldType::ProverHelps => self.compute_folds_helped(parsed), - } - } - - fn compute_folds_full(&self, parsed: &ParsedProof) -> Vec> { - let mut domain_size = self.params.starting_domain.backing_domain.size(); - let coset_domain_size = 1 << self.params.folding_factor; - - let mut result = Vec::new(); - - for round in &parsed.rounds { - // This is such that coset_generator^coset_domain_size = F::ONE - //let _coset_generator = domain_gen.pow(&[(domain_size / coset_domain_size) as u64]); - let coset_generator_inv = round - .domain_gen_inv - .pow([(domain_size / coset_domain_size) as u64]); - - let evaluations: Vec<_> = round - .stir_challenges_indexes - .iter() - .zip(&round.stir_challenges_answers) - .map(|(index, answers)| { - // The coset is w^index * - //let _coset_offset = domain_gen.pow(&[*index as u64]); - let coset_offset_inv = round.domain_gen_inv.pow([*index as u64]); - - compute_fold( - answers, - &round.folding_randomness.0, - coset_offset_inv, - coset_generator_inv, - self.two_inv, - self.params.folding_factor, - ) - }) - .collect(); - result.push(evaluations); - domain_size /= 2; - } - - let domain_gen_inv = parsed.final_domain_gen_inv; - - // Final round - let coset_generator_inv = domain_gen_inv.pow([(domain_size / coset_domain_size) as u64]); - let evaluations: Vec<_> = parsed - .final_randomness_indexes - .iter() - .zip(&parsed.final_randomness_answers) - .map(|(index, answers)| { - // The coset is w^index * - //let _coset_offset = domain_gen.pow(&[*index as u64]); - let coset_offset_inv = domain_gen_inv.pow([*index as u64]); - - compute_fold( - answers, - &parsed.final_folding_randomness.0, - coset_offset_inv, - coset_generator_inv, - self.two_inv, - self.params.folding_factor, - ) - }) - .collect(); - result.push(evaluations); - - result - } - - fn compute_folds_helped(&self, parsed: &ParsedProof) -> Vec> { - let mut result = Vec::new(); - - for round in &parsed.rounds { - let evaluations: Vec<_> = round - .stir_challenges_answers - .iter() - .map(|answers| { - CoefficientList::new(answers.to_vec()).evaluate(&round.folding_randomness) - }) - .collect(); - result.push(evaluations); - } - - // Final round - let evaluations: Vec<_> = parsed - .final_randomness_answers - .iter() - .map(|answers| { - CoefficientList::new(answers.to_vec()).evaluate(&parsed.final_folding_randomness) - }) - .collect(); - result.push(evaluations); - - result - } - - pub fn verify( - &self, - arthur: &mut Arthur, - whir_proof: &WhirProof, - ) -> ProofResult<()> { - // We first do a pass in which we rederive all the FS challenges - // Then we will check the algebraic part (so to optimise inversions) - let parsed_commitment = self.parse_commitment(arthur)?; - let parsed = self.parse_proof(arthur, &parsed_commitment, whir_proof)?; - - let computed_folds = self.compute_folds(&parsed); - - let mut prev: Option<(SumcheckPolynomial, F)> = None; - for (round, folds) in parsed.rounds.iter().zip(&computed_folds) { - let (sumcheck_poly, new_randomness) = &round.sumcheck_rounds[0].clone(); - - let values = round.ood_answers.iter().copied().chain(folds.clone()); - - let prev_eval = if let Some((prev_poly, randomness)) = prev { - prev_poly.evaluate_at_point(&randomness.into()) - } else { - F::ZERO - }; - let claimed_sum = prev_eval - + values - .zip(&round.combination_randomness) - .map(|(val, rand)| val * rand) - .sum::(); - - if sumcheck_poly.sum_over_hypercube() != claimed_sum { - return Err(ProofError::InvalidProof); - } - - prev = Some((sumcheck_poly.clone(), *new_randomness)); - - // Check the rest of the round - for (sumcheck_poly, new_randomness) in &round.sumcheck_rounds[1..] { - let (prev_poly, randomness) = prev.unwrap(); - if sumcheck_poly.sum_over_hypercube() - != prev_poly.evaluate_at_point(&randomness.into()) - { - return Err(ProofError::InvalidProof); - } - prev = Some((sumcheck_poly.clone(), *new_randomness)); - } - } - - // Check the foldings computed from the proof match the evaluations of the polynomial - let final_folds = &computed_folds[computed_folds.len() - 1]; - let final_evaluations = parsed - .final_coefficients - .evaluate_at_univariate(&parsed.final_randomness_points); - if !final_folds - .iter() - .zip(final_evaluations) - .all(|(&fold, eval)| fold == eval) - { - return Err(ProofError::InvalidProof); - } - - // Check the final sumchecks - if self.params.final_sumcheck_rounds > 0 { - let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { - prev_poly.evaluate_at_point(&randomness.into()) - } else { - F::ZERO - }; - - let (sumcheck_poly, new_randomness) = &parsed.final_sumcheck_rounds[0].clone(); - - let claimed_sum = prev_sumcheck_poly_eval; - if sumcheck_poly.sum_over_hypercube() != claimed_sum { - return Err(ProofError::InvalidProof); - } - - prev = Some((sumcheck_poly.clone(), *new_randomness)); - - // Check the rest of the round - for (sumcheck_poly, new_randomness) in &parsed.final_sumcheck_rounds[1..] { - let (prev_poly, randomness) = prev.unwrap(); - if sumcheck_poly.sum_over_hypercube() - != prev_poly.evaluate_at_point(&randomness.into()) - { - return Err(ProofError::InvalidProof); - } - prev = Some((sumcheck_poly.clone(), *new_randomness)); - } - } - - let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { - prev_poly.evaluate_at_point(&randomness.into()) - } else { - F::ZERO - }; - - // Check the final sumcheck evaluation - let evaluation_of_v_poly = self.compute_v_poly(&parsed); - - if prev_sumcheck_poly_eval - != evaluation_of_v_poly - * parsed - .final_coefficients - .evaluate(&parsed.final_sumcheck_randomness) - { - return Err(ProofError::InvalidProof); - } - - Ok(()) - } -}