diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 1ab14e3f1..ec634df94 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -32,8 +32,10 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, }; -use p3::field::FieldAlgebra; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, +}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, @@ -65,6 +67,7 @@ impl CpuEccProver { pub fn create_ecc_proof<'a, E: ExtensionField>( &self, + num_instances: usize, mut xs: Vec>, mut ys: Vec>, invs: Vec>, @@ -78,17 +81,39 @@ impl CpuEccProver { let out_rt = transcript.sample_and_append_vec(b"ecc", n); let num_threads = optimal_sumcheck_threads(out_rt.len()); - let alpha_pows = - transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha"); + // expression with add (3 zero constrains) and bypass (2 zero constrains) + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len()); - let sel = SelectorType::Prefix(E::BaseField::ZERO, 0.into()); - let num_instances = (1 << n) - 1; - let mut sel_mle: MultilinearExtension<'_, E> = sel.compute(&out_rt, num_instances).unwrap(); - let sel_expr = expr_builder.lift(sel_mle.to_either()); + let sel_add = SelectorType::QuarkBinaryTreeLessThan(0.into()); + let mut sel_add_mle: MultilinearExtension<'_, E> = + sel_add.compute(&out_rt, num_instances).unwrap(); + // we construct sel_bypass witness here + // verifier can derive it via `sel_bypass = eq - sel_add - sel_last_onehot` + let mut sel_bypass_mle: Vec = build_eq_x_r_vec(&out_rt); + match sel_add_mle.evaluations() { + FieldType::Ext(sel_add_mle) => sel_add_mle + .par_iter() + .zip_eq(sel_bypass_mle.par_iter_mut()) + .for_each(|(sel_add, sel_bypass)| { + if *sel_add != E::ZERO { + *sel_bypass = E::ZERO; + } + }), + _ => unreachable!(), + } + *sel_bypass_mle.last_mut().unwrap() = E::ZERO; + let mut sel_bypass_mle = sel_bypass_mle.into_mle(); + let sel_add_expr = expr_builder.lift(sel_add_mle.to_either()); + let sel_bypass_expr = expr_builder.lift(sel_bypass_mle.to_either()); - let mut exprs = vec![]; + let mut exprs_add = vec![]; + let mut exprs_bypass = vec![]; let filter_bj = |v: &[MultilinearExtension<'_, E>], j: usize| { v.iter() @@ -157,43 +182,58 @@ impl CpuEccProver { ); // affine addition // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1) - exprs.extend( + exprs_add.extend( (s.clone() * (&x0 - &x1) - (&y0 - &y1)) .to_exprs() .into_iter() - .zip(alpha_pows.iter().take(SEPTIC_EXTENSION_DEGREE)) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1) - exprs.extend( + exprs_add.extend( ((&s * &s) - &x0 - &x1 - &x3) .to_exprs() .into_iter() - .zip( - alpha_pows[SEPTIC_EXTENSION_DEGREE..] - .iter() - .take(SEPTIC_EXTENSION_DEGREE), - ) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) - exprs.extend( + exprs_add.extend( (s.clone() * (&x0 - &x3) - (&y0 + &y3)) .to_exprs() .into_iter() - .zip( - alpha_pows[SEPTIC_EXTENSION_DEGREE * 2..] - .iter() - .take(SEPTIC_EXTENSION_DEGREE), - ) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + let exprs_add = exprs_add.into_iter().sum::>() * sel_add_expr; + + // deal with bypass + // 0 = (x[1,b] - x[b,0]) + exprs_bypass.extend( + (&x3 - &x0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); + // 0 = (y[1,b] - y[b,0]) + exprs_bypass.extend( + (&y3 - &y0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + assert!(alpha_pows_iter.next().is_none()); + + let exprs_bypass = exprs_bypass.into_iter().sum::>() * sel_bypass_expr; + let (zerocheck_proof, state) = IOPProverState::prove( - expr_builder - .to_virtual_polys(&[exprs.into_iter().sum::>() * sel_expr], &[]), + expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass], &[]), transcript, ); @@ -202,10 +242,11 @@ impl CpuEccProver { assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[0,rt] - assert_eq!(evals.len(), 1 + SEPTIC_EXTENSION_DEGREE * 7); + assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7); #[cfg(feature = "sanity-check")] { + let last_evaluation_index = (1 << n) - 1; let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); let x0 = filter_bj(&xs, 0); let y0 = filter_bj(&ys, 0); @@ -214,11 +255,11 @@ impl CpuEccProver { let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); let final_sum_x: SepticExtension = (x3.iter()) - .map(|x| x.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] + .map(|x| x.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] .collect_vec() .into(); let final_sum_y: SepticExtension = (y3.iter()) - .map(|y| y.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] + .map(|y| y.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] .collect_vec() .into(); let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); @@ -226,7 +267,7 @@ impl CpuEccProver { assert_eq!(final_sum, sum); // check evaluations assert_eq!( - eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt), + eq_eval_less_or_equal_than(last_evaluation_index - 1, &out_rt, &rt), evals[0] ); for i in 0..SEPTIC_EXTENSION_DEGREE { @@ -258,7 +299,7 @@ impl CpuEccProver { // TODO: prove the validity of s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] EccQuarkProof { zerocheck_proof, - num_vars: n, + num_instances, evals, sum, } @@ -1051,8 +1092,12 @@ where #[cfg(test)] mod tests { - use std::iter::repeat; - + use crate::scheme::{ + constants::SEPTIC_EXTENSION_DEGREE, + cpu::CpuEccProver, + septic_curve::{SepticExtension, SepticPoint}, + verifier::EccVerifier, + }; use ff_ext::BabyBearExt4; use itertools::Itertools; use multilinear_extensions::{ @@ -1060,22 +1105,22 @@ mod tests { util::transpose, }; use p3::babybear::BabyBear; + use std::iter::repeat_n; use transcript::BasicTranscript; - - use crate::scheme::{ - constants::SEPTIC_EXTENSION_DEGREE, - cpu::CpuEccProver, - septic_curve::{SepticExtension, SepticPoint}, - verifier::EccVerifier, - }; + use witness::next_pow2_instance_padding; #[test] fn test_ecc_quark_prover() { + for n_points in 1..2 ^ 10 { + test_ecc_quark_prover_inner(n_points) + } + } + + fn test_ecc_quark_prover_inner(n_points: usize) { type E = BabyBearExt4; type F = BabyBear; - let log2_n = 6; - let n_points = 1 << log2_n; + let log2_n = next_pow2_instance_padding(n_points).ilog2(); let mut rng = rand::thread_rng(); let final_sum; @@ -1085,7 +1130,11 @@ mod tests { let mut points = (0..n_points) .map(|_| SepticPoint::::random(&mut rng)) .collect_vec(); - let mut s = Vec::with_capacity(n_points); + points.extend(repeat_n( + SepticPoint::point_at_infinity(), + (1 << log2_n) - points.len(), + )); + let mut s = Vec::with_capacity(1 << (log2_n + 1)); for layer in (1..=log2_n).rev() { let num_inputs = 1 << layer; @@ -1094,17 +1143,19 @@ mod tests { s.extend(inputs.chunks_exact(2).map(|chunk| { let p = &chunk[0]; let q = &chunk[1]; - - (&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap() + if q.is_infinity { + SepticExtension::zero() + } else { + (&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap() + } })); points.extend( - points[points.len() - num_inputs..] + inputs .chunks_exact(2) .map(|chunk| { let p = chunk[0].clone(); let q = chunk[1].clone(); - p + q }) .collect_vec(), @@ -1113,11 +1164,14 @@ mod tests { final_sum = points.last().cloned().unwrap(); // padding to 2*N - s.extend(repeat(SepticExtension::zero()).take(n_points + 1)); + s.extend(repeat_n( + SepticExtension::zero(), + (1 << (log2_n + 1)) - s.len(), + )); points.push(SepticPoint::point_at_infinity()); - assert_eq!(s.len(), 2 * n_points); - assert_eq!(points.len(), 2 * n_points); + assert_eq!(s.len(), 1 << (log2_n + 1)); + assert_eq!(points.len(), 1 << (log2_n + 1)); // transform points to row major matrix let trace = points @@ -1144,6 +1198,7 @@ mod tests { let mut transcript = BasicTranscript::new(b"test"); let prover = CpuEccProver::new(); let quark_proof = prover.create_ecc_proof( + n_points, xs.to_vec(), ys.to_vec(), s.to_vec(), @@ -1156,6 +1211,7 @@ mod tests { assert!( verifier .verify_ecc_proof(&quark_proof, &mut transcript) + .inspect_err(|err| println!("err {:?}", err)) .is_ok() ); } diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index ed3030d23..99a5878a0 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -594,6 +594,28 @@ impl MulAssign for SepticExtension { #[derive(Clone, Debug)] pub struct SymbolicSepticExtension(pub Vec>); +impl SymbolicSepticExtension { + pub fn mul_scalar(&self, scalar: Either) -> Self { + let res = self + .0 + .iter() + .map(|a| a.clone() * Expression::Constant(scalar)) + .collect(); + + SymbolicSepticExtension(res) + } + + pub fn add_scalar(&self, scalar: Either) -> Self { + let res = self + .0 + .iter() + .map(|a| a.clone() + Expression::Constant(scalar)) + .collect(); + + SymbolicSepticExtension(res) + } +} + impl Add for &SymbolicSepticExtension { type Output = SymbolicSepticExtension; diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index e48978c51..f7ff6cd4b 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -2,7 +2,7 @@ use crate::{ scheme::{ constants::{MIN_PAR_SIZE, SEPTIC_JACOBIAN_NUM_MLES}, hal::{MainSumcheckProver, ProofInput, ProverDevice}, - septic_curve::{SepticExtension, SepticJacobianPoint, SepticPoint}, + septic_curve::{SepticExtension, SepticJacobianPoint}, }, structs::ComposedConstrainSystem, }; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index abe0f6ec2..65eba768f 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -5,11 +5,12 @@ use ff_ext::ExtensionField; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::{gkr::GKRClaims, utils::eq_eval_less_or_equal_than}; +use gkr_iop::{gkr::GKRClaims, selector::SelectorType}; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Instance, StructuralWitIn, StructuralWitInType, + Expression, Instance, StructuralWitIn, StructuralWitInType, + StructuralWitInType::StackedConstantSequence, mle::IntoMLE, util::ceil_log2, utils::eval_by_expr_with_instance, @@ -812,7 +813,7 @@ impl TowerVerify { let max_num_variables = num_variables.iter().max().unwrap(); - let (next_rt, _) = (0..(max_num_variables-1)).try_fold( + let (next_rt, _) = (0..(max_num_variables - 1)).try_fold( ( PointAndEval { point: initial_rt, @@ -845,31 +846,31 @@ impl TowerVerify { // prod'[b] = prod[0,b] * prod[1,b] // prod'[out_rt] = \sum_b eq(out_rt,b) * prod'[b] = \sum_b eq(out_rt,b) * prod[0,b] * prod[1,b] eq * *alpha - * if round < *max_round-1 {tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product()} else { - E::ZERO - } + * if round < *max_round - 1 { tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product() } else { + E::ZERO + } }) .sum::() + (0..num_logup_spec) - .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) - .zip_eq(num_variables[num_prod_spec..].iter()) - .map(|((spec_index, alpha), max_round)| { - // logup_q'[b] = logup_q[0,b] * logup_q[1,b] - // logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b] - // logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]) - // logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b] - let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); - eq * if round < *max_round-1 { - let evals = &tower_proofs.logup_specs_eval[spec_index][round]; - let (p1, p2, q1, q2) = - (evals[0], evals[1], evals[2], evals[3]); - *alpha_numerator * (p1 * q2 + p2 * q1) - + *alpha_denominator * (q1 * q2) - } else { - E::ZERO - } - }) - .sum::(); + .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) + .zip_eq(num_variables[num_prod_spec..].iter()) + .map(|((spec_index, alpha), max_round)| { + // logup_q'[b] = logup_q[0,b] * logup_q[1,b] + // logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b] + // logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]) + // logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b] + let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); + eq * if round < *max_round - 1 { + let evals = &tower_proofs.logup_specs_eval[spec_index][round]; + let (p1, p2, q1, q2) = + (evals[0], evals[1], evals[2], evals[3]); + *alpha_numerator * (p1 * q2 + p2 * q1) + + *alpha_denominator * (q1 * q2) + } else { + E::ZERO + } + }) + .sum::(); if expected_evaluation != sumcheck_claim.expected_evaluation { return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); @@ -878,7 +879,7 @@ impl TowerVerify { // derive single eval // rt' = r_merge || rt // r_merge.len() == ceil_log2(num_product_fanin) - let r_merge =transcript.sample_and_append_vec(b"merge", log2_num_fanin); + let r_merge = transcript.sample_and_append_vec(b"merge", log2_num_fanin); let coeffs = build_eq_x_r_vec_sequential(&r_merge); assert_eq!(coeffs.len(), num_fanin); let rt_prime = [rt, r_merge].concat(); @@ -894,17 +895,17 @@ impl TowerVerify { .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { // prod'[rt,r_merge] = \sum_b eq(r_merge, b) * prod'[b,rt] - if round < max_round -1 { + if round < max_round - 1 { // merged evaluation let evals = izip!( tower_proofs.prod_specs_eval[spec_index][round].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); // this will keep update until round > evaluation prod_spec_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals); - if next_round < max_round -1 { + if next_round < max_round - 1 { *alpha * evals } else { E::ZERO @@ -918,28 +919,28 @@ impl TowerVerify { .zip_eq(next_alpha_pows[num_prod_spec..].chunks(2)) .zip_eq(num_variables[num_prod_spec..].iter()) .map(|((spec_index, alpha), max_round)| { - if round < max_round -1 { + if round < max_round - 1 { let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); // merged evaluation let p_evals = izip!( tower_proofs.logup_specs_eval[spec_index][round][0..2].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); let q_evals = izip!( tower_proofs.logup_specs_eval[spec_index][round][2..4].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); // this will keep update until round > evaluation logup_spec_p_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals); logup_spec_q_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals); - if next_round < max_round -1 { + if next_round < max_round - 1 { *alpha_numerator * p_evals + *alpha_denominator * q_evals } else { E::ZERO @@ -981,50 +982,53 @@ impl EccVerifier { proof: &EccQuarkProof, transcript: &mut impl Transcript, ) -> Result<(), ZKVMError> { - let out_rt = transcript.sample_and_append_vec(b"ecc", proof.num_vars); - let alpha_pows = - transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha"); + let num_vars = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; + let out_rt = transcript.sample_and_append_vec(b"ecc", num_vars); + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); let sumcheck_claim = IOPVerifierState::verify( E::ZERO, &proof.zerocheck_proof, &VPAuxInfo { max_degree: 3, - max_num_variables: proof.num_vars, + max_num_variables: num_vars, phantom: PhantomData, }, transcript, ); - let s0: SepticExtension = proof.evals[1..][0..SEPTIC_EXTENSION_DEGREE] + let s0: SepticExtension = proof.evals[2..][0..][..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let x0: SepticExtension = proof.evals[1..] - [SEPTIC_EXTENSION_DEGREE..2 * SEPTIC_EXTENSION_DEGREE] + let x0: SepticExtension = proof.evals[2..][SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let y0: SepticExtension = proof.evals[1..] - [2 * SEPTIC_EXTENSION_DEGREE..3 * SEPTIC_EXTENSION_DEGREE] + let y0: SepticExtension = proof.evals[2..][2 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let x1: SepticExtension = proof.evals[1..] - [3 * SEPTIC_EXTENSION_DEGREE..4 * SEPTIC_EXTENSION_DEGREE] + let x1: SepticExtension = proof.evals[2..][3 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let y1: SepticExtension = proof.evals[1..] - [4 * SEPTIC_EXTENSION_DEGREE..5 * SEPTIC_EXTENSION_DEGREE] + let y1: SepticExtension = proof.evals[2..][4 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let x3: SepticExtension = proof.evals[1..] - [5 * SEPTIC_EXTENSION_DEGREE..6 * SEPTIC_EXTENSION_DEGREE] + let x3: SepticExtension = proof.evals[2..][5 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let y3: SepticExtension = proof.evals[1..] - [6 * SEPTIC_EXTENSION_DEGREE..7 * SEPTIC_EXTENSION_DEGREE] + let y3: SepticExtension = proof.evals[2..][6 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let num_instances = (1 << proof.num_vars) - 1; let rt = sumcheck_claim .point .iter() @@ -1034,6 +1038,8 @@ impl EccVerifier { // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) + // zerocheck: 0 = (x[1,b] - x[b,0]) + // zerocheck: 0 = (y[1,b] - y[b,0]) // // note that they are not septic extension field elements, // we just want to reuse the multiply/add/sub formulas @@ -1041,25 +1047,60 @@ impl EccVerifier { let v2: SepticExtension = s0.square() - &x0 - &x1 - &x3; let v3: SepticExtension = s0 * (&x0 - &x3) - (&y0 + &y3); - let v: E = vec![v1, v2, v3] - .into_iter() - .enumerate() - .flat_map(|(i, v)| { - let start = i * SEPTIC_EXTENSION_DEGREE; - let end = (i + 1) * SEPTIC_EXTENSION_DEGREE; - v.0.into_iter() - .zip(alpha_pows[start..end].iter()) - .map(|(c, alpha)| c * *alpha) - }) - .sum(); + let v4: SepticExtension = &x3 - &x0; + let v5: SepticExtension = &y3 - &y0; + + let [v1, v2, v3, v4, v5] = [v1, v2, v3, v4, v5].map(|v| { + v.0.into_iter() + .zip(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(c, alpha)| c * *alpha) + .collect_vec() + }); + + let sel_add_expr = SelectorType::::QuarkBinaryTreeLessThan(Expression::StructuralWitIn( + 0, + // this value doesn't matter, as we only need structural id + StackedConstantSequence { max_value: 0 }, + )); + let mut sel_evals = vec![E::ZERO]; + sel_add_expr.evaluate(&mut sel_evals, &out_rt, &rt, proof.num_instances, 0); + let expected_sel_add = sel_evals[0]; + + if proof.evals[0] != expected_sel_add { + return Err(ZKVMError::VerifyError( + (format!( + "sel_add evaluation mismatch, expected {}, got {}", + expected_sel_add, proof.evals[0] + )) + .into(), + )); + } + + // derive `sel_bypass = eq - sel_add - sel_last_onehot` + let expected_sel_bypass = eq_eval(&out_rt, &rt) + - expected_sel_add + - (out_rt.iter().copied().product::() * rt.iter().copied().product::()); - let sel = eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt); - if sumcheck_claim.expected_evaluation != v * sel { + if proof.evals[1] != expected_sel_bypass { + return Err(ZKVMError::VerifyError( + (format!( + "sel_bypass evaluation mismatch, expected {}, got {}", + expected_sel_bypass, proof.evals[1] + )) + .into(), + )); + } + + let add_evaluations = vec![v1, v2, v3].into_iter().flatten().sum::(); + let bypass_evaluations = vec![v4, v5].into_iter().flatten().sum::(); + if sumcheck_claim.expected_evaluation + != add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass + { return Err(ZKVMError::VerifyError( (format!( "ecc zerocheck failed: mismatched evaluation, expected {}, got {}", sumcheck_claim.expected_evaluation, - v * sel + add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass )) .into(), )); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index c8d37346f..afa5cf256 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -31,7 +31,7 @@ use witness::RowMajorMatrix; ))] pub struct EccQuarkProof { pub zerocheck_proof: IOPProof, - pub num_vars: usize, + pub num_instances: usize, pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt] pub sum: SepticPoint, } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 1d4e6c56a..97cbf3c73 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -453,7 +453,8 @@ pub fn extend_exprs_with_rotation( | SelectorType::Prefix(_, sel) | SelectorType::OrderedSparse32 { expression: sel, .. - } => match_expr(sel) * zero_check_expr, + } + | SelectorType::QuarkBinaryTreeLessThan(sel) => match_expr(sel) * zero_check_expr, }; zero_check_exprs.push(expr); } diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index bc57295f1..02f5547b2 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -28,10 +28,11 @@ pub enum SelectorType { indices: Vec, expression: Expression, }, + /// binary tree [`quark`] from paper + QuarkBinaryTreeLessThan(Expression), } impl SelectorType { - /// Compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) pub fn compute( &self, out_point: &Point, @@ -50,6 +51,7 @@ impl SelectorType { } Some(sel.into_mle()) } + // compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) SelectorType::OrderedSparse32 { indices, .. } => { let mut sel = build_eq_x_r_vec(out_point); sel.par_chunks_exact_mut(CYCLIC_POW2_5.len()) @@ -75,10 +77,68 @@ impl SelectorType { }); Some(sel.into_mle()) } + // also see evaluate() function for more explanation + SelectorType::QuarkBinaryTreeLessThan(_) => { + // num_instances: number of prefix one in leaf layer + let mut sel: Vec = build_eq_x_r_vec(out_point); + let n = sel.len(); + + let num_instances_sequence = (0..out_point.len()) + // clean up sig bits + .scan( + (num_instances / 2, num_instances.div_ceil(2)), + |(n_instance, raw_instance_ceiling), _| { + if *n_instance > 0 { + let cur = *n_instance; + *n_instance = *raw_instance_ceiling / 2; + *raw_instance_ceiling = raw_instance_ceiling.div_ceil(2); + Some(cur) + } else { + Some(0) + } + }, + ) + .collect::>(); + + // split sel into different size of region, set tailing 0 of respective chunk size + // 1st round: take v = sel[0..sel.len()/2], zero out v[num_instances_sequence[0]..] + // 2nd round: take v = sel[sel.len()/2 .. sel.len()/4], zero out v[num_instances_sequence[1]..] + // ... + // each round: progressively smaller chunk + // example: round 0 uses first half, round 1 uses next quarter, etc. + // compute cumulative start indices: + // e.g. chunk = n/2, then start = 0, chunk, chunk + chunk/2, chunk + chunk/2 + chunk/4, ... + // compute disjoint start indices and lengths + let chunks: Vec<(usize, usize)> = { + let mut result = Vec::new(); + let mut start = 0; + let mut chunk_len = n / 2; + while chunk_len > 0 { + result.push((start, chunk_len)); + start += chunk_len; + chunk_len /= 2; + } + result + }; + + for (i, (start, len)) in chunks.into_iter().enumerate() { + let slice = &mut sel[start..start + len]; + + // determine from which index to zero + let zero_start = num_instances_sequence.get(i).copied().unwrap_or(0).min(len); + + for x in &mut slice[zero_start..] { + *x = E::ZERO; + } + } + + // zero out last bh evaluations + *sel.last_mut().unwrap() = E::ZERO; + Some(sel.into_mle()) + } } } - /// Evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) pub fn evaluate( &self, evals: &mut Vec, @@ -100,6 +160,7 @@ impl SelectorType { eq_eval_less_or_equal_than(num_instances - 1, out_point, in_point), ) } + // evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) SelectorType::OrderedSparse32 { indices, expression, @@ -114,6 +175,70 @@ impl SelectorType { eq_eval_less_or_equal_than(num_instances - 1, &out_point[5..], &in_point[5..]); (expression, eval * sel) } + SelectorType::QuarkBinaryTreeLessThan(expr) => { + // num_instances count on leaf layer + // where nodes size is 2^(N) / 2 + // out_point.len() is also log(2^(N)) - 1 + // so num_instances and 1 << out_point.len() are on same scaling + assert!(num_instances > 0); + assert!(num_instances <= (1 << out_point.len())); + if out_point.is_empty() { + panic!("empty out_point size") + } + assert_eq!(out_point.len(), in_point.len()); + + // we break down this special selector evaluation into recursive structure + // iterating through out_point and in_point, for each i + // next_eval = lhs * (1-out_point[i]) * (1 - in_point[i]) + prev_eval * out_point[i] * in_point[i] + // where the lhs is in consecutive prefix 1 follow by 0 + + // calculate prefix 1 length of each layer + let mut prefix_one_seq = (0..out_point.len()) + .scan( + (num_instances / 2, num_instances.div_ceil(2)), + |(n_instance, raw_instance_ceiling), _| { + if *n_instance > 0 { + let cur = *n_instance; + *n_instance = *raw_instance_ceiling / 2; + *raw_instance_ceiling = raw_instance_ceiling.div_ceil(2); + Some(cur) + } else { + Some(0) + } + }, + ) + .collect::>(); + prefix_one_seq.reverse(); + let mut prefix_one_seq_iter = prefix_one_seq.iter(); + + let mut res = if let Some(first) = prefix_one_seq_iter.by_ref().next() { + if *first > 0 { + assert_eq!(*first, 1); + (E::ONE - out_point[0]) * (E::ONE - in_point[0]) + } else { + E::ZERO + } + } else { + unreachable!() + }; + for i in 1..out_point.len() { + let num_prefix_one_lhs = prefix_one_seq_iter.by_ref().next().unwrap(); + let lhs_res = if *num_prefix_one_lhs > 0 { + (E::ONE - out_point[i]) + * (E::ONE - in_point[i]) + * eq_eval_less_or_equal_than( + *num_prefix_one_lhs - 1, + &out_point[..i], + &in_point[..i], + ) + } else { + E::ZERO + }; + let rhs_res = (out_point[i] * in_point[i]) * res; + res = lhs_res + rhs_res; + } + (expr, res) + } }; let Expression::StructuralWitIn(wit_id, _) = expr else { panic!("Wrong selector expression format");