Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 85 additions & 32 deletions ceno_zkvm/src/scheme/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ use multilinear_extensions::{
virtual_poly::build_eq_x_r_vec,
virtual_polys::VirtualPolynomialsBuilder,
};
use p3::field::FieldAlgebra;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use p3::field::{Field, FieldAlgebra};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
};
use std::{collections::BTreeMap, sync::Arc};
use sumcheck::{
macros::{entered_span, exit_span},
Expand All @@ -43,6 +46,7 @@ use sumcheck::{
use transcript::Transcript;
use witness::next_pow2_instance_padding;

use gkr_iop::hal::MultilinearPolynomial;
#[cfg(feature = "sanity-check")]
use {crate::scheme::septic_curve::SepticExtension, gkr_iop::utils::eq_eval_less_or_equal_than};

Expand All @@ -65,6 +69,7 @@ impl CpuEccProver {

pub fn create_ecc_proof<'a, E: ExtensionField>(
&self,
num_instances: usize,
mut xs: Vec<MultilinearExtension<'a, E>>,
mut ys: Vec<MultilinearExtension<'a, E>>,
invs: Vec<MultilinearExtension<'a, E>>,
Expand All @@ -78,17 +83,38 @@ impl CpuEccProver {
let out_rt = transcript.sample_and_append_vec(b"ecc", n);
let num_threads = optimal_sumcheck_threads(out_rt.len());

let alpha_pows =
transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha");
// 2: expression got add and double
// 3: each contribute 3 zero constrains
let alpha_pows = transcript
.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3 * 2, b"ecc_alpha");
let mut alpha_pows_iter = alpha_pows.iter();

let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len());

let sel = SelectorType::Prefix(E::BaseField::ZERO, 0.into());
let num_instances = (1 << n) - 1;
let mut sel_mle: MultilinearExtension<'_, E> = sel.compute(&out_rt, num_instances).unwrap();
let sel_expr = expr_builder.lift(sel_mle.to_either());
let sel_add = SelectorType::QuarkBinaryTreeLessThan(0.into());
let mut sel_add_mle: MultilinearExtension<'_, E> =
sel_add.compute(&out_rt, num_instances).unwrap();
// we construct sel_double witness here
// verifier can derive it via `sel_double = 1 - sel_add - last_onehot`
let mut sel_double_mle: Vec<E> = build_eq_x_r_vec(&out_rt);
match sel_add_mle.evaluations() {
FieldType::Ext(sel_add_mle) => sel_add_mle
.par_iter()
.zip_eq(sel_double_mle.par_iter_mut())
.for_each(|(sel_add, sel_double)| {
if *sel_add != E::ZERO {
*sel_double = E::ZERO;
}
}),
_ => unreachable!(),
}
*sel_double_mle.last_mut().unwrap() = E::ZERO;
let mut sel_double_mle = sel_double_mle.into_mle();
let sel_add_expr = expr_builder.lift(sel_add_mle.to_either());
let sel_double_expr = expr_builder.lift(sel_double_mle.to_either());

let mut exprs = vec![];
let mut exprs_add = vec![];
let mut exprs_double = vec![];

let filter_bj = |v: &[MultilinearExtension<'_, E>], j: usize| {
v.iter()
Expand Down Expand Up @@ -157,43 +183,69 @@ impl CpuEccProver {
);
// affine addition
// zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1)
exprs.extend(
exprs_add.extend(
(s.clone() * (&x0 - &x1) - (&y0 - &y1))
.to_exprs()
.into_iter()
.zip(alpha_pows.iter().take(SEPTIC_EXTENSION_DEGREE))
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

// zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1)
exprs.extend(
exprs_add.extend(
((&s * &s) - &x0 - &x1 - &x3)
.to_exprs()
.into_iter()
.zip(
alpha_pows[SEPTIC_EXTENSION_DEGREE..]
.iter()
.take(SEPTIC_EXTENSION_DEGREE),
)
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

// zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1)
exprs.extend(
exprs_add.extend(
(s.clone() * (&x0 - &x3) - (&y0 + &y3))
.to_exprs()
.into_iter()
.zip(
alpha_pows[SEPTIC_EXTENSION_DEGREE * 2..]
.iter()
.take(SEPTIC_EXTENSION_DEGREE),
)
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

let exprs_add = exprs_add.into_iter().sum::<Expression<E>>() * sel_add_expr;

// deal with double
// 0 = s[0,b] * (2*y[b,0]) - (3*x[b,0]^2 + a)
exprs_double.extend(
(s.clone() * (&y0.mul_scalar(Either::Left(E::BaseField::TWO)))
- ((&x0 * &x0.mul_scalar(Either::Left(E::BaseField::from_canonical_u32(3))))
.add_scalar(Either::Left(E::BaseField::TWO))))
.to_exprs()
.into_iter()
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

// 0 = s[0,b]^2 - 2*x[b,0] - x[1,b]
exprs_double.extend(
((&s * &s) - (&x0.mul_scalar(Either::Left(E::BaseField::TWO))) - &x3)
.to_exprs()
.into_iter()
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

// 0 = s * (x[b,0] - x[1,b]) - (y[b,0] + y[1, b])
exprs_double.extend(
(s.clone() * (&x0 - &x3) - (&y0 + &y3))
.to_exprs()
.into_iter()
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);
assert!(alpha_pows_iter.next().is_none());

let exprs_double = exprs_double.into_iter().sum::<Expression<E>>() * sel_double_expr;

let (zerocheck_proof, state) = IOPProverState::prove(
expr_builder
.to_virtual_polys(&[exprs.into_iter().sum::<Expression<E>>() * sel_expr], &[]),
expr_builder.to_virtual_polys(&[exprs_add + exprs_double], &[]),
transcript,
);

Expand All @@ -202,10 +254,11 @@ impl CpuEccProver {

assert_eq!(zerocheck_proof.extract_sum(), E::ZERO);
// 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[0,rt]
assert_eq!(evals.len(), 1 + SEPTIC_EXTENSION_DEGREE * 7);
assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7);

#[cfg(feature = "sanity-check")]
{
let last_evaluation_index = (1 << n) - 1;
let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec();
let x0 = filter_bj(&xs, 0);
let y0 = filter_bj(&ys, 0);
Expand All @@ -214,19 +267,19 @@ impl CpuEccProver {
let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec();
let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec();
let final_sum_x: SepticExtension<E::BaseField> = (x3.iter())
.map(|x| x.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0]
.map(|x| x.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0]
.collect_vec()
.into();
let final_sum_y: SepticExtension<E::BaseField> = (y3.iter())
.map(|y| y.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0]
.map(|y| y.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0]
.collect_vec()
.into();
let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y);

assert_eq!(final_sum, sum);
// check evaluations
assert_eq!(
eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt),
eq_eval_less_or_equal_than(last_evaluation_index - 1, &out_rt, &rt),
evals[0]
);
for i in 0..SEPTIC_EXTENSION_DEGREE {
Expand Down Expand Up @@ -258,7 +311,7 @@ impl CpuEccProver {
// TODO: prove the validity of s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt]
EccQuarkProof {
zerocheck_proof,
num_vars: n,
num_instances,
evals,
sum,
}
Expand Down Expand Up @@ -1099,12 +1152,11 @@ mod tests {
}));

points.extend(
points[points.len() - num_inputs..]
inputs
.chunks_exact(2)
.map(|chunk| {
let p = chunk[0].clone();
let q = chunk[1].clone();

p + q
})
.collect_vec(),
Expand Down Expand Up @@ -1144,6 +1196,7 @@ mod tests {
let mut transcript = BasicTranscript::new(b"test");
let prover = CpuEccProver::new();
let quark_proof = prover.create_ecc_proof(
n_points,
xs.to_vec(),
ys.to_vec(),
s.to_vec(),
Expand Down
22 changes: 22 additions & 0 deletions ceno_zkvm/src/scheme/septic_curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,28 @@ impl<F: Field> MulAssign<Self> for SepticExtension<F> {
#[derive(Clone, Debug)]
pub struct SymbolicSepticExtension<E: ExtensionField>(pub Vec<Expression<E>>);

impl<E: ExtensionField> SymbolicSepticExtension<E> {
pub fn mul_scalar(&self, scalar: Either<E::BaseField, E>) -> Self {
let res = self
.0
.iter()
.map(|a| a.clone() * Expression::Constant(scalar))
.collect();

SymbolicSepticExtension(res)
}

pub fn add_scalar(&self, scalar: Either<E::BaseField, E>) -> Self {
let res = self
.0
.iter()
.map(|a| a.clone() + Expression::Constant(scalar))
.collect();

SymbolicSepticExtension(res)
}
}

impl<E: ExtensionField> Add<Self> for &SymbolicSepticExtension<E> {
type Output = SymbolicSepticExtension<E>;

Expand Down
Loading
Loading