Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ jobs:
RUSTFLAGS: "-C opt-level=3"
run: cargo run --package ceno_zkvm --features u16limb_circuit --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/debug/examples/fibonacci

- name: Run gkr keccak (release)
env:
RUSTFLAGS: "-C opt-level=3"
run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/keccak_syscall

- name: Run fibonacci (release)
env:
Expand Down
8 changes: 6 additions & 2 deletions ceno_zkvm/src/instructions/riscv/ecall/keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,12 @@ impl<E: ExtensionField> Instruction<E> for KeccakInstruction<E> {

let (out_evals, mut chip) = layout.finalize(cb);

let layer =
Layer::from_circuit_builder(cb, "Rounds".to_string(), layout.n_challenges(), out_evals);
let layer = Layer::from_circuit_builder(
cb,
"GkrKeccak".to_string(),
layout.n_challenges(),
out_evals,
);
chip.add_layer(layer);

let circuit = chip.gkr_circuit();
Expand Down
5 changes: 4 additions & 1 deletion ceno_zkvm/src/precompiles/lookup_keccakf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ impl<E: ExtensionField> KeccakLayout<E> {
sel_mem_read,
sel_mem_write,
eq_zero,
// TODO: explain this??
eq_rotation_left,
eq_rotation_right,
eq_rotation,
Expand Down Expand Up @@ -221,10 +222,12 @@ impl<E: ExtensionField> KeccakLayout<E> {
},
selector_type_layout: SelectorTypeLayout {
sel_mem_read: SelectorType::OrderedSparse32 {
// read at the first round
indices: vec![CYCLIC_POW2_5[0] as usize],
expression: sel_mem_read.expr(),
},
sel_mem_write: SelectorType::OrderedSparse32 {
// write at the last round
indices: vec![CYCLIC_POW2_5[ROUNDS - 1] as usize],
expression: sel_mem_write.expr(),
},
Expand Down Expand Up @@ -490,7 +493,7 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> {
layout.input32_exprs = keccak_input32.try_into().unwrap();
layout.output32_exprs = keccak_output32.try_into().unwrap();

// rotation constrain: rotation(keccak_input8).next() == keccak_output8
// rotation constraint: rotation(keccak_input8).next() == keccak_output8
izip!(keccak_input8, keccak_output8)
.for_each(|(input, output)| system.rotate_and_assert_eq(input.expr(), output.expr()));
system.set_rotation_params(RotationParams {
Expand Down
4 changes: 4 additions & 0 deletions gkr_iop/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ pub mod ram;
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
#[serde(bound = "E: ExtensionField + DeserializeOwned")]
pub struct RotationParams<E: ExtensionField> {
// TODO: explain this
// f(left_point) = \sum eq(left_point,b) * f(b)
// f(right_point) = \sum eq(right_point,b) * f(b)
// g(point) = \sum eq(point,b) * g(b)
pub rotation_eqs: Option<[Expression<E>; ROTATION_OPENING_COUNT]>,
pub rotation_cyclic_group_log2: usize,
pub rotation_cyclic_subgroup_size: usize,
Expand Down
1 change: 1 addition & 0 deletions gkr_iop/src/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ impl<E: ExtensionField> GKRCircuit<E> {

let mut challenges = challenges.to_vec();
let mut evaluations = out_evals.to_vec();
// TODO: can we avoid this resize?
evaluations.resize(self.n_evaluations, PointAndEval::default());
for (i, (layer, layer_proof)) in izip!(&self.layers, sumcheck_proofs).enumerate() {
tracing::debug!("verifier layer {i} layer with layer name {}", layer.name);
Expand Down
20 changes: 16 additions & 4 deletions gkr_iop/src/gkr/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub struct Layer<E: ExtensionField> {
/// first tuple value is optional eq
pub out_sel_and_eval_exprs: Vec<ExprEvalType<E>>,

// format: ([eq0, eq1, eq2], Vec<(rotatition_expr, expr)>) such that rotation_expr - expr == 0
// format: ([eq0, eq1, eq2], Vec<(rotation_expr, expr)>) such that rotation_expr - expr == 0
// there got 3 different eq for (left, right, target) during rotation argument
// refer https://hackmd.io/HAAj1JTQQiKfu0SIwOJDRw?view#Rotation
pub rotation_exprs: RotateExprs<E>,
Expand Down Expand Up @@ -184,6 +184,7 @@ impl<E: ExtensionField> Layer<E> {
num_instances: usize,
) -> LayerProof<E> {
self.update_challenges(challenges, transcript);
// TODO: how to understand out_sel_and_eval_exprs??
let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges);

let (sumcheck_layer_proof, point) = match self.ty {
Expand Down Expand Up @@ -227,7 +228,7 @@ impl<E: ExtensionField> Layer<E> {
proof: LayerProof<E>,
claims: &mut [PointAndEval<E>],
pub_io_evals: &[E],
challenges: &mut Vec<E>,
challenges: &mut Vec<E>, // TODO: can we avoid &mut here?
transcript: &mut Trans,
num_instances: usize,
) -> Result<(), BackendError> {
Expand Down Expand Up @@ -349,6 +350,10 @@ impl<E: ExtensionField> Layer<E> {
.zip_eq(&r_record_evals.1)
.enumerate()
{
// with padding we have
// r_v = sel * r_expr + (1-sel) * 1
// = sel * (r_expr - 1) + 1
// therefore we have r_v - 1 = sel * (r_expr - 1)
expressions.push(ram_expr - E::BaseField::ONE.expr());
evals.push(EvalExpression::<E>::Linear(
// evaluation = claim * one - one (padding)
Expand All @@ -369,6 +374,8 @@ impl<E: ExtensionField> Layer<E> {
.zip_eq(&w_record_evals.1)
.enumerate()
{
// w_v = sel * w_expr + (1-sel) * 1
// therefore we have w_v - 1 = sel * (w_expr - 1)
expressions.push(ram_expr - E::BaseField::ONE.expr());
evals.push(EvalExpression::<E>::Linear(
// evaluation = claim * one - one (padding)
Expand All @@ -389,6 +396,8 @@ impl<E: ExtensionField> Layer<E> {
.zip_eq(&lookup_evals.1)
.enumerate()
{
// lk_numerator_v = sel * lk_numerator + (1 - sel) * alpha
// therefore we have lk_numerator_v - alpha = sel * (lk_numerator - alpha)
expressions.push(lookup - cb.cs.chip_record_alpha.clone());
evals.push(EvalExpression::<E>::Linear(
// evaluation = claim * one - alpha (padding)
Expand Down Expand Up @@ -436,10 +445,10 @@ impl<E: ExtensionField> Layer<E> {
cb.cs.num_witin as usize,
cb.cs.num_structural_witin as usize,
cb.cs.num_fixed,
expressions,
expressions, // used to derive the sumcheck expression
n_challenges,
in_eval_expr,
expr_evals,
expr_evals, // used to derive the expected sum in the sumcheck
((None, vec![]), 0, 0),
expr_names,
)
Expand All @@ -452,6 +461,9 @@ impl<E: ExtensionField> Layer<E> {
else {
panic!("rotation params not set");
};

// TODO: include rotation_eqs in out_sel_and_eval_exprs

Layer::new(
layer_name,
LayerType::Zerocheck,
Expand Down
15 changes: 11 additions & 4 deletions gkr_iop/src/gkr/layer/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
(None, None, None, None)
};

// 2th sumcheck: batch rotation with other constrains
// f(0, r1, r2, ...) = \sum_b eq(left_point, b) * f(b)
// f(1, r1, 1-r2,r3,...) = \sum_b eq(right_point, b) * f(b)
// g(r0, r1, r2, ...) = \sum_b eq(point, b) * g(b)

// 2th sumcheck: batch rotation with other constraints
let span = entered_span!("build_out_points_eq", profiling_4 = true);
let main_sumcheck_challenges = chain!(
challenges.iter().copied(),
Expand Down Expand Up @@ -249,8 +253,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
}
}

/// This is to prove the following n rotation arguments:
/// For the i-th argument, we check rotated(rotation_expr[i].0) == rotation_expr[i].1
/// This is to prove the following N rotation arguments:
/// For the i-th argument, we check rotation_expr[i].0 == rotation_expr[i].1
/// This is proved through the following arguments:
/// 0 = \sum_{b = 0}^{N - 1} sel(b) * \sum_i alpha^i * (rotated_rotation_expr[i].0(b) - rotation_expr[i].1(b))
/// With the randomness rx, we check: (currently we only support cycle with length 32)
Expand All @@ -274,6 +278,8 @@ pub(crate) fn prove_rotation<E: ExtensionField, PCS: PolynomialCommitmentScheme<
// rotated_mles is non-deterministic input, rotated from existing witness polynomial
// we will reduce it to zero check, and finally reduce to committed polynomial opening
let (mut selector, mut rotated_mles) = {
// sanity check on max_num_variables
assert_eq!(rt.len(), max_num_variables);
let eq = build_eq_x_r_vec(rt);
let mut mles = raw_rotation_exprs
.par_iter()
Expand Down Expand Up @@ -306,7 +312,8 @@ pub(crate) fn prove_rotation<E: ExtensionField, PCS: PolynomialCommitmentScheme<
let builder = VirtualPolynomialsBuilder::new_with_mles(
num_threads,
max_num_variables,
// mles format [rotation_mle1, target_mle1, rotation_mle2, target_mle2, ....., selector, eq]
// keep the order of mles = [rotation_mle1, target_mle1, rotation_mle2, target_mle2, ....., selector]
// to be consistent with `rotation_sumcheck_expression`
rotated_mles
.iter_mut()
.zip_eq(raw_rotation_exprs)
Expand Down
21 changes: 21 additions & 0 deletions gkr_iop/src/gkr/layer/zerocheck_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
// build rotation expression
let num_rotations = self.rotation_exprs.1.len();
let rotation_expr = if num_rotations > 0 {
// TODO: is there a more natural way to construct rotation sumcheck expression?
// 0 = \sum_i alpha^i * (rotate_i - target_i)
let alpha_pows_expr = (2..)
.take(num_rotations)
.map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO))
Expand All @@ -104,6 +106,7 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
.iter()
.flat_map(|(sel_type, out_eval)| izip!(std::iter::repeat(sel_type), out_eval.iter()))
.collect();
// TODO: explain exprs_with_selector_out_eval_monomial_form???
self.exprs_with_selector_out_eval_monomial_form = self
.exprs
.iter()
Expand Down Expand Up @@ -221,6 +224,7 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
if let Some(rotation_proof) = rotation_proof {
// verify rotation proof
let rt = eval_and_dedup_points
// TODO: why first?
.first()
.and_then(|(_, rt)| rt.as_ref())
.expect("rotation proof should have at least one point");
Expand Down Expand Up @@ -260,6 +264,7 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
)
.collect_vec();

// expected sum in the main sumcheck
let sigma = dot_product(
main_sumcheck_challenges.iter().skip(2).copied(), // skip first 2 global challenges
eval_and_dedup_points
Expand All @@ -283,9 +288,13 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
);
let in_point = in_point.into_iter().map(|c| c.elements).collect_vec();

// TODO: is there any check on eq_left, eq_right, eq for rotation argument?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hero78119 It seems to me that eq_left / eq_right / eq_target is not checked at all.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rotation eq will be re-construct when the random point was derived from previous layer

.chain(rotation_left.par_iter().map(|rotation_left| {
MultilinearExtension::from_evaluations_ext_vec(
rotation_left.len(),
build_eq_x_r_vec(rotation_left),
)
}))
// for rotation right point
.chain(rotation_right.par_iter().map(|rotation_right| {
MultilinearExtension::from_evaluations_ext_vec(
rotation_right.len(),
build_eq_x_r_vec(rotation_right),
)
}))
// for rotation point
.chain(rotation_point.par_iter().map(|rotation_point| {
MultilinearExtension::from_evaluations_ext_vec(
rotation_point.len(),
build_eq_x_r_vec(rotation_point),
)
}))
.collect::<Vec<_>>();

in summary, eq will transform from A to B format

  • format A: each eq will be assign as selector with only 0, 1 during witness generation to serve layer by layer witness inference
  • format B: once random point $$r$$ derive from previous layer, each eq will be constructed via build_eq(r) then proceed sumcheck

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the array self.out_sel_and_eval_exprs does not include eq_left / eq_right / eq.

izip!(&self.out_sel_and_eval_exprs, &eval_and_dedup_points).for_each(
|((sel_type, _), (_, out_point))| {
sel_type.evaluate(
&mut main_evals,
out_point.as_ref().unwrap(),
&in_point,
num_instances,
self.n_witin,
);
},
);


// eval eq and set to respective witin
izip!(&self.out_sel_and_eval_exprs, &eval_and_dedup_points).for_each(
|((sel_type, _), (_, out_point))| {
// TODO: instead of overwrite main_evals, we should just read it out
// and then compare it with expected value
sel_type.evaluate(
&mut main_evals,
out_point.as_ref().unwrap(),
Expand Down Expand Up @@ -359,6 +368,7 @@ fn verify_rotation<E: ExtensionField>(
},
transcript,
);
// TODO: this name is misleading
let origin_point = in_point.into_iter().map(|c| c.elements).collect_vec();

// compute the selector evaluation
Expand Down Expand Up @@ -387,6 +397,7 @@ fn verify_rotation<E: ExtensionField>(
right_evals.push(*right_eval);
target_evals.push(*target_eval);
[
// e.g. in the case of keccak, it's (1-s4)*left_eval + s4*right_eval
(E::ONE - origin_point[rotation_cyclic_group_log2 - 1]) * *left_eval
+ origin_point[rotation_cyclic_group_log2 - 1] * *right_eval,
*target_eval,
Expand Down Expand Up @@ -438,6 +449,7 @@ pub fn extend_exprs_with_rotation<E: ExtensionField>(
invalid => panic!("invalid eq format {:?}", invalid),
};

// the sumcheck expression in the zerocheck is \sum_i alpha^i * sel_i * expr_i
for (sel_type, out_evals) in layer.out_sel_and_eval_exprs.iter() {
let group_length = out_evals.len();
let zero_check_expr = expr_iter
Expand All @@ -447,6 +459,7 @@ pub fn extend_exprs_with_rotation<E: ExtensionField>(
.zip_eq(alpha_pows_iter.by_ref().take(group_length))
.map(|(expr, alpha)| alpha * expr)
.sum::<Expression<E>>();
// TODO: why selectors are treated as WitIn?
let expr = match sel_type {
SelectorType::None => zero_check_expr,
SelectorType::Whole(sel)
Expand Down Expand Up @@ -492,6 +505,8 @@ pub fn extend_exprs_with_rotation<E: ExtensionField>(
})
.sum();

// batch 3 additional openings occurred in rotation argument into zerocheck argument

// push rotation expr to zerocheck expr
if let Some(
[
Expand All @@ -511,12 +526,18 @@ pub fn extend_exprs_with_rotation<E: ExtensionField>(
Expression::StructuralWitIn(right_eq_id, ..),
Expression::StructuralWitIn(eq_id, ..),
) => (
// TODO: why convert to WitIn?
Expression::WitIn(offset_eq_id + *left_eq_id),
Expression::WitIn(offset_eq_id + *right_eq_id),
Expression::WitIn(offset_eq_id + *eq_id),
),
invalid => panic!("invalid eq format {:?}", invalid),
};

// f(left_point) = \sum_b eq(left_point, b) * f(b)
// f(right_point) = \sum_b eq(right_point, b) * f(b)
// g(target_point) = \sum_b eq(target_point, b) * g(b

// add rotation left expr
zero_check_exprs.push(rotation_left_eq_expr * left_rotation_expr);
// add rotation right expr
Expand Down
6 changes: 6 additions & 0 deletions gkr_iop/src/selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl<E: ExtensionField> SelectorType<E> {
return;
}

// TODO: no need to call .copied()
let mut indices_iter = indices.iter().copied();
let mut next_keep = indices_iter.next();

Expand All @@ -78,6 +79,7 @@ impl<E: ExtensionField> SelectorType<E> {
}
}

// TODO: correct the comment
/// Evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..]))
pub fn evaluate(
&self,
Expand All @@ -104,6 +106,10 @@ impl<E: ExtensionField> SelectorType<E> {
indices,
expression,
} => {
// p(b) = eq(x,b) if b[..5] \in indices else 0
// p(y) = \sum_b p(b) * eq(y,b)
// = \sum_{bl \in indices} eq(xl,bl)*eq(yl,bl) * \sum_{br <= num_instances-1} eq(xr,br)*eq(yr,br)
// = \sum_{bl \in indices} eq(xl,yl,bl) * eq_eval_less_or_equal_than(num_instances-1, xr,yr)
let out_subgroup_eq = build_eq_x_r_vec(&out_point[..5]);
let in_subgroup_eq = build_eq_x_r_vec(&in_point[..5]);
let mut eval = E::ZERO;
Expand Down
26 changes: 20 additions & 6 deletions gkr_iop/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ pub fn rotation_next_base_mle<'a, E: ExtensionField>(
cyclic_group_log2_size: usize,
) -> MultilinearExtension<'a, E> {
let cyclic_group_size = 1 << cyclic_group_log2_size;
let rotation_index = bh.into_iter().take(cyclic_group_size + 1).collect_vec();
// TODO: why +1
let rotation_index = bh.into_iter().take(cyclic_group_size).collect_vec();
let mut rotated_mle_evals = Vec::with_capacity(mle.evaluations().len());
rotated_mle_evals.par_extend(
(0..mle.evaluations().len())
Expand All @@ -42,11 +43,13 @@ pub fn rotation_next_base_mle<'a, E: ExtensionField>(

rotate_chunk[0] = original_chunk[0];

for i in (0..rotation_index.len() - 1).rev() {
let to = rotation_index[i] as usize;
let from = rotation_index[i + 1] as usize;
rotate_chunk[to] = original_chunk[from];
}
rotation_index
.iter()
.tuple_windows()
.for_each(|(cur, next)| {
// f'(b) = f(next(b))
rotate_chunk[*cur as usize] = original_chunk[*next as usize];
});
});
MultilinearExtension::from_evaluation_vec_smart(mle.num_vars(), rotated_mle_evals)
}
Expand Down Expand Up @@ -160,6 +163,14 @@ pub const fn wits_fixed_and_eqs<const N: usize, const M: usize, const Q: usize>(
(wits, fixed, eqs)
}

/// p(b) = eq(x,b) if b <= max_idx else 0
///
/// Its mle is defined as
///
/// p(y) = \sum_b p(b)*eq(b,y) = \sum_{b <= max_idx} eq(x,b)*eq(b,y)
///
/// it's easy to see that eq(x,b)*eq(y,b) = eq(x,y,b).
///
/// This is to compute a variant of eq(\mathbf{x}, \mathbf{y}) for indices in
/// [0..=max_idx]. Specifically, it is an MLE of the following vector:
/// partial_eq_{\mathbf{x}}(\mathbf{y})
Expand Down Expand Up @@ -239,8 +250,11 @@ mod tests {
let mut rng = rand::thread_rng();
let point: Vec<_> = (0..7).map(|_| E::random(&mut rng)).collect();
let (left_point, right_point) = bh.get_rotation_points(&point);
// f(next(r)) = (1-r5)*f(0,r1,...) + r5*f(1,r1,1-r2,...)
let rotated_eval = rotated.evaluate(&point);
// f(0,r1,...)
let left_eval = poly.evaluate(&left_point);
// f(1,r1,1-r2,...)
let right_eval = poly.evaluate(&right_point);
assert_eq!(
rotated_eval,
Expand Down