Skip to content

Commit 22a9b7e

Browse files
authored
[precompile] separate prover backend for cpu/gpu (#972)
to support e2e integration, build on top of #971
1 parent 4c58c75 commit 22a9b7e

28 files changed

+911
-664
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ceno_zkvm/benches/riscv_add.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,15 @@ use std::{collections::BTreeMap, time::Duration};
33
use ceno_zkvm::{
44
self,
55
instructions::{Instruction, riscv::arith::AddInstruction},
6-
scheme::{
7-
cpu::{CpuBackend, CpuProver},
8-
hal::ProofInput,
9-
prover::ZKVMProver,
10-
},
6+
scheme::{hal::ProofInput, prover::ZKVMProver},
117
structs::{ZKVMConstraintSystem, ZKVMFixedTraces},
128
};
139
mod alloc;
1410
use criterion::*;
1511

1612
use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES;
1713
use ff_ext::GoldilocksExt2;
14+
use gkr_iop::cpu::{CpuBackend, CpuProver};
1815
use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel};
1916

2017
use rand::rngs::OsRng;

ceno_zkvm/src/e2e.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use crate::{
33
instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig},
44
scheme::{
55
PublicValues, ZKVMProof,
6-
cpu::{CpuBackend, CpuProver},
76
mock_prover::{LkMultiplicityKey, MockProver},
87
prover::ZKVMProver,
98
verifier::ZKVMVerifier,
@@ -23,6 +22,7 @@ use clap::ValueEnum;
2322
use ff_ext::ExtensionField;
2423
#[cfg(debug_assertions)]
2524
use ff_ext::{Instrumented, PoseidonField};
25+
use gkr_iop::cpu::{CpuBackend, CpuProver};
2626
use itertools::{Itertools, MinMaxResult, chain};
2727
use mpcs::{PolynomialCommitmentScheme, SecurityLevel};
2828
use std::{

ceno_zkvm/src/instructions.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use ff_ext::ExtensionField;
33
use gkr_iop::{
44
ProtocolBuilder, ProtocolWitnessGenerator,
55
gkr::{GKRCircuit, GKRCircuitOutput, GKRCircuitWitness},
6+
hal::ProverBackend,
67
};
78
use multilinear_extensions::util::max_usable_threads;
89
use rayon::{
@@ -95,7 +96,7 @@ pub trait GKRIOPInstruction<E: ExtensionField>
9596
where
9697
Self: Instruction<E>,
9798
{
98-
type Layout<'a>: ProtocolWitnessGenerator<'a, E> + ProtocolBuilder<E>;
99+
type Layout: ProtocolWitnessGenerator<E> + ProtocolBuilder<E>;
99100

100101
/// Similar to Instruction::construct_circuit; generally
101102
/// meant to extend InstructionConfig with GKR-specific
@@ -108,8 +109,8 @@ where
108109
}
109110

110111
/// Should generate phase1 witness for GKR from step records
111-
fn phase1_witness_from_steps<'a>(
112-
layout: &Self::Layout<'a>,
112+
fn phase1_witness_from_steps(
113+
layout: &Self::Layout,
113114
steps: &[StepRecord],
114115
) -> RowMajorMatrix<E::BaseField>;
115116

@@ -125,17 +126,17 @@ where
125126

126127
/// Similar to Instruction::assign_instances, but with access to the GKR layout.
127128
#[allow(clippy::type_complexity)]
128-
fn assign_instances_with_gkr_iop<'a>(
129+
fn assign_instances_with_gkr_iop<'a, PB: ProverBackend<E = E>>(
129130
_config: &Self::InstructionConfig,
130131
_num_witin: usize,
131132
_steps: Vec<StepRecord>,
132133
_gkr_circuit: &GKRCircuit<E>,
133-
_gkr_layout: &Self::Layout<'a>,
134+
_gkr_layout: &Self::Layout,
134135
) -> Result<
135136
(
136137
RowMajorMatrix<E::BaseField>,
137-
GKRCircuitWitness<'a, E>,
138-
GKRCircuitOutput<'a, E>,
138+
GKRCircuitWitness<'a, PB>,
139+
GKRCircuitOutput<'a, PB>,
139140
LkMultiplicity,
140141
),
141142
ZKVMError,

ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ impl<E: ExtensionField, S: SyscallSpec> Instruction<E> for LargeEcallDummy<E, S>
124124
}
125125

126126
impl<E: ExtensionField> GKRIOPInstruction<E> for LargeEcallDummy<E, KeccakSpec> {
127-
type Layout<'a> = KeccakLayout<E>;
127+
type Layout = KeccakLayout<E>;
128128

129129
fn gkr_info() -> crate::instructions::GKRinfo {
130130
GKRinfo {
@@ -191,8 +191,8 @@ impl<E: ExtensionField> GKRIOPInstruction<E> for LargeEcallDummy<E, KeccakSpec>
191191
}
192192
}
193193

194-
fn phase1_witness_from_steps<'a>(
195-
layout: &Self::Layout<'a>,
194+
fn phase1_witness_from_steps(
195+
layout: &Self::Layout,
196196
steps: &[StepRecord],
197197
) -> RowMajorMatrix<E::BaseField> {
198198
let instances = steps

ceno_zkvm/src/scheme/cpu/mod.rs

Lines changed: 7 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use super::hal::{
2-
DeviceTransporter, MainSumcheckProver, MultilinearPolynomial, OpeningProver, ProverBackend,
3-
ProverDevice, TowerProver, TraceCommitter,
2+
DeviceTransporter, MainSumcheckProver, OpeningProver, ProverDevice, TowerProver, TraceCommitter,
43
};
54
use crate::{
65
circuit_builder::ConstraintSystem,
@@ -17,8 +16,12 @@ use crate::{
1716
};
1817
use either::Either;
1918
use ff_ext::ExtensionField;
19+
use gkr_iop::{
20+
cpu::{CpuBackend, CpuProver},
21+
hal::ProverBackend,
22+
};
2023
use itertools::{Itertools, chain};
21-
use mpcs::{Point, PolynomialCommitmentScheme, SecurityLevel};
24+
use mpcs::{Point, PolynomialCommitmentScheme};
2225
use multilinear_extensions::{
2326
Expression, Instance,
2427
mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension},
@@ -28,10 +31,7 @@ use multilinear_extensions::{
2831
virtual_poly::build_eq_x_r_vec,
2932
virtual_polys::VirtualPolynomialsBuilder,
3033
};
31-
use p3::{
32-
field::{FieldAlgebra, TwoAdicField},
33-
matrix::dense::RowMajorMatrix,
34-
};
34+
use p3::field::FieldAlgebra;
3535
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
3636
use std::{collections::BTreeMap, sync::Arc};
3737
use sumcheck::{
@@ -42,63 +42,6 @@ use sumcheck::{
4242
use transcript::Transcript;
4343
use witness::next_pow2_instance_padding;
4444

45-
pub struct CpuBackend<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
46-
pub param: PCS::Param,
47-
_marker: std::marker::PhantomData<E>,
48-
}
49-
50-
impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> Default for CpuBackend<E, PCS> {
51-
fn default() -> Self {
52-
Self::new()
53-
}
54-
}
55-
56-
impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> CpuBackend<E, PCS> {
57-
pub fn new() -> Self {
58-
let param =
59-
PCS::setup(E::BaseField::TWO_ADICITY, SecurityLevel::Conjecture100bits).unwrap();
60-
Self {
61-
param,
62-
_marker: std::marker::PhantomData,
63-
}
64-
}
65-
}
66-
67-
impl<'a, E: ExtensionField> MultilinearPolynomial<E> for MultilinearExtension<'a, E> {
68-
fn num_vars(&self) -> usize {
69-
self.num_vars()
70-
}
71-
72-
fn eval(&self, point: Point<E>) -> E {
73-
self.evaluate(&point)
74-
}
75-
}
76-
77-
impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ProverBackend for CpuBackend<E, PCS> {
78-
type E = E;
79-
type Pcs = PCS;
80-
type MultilinearPoly<'a> = MultilinearExtension<'a, E>;
81-
type Matrix = RowMajorMatrix<E::BaseField>;
82-
type PcsData = PCS::CommitmentWithWitness;
83-
}
84-
85-
/// CPU prover for CPU backend
86-
pub struct CpuProver<PB: ProverBackend> {
87-
backend: PB,
88-
pp: Option<<<PB as ProverBackend>::Pcs as PolynomialCommitmentScheme<PB::E>>::ProverParam>,
89-
largest_poly_size: Option<usize>,
90-
}
91-
92-
impl<PB: ProverBackend> CpuProver<PB> {
93-
pub fn new(backend: PB) -> Self {
94-
Self {
95-
backend,
96-
pp: None,
97-
largest_poly_size: None,
98-
}
99-
}
100-
}
101-
10245
pub struct CpuTowerProver;
10346

10447
impl CpuTowerProver {

ceno_zkvm/src/scheme/hal.rs

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,13 @@ use crate::{
66
structs::{TowerProofs, ZKVMProvingKey},
77
};
88
use ff_ext::ExtensionField;
9+
use gkr_iop::hal::ProverBackend;
910
use mpcs::{Point, PolynomialCommitmentScheme};
1011
use multilinear_extensions::{mle::MultilinearExtension, util::ceil_log2};
1112
use sumcheck::structs::IOPProverMessage;
1213
use transcript::Transcript;
1314
use witness::next_pow2_instance_padding;
1415

15-
pub trait MultilinearPolynomial<E: ExtensionField> {
16-
fn num_vars(&self) -> usize;
17-
fn eval(&self, point: Point<E>) -> E;
18-
}
19-
20-
/// Defines basic types like field, pcs that are common among all devices
21-
/// and also defines the types that are specific to device.
22-
pub trait ProverBackend {
23-
/// types that are common across all devices
24-
type E: ExtensionField;
25-
type Pcs: PolynomialCommitmentScheme<Self::E>;
26-
27-
/// device-specific types
28-
// TODO: remove lifetime bound
29-
type MultilinearPoly<'a>: Send + Sync + Clone + MultilinearPolynomial<Self::E>;
30-
type Matrix: Send + Sync + Clone;
31-
type PcsData;
32-
}
33-
3416
pub trait ProverDevice<PB>:
3517
TraceCommitter<PB>
3618
+ TowerProver<PB>

ceno_zkvm/src/scheme/prover.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
use ff_ext::ExtensionField;
2+
use gkr_iop::{
3+
cpu::{CpuBackend, CpuProver},
4+
hal::ProverBackend,
5+
};
26
use std::{
37
collections::{BTreeMap, HashMap},
48
marker::PhantomData,
59
sync::Arc,
610
};
711

8-
use crate::scheme::hal::{MainSumcheckEvals, MultilinearPolynomial};
12+
use crate::scheme::hal::MainSumcheckEvals;
13+
use gkr_iop::hal::MultilinearPolynomial;
914
use itertools::Itertools;
1015
use mpcs::{Point, PolynomialCommitmentScheme};
1116
use multilinear_extensions::{
@@ -27,11 +32,7 @@ use crate::{
2732
structs::{ProvingKey, TowerProofs, ZKVMProvingKey, ZKVMWitnesses},
2833
};
2934

30-
use super::{
31-
PublicValues, ZKVMChipProof, ZKVMProof,
32-
cpu::{CpuBackend, CpuProver},
33-
hal::{ProverBackend, ProverDevice},
34-
};
35+
use super::{PublicValues, ZKVMChipProof, ZKVMProof, hal::ProverDevice};
3536

3637
type CreateTableProof<E> = (ZKVMChipProof<E>, HashMap<usize, E>, Point<E>);
3738

ceno_zkvm/src/scheme/tests.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{
88
riscv::{arith::AddInstruction, ecall::HaltInstruction},
99
},
1010
scheme::{
11-
cpu::{CpuBackend, CpuProver, CpuTowerProver},
11+
cpu::CpuTowerProver,
1212
hal::{ProofInput, TowerProverSpec},
1313
prover::ZkVMCpuProver,
1414
},
@@ -25,6 +25,7 @@ use ceno_emul::{
2525
Platform, Program, StepRecord, VMState, encode_rv32,
2626
};
2727
use ff_ext::{ExtensionField, FieldInto, FromUniformBytes, GoldilocksExt2};
28+
use gkr_iop::cpu::{CpuBackend, CpuProver};
2829
use multilinear_extensions::{ToExpr, WitIn, mle::MultilinearExtension};
2930

3031
#[cfg(debug_assertions)]

gkr_iop/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ clap.workspace = true
1616
either.workspace = true
1717
ff_ext = { path = "../ff_ext" }
1818
itertools.workspace = true
19+
mpcs = { path = "../mpcs" }
1920
multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" }
2021
ndarray.workspace = true
22+
p3.workspace = true
2123
p3-field.workspace = true
2224
p3-goldilocks.workspace = true
2325
p3-util.workspace = true

0 commit comments

Comments
 (0)