Skip to content

Commit 80a33d3

Browse files
authored
Feat: simplify basefold's verifier (#974)
## Rationale The current API used by `mpcs` crate is tightly coupled with our zkvm's design. This is not desireable. ## Summary We decide to be align with the `Plonky3` style API. - [x] #979. - [x] #981. - [x] #982
1 parent 1b02425 commit 80a33d3

File tree

17 files changed

+667
-1276
lines changed

17 files changed

+667
-1276
lines changed

ceno_zkvm/benches/riscv_add.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{collections::BTreeMap, time::Duration};
1+
use std::time::Duration;
22

33
use ceno_zkvm::{
44
self,
@@ -77,10 +77,7 @@ fn bench_add(c: &mut Criterion) {
7777
for _ in 0..iters {
7878
// generate mock witness
7979
let num_instances = 1 << instance_num_vars;
80-
let rmms = BTreeMap::from([(
81-
0,
82-
RowMajorMatrix::rand(&mut OsRng, num_instances, num_witin),
83-
)]);
80+
let rmms = vec![RowMajorMatrix::rand(&mut OsRng, num_instances, num_witin)];
8481

8582
let instant = std::time::Instant::now();
8683
let num_instances = 1 << instance_num_vars;

ceno_zkvm/src/error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub enum ZKVMError {
1313
UtilError(UtilError),
1414
WitnessNotFound(String),
1515
InvalidWitness(String),
16+
InvalidProof(String),
1617
VKNotFound(String),
1718
FixedTraceNotFound(String),
1819
VerifyError(String),

ceno_zkvm/src/scheme.rs

Lines changed: 27 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pub struct ZKVMChipProof<E: ExtensionField> {
4949

5050
pub tower_proof: TowerProofs<E>,
5151

52+
pub num_instances: usize,
5253
pub fixed_in_evals: Vec<E>,
5354
pub wits_in_evals: Vec<E>,
5455
}
@@ -114,32 +115,25 @@ pub struct ZKVMProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
114115
pub raw_pi: Vec<Vec<E::BaseField>>,
115116
// the evaluation of raw_pi.
116117
pub pi_evals: Vec<E>,
117-
// circuit size -> instance mapping
118-
pub num_instances: Vec<(usize, usize)>,
119-
opcode_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
120-
table_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
118+
chip_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
121119
witin_commit: <PCS as PolynomialCommitmentScheme<E>>::Commitment,
122-
pub fixed_witin_opening_proof: PCS::Proof,
120+
pub opening_proof: PCS::Proof,
123121
}
124122

125123
impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
126124
pub fn new(
127125
raw_pi: Vec<Vec<E::BaseField>>,
128126
pi_evals: Vec<E>,
129-
opcode_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
130-
table_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
127+
chip_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
131128
witin_commit: <PCS as PolynomialCommitmentScheme<E>>::Commitment,
132-
fixed_witin_opening_proof: PCS::Proof,
133-
num_instances: Vec<(usize, usize)>,
129+
opening_proof: PCS::Proof,
134130
) -> Self {
135131
Self {
136132
raw_pi,
137133
pi_evals,
138-
opcode_proofs,
139-
table_proofs,
134+
chip_proofs,
140135
witin_commit,
141-
fixed_witin_opening_proof,
142-
num_instances,
136+
opening_proof,
143137
}
144138
}
145139

@@ -164,23 +158,19 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
164158
}
165159

166160
pub fn num_circuits(&self) -> usize {
167-
self.opcode_proofs.len() + self.table_proofs.len()
161+
self.chip_proofs.len()
168162
}
169163

170164
pub fn has_halt(&self, vk: &ZKVMVerifyingKey<E, PCS>) -> bool {
165+
let halt_circuit_index = vk
166+
.circuit_vks
167+
.keys()
168+
.position(|circuit_name| *circuit_name == HaltInstruction::<E>::name())
169+
.expect("halt circuit not exist");
171170
let halt_instance_count = self
172-
.num_instances
173-
.iter()
174-
.find_map(|(circuit_index, num_instances)| {
175-
(*circuit_index
176-
== vk
177-
.circuit_vks
178-
.keys()
179-
.position(|circuit_name| *circuit_name == HaltInstruction::<E>::name())
180-
.expect("halt circuit not exist"))
181-
.then_some(*num_instances)
182-
})
183-
.unwrap_or(0);
171+
.chip_proofs
172+
.get(&halt_circuit_index)
173+
.map_or(0, |proof| proof.num_instances);
184174
if halt_instance_count > 0 {
185175
assert_eq!(
186176
halt_instance_count, 1,
@@ -203,39 +193,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
203193
let mpcs_opcode_commitment =
204194
bincode::serialized_size(&self.witin_commit).expect("serialization error");
205195
let mpcs_opcode_opening =
206-
bincode::serialized_size(&self.fixed_witin_opening_proof).expect("serialization error");
196+
bincode::serialized_size(&self.opening_proof).expect("serialization error");
207197

208-
// opcode circuit for tower proof size
209-
let tower_proof_opcode = self
210-
.opcode_proofs
211-
.iter()
212-
.map(|(circuit_index, proof)| {
213-
let size = bincode::serialized_size(&proof.tower_proof);
214-
size.inspect(|size| {
215-
*by_circuitname_stats.entry(circuit_index).or_insert(0) += size;
216-
})
217-
})
218-
.collect::<Result<Vec<u64>, _>>()
219-
.expect("serialization error")
220-
.iter()
221-
.sum::<u64>();
222-
// opcode circuit main sumcheck
223-
let main_sumcheck_opcode = self
224-
.opcode_proofs
225-
.iter()
226-
.map(|(circuit_index, proof)| {
227-
let size = bincode::serialized_size(&proof.main_sumcheck_proofs);
228-
size.inspect(|size| {
229-
*by_circuitname_stats.entry(circuit_index).or_insert(0) += size;
230-
})
231-
})
232-
.collect::<Result<Vec<u64>, _>>()
233-
.expect("serialization error")
234-
.iter()
235-
.sum::<u64>();
236-
// table circuit for tower proof size
237-
let tower_proof_table = self
238-
.table_proofs
198+
// tower proof size
199+
let tower_proof = self
200+
.chip_proofs
239201
.iter()
240202
.map(|(circuit_index, proof)| {
241203
let size = bincode::serialized_size(&proof.tower_proof);
@@ -247,9 +209,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
247209
.expect("serialization error")
248210
.iter()
249211
.sum::<u64>();
250-
// table circuit same r sumcheck
251-
let same_r_sumcheck_table = self
252-
.table_proofs
212+
// main sumcheck
213+
let main_sumcheck = self
214+
.chip_proofs
253215
.iter()
254216
.map(|(circuit_index, proof)| {
255217
let size = bincode::serialized_size(&proof.main_sumcheck_proofs);
@@ -286,20 +248,16 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
286248
"overall_size {:.2}mb. \n\
287249
mpcs commitment {:?}% \n\
288250
mpcs opening {:?}% \n\
289-
opcode tower proof {:?}% \n\
290-
opcode main sumcheck proof {:?}% \n\
291-
table tower proof {:?}% \n\
292-
table same r sumcheck proof {:?}% \n\n\
251+
tower proof {:?}% \n\
252+
main sumcheck proof {:?}% \n\
293253
by circuit_name break down: \n\
294254
{}
295255
",
296256
byte_to_mb(overall_size),
297257
(mpcs_opcode_commitment * 100).div(overall_size),
298258
(mpcs_opcode_opening * 100).div(overall_size),
299-
(tower_proof_opcode * 100).div(overall_size),
300-
(main_sumcheck_opcode * 100).div(overall_size),
301-
(tower_proof_table * 100).div(overall_size),
302-
(same_r_sumcheck_table * 100).div(overall_size),
259+
(tower_proof * 100).div(overall_size),
260+
(main_sumcheck * 100).div(overall_size),
303261
by_circuitname_stats,
304262
)
305263
}

ceno_zkvm/src/scheme/cpu/mod.rs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> TraceCommitter<CpuBa
269269
self.pp = Some(prover_param);
270270
self.pp.as_ref().unwrap()
271271
};
272-
let pcs_data = PCS::batch_commit(prover_param, traces).unwrap();
272+
let pcs_data = PCS::batch_commit(prover_param, traces.into_values().collect_vec()).unwrap();
273273
let commit = PCS::get_pure_commitment(&pcs_data);
274274
let mles = PCS::get_arc_mle_witness_from_commitment(&pcs_data)
275275
.into_par_iter()
@@ -752,22 +752,40 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> OpeningProver<CpuBac
752752
witness_data: PCS::CommitmentWithWitness,
753753
fixed_data: Option<Arc<PCS::CommitmentWithWitness>>,
754754
points: Vec<Point<E>>,
755-
evals: Vec<Vec<E>>,
755+
mut evals: Vec<Vec<E>>, // where each inner Vec<E> = wit_evals + fixed_evals
756756
circuit_num_polys: &[(usize, usize)],
757757
num_instances: &[(usize, usize)],
758758
transcript: &mut impl Transcript<E>,
759759
) -> PCS::Proof {
760-
PCS::batch_open(
761-
self.pp.as_ref().unwrap(),
762-
num_instances,
763-
fixed_data.as_ref().map(|f| f.as_ref()),
760+
let mut rounds = vec![];
761+
rounds.push((
764762
&witness_data,
765-
&points,
766-
&evals,
767-
circuit_num_polys,
768-
transcript,
769-
)
770-
.unwrap()
763+
points
764+
.iter()
765+
.zip_eq(evals.iter_mut())
766+
.zip_eq(num_instances.iter())
767+
.map(|((point, evals), (chip_idx, _))| {
768+
let (num_witin, _) = circuit_num_polys[*chip_idx];
769+
(point.clone(), evals.drain(..num_witin).collect_vec())
770+
})
771+
.collect_vec(),
772+
));
773+
if let Some(fixed_data) = fixed_data.as_ref().map(|f| f.as_ref()) {
774+
rounds.push((
775+
fixed_data,
776+
points
777+
.iter()
778+
.zip_eq(evals.iter_mut())
779+
.zip_eq(num_instances.iter())
780+
.filter(|(_, (chip_idx, _))| {
781+
let (_, num_fixed) = circuit_num_polys[*chip_idx];
782+
num_fixed > 0
783+
})
784+
.map(|((point, evals), _)| (point.clone(), evals.to_vec()))
785+
.collect_vec(),
786+
));
787+
}
788+
PCS::batch_open(self.pp.as_ref().unwrap(), rounds, transcript).unwrap()
771789
}
772790
}
773791

ceno_zkvm/src/scheme/prover.rs

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,7 @@ impl<
8484
) -> Result<ZKVMProof<E, PCS>, ZKVMError> {
8585
let raw_pi = pi.to_vec::<E>();
8686
let mut pi_evals = ZKVMProof::<E, PCS>::pi_evals(&raw_pi);
87-
let mut opcode_proofs: BTreeMap<usize, ZKVMChipProof<E>> = BTreeMap::new();
88-
let mut table_proofs: BTreeMap<usize, ZKVMChipProof<E>> = BTreeMap::new();
87+
let mut chip_proofs: BTreeMap<usize, ZKVMChipProof<E>> = BTreeMap::new();
8988

9089
let span = entered_span!("commit_to_pi", profiling_1 = true);
9190
// including raw public input to transcript
@@ -233,23 +232,23 @@ impl<
233232
);
234233
points.push(input_opening_point);
235234
evaluations.push(opcode_proof.wits_in_evals.clone());
236-
opcode_proofs.insert(index, opcode_proof);
235+
chip_proofs.insert(index, opcode_proof);
237236
} else {
238237
// FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances
239238
input.num_instances = 1 << input.log2_num_instances();
240-
let (table_proof, pi_in_evals, input_opening_point) = self.create_chip_proof(
241-
circuit_name,
242-
pk,
243-
input,
244-
&mut transcript,
245-
&challenges,
246-
)?;
239+
let (mut table_proof, pi_in_evals, input_opening_point) = self
240+
.create_chip_proof(circuit_name, pk, input, &mut transcript, &challenges)?;
247241
points.push(input_opening_point);
248-
evaluations.push(table_proof.wits_in_evals.clone());
249-
if cs.num_fixed() > 0 {
250-
evaluations.push(table_proof.fixed_in_evals.clone());
251-
}
252-
table_proofs.insert(index, table_proof);
242+
evaluations.push(
243+
[
244+
table_proof.wits_in_evals.clone(),
245+
table_proof.fixed_in_evals.clone(),
246+
]
247+
.concat(),
248+
);
249+
// FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances
250+
table_proof.num_instances = num_instances;
251+
chip_proofs.insert(index, table_proof);
253252
for (idx, eval) in pi_in_evals {
254253
pi_evals[idx] = eval;
255254
}
@@ -282,12 +281,9 @@ impl<
282281
let vm_proof = ZKVMProof::new(
283282
raw_pi,
284283
pi_evals,
285-
opcode_proofs,
286-
table_proofs,
284+
chip_proofs,
287285
witin_commit,
288286
mpcs_opening_proof,
289-
// verifier need this information from prover to achieve non-uniform design.
290-
num_instances,
291287
);
292288
exit_span!(main_proofs_span);
293289

@@ -363,6 +359,7 @@ impl<
363359
tower_proof,
364360
fixed_in_evals,
365361
wits_in_evals,
362+
num_instances: input.num_instances,
366363
},
367364
pi_in_evals,
368365
input_opening_point,

ceno_zkvm/src/scheme/tests.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{collections::BTreeMap, marker::PhantomData};
1+
use std::marker::PhantomData;
22

33
use crate::{
44
circuit_builder::CircuitBuilder,
@@ -138,12 +138,8 @@ fn test_rw_lk_expression_combination() {
138138
let rmm = zkvm_witness.into_iter_sorted().next().unwrap().1.remove(0);
139139
let wits_in = rmm.to_mles();
140140
// commit to main traces
141-
let commit_with_witness = Pcs::batch_commit_and_write(
142-
&prover.pk.pp,
143-
vec![(0, rmm)].into_iter().collect::<BTreeMap<_, _>>(),
144-
&mut transcript,
145-
)
146-
.unwrap();
141+
let commit_with_witness =
142+
Pcs::batch_commit_and_write(&prover.pk.pp, vec![rmm], &mut transcript).unwrap();
147143
let witin_commit = Pcs::get_pure_commitment(&commit_with_witness);
148144

149145
let wits_in = wits_in.into_iter().map(|v| v.into()).collect_vec();
@@ -189,7 +185,6 @@ fn test_rw_lk_expression_combination() {
189185
name.as_str(),
190186
verifier.vk.circuit_vks.get(&name).unwrap(),
191187
&proof,
192-
num_instances,
193188
&[],
194189
&mut v_transcript,
195190
NUM_FANIN,

0 commit comments

Comments
 (0)