diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index e429be901..ef20b0299 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -141,6 +141,15 @@ jobs: RUSTFLAGS: "-C opt-level=3" run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/bn254_curve_syscalls + - name: Run fibonacci (release) in 3 shards with CENO_CROSS_SHARD_LIMIT + env: + RUST_LOG: debug + RUSTFLAGS: "-C opt-level=3" + MOCK_PROVING: 1 + CENO_CROSS_SHARD_LIMIT: 32 + run: cargo run --release --package ceno_zkvm --features sanity-check --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=20000 --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci + + - name: Install cargo make run: | cargo make --version || cargo install cargo-make diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 5e2c066a9..232069ad3 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -7,6 +7,7 @@ use crate::{ hal::ProverDevice, mock_prover::{LkMultiplicityKey, MockProver}, prover::ZKVMProver, + septic_curve::SepticPoint, verifier::ZKVMVerifier, }, state::GlobalState, @@ -44,6 +45,7 @@ use witness::next_pow2_instance_padding; pub const DEFAULT_MIN_CYCLE_PER_SHARDS: Cycle = 1 << 24; pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 27; +pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT: usize = 1 << 20; /// The polynomial commitment scheme kind #[derive( @@ -175,11 +177,16 @@ pub struct ShardContext<'a> { Either>, &'a mut BTreeMap>, pub cur_shard_cycle_range: std::ops::Range, pub expected_inst_per_shard: usize, + pub max_num_cross_shard_accesses: usize, } impl<'a> Default for ShardContext<'a> { fn default() -> Self { let max_threads = max_usable_threads(); + let max_num_cross_shard_accesses = std::env::var("CENO_CROSS_SHARD_LIMIT") + .map(|v| v.parse().unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT)) + .unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT); + Self { shard_id: 0, num_shards: 1, @@ -202,6 +209,7 @@ impl<'a> Default for ShardContext<'a> { ), cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX, expected_inst_per_shard: usize::MAX, + max_num_cross_shard_accesses, } } } @@ -231,6 +239,10 @@ impl<'a> ShardContext<'a> { let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize; let max_threads = max_usable_threads(); + let max_num_cross_shard_accesses = std::env::var("CENO_CROSS_SHARD_LIMIT") + .map(|v| v.parse().unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT)) + .unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT); + // strategies // 0. set cur_num_shards = num_provers // 1. split instructions evenly by cur_num_shards @@ -323,6 +335,7 @@ impl<'a> ShardContext<'a> { ), cur_shard_cycle_range, expected_inst_per_shard, + max_num_cross_shard_accesses, } }) .collect_vec() @@ -355,6 +368,7 @@ impl<'a> ShardContext<'a> { write_records_tbs: Either::Right(write), cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), expected_inst_per_shard: self.expected_inst_per_shard, + max_num_cross_shard_accesses: self.max_num_cross_shard_accesses, }, ) .collect_vec(), @@ -1125,17 +1139,26 @@ pub fn generate_witness<'a, E: ExtensionField>( pi.end_pc = current_shard_end_pc; pi.end_cycle = current_shard_end_cycle; // set shard ram bus expected output to pi - let shard_ram_witness = zkvm_witness.get_table_witness(&ShardRamCircuit::::name()); - if let Some(shard_ram_witness) = shard_ram_witness - && shard_ram_witness[0].num_instances() > 0 - { - for (f, v) in ShardRamCircuit::::extract_ec_sum( - &system_config.mmu_config.ram_bus_circuit, - &shard_ram_witness[0], - ) - .into_iter() - .zip_eq(pi.shard_rw_sum.as_mut_slice()) - { + let shard_ram_witnesses = zkvm_witness.get_witness(&ShardRamCircuit::::name()); + + if let Some(shard_ram_witnesses) = shard_ram_witnesses { + let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses + .iter() + .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) + .map(|shard_ram_witness| { + ShardRamCircuit::::extract_ec_sum( + &system_config.mmu_config.ram_bus_circuit, + &shard_ram_witness.witness_rmms[0], + ) + }) + .sum(); + + let xy = shard_ram_ec_sum + .x + .0 + .iter() + .chain(shard_ram_ec_sum.y.0.iter()); + for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { *v = f.to_canonical_u64() as u32; } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index cd540ebfa..fae5c14ca 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -196,7 +196,7 @@ impl MmuConfig<'_, E> { &self.local_final_circuit, &(shard_ctx, all_records.as_slice()), )?; - witness.assign_global_chip_circuit( + witness.assign_shared_circuit( cs, &(shard_ctx, all_records.as_slice()), &self.ram_bus_circuit, diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index 716462927..c2145a08a 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -39,6 +39,9 @@ impl ZKVMConstraintSystem { fixed_traces.insert(circuit_index, fixed_trace_rmm); } + vm_pk + .circuit_name_to_index + .insert(c_name.clone(), circuit_index); let circuit_pk = cs.key_gen(); assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); } diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 47b45ab1f..b64722375 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, fmt::{self, Debug}, + iter, ops::Div, rc::Rc, }; @@ -156,7 +157,8 @@ pub struct ZKVMProof> { pub raw_pi: Vec>, // the evaluation of raw_pi. pub pi_evals: Vec, - pub chip_proofs: BTreeMap>, + // each circuit may have multiple proof instances + pub chip_proofs: BTreeMap>>, pub witin_commit: >::Commitment, pub opening_proof: PCS::Proof, } @@ -165,7 +167,7 @@ impl> ZKVMProof { pub fn new( raw_pi: Vec>, pi_evals: Vec, - chip_proofs: BTreeMap>, + chip_proofs: BTreeMap>>, witin_commit: >::Commitment, opening_proof: PCS::Proof, ) -> Self { @@ -211,7 +213,13 @@ impl> ZKVMProof { let halt_instance_count = self .chip_proofs .get(&halt_circuit_index) - .map_or(0, |proof| proof.num_instances.iter().sum()); + .map_or(0, |proofs| { + proofs + .iter() + .flat_map(|proof| &proof.num_instances) + .copied() + .sum() + }); if halt_instance_count > 0 { assert_eq!( halt_instance_count, 1, @@ -240,6 +248,9 @@ impl + Serialize> fmt::Dis let tower_proof = self .chip_proofs .iter() + .flat_map(|(circuit_index, proofs)| { + iter::repeat_n(circuit_index, proofs.len()).zip(proofs) + }) .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.tower_proof); size.inspect(|size| { @@ -254,6 +265,9 @@ impl + Serialize> fmt::Dis let main_sumcheck = self .chip_proofs .iter() + .flat_map(|(circuit_index, proofs)| { + iter::repeat_n(circuit_index, proofs.len()).zip(proofs) + }) .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.main_sumcheck_proofs); size.inspect(|size| { diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 290640b7a..fbaa4e25e 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -24,17 +24,21 @@ use gkr_iop::{ use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Expression, + Expression, ToExpr, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, - virtual_poly::build_eq_x_r_vec, + virtual_poly::{build_eq_x_r_vec, eq_eval}, virtual_polys::VirtualPolynomialsBuilder, }; use rayon::iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; -use std::{collections::BTreeMap, sync::Arc}; +use std::{ + collections::BTreeMap, + iter::{once, repeat_n}, + sync::Arc, +}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverState}, @@ -75,9 +79,9 @@ impl CpuEccProver { let out_rt = transcript.sample_and_append_vec(b"ecc", n); let num_threads = optimal_sumcheck_threads(out_rt.len()); - // expression with add (3 zero constrains) and bypass (2 zero constrains) + // expression with add (3 zero constraints), bypass (2 zero constraints), export (2 zero constraints) let alpha_pows = transcript.sample_and_append_challenge_pows( - SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2 + SEPTIC_EXTENSION_DEGREE * 2, b"ecc_alpha", ); let mut alpha_pows_iter = alpha_pows.iter(); @@ -92,6 +96,17 @@ impl CpuEccProver { }; let mut sel_add_mle: MultilinearExtension<'_, E> = sel_add.compute(&out_rt, &sel_add_ctx).unwrap(); + + // [1,1,...,1,0] + let last_evaluation_index = (1 << n) - 2; + let lsi_on_hypercube = repeat_n(E::ONE, n - 1).chain(once(E::ZERO)).collect_vec(); + let mut sel_export = (0..(1 << n)) + .into_par_iter() + .map(|_| E::ZERO) + .collect::>(); + sel_export[last_evaluation_index] = eq_eval(&out_rt, lsi_on_hypercube.as_slice()); + let mut sel_export_mle = sel_export.into_mle(); + // we construct sel_bypass witness here // verifier can derive it via `sel_bypass = eq - sel_add - sel_last_onehot` let mut sel_bypass_mle: Vec = build_eq_x_r_vec(&out_rt); @@ -110,6 +125,7 @@ impl CpuEccProver { let mut sel_bypass_mle = sel_bypass_mle.into_mle(); let sel_add_expr = expr_builder.lift(sel_add_mle.to_either()); let sel_bypass_expr = expr_builder.lift(sel_bypass_mle.to_either()); + let sel_export_expr = expr_builder.lift(sel_export_mle.to_either()); let mut exprs_add = vec![]; let mut exprs_bypass = vec![]; @@ -219,12 +235,35 @@ impl CpuEccProver { .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); - assert!(alpha_pows_iter.next().is_none()); + + // export x[1,...,1,0], y[1,...,1,0] for final result + let xp = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let yp = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); + let final_sum_x: SepticExtension = (xp.iter()) + .map(|x| x.get_base_field_vec()[last_evaluation_index]) // x[1,...,1,0] + .collect_vec() + .into(); + let final_sum_y: SepticExtension = (yp.iter()) + .map(|y| y.get_base_field_vec()[last_evaluation_index]) // x[1,...,1,0] + .collect_vec() + .into(); + // 0 = sel_export * (x[1,b] - final_sum.x) + // 0 = sel_export * (y[1,b] - final_sum.y) + let export_expr = + x3.0.iter() + .zip_eq(final_sum_x.0.iter()) + // .chain(y3.0.iter().zip_eq(final_sum_y.0.iter())) + .map(|(x, final_x)| x - final_x.expr()) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))) + .sum::>() + * sel_export_expr; + // assert!(alpha_pows_iter.next().is_none()); let exprs_bypass = exprs_bypass.into_iter().sum::>() * sel_bypass_expr; let (zerocheck_proof, state) = IOPProverState::prove( - expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass], &[]), + expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass + export_expr], &[]), transcript, ); @@ -232,20 +271,7 @@ impl CpuEccProver { let evals = state.get_mle_flatten_final_evaluations(); // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[1,rt] - assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7); - - let last_evaluation_index = (1 << n) - 1; - let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); - let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); - let final_sum_x: SepticExtension = (x3.iter()) - .map(|x| x.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] - .collect_vec() - .into(); - let final_sum_y: SepticExtension = (y3.iter()) - .map(|y| y.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] - .collect_vec() - .into(); - let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); + assert_eq!(evals.len(), 3 + SEPTIC_EXTENSION_DEGREE * 7); #[cfg(feature = "sanity-check")] { @@ -254,8 +280,10 @@ impl CpuEccProver { let y0 = filter_bj(&ys, 0); let x1 = filter_bj(&xs, 1); let y1 = filter_bj(&ys, 1); + let sel_export = eq_eval(&out_rt, &lsi_on_hypercube) * eq_eval(&rt, &lsi_on_hypercube); + assert_eq!(sel_export, evals[2]); - let evals = &evals[2..]; + let evals = &evals[3..]; // check evaluations for i in 0..SEPTIC_EXTENSION_DEGREE { assert_eq!(s[i].evaluate(&rt), evals[i]); @@ -263,10 +291,11 @@ impl CpuEccProver { assert_eq!(y0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 2 + i]); assert_eq!(x1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 3 + i]); assert_eq!(y1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 4 + i]); - assert_eq!(x3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 5 + i]); - assert_eq!(y3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 6 + i]); + assert_eq!(xp[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 5 + i]); + assert_eq!(yp[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 6 + i]); } } + let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); EccQuarkProof { diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index b1e89cab5..e9b7f8575 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -34,7 +34,7 @@ use p3::field::{Field, FieldAlgebra}; use rand::thread_rng; use std::{ cmp::max, - collections::{BTreeMap, BTreeSet, HashMap, HashSet}, + collections::{BTreeSet, HashMap, HashSet}, fmt::Debug, fs::File, hash::Hash, @@ -1004,21 +1004,14 @@ Hints: let mut fixed_mles = HashMap::new(); let mut num_instances = HashMap::new(); - let circuit_index_fixed_num_instances: BTreeMap = fixed_trace - .circuit_fixed_traces - .iter() - .map(|(circuit_name, rmm)| { - ( - circuit_name.clone(), - rmm.as_ref().map(|rmm| rmm.num_instances()).unwrap_or(0), - ) - }) - .collect(); let mut lkm_tables = LkMultiplicityRaw::::default(); let mut lkm_opcodes = LkMultiplicityRaw::::default(); // Process all circuits. - for (circuit_name, composed_cs) in &cs.circuit_css { + for circuit_input in witnesses.iter_sorted() { + let circuit_name = &circuit_input.name; + let composed_cs = cs.circuit_css.get(circuit_name).unwrap(); + // for (circuit_name, composed_cs) in &cs.circuit_css { let ComposedConstrainSystem { zkvm_v1_css: cs, .. } = &composed_cs; @@ -1037,19 +1030,11 @@ Hints: continue; } - let [witness, structural_witness] = witnesses - .get_opcode_witness(circuit_name) - .or_else(|| witnesses.get_table_witness(circuit_name)) - .unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name)); + let [witness, structural_witness] = &circuit_input.witness_rmms; let num_rows = if witness.num_instances() > 0 { witness.num_instances() } else if structural_witness.num_instances() > 0 { structural_witness.num_instances() - } else if composed_cs.is_static_circuit() { - circuit_index_fixed_num_instances - .get(circuit_name) - .copied() - .unwrap_or(0) } else { 0 }; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 577641fc1..64b427521 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -25,7 +25,6 @@ use sumcheck::{ structs::IOPProverMessage, }; use transcript::Transcript; -use witness::RowMajorMatrix; use super::{PublicValues, ZKVMChipProof, ZKVMProof, hal::ProverDevice}; use crate::{ @@ -86,7 +85,7 @@ impl< ) -> Result, ZKVMError> { let raw_pi = pi.to_vec::(); let mut pi_evals = ZKVMProof::::pi_evals(&raw_pi); - let mut chip_proofs: BTreeMap> = BTreeMap::new(); + let mut chip_proofs = BTreeMap::new(); let span = entered_span!("commit_to_pi", profiling_1 = true); // including raw public input to transcript @@ -111,90 +110,56 @@ impl< } exit_span!(span); - // keep track of circuit name to index mapping - let circuit_name_index_mapping = self - .pk - .circuit_pks - .keys() - .enumerate() - .map(|(k, v)| (v, k)) - .collect::>(); // only keep track of circuits that have non-zero instances - let mut num_instances = Vec::with_capacity(self.pk.circuit_pks.len()); - let mut num_instances_with_rotation = Vec::with_capacity(self.pk.circuit_pks.len()); - let mut circuit_name_num_instances_mapping = BTreeMap::new(); - for (index, (circuit_name, ProvingKey { vk, .. })) in self.pk.circuit_pks.iter().enumerate() - { - // skip omc init on >1 shard - if !shard_ctx.is_first_shard() && vk.get_cs().with_omc_init_only() { + + for chip_input in witnesses.iter_sorted() { + let pk = self + .pk + .circuit_pks + .get(&chip_input.name) + .ok_or(ZKVMError::VKNotFound( + format!("proving key for circuit {} not found", chip_input.name).into(), + ))?; + + // include omc init tables iff it's in first shard + if !shard_ctx.is_first_shard() && pk.get_cs().with_omc_init_only() { continue; } - // num_instance from witness might include rotation - if let Some(num_instance) = witnesses - .num_instances - .get(circuit_name) - .cloned() - .and_then(|num_instance| { - if num_instance.iter().sum::() > 0 { - Some(num_instance) - } else { - None - } - }) - .or_else(|| { - vk.get_cs().is_static_circuit().then(|| { - self.pk - .circuit_index_fixed_num_instances - .get(&index) - .copied() - .map(|num_instance| vec![num_instance]) - .unwrap_or(vec![]) - }) - }) - { - let num_instance_exclude_rotation = num_instance - .iter() - .map(|num_instance| num_instance >> vk.get_cs().rotation_vars().unwrap_or(0)) - .collect_vec(); - num_instances.push((index, num_instance_exclude_rotation.clone())); - circuit_name_num_instances_mapping - .insert(circuit_name, num_instance_exclude_rotation); - num_instances_with_rotation.push((index, num_instance)); + + if chip_input.num_instances() == 0 { + continue; } - } - // write (circuit_idx, num_var) to transcript - for (circuit_idx, num_instance) in &num_instances { + // num_instance from witness might include rotation + let num_instances = chip_input + .num_instances + .iter() + .map(|num_instance| num_instance >> pk.get_cs().rotation_vars().unwrap_or(0)) + .collect_vec(); + let circuit_idx = self.pk.circuit_name_to_index.get(&chip_input.name).unwrap(); + // write (circuit_idx, num_var) to transcript transcript.append_message(&circuit_idx.to_le_bytes()); - for num_instance in num_instance { + for num_instance in num_instances { transcript.append_message(&num_instance.to_le_bytes()); } } + // extract chip meta info before consuming witnesses + // (circuit_name, num_instances) + let name_and_instances = witnesses.get_witnesses_name_instance(); + let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true); let mut wits_rmms = BTreeMap::new(); - let mut structural_wits = BTreeMap::new(); + let mut structural_rmms = Vec::with_capacity(name_and_instances.len()); // commit to opcode circuits first and then commit to table circuits, sorted by name - for (circuit_name, mut rmm) in witnesses.into_iter_sorted() { - let witness_rmm = rmm.remove(0); - // only table got structural witness - let structural_witness_rmm = if !rmm.is_empty() { - rmm.remove(0) - } else { - RowMajorMatrix::empty() - }; + for (i, chip_input) in witnesses.into_iter_sorted().enumerate() { + let [witness_rmm, structural_witness_rmm] = chip_input.witness_rmms; if witness_rmm.num_instances() > 0 { - wits_rmms.insert(circuit_name_index_mapping[&circuit_name], witness_rmm); - } - if structural_witness_rmm.num_instances() > 0 { - let num_instances = circuit_name_num_instances_mapping - .get(&circuit_name) - .unwrap(); - let structural_witness = structural_witness_rmm.to_mles(); - structural_wits.insert(circuit_name, (structural_witness, num_instances)); + wits_rmms.insert(i, witness_rmm); } + structural_rmms.push(structural_witness_rmm); } // commit to witness traces in batch @@ -222,78 +187,87 @@ impl< exit_span!(public_input_span); let main_proofs_span = entered_span!("main_proofs", profiling_1 = true); - let (points, evaluations) = self.pk.circuit_pks.iter().enumerate().try_fold( - (vec![], vec![]), - |(mut points, mut evaluations), (index, (circuit_name, pk))| { - let num_instances = circuit_name_num_instances_mapping - .get(&circuit_name) - .cloned() - .unwrap_or_default(); - let cs = pk.get_cs(); - if !shard_ctx.is_first_shard() && cs.with_omc_init_only() { - assert!(num_instances.is_empty()); - // skip drain respective fixed because we use different set of fixed commitment - return Ok::<(Vec<_>, Vec>), ZKVMError>((points, evaluations)); - } - if num_instances.is_empty() { - // we need to drain respective fixed when num_instances is 0 - if cs.num_fixed() > 0 { - let _ = fixed_mles.drain(..cs.num_fixed()).collect_vec(); - } - return Ok::<(Vec<_>, Vec>), ZKVMError>((points, evaluations)); - } - transcript.append_field_element(&E::BaseField::from_canonical_u64(index as u64)); - - // TODO: add an enum for circuit type either in constraint_system or vk - let witness_mle = witness_mles - .drain(..cs.num_witin()) - .map(|mle| mle.into()) - .collect_vec(); - - let structural_witness_span = - entered_span!("structural_witness", profiling_2 = true); - let structural_mles = structural_wits - .remove(circuit_name) - .map(|(sw, _)| sw) - .unwrap_or(vec![]); - let structural_witness = self.device.transport_mles(&structural_mles); - exit_span!(structural_witness_span); - - let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); - let input = ProofInput { - witness: witness_mle, - fixed, - structural_witness, - public_input: public_input.clone(), - pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(), - num_instances: num_instances.clone(), - has_ecc_ops: cs.has_ecc_ops(), - }; - - let (opcode_proof, pi_in_evals, input_opening_point) = - self.create_chip_proof(circuit_name, pk, input, &mut transcript, &challenges)?; - tracing::trace!( - "generated proof for opcode {} with num_instances={:?}", - circuit_name, - num_instances - ); - if cs.num_witin() > 0 || cs.num_fixed() > 0 { - points.push(input_opening_point); - evaluations.push(vec![ - opcode_proof.wits_in_evals.clone(), - opcode_proof.fixed_in_evals.clone(), - ]); - } else { - assert!(opcode_proof.wits_in_evals.is_empty()); - assert!(opcode_proof.fixed_in_evals.is_empty()); - } - chip_proofs.insert(index, opcode_proof); - for (idx, eval) in pi_in_evals { - pi_evals[idx] = eval; + + let mut points = Vec::new(); + let mut evaluations = Vec::new(); + for ((circuit_name, num_instances), structural_rmm) in name_and_instances + .into_iter() + .zip_eq(structural_rmms.into_iter()) + { + let circuit_idx = self + .pk + .circuit_name_to_index + .get(&circuit_name) + .cloned() + .expect("invalid circuit {} not exist in ceno zkvm"); + let pk = self.pk.circuit_pks.get(&circuit_name).unwrap(); + let cs = pk.get_cs(); + if !shard_ctx.is_first_shard() && cs.with_omc_init_only() { + assert!(num_instances.is_empty()); + // skip drain respective fixed because we use different set of fixed commitment + continue; + } + if num_instances.is_empty() { + // we need to drain respective fixed when num_instances is 0 + if cs.num_fixed() > 0 { + let _ = fixed_mles.drain(..cs.num_fixed()).collect_vec(); } - Ok((points, evaluations)) - }, - )?; + continue; + } + transcript.append_field_element(&E::BaseField::from_canonical_u64(circuit_idx as u64)); + + // TODO: add an enum for circuit type either in constraint_system or vk + let witness_mle = witness_mles + .drain(..cs.num_witin()) + .map(|mle| mle.into()) + .collect_vec(); + + let structural_witness_span = entered_span!("structural_witness", profiling_2 = true); + let structural_mles = structural_rmm.to_mles(); + let structural_witness = self.device.transport_mles(&structural_mles); + exit_span!(structural_witness_span); + + let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); + let input = ProofInput { + witness: witness_mle, + fixed, + structural_witness, + public_input: public_input.clone(), + pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(), + num_instances: num_instances.clone(), + has_ecc_ops: cs.has_ecc_ops(), + }; + + let (opcode_proof, pi_in_evals, input_opening_point) = self.create_chip_proof( + circuit_name.as_str(), + pk, + input, + &mut transcript, + &challenges, + )?; + tracing::trace!( + "generated proof for opcode {} with num_instances={:?}", + circuit_name, + num_instances + ); + if cs.num_witin() > 0 || cs.num_fixed() > 0 { + points.push(input_opening_point); + evaluations.push(vec![ + opcode_proof.wits_in_evals.clone(), + opcode_proof.fixed_in_evals.clone(), + ]); + } else { + assert!(opcode_proof.wits_in_evals.is_empty()); + assert!(opcode_proof.fixed_in_evals.is_empty()); + } + chip_proofs + .entry(circuit_idx) + .or_insert(vec![]) + .push(opcode_proof); + for (idx, eval) in pi_in_evals { + pi_evals[idx] = eval; + } + } exit_span!(main_proofs_span); // batch opening pcs diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index be85360f9..5027e4fa5 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -154,7 +154,12 @@ fn test_rw_lk_expression_combination() { // get proof let prover = ZKVMProver::new(pk, device); let mut transcript = BasicTranscript::new(b"test"); - let mut rmm = zkvm_witness.into_iter_sorted().next().unwrap().1; + let mut rmm: Vec<_> = zkvm_witness + .into_iter_sorted() + .next() + .unwrap() + .witness_rmms + .into(); let (rmm, structural_rmm) = (rmm.remove(0), rmm.remove(0)); let wits_in = rmm.to_mles(); let structural_wits_in = structural_rmm.to_mles(); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 2513076a7..57f324810 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,6 +1,9 @@ use either::Either; use ff_ext::ExtensionField; -use std::{iter, marker::PhantomData}; +use std::{ + iter::{self, once, repeat_n}, + marker::PhantomData, +}; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; @@ -201,9 +204,10 @@ impl> ZKVMVerifier } // write (circuit_idx, num_instance) to transcript - for (circuit_idx, proof) in &vm_proof.chip_proofs { + for (circuit_idx, proofs) in &vm_proof.chip_proofs { transcript.append_message(&circuit_idx.to_le_bytes()); - for num_instance in &proof.num_instances { + // length of proof.num_instances will be constrained in verify_chip_proof + for num_instance in proofs.iter().flat_map(|proof| &proof.num_instances) { transcript.append_message(&num_instance.to_le_bytes()); } } @@ -235,18 +239,34 @@ impl> ZKVMVerifier let mut witin_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut fixed_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut shard_ec_sum = SepticPoint::::default(); - for (index, proof) in &vm_proof.chip_proofs { - let num_instance: usize = proof.num_instances.iter().sum(); - assert!(num_instance > 0); + + // check num proofs + for (index, proofs) in &vm_proof.chip_proofs { let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; - if shard_id > 0 && circuit_vk.get_cs().with_omc_init_only() { return Err(ZKVMError::InvalidProof( format!("{shard_id}th shard non-first shard got omc dynamic table init",) .into(), )); } + if shard_id == 0 && circuit_vk.get_cs().with_omc_init_only() && proofs.len() != 1 { + return Err(ZKVMError::InvalidProof( + format!("{shard_id}th shard first shard got > 1 omc dynamic table init",) + .into(), + )); + } + } + + for (index, proof) in vm_proof + .chip_proofs + .iter() + .flat_map(|(index, proofs)| iter::repeat_n(index, proofs.len()).zip(proofs)) + { + let num_instance: usize = proof.num_instances.iter().sum(); + assert!(num_instance > 0); + let circuit_name = &self.vk.circuit_index_to_name[index]; + let circuit_vk = &self.vk.circuit_vks[circuit_name]; // check chip proof is well-formed if proof.wits_in_evals.len() != circuit_vk.get_cs().num_witin() @@ -501,25 +521,8 @@ impl> ZKVMVerifier tracing::debug!("verifying ecc proof..."); assert!(proof.ecc_proof.is_some()); let ecc_proof = proof.ecc_proof.as_ref().unwrap(); - - let expected_septic_xy = cs - .ec_final_sum - .iter() - .map(|expr| { - eval_by_expr_with_instance(&[], &[], &[], pi, challenges, expr) - .right() - .and_then(|v| v.as_base()) - .unwrap() - }) - .collect_vec(); - let expected_septic_x: SepticExtension = - expected_septic_xy[0..SEPTIC_EXTENSION_DEGREE].into(); - let expected_septic_y: SepticExtension = - expected_septic_xy[SEPTIC_EXTENSION_DEGREE..].into(); - - assert_eq!(&ecc_proof.sum.x, &expected_septic_x); - assert_eq!(&ecc_proof.sum.y, &expected_septic_y); assert!(!ecc_proof.sum.is_infinity); + EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; tracing::debug!("ecc proof verified."); Some(ecc_proof.sum.clone()) @@ -889,7 +892,7 @@ impl EccVerifier { let num_vars = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; let out_rt = transcript.sample_and_append_vec(b"ecc", num_vars); let alpha_pows = transcript.sample_and_append_challenge_pows( - SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 4, b"ecc_alpha", ); let mut alpha_pows_iter = alpha_pows.iter(); @@ -905,19 +908,20 @@ impl EccVerifier { transcript, ); - let s0: SepticExtension = proof.evals[2..][0..][..SEPTIC_EXTENSION_DEGREE].into(); + let evals = &proof.evals[3..]; // skip sel_add, sel_bypass, sel_export + let s0: SepticExtension = evals[0..][..SEPTIC_EXTENSION_DEGREE].into(); let x0: SepticExtension = - proof.evals[2..][SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + evals[SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); let y0: SepticExtension = - proof.evals[2..][2 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + evals[2 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); let x1: SepticExtension = - proof.evals[2..][3 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + evals[3 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); let y1: SepticExtension = - proof.evals[2..][4 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + evals[4 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); let x3: SepticExtension = - proof.evals[2..][5 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + evals[5 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); let y3: SepticExtension = - proof.evals[2..][6 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + evals[6 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); let rt = sumcheck_claim .point @@ -925,11 +929,13 @@ impl EccVerifier { .map(|c| c.elements) .collect_vec(); - // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) - // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] - // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) + // zerocheck: 0 = s[1,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) + // zerocheck: 0 = s[1,b]^2 - x[b,0] - x[b,1] - x[1,b] + // zerocheck: 0 = s[1,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) // zerocheck: 0 = (x[1,b] - x[b,0]) // zerocheck: 0 = (y[1,b] - y[b,0]) + // zerocheck: 0 = (x[1,b] - final_x) + // zerocheck: 0 = (y[1,b] - final_y) // // note that they are not septic extension field elements, // we just want to reuse the multiply/add/sub formulas @@ -989,10 +995,36 @@ impl EccVerifier { )); } + // derive `sel_export` + let lsi_on_hypercube = repeat_n(E::ONE, out_rt.len() - 1) + .chain(once(E::ZERO)) + .collect_vec(); + let expected_sel_export = + eq_eval(&out_rt, &lsi_on_hypercube) * eq_eval(&rt, &lsi_on_hypercube); + if proof.evals[2] != expected_sel_export { + return Err(ZKVMError::VerifyError( + (format!( + "sel_export evaluation mismatch, expected {}, got {}", + expected_sel_export, proof.evals[2] + )) + .into(), + )); + } + let export_evaluations: E = + x3.0.iter() + .zip_eq(proof.sum.x.0.iter()) + // .chain(y3.0.iter().zip_eq(proof.sum.y.0.iter())) + .map(|(a, b)| *a - *b) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(c, alpha)| c * *alpha) + .sum(); + let add_evaluations = vec![v1, v2, v3].into_iter().flatten().sum::(); let bypass_evaluations = vec![v4, v5].into_iter().flatten().sum::(); if sumcheck_claim.expected_evaluation - != add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass + != add_evaluations * expected_sel_add + + bypass_evaluations * expected_sel_bypass + + export_evaluations * expected_sel_export { return Err(ZKVMError::VerifyError( (format!( diff --git a/ceno_zkvm/src/stats.rs b/ceno_zkvm/src/stats.rs index 270259c30..35ec58025 100644 --- a/ceno_zkvm/src/stats.rs +++ b/ceno_zkvm/src/stats.rs @@ -236,7 +236,7 @@ impl Report { let num_instances = zkvm_witnesses .clone() .into_iter_sorted() - .map(|(key, value)| (key, value[0].num_instances())) + .map(|chip_input| (chip_input.name, chip_input.num_instances[0])) .collect::>(); Self::new(static_report, num_instances, program_name) } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index e8f5edd10..c24989a02 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -16,7 +16,10 @@ use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Mult use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{Expression, Instance}; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::{ + iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}, + prelude::ParallelSlice, +}; use rustc_hash::FxHashSet; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ @@ -140,11 +143,6 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.num_fixed } - /// static circuit means there is only fixed column - pub fn is_static_circuit(&self) -> bool { - (self.num_witin() + self.num_structural_witin()) == 0 && self.num_fixed() > 0 - } - pub fn num_reads(&self) -> usize { self.zkvm_v1_css.r_expressions.len() + self.zkvm_v1_css.r_table_expressions.len() } @@ -319,23 +317,42 @@ impl ZKVMFixedTraces { } } +#[derive(Clone)] +pub struct ChipInput { + pub name: String, + pub witness_rmms: RMMCollections, + // in shard ram chip, num_instances length would be > 1 + pub num_instances: Vec, +} + +impl ChipInput { + pub fn new( + name: String, + witness_rmms: RMMCollections, + num_instances: Vec, + ) -> Self { + Self { + name, + witness_rmms, + num_instances, + } + } + + pub fn num_instances(&self) -> usize { + self.num_instances.iter().sum() + } +} + #[derive(Default, Clone)] pub struct ZKVMWitnesses { - witnesses_opcodes: BTreeMap>, - witnesses_tables: BTreeMap>, + witnesses: BTreeMap>>, lk_mlts: BTreeMap>, combined_lk_mlt: Option>>, - // in ram bus chip, num_instances length would be > 1 - pub num_instances: BTreeMap>, } impl ZKVMWitnesses { - pub fn get_opcode_witness(&self, name: &String) -> Option<&RMMCollections> { - self.witnesses_opcodes.get(name) - } - - pub fn get_table_witness(&self, name: &String) -> Option<&RMMCollections> { - self.witnesses_tables.get(name) + pub fn get_witness(&self, name: &String) -> Option<&Vec>> { + self.witnesses.get(name) } pub fn get_lk_mlt(&self, name: &String) -> Option<&Multiplicity> { @@ -359,13 +376,17 @@ impl ZKVMWitnesses { cs.zkvm_v1_css.num_structural_witin as usize, records, )?; - assert!( - self.num_instances - .insert(OC::name(), vec![witness[0].num_instances()]) - .is_none() + let num_instances = vec![witness[0].num_instances()]; + let input = ChipInput::new( + OC::name(), + witness, + if num_instances[0] > 0 { + num_instances + } else { + vec![] + }, ); - assert!(self.witnesses_opcodes.insert(OC::name(), witness).is_none()); - assert!(!self.witnesses_tables.contains_key(&OC::name())); + assert!(self.witnesses.insert(OC::name(), vec![input]).is_none()); assert!( self.lk_mlts .insert(OC::name(), logup_multiplicity) @@ -418,18 +439,21 @@ impl ZKVMWitnesses { input, )?; let num_instances = std::cmp::max(witness[0].num_instances(), witness[1].num_instances()); - assert!( - self.num_instances - .insert(TC::name(), vec![num_instances]) - .is_none() + let input = ChipInput::new( + TC::name(), + witness, + if num_instances > 0 { + vec![num_instances] + } else { + vec![] + }, ); - assert!(self.witnesses_tables.insert(TC::name(), witness).is_none()); - assert!(!self.witnesses_opcodes.contains_key(&TC::name())); + assert!(self.witnesses.insert(TC::name(), vec![input]).is_none()); Ok(()) } - pub fn assign_global_chip_circuit( + pub fn assign_shared_circuit( &mut self, cs: &ZKVMConstraintSystem, // shard_ctx: &ShardContext, @@ -492,8 +516,6 @@ impl ZKVMWitnesses { } else { vec![] }; - let non_first_shard_records_len = non_first_shard_records.len(); - let global_input = shard_ctx .write_records() .par_iter() @@ -529,64 +551,61 @@ impl ZKVMWitnesses { assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); - let witness = ShardRamCircuit::assign_instances( - config, - cs.zkvm_v1_css.num_witin as usize, - cs.zkvm_v1_css.num_structural_witin as usize, - self.combined_lk_mlt.as_ref().unwrap(), - &global_input, - )?; - // set num_read, num_write as separate instance - assert!( - self.num_instances - .insert( + let circuit_inputs = global_input + .par_chunks(shard_ctx.max_num_cross_shard_accesses) + .map(|shard_accesses| { + let witness = ShardRamCircuit::assign_instances( + config, + cs.zkvm_v1_css.num_witin as usize, + cs.zkvm_v1_css.num_structural_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + shard_accesses, + )?; + let num_reads = shard_accesses + .par_iter() + .filter(|access| access.record.is_to_write_set) + .count(); + let num_writes = shard_accesses.len() - num_reads; + + Ok(ChipInput::new( ShardRamCircuit::::name(), - vec![ - // global write -> local read - shard_ctx - .write_records() - .iter() - .map(|records| records.len()) - .sum::() - + non_first_shard_records_len, - // global read -> local write - shard_ctx - .read_records() - .iter() - .map(|records| records.len()) - .sum(), - ] - ) - .is_none() - ); + witness, + vec![num_reads, num_writes], + )) + }) + .collect::, ZKVMError>>()?; + // set num_read, num_write as separate instance assert!( - self.witnesses_tables - .insert(ShardRamCircuit::::name(), witness) + self.witnesses + .insert(ShardRamCircuit::::name(), circuit_inputs) .is_none() ); - assert!( - !self - .witnesses_opcodes - .contains_key(&ShardRamCircuit::::name()) - ); Ok(()) } + pub fn get_witnesses_name_instance(&self) -> Vec<(String, Vec)> { + self.witnesses + .iter() + .flat_map(|(_, chip_inputs)| { + chip_inputs + .iter() + .map(|chip_input| (chip_input.name.clone(), chip_input.num_instances.clone())) + }) + .collect_vec() + } + + pub fn iter_sorted(&self) -> impl Iterator> { + self.witnesses + .iter() + .flat_map(|(_, chip_input)| chip_input.iter()) + } + /// Iterate opcode/table circuits, sorted by alphabetical order. - pub fn into_iter_sorted( - self, - ) -> impl Iterator>)> { - self.witnesses_opcodes - .into_iter() - .map(|(name, witnesses)| (name, witnesses.into())) - .chain( - self.witnesses_tables - .into_iter() - .map(|(name, witnesses)| (name, witnesses.into())), - ) - .collect::>() + pub fn into_iter_sorted(self) -> impl Iterator> { + self.witnesses .into_iter() + .flat_map(|(_, chip_inputs)| chip_inputs.into_iter()) } } pub struct ZKVMProvingKey> { @@ -596,6 +615,7 @@ pub struct ZKVMProvingKey> pub entry_pc: u32, // pk for opcode and table circuits pub circuit_pks: BTreeMap>, + pub circuit_name_to_index: BTreeMap, // Fixed commitments are separated into two groups: // @@ -629,6 +649,7 @@ impl> ZKVMProvingKey ShardRamCircuit { pub fn extract_ec_sum( config: &ShardRamConfig, rmm: &witness::RowMajorMatrix<::BaseField>, - ) -> Vec<::BaseField> { + ) -> SepticPoint<::BaseField> { assert!(rmm.height() >= 2); let instance = &rmm[rmm.height() - 2]; - config + let xy = config .x .iter() .chain(config.y.iter()) .map(|witin| instance[witin.id as usize]) - .collect_vec() + .collect_vec(); + + let x: SepticExtension = xy[0..SEPTIC_EXTENSION_DEGREE].into(); + let y: SepticExtension = xy[SEPTIC_EXTENSION_DEGREE..].into(); + + SepticPoint::from_affine(x, y) } } impl TableCircuit for ShardRamCircuit { type TableConfig = ShardRamConfig; type FixedInput = (); - type WitnessInput = Vec>; + type WitnessInput = [ShardRamInput]; fn name() -> String { "ShardRamCircuit".to_string() @@ -652,7 +657,6 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::riscv::constants::SHARD_RW_SUM_IDX, scheme::{ PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, verifier::ZKVMVerifier, @@ -768,12 +772,7 @@ mod tests { // api extract ec sum from rmm witness assert_eq!( - public_value - .to_vec::() - .into_iter() - .skip(SHARD_RW_SUM_IDX) - .flatten() - .collect_vec(), + global_ec_sum, ShardRamCircuit::extract_ec_sum(&config, &witness[0]) );