Skip to content

Commit a4def62

Browse files
authored
improve batching sumcheck concurrency to be dominated by the max num_vars polynomial (#895)
Previously, the (devirgo) sumcheck concurrency when batching polynomials with different numbers of variables was dominated by the smallest num_vars. The reason is that for a polynomial with num_vars < log_2(threads) we can NOT divide its evaluations into #thread regions. This PR addresses the issue by introducing extra meta information for each polynomial, so we are able to differentiate those small polynomials and calculate their multiplicity correctly.

### Design rationale

For extremely small polynomials (num_vars << log(threads)), we define a new `PolyMeta::Phase2Only` to differentiate them. These small polynomials will be handled by the main worker only during phase 1 sumcheck. When moving to phase 2 sumcheck, those small polynomials will be expanded in size to match all the others. This cost is negligible given that phase 2 only needs to run on #thread evaluations.

### benchmark

There is no difference before/after the change, since we don't have batched sumcheck with different num_vars in the critical path, which meets expectations.

e2e Fibonacci 2^20
```
fibonacci_max_steps_1048576/prove_fibonacci/fibonacci_max_steps_1048576
time: [2.8894 s 2.9028 s 2.9180 s]
change: [+0.1414% +0.8059% +1.5474%] (p = 0.05 < 0.05)
Change within noise threshold.
```

e2e Fibonacci 2^21
```
fibonacci_max_steps_2097152/prove_fibonacci/fibonacci_max_steps_2097152
time: [5.1785 s 5.2061 s 5.2291 s]
change: [+0.7256% +1.3249% +1.9085%] (p = 0.00 < 0.05)
Change within noise threshold.
```

e2e Fibonacci 2^22
```
fibonacci_max_steps_4194304/prove_fibonacci/fibonacci_max_steps_4194304
time: [10.308 s 10.330 s 10.353 s]
change: [+0.9437% +1.2216% +1.4899%] (p = 0.00 < 0.05)
Change within noise threshold.
```
1 parent 5c2ac46 commit a4def62

File tree

11 files changed

+489
-276
lines changed

11 files changed

+489
-276
lines changed

Diff for: ceno_zkvm/src/scheme/prover.rs

+5-13
Original file line numberDiff line numberDiff line change
@@ -572,11 +572,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
572572
}
573573

574574
tracing::debug!("main sel sumcheck start");
575-
let (main_sel_sumcheck_proofs, state) = IOPProverState::prove_batch_polys(
576-
num_threads,
577-
virtual_polys.get_batched_polys(),
578-
transcript,
579-
);
575+
let (main_sel_sumcheck_proofs, state) = IOPProverState::prove(virtual_polys, transcript);
580576
tracing::debug!("main sel sumcheck end");
581577

582578
let main_sel_evals = state.get_mle_final_evaluations();
@@ -1015,11 +1011,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
10151011
virtual_polys.add_mle_list(vec![eq, lk_d_wit], *alpha);
10161012
}
10171013

