Skip to content

Commit d6d2c87

Browse files
authored
[precompile] mem read/write rlc within precompile circuit (#973)
This PR add the missing witness: current ts, mem read ts, mem start addr in keccak precompile. Those information are needed for consistent check. Beside, mem read/write are rlc with those record > This PR is not fully sound, as we need to revamp current eq with selector logic ### benchmark The observed performance regression is expected, as the additional workload, whether handled inside or outside the precompile circuit, is necessary and unavoidable. | Benchmark | Median Time (s) | Median Change (%) | |----------------------------------|------------------|----------------------------------------| | keccak_lookup_f_4096 | 1.0320 | +4.7694% (Performance has regressed) | | keccak_lookup_f_8192 | 2.0793 | +9.6826% (Performance has regressed) |
1 parent 22a9b7e commit d6d2c87

File tree

12 files changed

+256
-112
lines changed

12 files changed

+256
-112
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ceno_zkvm/src/chip_handler/global_state.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ impl<E: ExtensionField> GlobalStateRegisterMachineChipOperations<E> for CircuitB
1212
pc,
1313
ts,
1414
];
15-
1615
self.read_record(|| "state_in", RAMType::GlobalState, record)
1716
}
1817

ceno_zkvm/src/e2e.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ pub fn run_e2e_with_checkpoint<
706706

707707
// proving
708708
let backend: CpuBackend<E, PCS> = CpuBackend::new();
709-
let device: CpuProver<CpuBackend<E, PCS>> = CpuProver::new(backend);
709+
let device = CpuProver::new(backend);
710710
let mut prover = ZKVMProver::new(pk, device);
711711

712712
if is_mock_proving {
@@ -780,7 +780,7 @@ pub fn run_e2e_proof<
780780

781781
// proving
782782
let backend: CpuBackend<E, PCS> = CpuBackend::new();
783-
let device: CpuProver<CpuBackend<E, PCS>> = CpuProver::new(backend);
783+
let device = CpuProver::new(backend);
784784
let mut prover = ZKVMProver::new(pk, device);
785785

786786
if is_mock_proving {

ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::marker::PhantomData;
22

3-
use ceno_emul::{Change, InsnKind, KeccakSpec, StepRecord, SyscallSpec};
3+
use ceno_emul::{ByteAddr, Change, InsnKind, KeccakSpec, StepRecord, SyscallSpec};
44
use ff_ext::{ExtensionField, SmallField};
55
use itertools::{Itertools, zip_eq};
66
use witness::RowMajorMatrix;
@@ -22,7 +22,9 @@ use multilinear_extensions::{ToExpr, WitIn};
2222

2323
use gkr_iop::{
2424
ProtocolWitnessGenerator,
25-
precompiles::{AND_LOOKUPS, KeccakLayout, KeccakTrace, RANGE_LOOKUPS, XOR_LOOKUPS},
25+
precompiles::{
26+
AND_LOOKUPS, KECCAK_INPUT32_SIZE, KeccakLayout, KeccakTrace, RANGE_LOOKUPS, XOR_LOOKUPS,
27+
},
2628
};
2729

2830
/// LargeEcallDummy can handle any instruction and produce its effects,
@@ -208,8 +210,14 @@ impl<E: ExtensionField> GKRIOPInstruction<E> for LargeEcallDummy<E, KeccakSpec>
208210
.unwrap()
209211
})
210212
.collect_vec();
213+
let num_instances = instances.len();
211214

212-
layout.phase1_witness_group(KeccakTrace { instances })
215+
layout.phase1_witness_group(KeccakTrace {
216+
instances,
217+
ram_start_addr: vec![ByteAddr::from(0); num_instances],
218+
cur_ts: vec![0; num_instances],
219+
read_ts: vec![[0; KECCAK_INPUT32_SIZE]; num_instances],
220+
})
213221
}
214222

