diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 1444b61f8..dd06e510b 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -61,6 +61,10 @@ jobs: RUSTFLAGS: "-C opt-level=3" run: cargo run --package ceno_zkvm --features u16limb_circuit --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/debug/examples/fibonacci + - name: Run gkr keccak (release) + env: + RUSTFLAGS: "-C opt-level=3" + run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/keccak_syscall - name: Run fibonacci (release) env: diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index d8929667b..9b170819f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -124,8 +124,12 @@ impl Instruction for KeccakInstruction { let (out_evals, mut chip) = layout.finalize(cb); - let layer = - Layer::from_circuit_builder(cb, "Rounds".to_string(), layout.n_challenges(), out_evals); + let layer = Layer::from_circuit_builder( + cb, + "GkrKeccak".to_string(), + layout.n_challenges(), + out_evals, + ); chip.add_layer(layer); let circuit = chip.gkr_circuit(); diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 7a7d62fda..fd82b3feb 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -184,6 +184,7 @@ impl KeccakLayout { sel_mem_read, sel_mem_write, eq_zero, + // TODO: explain this?? eq_rotation_left, eq_rotation_right, eq_rotation, @@ -221,10 +222,12 @@ impl KeccakLayout { }, selector_type_layout: SelectorTypeLayout { sel_mem_read: SelectorType::OrderedSparse32 { + // read at the first round indices: vec![CYCLIC_POW2_5[0] as usize], expression: sel_mem_read.expr(), }, sel_mem_write: SelectorType::OrderedSparse32 { + // write at the last round indices: vec![CYCLIC_POW2_5[ROUNDS - 1] as usize], expression: sel_mem_write.expr(), }, @@ -490,7 +493,7 @@ impl ProtocolBuilder for KeccakLayout { layout.input32_exprs = keccak_input32.try_into().unwrap(); layout.output32_exprs = keccak_output32.try_into().unwrap(); - // rotation constrain: rotation(keccak_input8).next() == keccak_output8 + // rotation constraint: rotation(keccak_input8).next() == keccak_output8 izip!(keccak_input8, keccak_output8) .for_each(|(input, output)| system.rotate_and_assert_eq(input.expr(), output.expr())); system.set_rotation_params(RotationParams { diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index f3db364ee..51eb9fd3c 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -16,6 +16,10 @@ pub mod ram; #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)] #[serde(bound = "E: ExtensionField + DeserializeOwned")] pub struct RotationParams { + // TODO: explain this + // f(left_point) = \sum eq(left_point,b) * f(b) + // f(right_point) = \sum eq(right_point,b) * f(b) + // g(point) = \sum eq(point,b) * g(b) pub rotation_eqs: Option<[Expression; ROTATION_OPENING_COUNT]>, pub rotation_cyclic_group_log2: usize, pub rotation_cyclic_subgroup_size: usize, diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 7d80229fd..86ba0c75e 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -131,6 +131,7 @@ impl GKRCircuit { let mut challenges = challenges.to_vec(); let mut evaluations = out_evals.to_vec(); + // TODO: can we avoid this resize? evaluations.resize(self.n_evaluations, PointAndEval::default()); for (i, (layer, layer_proof)) in izip!(&self.layers, sumcheck_proofs).enumerate() { tracing::debug!("verifier layer {i} layer with layer name {}", layer.name); diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index ed4808ff0..d42c6b7e7 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -80,7 +80,7 @@ pub struct Layer { /// first tuple value is optional eq pub out_sel_and_eval_exprs: Vec>, - // format: ([eq0, eq1, eq2], Vec<(rotatition_expr, expr)>) such that rotation_expr - expr == 0 + // format: ([eq0, eq1, eq2], Vec<(rotation_expr, expr)>) such that rotation_expr - expr == 0 // there got 3 different eq for (left, right, target) during rotation argument // refer https://hackmd.io/HAAj1JTQQiKfu0SIwOJDRw?view#Rotation pub rotation_exprs: RotateExprs, @@ -184,6 +184,7 @@ impl Layer { num_instances: usize, ) -> LayerProof { self.update_challenges(challenges, transcript); + // TODO: how to understand out_sel_and_eval_exprs?? let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); let (sumcheck_layer_proof, point) = match self.ty { @@ -227,7 +228,7 @@ impl Layer { proof: LayerProof, claims: &mut [PointAndEval], pub_io_evals: &[E], - challenges: &mut Vec, + challenges: &mut Vec, // TODO: can we avoid &mut here? transcript: &mut Trans, num_instances: usize, ) -> Result<(), BackendError> { @@ -349,6 +350,10 @@ impl Layer { .zip_eq(&r_record_evals.1) .enumerate() { + // with padding we have + // r_v = sel * r_expr + (1-sel) * 1 + // = sel * (r_expr - 1) + 1 + // therefore we have r_v - 1 = sel * (r_expr - 1) expressions.push(ram_expr - E::BaseField::ONE.expr()); evals.push(EvalExpression::::Linear( // evaluation = claim * one - one (padding) @@ -369,6 +374,8 @@ impl Layer { .zip_eq(&w_record_evals.1) .enumerate() { + // w_v = sel * w_expr + (1-sel) * 1 + // therefore we have w_v - 1 = sel * (w_expr - 1) expressions.push(ram_expr - E::BaseField::ONE.expr()); evals.push(EvalExpression::::Linear( // evaluation = claim * one - one (padding) @@ -389,6 +396,8 @@ impl Layer { .zip_eq(&lookup_evals.1) .enumerate() { + // lk_numerator_v = sel * lk_numerator + (1 - sel) * alpha + // therefore we have lk_numerator_v - alpha = sel * (lk_numerator - alpha) expressions.push(lookup - cb.cs.chip_record_alpha.clone()); evals.push(EvalExpression::::Linear( // evaluation = claim * one - alpha (padding) @@ -436,10 +445,10 @@ impl Layer { cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - expressions, + expressions, // used to derive the sumcheck expression n_challenges, in_eval_expr, - expr_evals, + expr_evals, // used to derive the expected sum in the sumcheck ((None, vec![]), 0, 0), expr_names, ) @@ -452,6 +461,9 @@ impl Layer { else { panic!("rotation params not set"); }; + + // TODO: include rotation_eqs in out_sel_and_eval_exprs + Layer::new( layer_name, LayerType::Zerocheck, diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index fa4c33c5e..fa9427749 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -158,7 +158,11 @@ impl> ZerocheckLayerProver (None, None, None, None) }; - // 2th sumcheck: batch rotation with other constrains + // f(0, r1, r2, ...) = \sum_b eq(left_point, b) * f(b) + // f(1, r1, 1-r2,r3,...) = \sum_b eq(right_point, b) * f(b) + // g(r0, r1, r2, ...) = \sum_b eq(point, b) * g(b) + + // 2th sumcheck: batch rotation with other constraints let span = entered_span!("build_out_points_eq", profiling_4 = true); let main_sumcheck_challenges = chain!( challenges.iter().copied(), @@ -249,8 +253,8 @@ impl> ZerocheckLayerProver } } -/// This is to prove the following n rotation arguments: -/// For the i-th argument, we check rotated(rotation_expr[i].0) == rotation_expr[i].1 +/// This is to prove the following N rotation arguments: +/// For the i-th argument, we check rotation_expr[i].0 == rotation_expr[i].1 /// This is proved through the following arguments: /// 0 = \sum_{b = 0}^{N - 1} sel(b) * \sum_i alpha^i * (rotated_rotation_expr[i].0(b) - rotation_expr[i].1(b)) /// With the randomness rx, we check: (currently we only support cycle with length 32) @@ -274,6 +278,8 @@ pub(crate) fn prove_rotation ZerocheckLayer for Layer { // build rotation expression let num_rotations = self.rotation_exprs.1.len(); let rotation_expr = if num_rotations > 0 { + // TODO: is there a more natural way to construct rotation sumcheck expression? + // 0 = \sum_i alpha^i * (rotate_i - target_i) let alpha_pows_expr = (2..) .take(num_rotations) .map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO)) @@ -104,6 +106,7 @@ impl ZerocheckLayer for Layer { .iter() .flat_map(|(sel_type, out_eval)| izip!(std::iter::repeat(sel_type), out_eval.iter())) .collect(); + // TODO: explain exprs_with_selector_out_eval_monomial_form??? self.exprs_with_selector_out_eval_monomial_form = self .exprs .iter() @@ -221,6 +224,7 @@ impl ZerocheckLayer for Layer { if let Some(rotation_proof) = rotation_proof { // verify rotation proof let rt = eval_and_dedup_points + // TODO: why first? .first() .and_then(|(_, rt)| rt.as_ref()) .expect("rotation proof should have at least one point"); @@ -260,6 +264,7 @@ impl ZerocheckLayer for Layer { ) .collect_vec(); + // expected sum in the main sumcheck let sigma = dot_product( main_sumcheck_challenges.iter().skip(2).copied(), // skip first 2 global challenges eval_and_dedup_points @@ -283,9 +288,13 @@ impl ZerocheckLayer for Layer { ); let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); + // TODO: is there any check on eq_left, eq_right, eq for rotation argument? + // eval eq and set to respective witin izip!(&self.out_sel_and_eval_exprs, &eval_and_dedup_points).for_each( |((sel_type, _), (_, out_point))| { + // TODO: instead of overwrite main_evals, we should just read it out + // and then compare it with expected value sel_type.evaluate( &mut main_evals, out_point.as_ref().unwrap(), @@ -359,6 +368,7 @@ fn verify_rotation( }, transcript, ); + // TODO: this name is misleading let origin_point = in_point.into_iter().map(|c| c.elements).collect_vec(); // compute the selector evaluation @@ -387,6 +397,7 @@ fn verify_rotation( right_evals.push(*right_eval); target_evals.push(*target_eval); [ + // e.g. in the case of keccak, it's (1-s4)*left_eval + s4*right_eval (E::ONE - origin_point[rotation_cyclic_group_log2 - 1]) * *left_eval + origin_point[rotation_cyclic_group_log2 - 1] * *right_eval, *target_eval, @@ -438,6 +449,7 @@ pub fn extend_exprs_with_rotation( invalid => panic!("invalid eq format {:?}", invalid), }; + // the sumcheck expression in the zerocheck is \sum_i alpha^i * sel_i * expr_i for (sel_type, out_evals) in layer.out_sel_and_eval_exprs.iter() { let group_length = out_evals.len(); let zero_check_expr = expr_iter @@ -447,6 +459,7 @@ pub fn extend_exprs_with_rotation( .zip_eq(alpha_pows_iter.by_ref().take(group_length)) .map(|(expr, alpha)| alpha * expr) .sum::>(); + // TODO: why selectors are treated as WitIn? let expr = match sel_type { SelectorType::None => zero_check_expr, SelectorType::Whole(sel) @@ -492,6 +505,8 @@ pub fn extend_exprs_with_rotation( }) .sum(); + // batch 3 additional openings occurred in rotation argument into zerocheck argument + // push rotation expr to zerocheck expr if let Some( [ @@ -511,12 +526,18 @@ pub fn extend_exprs_with_rotation( Expression::StructuralWitIn(right_eq_id, ..), Expression::StructuralWitIn(eq_id, ..), ) => ( + // TODO: why convert to WitIn? Expression::WitIn(offset_eq_id + *left_eq_id), Expression::WitIn(offset_eq_id + *right_eq_id), Expression::WitIn(offset_eq_id + *eq_id), ), invalid => panic!("invalid eq format {:?}", invalid), }; + + // f(left_point) = \sum_b eq(left_point, b) * f(b) + // f(right_point) = \sum_b eq(right_point, b) * f(b) + // g(target_point) = \sum_b eq(target_point, b) * g(b + // add rotation left expr zero_check_exprs.push(rotation_left_eq_expr * left_rotation_expr); // add rotation right expr diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index 4fd61dbed..8053d21ac 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -61,6 +61,7 @@ impl SelectorType { return; } + // TODO: no need to call .copied() let mut indices_iter = indices.iter().copied(); let mut next_keep = indices_iter.next(); @@ -78,6 +79,7 @@ impl SelectorType { } } + // TODO: correct the comment /// Evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) pub fn evaluate( &self, @@ -104,6 +106,10 @@ impl SelectorType { indices, expression, } => { + // p(b) = eq(x,b) if b[..5] \in indices else 0 + // p(y) = \sum_b p(b) * eq(y,b) + // = \sum_{bl \in indices} eq(xl,bl)*eq(yl,bl) * \sum_{br <= num_instances-1} eq(xr,br)*eq(yr,br) + // = \sum_{bl \in indices} eq(xl,yl,bl) * eq_eval_less_or_equal_than(num_instances-1, xr,yr) let out_subgroup_eq = build_eq_x_r_vec(&out_point[..5]); let in_subgroup_eq = build_eq_x_r_vec(&in_point[..5]); let mut eval = E::ZERO; diff --git a/gkr_iop/src/utils.rs b/gkr_iop/src/utils.rs index 285510c4b..9406e6f39 100644 --- a/gkr_iop/src/utils.rs +++ b/gkr_iop/src/utils.rs @@ -22,7 +22,8 @@ pub fn rotation_next_base_mle<'a, E: ExtensionField>( cyclic_group_log2_size: usize, ) -> MultilinearExtension<'a, E> { let cyclic_group_size = 1 << cyclic_group_log2_size; - let rotation_index = bh.into_iter().take(cyclic_group_size + 1).collect_vec(); + // TODO: why +1 + let rotation_index = bh.into_iter().take(cyclic_group_size).collect_vec(); let mut rotated_mle_evals = Vec::with_capacity(mle.evaluations().len()); rotated_mle_evals.par_extend( (0..mle.evaluations().len()) @@ -42,11 +43,13 @@ pub fn rotation_next_base_mle<'a, E: ExtensionField>( rotate_chunk[0] = original_chunk[0]; - for i in (0..rotation_index.len() - 1).rev() { - let to = rotation_index[i] as usize; - let from = rotation_index[i + 1] as usize; - rotate_chunk[to] = original_chunk[from]; - } + rotation_index + .iter() + .tuple_windows() + .for_each(|(cur, next)| { + // f'(b) = f(next(b)) + rotate_chunk[*cur as usize] = original_chunk[*next as usize]; + }); }); MultilinearExtension::from_evaluation_vec_smart(mle.num_vars(), rotated_mle_evals) } @@ -160,6 +163,14 @@ pub const fn wits_fixed_and_eqs( (wits, fixed, eqs) } +/// p(b) = eq(x,b) if b <= max_idx else 0 +/// +/// Its mle is defined as +/// +/// p(y) = \sum_b p(b)*eq(b,y) = \sum_{b <= max_idx} eq(x,b)*eq(b,y) +/// +/// it's easy to see that eq(x,b)*eq(y,b) = eq(x,y,b). +/// /// This is to compute a variant of eq(\mathbf{x}, \mathbf{y}) for indices in /// [0..=max_idx]. Specifically, it is an MLE of the following vector: /// partial_eq_{\mathbf{x}}(\mathbf{y}) @@ -239,8 +250,11 @@ mod tests { let mut rng = rand::thread_rng(); let point: Vec<_> = (0..7).map(|_| E::random(&mut rng)).collect(); let (left_point, right_point) = bh.get_rotation_points(&point); + // f(next(r)) = (1-r5)*f(0,r1,...) + r5*f(1,r1,1-r2,...) let rotated_eval = rotated.evaluate(&point); + // f(0,r1,...) let left_eval = poly.evaluate(&left_point); + // f(1,r1,1-r2,...) let right_eval = poly.evaluate(&right_point); assert_eq!( rotated_eval,