1018-
let (same_r_sumcheck_proofs, state) = IOPProverState::prove_batch_polys(
1019-
num_threads,
1020-
virtual_polys.get_batched_polys(),
1021-
transcript,
1022-
);
1014+
let (same_r_sumcheck_proofs, state) =
1015+
IOPProverState::prove(virtual_polys, transcript);
10231016
let evals = state.get_mle_final_evaluations();
10241017
let mut evals_iter = evals.into_iter();
10251018
let rw_in_evals = cs
@@ -1271,9 +1264,8 @@ impl TowerProver {
12711264
// NOTE: at the time of adding this span, visualizing it with the flamegraph layer
12721265
// shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys`
12731266
// This is likely a bug in the tracing-flame crate.
1274-
let (sumcheck_proofs, state) = IOPProverState::prove_batch_polys(
1275-
num_threads,
1276-
virtual_polys.get_batched_polys(),
1267+
let (sumcheck_proofs, state) = IOPProverState::prove(
1268+
virtual_polys,
12771269
transcript,
12781270
);
12791271
exit_span!(wrap_batch_span);

Diff for: mpcs/src/basefold/commit_phase.rs

+19-7
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ use serde::{Serialize, de::DeserializeOwned};
2121
use sumcheck::{
2222
macros::{entered_span, exit_span},
2323
structs::IOPProverState,
24-
util::{AdditiveVec, merge_sumcheck_polys, optimal_sumcheck_threads},
24+
util::{AdditiveVec, merge_sumcheck_prover_state, optimal_sumcheck_threads},
2525
};
2626
use transcript::{Challenge, Transcript};
2727

2828
use multilinear_extensions::{
2929
commutative_op_mle_pair,
3030
mle::{DenseMultilinearExtension, IntoMLE},
31+
util::ceil_log2,
3132
virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec},
3233
virtual_polys::VirtualPolynomials,
3334
};
@@ -98,15 +99,23 @@ where
9899
end_timer!(build_eq_timer);
99100

100101
let num_threads = optimal_sumcheck_threads(num_vars);
102+
let log_num_threads = ceil_log2(num_threads);
101103

102104
let mut polys = VirtualPolynomials::new(num_threads, num_vars);
103105
polys.add_mle_list(vec![&eq, &running_evals], E::ONE);
104-
let batched_polys = polys.get_batched_polys();
106+
let (batched_polys, poly_meta) = polys.get_batched_polys();
105107

106108
let mut prover_states = batched_polys
107109
.into_iter()
108-
.map(|poly| {
109-
IOPProverState::prover_init_with_extrapolation_aux(poly, vec![(vec![], vec![])])
110+
.enumerate()
111+
.map(|(thread_id, poly)| {
112+
IOPProverState::prover_init_with_extrapolation_aux(
113+
thread_id == 0, // set thread_id 0 to be main worker
114+
poly,
115+
vec![(vec![], vec![])],
116+
Some(log_num_threads),
117+
Some(poly_meta.clone()),
118+
)
110119
})
111120
.collect::<Vec<_>>();
112121

@@ -140,13 +149,16 @@ where
140149
}
141150

142151
// deal with log(#thread) basefold rounds
143-
let merge_sumcheck_polys_span = entered_span!("merge_sumcheck_polys");
144-
let poly = merge_sumcheck_polys(&prover_states);
152+
let merge_sumcheck_prover_state_span = entered_span!("merge_sumcheck_prover_state");
153+
let poly = merge_sumcheck_prover_state(prover_states);
145154
let mut prover_states = vec![IOPProverState::prover_init_with_extrapolation_aux(
155+
true,
146156
poly,
147157
vec![(vec![], vec![])],
158+
None,
159+
None,
148160
)];
149-
exit_span!(merge_sumcheck_polys_span);
161+
exit_span!(merge_sumcheck_prover_state_span);
150162

151163
let mut challenge = None;
152164

Diff for: multilinear_extensions/src/virtual_poly.rs

+8-10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
1818

1919
pub type ArcMultilinearExtension<'a, E> =
2020
Arc<dyn MultilinearExtension<E, Output = DenseMultilinearExtension<E>> + 'a>;
21+
2122
#[rustfmt::skip]
2223
/// A virtual polynomial is a sum of products of multilinear polynomials;
2324
/// where the multilinear polynomials are stored via their multilinear
@@ -113,21 +114,17 @@ impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> {
113114
///
114115
/// The MLEs will be multiplied together, and then multiplied by the scalar
115116
/// `coefficient`.
116-
pub fn add_mle_list(&mut self, mle_list: Vec<ArcMultilinearExtension<'a, E>>, coefficient: E) {
117+
pub fn add_mle_list(
118+
&mut self,
119+
mle_list: Vec<ArcMultilinearExtension<'a, E>>,
120+
coefficient: E,
121+
) -> &[usize] {
117122
let mle_list: Vec<ArcMultilinearExtension<E>> = mle_list.into_iter().collect();
118123
let mut indexed_product = Vec::with_capacity(mle_list.len());
119124

120125
assert!(!mle_list.is_empty(), "input mle_list is empty");
121126
// sanity check: all mle in mle_list must have same num_vars()
122-
assert!(
123-
mle_list
124-
.iter()
125-
.map(|m| {
126-
assert!(m.num_vars() <= self.aux_info.max_num_variables);
127-
m.num_vars()
128-
})
129-
.all_equal()
130-
);
127+
assert!(mle_list.iter().map(|m| { m.num_vars() }).all_equal());
131128

132129
self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len());
133130

@@ -143,6 +140,7 @@ impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> {
143140
}
144141
}
145142
self.products.push((coefficient, indexed_product));
143+
&self.products.last().unwrap().1
146144
}
147145

148146
/// in-place merge with another virtual polynomial

Diff for: multilinear_extensions/src/virtual_polys.rs

+52-14
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
1-
use std::{collections::HashMap, sync::Arc};
1+
use std::{
2+
collections::{BTreeMap, HashMap},
3+
sync::Arc,
4+
};
25

36
use crate::{
47
util::ceil_log2,
58
virtual_poly::{ArcMultilinearExtension, VirtualPolynomial},
69
};
710
use ff_ext::ExtensionField;
811
use itertools::Itertools;
12+
use p3::util::log2_strict_usize;
913

1014
use crate::util::transpose;
1115

16+
#[derive(Debug, Default, Clone, Copy)]
17+
pub enum PolyMeta {
18+
#[default]
19+
Normal,
20+
Phase2Only,
21+
}
22+
1223
pub struct VirtualPolynomials<'a, E: ExtensionField> {
13-
num_threads: usize,
24+
pub num_threads: usize,
1425
polys: Vec<VirtualPolynomial<'a, E>>,
1526
/// a storage to keep thread based mles, specific to multi-thread logic
1627
thread_based_mles_storage: HashMap<usize, Vec<ArcMultilinearExtension<'a, E>>>,
28+
pub(crate) poly_meta: BTreeMap<usize, PolyMeta>,
1729
}
1830

1931
impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
@@ -25,6 +37,7 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
2537
.map(|_| VirtualPolynomial::new(max_num_variables - ceil_log2(num_threads)))
2638
.collect_vec(),
2739
thread_based_mles_storage: HashMap::new(),
40+
poly_meta: BTreeMap::new(),
2841
}
2942
}
3043