215223
fn assign_instance_with_gkr_iop(

ceno_zkvm/src/scheme/tests.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ use crate::{
1313
prover::ZkVMCpuProver,
1414
},
1515
set_val,
16-
structs::{
17-
PointAndEval, RAMType::Register, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses,
18-
},
16+
structs::{PointAndEval, RAMType, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
1917
tables::{ProgramTableCircuit, U16TableCircuit},
2018
witness::LkMultiplicity,
2119
};
@@ -64,8 +62,8 @@ impl<E: ExtensionField, const L: usize, const RW: usize> Instruction<E> for Test
6462
let reg_id = cb.create_witin(|| "reg_id");
6563
(0..RW).try_for_each(|_| {
6664
let record = vec![1.into(), reg_id.expr()];
67-
cb.read_record(|| "read", Register, record.clone())?;
68-
cb.write_record(|| "write", Register, record)?;
65+
cb.read_record(|| "read", RAMType::Register, record.clone())?;
66+
cb.write_record(|| "write", RAMType::Register, record)?;
6967
Result::<(), ZKVMError>::Ok(())
7068
})?;
7169
(0..L).try_for_each(|_| {

ceno_zkvm/src/structs.rs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@ use crate::{
77
witness::LkMultiplicity,
88
};
99
use ceno_emul::{CENO_PLATFORM, KeccakSpec, Platform, StepRecord, SyscallSpec};
10-
use either::Either;
1110
use ff_ext::ExtensionField;
1211
use gkr_iop::{LookupTable, gkr::GKRCircuit, precompiles::KeccakLayout};
1312
use itertools::Itertools;
1413
use mpcs::{Point, PolynomialCommitmentScheme};
15-
use multilinear_extensions::{Expression, impl_expr_from_unsigned};
14+
use multilinear_extensions::Expression;
1615
use serde::{Deserialize, Serialize, de::DeserializeOwned};
1716
use std::{
1817
collections::{BTreeMap, HashMap},
@@ -45,14 +44,7 @@ pub type ChallengeId = u16;
4544

4645
pub type ROMType = LookupTable;
4746

48-
#[derive(Clone, Debug, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
49-
pub enum RAMType {
50-
GlobalState,
51-
Register,
52-
Memory,
53-
}
54-
55-
impl_expr_from_unsigned!(RAMType);
47+
pub type RAMType = gkr_iop::RAMType;
5648

5749
pub type PointAndEval<F> = multilinear_extensions::mle::PointAndEval<F>;
5850

gkr_iop/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ version.workspace = true
1212
[dependencies]
1313
ark-std = { version = "0.5" }
1414
bincode.workspace = true
15+
ceno_emul = { path = "../ceno_emul" }
1516
clap.workspace = true
1617
either.workspace = true
1718
ff_ext = { path = "../ff_ext" }

gkr_iop/src/gkr/layer_constraint_system.rs

Lines changed: 133 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
use std::{cmp::Ordering, collections::BTreeMap};
22

3-
use ff_ext::ExtensionField;
4-
use itertools::{Itertools, chain, izip};
5-
use multilinear_extensions::{Expression, Fixed, ToExpr, WitnessId, rlc_chip_record};
6-
use p3_field::FieldAlgebra;
7-
83
use crate::{
9-
LookupTable,
4+
LookupTable, RAMType,
105
evaluation::EvalExpression,
116
gkr::layer::{Layer, LayerType, ROTATION_OPENING_COUNT},
127
};
8+
use ceno_emul::Tracer;
9+
use ff_ext::ExtensionField;
10+
use itertools::{Itertools, chain, izip};
11+
use multilinear_extensions::{Expression, Fixed, ToExpr, WitnessId, rlc_chip_record};
12+
use p3_field::FieldAlgebra;
1313

1414
#[derive(Clone, Debug, Default)]
1515
pub struct RotationParams<E: ExtensionField> {
@@ -39,6 +39,9 @@ pub struct LayerConstraintSystem<E: ExtensionField> {
3939
pub xor_lookups: Vec<Expression<E>>,
4040
pub range_lookups: Vec<Expression<E>>,
4141

42+
pub ram_read: Vec<Expression<E>>,
43+
pub ram_write: Vec<Expression<E>>,
44+
4245
// global challenge
4346
pub alpha: Expression<E>,
4447
pub beta: Expression<E>,
@@ -66,6 +69,8 @@ impl<E: ExtensionField> LayerConstraintSystem<E> {
6669
and_lookups: vec![],
6770
xor_lookups: vec![],
6871
range_lookups: vec![],
72+
ram_read: vec![],
73+
ram_write: vec![],
6974
alpha,
7075
beta,
7176
}
@@ -89,29 +94,31 @@ impl<E: ExtensionField> LayerConstraintSystem<E> {
8994
self.expr_names.push(name);
9095
}
9196

92-
fn rlc_table_lookup(
93-
&self,
94-
lookup_table: LookupTable,
95-
values: impl Iterator<Item = Expression<E>>,
96-
) -> Expression<E> {
97-
rlc_chip_record(
98-
chain!(
99-
std::iter::once(E::BaseField::from_canonical_u64(lookup_table as u64).expr()),
100-
values
101-
)
102-
.collect_vec(),
97+
pub fn lookup_and8(&mut self, a: Expression<E>, b: Expression<E>, c: Expression<E>) {
98+
let rlc_record = rlc_chip_record(
99+
vec![
100+
E::BaseField::from_canonical_u64(LookupTable::And as u64).expr(),
101+
a,
102+
b,
103+
c,
104+
],
103105
self.alpha.clone(),
104106
self.beta.clone(),
105-
)
106-
}
107-
108-
pub fn lookup_and8(&mut self, a: Expression<E>, b: Expression<E>, c: Expression<E>) {
109-
let rlc_record = self.rlc_table_lookup(LookupTable::And, [a, b, c].iter().cloned());
107+
);
110108
self.and_lookups.push(rlc_record);
111109
}
112110

113111
pub fn lookup_xor8(&mut self, a: Expression<E>, b: Expression<E>, c: Expression<E>) {
114-
let rlc_record = self.rlc_table_lookup(LookupTable::Xor, [a, b, c].iter().cloned());
112+
let rlc_record = rlc_chip_record(
113+
vec![
114+
E::BaseField::from_canonical_u64(LookupTable::Xor as u64).expr(),
115+
a,
116+
b,
117+
c,
118+
],
119+
self.alpha.clone(),
120+
self.beta.clone(),
121+
);
115122
self.xor_lookups.push(rlc_record);
116123
}
117124

@@ -120,17 +127,90 @@ impl<E: ExtensionField> LayerConstraintSystem<E> {
120127
/// `value << (16 - size)`.
121128
pub fn lookup_range(&mut self, value: Expression<E>, size: usize) {
122129
assert!(size <= 16);
123-
self.range_lookups
124-
.push(self.rlc_table_lookup(LookupTable::U16, vec![value.clone()].into_iter()));
130+
let rlc_record = rlc_chip_record(
131+
vec![
132+
E::BaseField::from_canonical_u64(LookupTable::U16 as u64).expr(),
133+
value.clone(),
134+
],
135+
self.alpha.clone(),
136+
self.beta.clone(),
137+
);
138+
self.range_lookups.push(rlc_record);
125139
if size < 16 {
126-
self.range_lookups.push(
127-
self.rlc_table_lookup(
128-
LookupTable::U16,
129-
[value * E::BaseField::from_canonical_u64(1 << (16 - size)).expr()]
130-
.iter()
131-
.cloned(),
132-
),
133-
)
140+
let rlc_record = rlc_chip_record(
141+
vec![
142+
E::BaseField::from_canonical_u64(LookupTable::U16 as u64).expr(),
143+
value * E::BaseField::from_canonical_u64(1 << (16 - size)).expr(),
144+
],
145+
self.alpha.clone(),
146+
self.beta.clone(),
147+
);
148+
self.range_lookups.push(rlc_record)
149+
}
150+
}
151+
152+
/// records RAM write operations into the `ram_write` trace vector using RLC encoding.
153+
///
154+
/// this function appends one RLC-encoded record per word written to memory,
155+
/// starting from `mem_start_addr`. Each record includes:
156+
/// - The operation type (assumed to be `RAMType::Memory`)
157+
/// - The memory address of the word (adjusted by `4 * index`)
158+
/// - The value being written
159+
/// - The timestamp of the write
160+
pub fn ram_write_record(
161+
&mut self,
162+
mem_start_addr: Expression<E>,
163+
values: Vec<Expression<E>>,
164+
cur_ts: Expression<E>,
165+
) {
166+
for (idx, value) in values.into_iter().enumerate() {
167+
let rlc_record = rlc_chip_record(
168+
vec![
169+
E::BaseField::from_canonical_u64(RAMType::Memory as u64).expr(),
170+
mem_start_addr.clone()
171+
// `4` is num bytes per word
172+
// TODO fetch from constant
173+
+ E::BaseField::from_canonical_u64((4 * idx) as u64).expr(),
174+
value,
175+
cur_ts.clone() + E::BaseField::from_canonical_u64(Tracer::SUBCYCLE_MEM).expr(),
176+
],
177+
self.alpha.clone(),
178+
self.beta.clone(),
179+
);
180+
self.ram_write.push(rlc_record);
181+
}
182+
}
183+
184+
/// records RAM read operations into the `ram_read` trace vector using RLC encoding.
185+
///
186+
/// this function appends one RLC-encoded record per word written to memory,
187+
/// starting from `mem_start_addr`. Each record includes:
188+
/// - The operation type (assumed to be `RAMType::Memory`)
189+
/// - The memory address of the word (adjusted by `4 * index`)
190+
/// - The value being read
191+
/// - The ts corresponding to each value
192+
pub fn ram_read_record(
193+
&mut self,
194+
mem_start_addr: Expression<E>,
195+
values: Vec<Expression<E>>,
196+
ts: Vec<Expression<E>>,
197+
) {
198+
assert_eq!(values.len(), ts.len());
199+
for (idx, (value, ts)) in izip!(values, ts).enumerate() {
200+
let rlc_record = rlc_chip_record(
201+
vec![
202+
E::BaseField::from_canonical_u64(RAMType::Memory as u64).expr(),
203+
mem_start_addr.clone()
204+
// `4` is num bytes per word
205+
// TODO fetch from constant
206+
+ E::BaseField::from_canonical_u64((4 * idx) as u64).expr(),
207+
value,
208+
ts,
209+
],
210+
self.alpha.clone(),
211+
self.beta.clone(),
212+
);
213+
self.ram_read.push(rlc_record);
134214
}
135215
}
136216

@@ -270,8 +350,27 @@ impl<E: ExtensionField> LayerConstraintSystem<E> {
270350
layer_name: String,
271351
in_expr_evals: Vec<usize>,
272352
n_challenges: usize,
353+
ram_write_evals: impl ExactSizeIterator<Item = (Option<Expression<E>>, usize)>,
354+
ram_read_evals: impl ExactSizeIterator<Item = (Option<Expression<E>>, usize)>,
273355
lookup_evals: impl ExactSizeIterator<Item = (Option<Expression<E>>, usize)>,
274356
) -> Layer<E> {
357+
// process ram read/write record
358+
assert_eq!(ram_write_evals.len(), self.ram_write.len(),);
359+
assert_eq!(ram_read_evals.len(), self.ram_read.len(),);
360+
361+
for (idx, ram_expr, ram_eval) in izip!(
362+
0..,
363+
chain!(self.ram_write.clone(), self.ram_read.clone(),),
364+
ram_write_evals.chain(ram_read_evals)
365+
) {
366+
self.add_non_zero_constraint(
367+
ram_expr - E::BaseField::ONE.expr(), // ONE is for padding
368+
(ram_eval.0, EvalExpression::Single(ram_eval.1)),
369+
format!("round 0th: {idx}th ram read/write operation"),
370+
);
371+
}
372+
373+
// process lookup records
275374
assert_eq!(
276375
lookup_evals.len(),
277376
self.and_lookups.len() + self.xor_lookups.len() + self.range_lookups.len()
@@ -336,6 +435,7 @@ impl<E: ExtensionField> LayerConstraintSystem<E> {
336435
is_linear_so_far && t.is_linear()
337436
});
338437

438+
// process evaluation group by eq expression
339439
let mut eq_map = BTreeMap::new();
340440
izip!(
341441
evals.into_iter(),

gkr_iop/src/lib.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use std::marker::PhantomData;
22

33
use crate::hal::{ProtocolWitnessGeneratorProver, ProverDevice};
44
use chip::Chip;
5+
use either::Either;
56
use ff_ext::ExtensionField;
67
use gkr::{GKRCircuit, GKRCircuitOutput, GKRCircuitWitness, layer::LayerWitness};
7-
use multilinear_extensions::mle::ArcMultilinearExtension;
8+
use multilinear_extensions::{Expression, impl_expr_from_unsigned, mle::ArcMultilinearExtension};
89
use strum_macros::EnumIter;
910
use transcript::Transcript;
1011
use utils::infer_layer_witness;
@@ -103,3 +104,13 @@ pub enum LookupTable {
103104
Pow, // a ** b where a is 2 and b is 5-bit value
104105
Instruction, // Decoded instruction from the fixed program.
105106
}
107+
108+
#[derive(Clone, Debug, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
109+
#[repr(usize)]
110+
pub enum RAMType {
111+
GlobalState,
112+
Register,
113+
Memory,
114+
}
115+
116+
impl_expr_from_unsigned!(RAMType);

0 commit comments

Comments
 (0)