@@ -44,32 +57,52 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
4457
}
4558

4659
pub fn add_mle_list(&mut self, polys: Vec<&'a ArcMultilinearExtension<'a, E>>, coeff: E) {
47-
let polys = polys
60+
let log2_num_threads = log2_strict_usize(self.num_threads);
61+
let (poly_meta, polys): (Vec<PolyMeta>, Vec<Vec<ArcMultilinearExtension<E>>>) = polys
4862
.into_iter()
4963
.map(|p| {
5064
let mle_ptr: usize = Arc::as_ptr(p) as *const () as usize;
51-
if let Some(mles) = self.thread_based_mles_storage.get(&mle_ptr) {
65+
let poly_meta = if p.num_vars() > log2_num_threads {
66+
PolyMeta::Normal
67+
} else {
68+
// polynomial is too small
69+
PolyMeta::Phase2Only
70+
};
71+
let mles_cloned = if let Some(mles) = self.thread_based_mles_storage.get(&mle_ptr) {
5272
mles.clone()
5373
} else {
5474
let mles = (0..self.num_threads)
55-
.map(|thread_id| {
56-
self.get_range_polys_by_thread_id(thread_id, vec![p])
57-
.remove(0)
75+
.map(|thread_id| match poly_meta {
76+
PolyMeta::Normal => self
77+
.get_range_polys_by_thread_id(thread_id, vec![p])
78+
.remove(0),
79+
PolyMeta::Phase2Only => Arc::new(p.get_ranged_mle(1, 0)),
5880
})
5981
.collect_vec();
6082
let mles_cloned = mles.clone();
6183
self.thread_based_mles_storage.insert(mle_ptr, mles);
6284
mles_cloned
63-
}
85+
};
86+
(poly_meta, mles_cloned)
6487
})
65-
.collect_vec();
88+
.unzip();
6689

6790
// poly -> thread to thread -> poly
6891
let polys = transpose(polys);
69-
(0..self.num_threads)
92+
let poly_index: &[usize] = self
93+
.polys
94+
.iter_mut()
7095
.zip_eq(polys)
71-
.for_each(|(thread_id, polys)| {
72-
self.polys[thread_id].add_mle_list(polys, coeff);
96+
.map(|(poly, polys)| poly.add_mle_list(polys, coeff))
97+
.collect_vec()
98+
.first()
99+
.expect("expect to get at index from first thread");
100+
101+
poly_index
102+
.iter()
103+
.zip_eq(&poly_meta)
104+
.for_each(|(index, poly_meta)| {
105+
self.poly_meta.insert(*index, *poly_meta);
73106
});
74107
}
75108

@@ -84,8 +117,13 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
84117
}
85118
}
86119

87-
pub fn get_batched_polys(self) -> Vec<VirtualPolynomial<'a, E>> {
88-
self.polys
120+
/// return thread_based polynomial with its polynomial type
121+
pub fn get_batched_polys(self) -> (Vec<VirtualPolynomial<'a, E>>, Vec<PolyMeta>) {
122+
let mut poly_meta = vec![PolyMeta::Normal; self.polys[0].flattened_ml_extensions.len()];
123+
for (index, poly_meta_by_index) in self.poly_meta {
124+
poly_meta[index] = poly_meta_by_index
125+
}
126+
(self.polys, poly_meta)
89127
}
90128

91129
pub fn degree(&self) -> usize {

Diff for: sumcheck/benches/devirgo_sumcheck.rs

+23-52
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
#![allow(clippy::manual_memcpy)]
22
#![allow(clippy::needless_range_loop)]
33

4-
use std::{array, time::Duration};
4+
use std::time::Duration;
55

66
use ark_std::test_rng;
77
use criterion::*;
88
use ff_ext::{ExtensionField, GoldilocksExt2};
99
use itertools::Itertools;
10-
use sumcheck::{structs::IOPProverState, util::ceil_log2};
10+
use p3::field::PrimeCharacteristicRing;
11+
use sumcheck::structs::IOPProverState;
1112

1213
use multilinear_extensions::{
1314
mle::DenseMultilinearExtension,
1415
op_mle,
1516
util::max_usable_threads,
1617
virtual_poly::{ArcMultilinearExtension, VirtualPolynomial},
18+
virtual_polys::VirtualPolynomials,
1719
};
1820
use transcript::BasicTranscript as Transcript;
1921

@@ -39,49 +41,15 @@ pub fn transpose<T>(v: Vec<Vec<T>>) -> Vec<Vec<T>> {
3941
.collect()
4042
}
4143

42-
fn prepare_input<'a, E: ExtensionField>(
43-
nv: usize,
44-
) -> (E, VirtualPolynomial<'a, E>, Vec<VirtualPolynomial<'a, E>>) {
44+
fn prepare_input<'a, E: ExtensionField>(nv: usize) -> (E, Vec<ArcMultilinearExtension<'a, E>>) {
4545
let mut rng = test_rng();
46-
let max_thread_id = max_usable_threads();
47-
let size_log2 = ceil_log2(max_thread_id);
48-
let fs: [ArcMultilinearExtension<'a, E>; NUM_DEGREE] = array::from_fn(|_| {
49-
let mle: ArcMultilinearExtension<'a, E> =
50-
DenseMultilinearExtension::<E>::random(nv, &mut rng).into();
51-
mle
52-
});
53-
54-
let mut virtual_poly_v1 = VirtualPolynomial::new(nv);
55-
virtual_poly_v1.add_mle_list(fs.to_vec(), E::ONE);
56-
57-
// devirgo version
58-
let virtual_poly_v2: Vec<Vec<ArcMultilinearExtension<'a, E>>> = transpose(
59-
fs.iter()
60-
.map(|f| match &f.evaluations() {
61-
multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations
62-
.chunks((1 << nv) >> size_log2)
63-
.map(|chunk| {
64-
let mle: ArcMultilinearExtension<'a, E> =
65-
DenseMultilinearExtension::<E>::from_evaluations_vec(
66-
nv - size_log2,
67-
chunk.to_vec(),
68-
)
69-
.into();
70-
mle
71-
})
72-
.collect_vec(),
73-
_ => unreachable!(),
74-
})
75-
.collect(),
76-
);
77-
let virtual_poly_v2: Vec<VirtualPolynomial<E>> = virtual_poly_v2
78-
.into_iter()
79-
.map(|fs| {
80-
let mut virtual_polynomial = VirtualPolynomial::new(fs[0].num_vars());
81-
virtual_polynomial.add_mle_list(fs, E::ONE);
82-
virtual_polynomial
46+
let fs = (0..NUM_DEGREE)
47+
.map(|_| {
48+
let mle: ArcMultilinearExtension<'a, E> =
49+
DenseMultilinearExtension::<E>::random(nv, &mut rng).into();
50+
mle
8351
})
84-
.collect();
52+
.collect_vec();
8553

8654
let asserted_sum = fs
8755
.iter()
@@ -97,7 +65,7 @@ fn prepare_input<'a, E: ExtensionField>(
9765
.cloned()
9866
.sum::<E>();
9967

100-
(asserted_sum, virtual_poly_v1, virtual_poly_v2)
68+
(asserted_sum, fs)
10169
}
10270

10371
fn sumcheck_fn(c: &mut Criterion) {
@@ -116,12 +84,15 @@ fn sumcheck_fn(c: &mut Criterion) {
11684
let mut time = Duration::new(0, 0);
11785
for _ in 0..iters {
11886
let mut prover_transcript = Transcript::new(b"test");
119-
let (_, virtual_poly, _) = { prepare_input(nv) };
87+
let (_, fs) = { prepare_input(nv) };
88+
89+
let mut virtual_poly_v1 = VirtualPolynomial::new(nv);
90+
virtual_poly_v1.add_mle_list(fs.to_vec(), E::ONE);
12091

12192
let instant = std::time::Instant::now();
12293
#[allow(deprecated)]
12394
let (_sumcheck_proof_v1, _) = IOPProverState::<E>::prove_parallel(
124-
virtual_poly.clone(),
95+
virtual_poly_v1,
12596
&mut prover_transcript,
12697
);
12798
let elapsed = instant.elapsed();
@@ -153,14 +124,14 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) {
153124
let mut time = Duration::new(0, 0);
154125
for _ in 0..iters {
155126
let mut prover_transcript = Transcript::new(b"test");
156-
let (_, _, virtual_poly_splitted) = { prepare_input(nv) };
127+
let (_, fs) = { prepare_input(nv) };
128+
129+
let mut virtual_poly_v2 = VirtualPolynomials::new(threads, nv);
130+
virtual_poly_v2.add_mle_list(fs.iter().collect_vec(), E::ONE);
157131

158132
let instant = std::time::Instant::now();
159-
let (_sumcheck_proof_v2, _) = IOPProverState::<E>::prove_batch_polys(
160-
threads,
161-
virtual_poly_splitted,
162-
&mut prover_transcript,
163-
);
133+
let (_sumcheck_proof_v2, _) =
134+
IOPProverState::<E>::prove(virtual_poly_v2, &mut prover_transcript);
164135
let elapsed = instant.elapsed();
165136
time += elapsed;
166137
}

0 commit comments

Comments
 (0)