Skip to content

Commit 584d9d9

Browse files
authored
continuation e2e in gpu (#1121)
reopen from #1118
1 parent 1e3c940 commit 584d9d9

File tree

5 files changed

+107
-48
lines changed

5 files changed

+107
-48
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ opt-level = 3
9797
lto = "thin"
9898

9999
# [patch."ssh://git@github.com/scroll-tech/ceno-gpu.git"]
100-
# ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal" }
100+
# ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features=["bb31"] }
101101

102102
# [patch."https://github.com/scroll-tech/gkr-backend"]
103103
# ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" }

ceno_zkvm/src/scheme/gpu/mod.rs

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use super::hal::{
2-
DeviceTransporter, MainSumcheckProver, OpeningProver, ProverDevice, TowerProver, TraceCommitter,
2+
DeviceTransporter, EccQuarkProver, MainSumcheckProver, OpeningProver, ProverDevice,
3+
TowerProver, TraceCommitter,
34
};
45
use crate::{
56
error::ZKVMError,
@@ -9,25 +10,15 @@ use crate::{
910
},
1011
structs::{ComposedConstrainSystem, PointAndEval, TowerProofs},
1112
};
12-
use ceno_gpu::bb31::GpuPolynomialExt;
1313
use ff_ext::{ExtensionField, GoldilocksExt2};
1414
use gkr_iop::{
15-
gkr::{
16-
self, Evaluation, GKRProof, GKRProverOutput,
17-
layer::{LayerWitness, gpu::utils::extract_mle_relationships_from_monomial_terms},
18-
},
15+
gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness},
1916
gpu::{GpuBackend, GpuProver},
20-
hal::ProverBackend,
17+
hal::{MultilinearPolynomial, ProverBackend},
2118
};
2219
use itertools::{Itertools, chain};
2320
use mpcs::{Point, PolynomialCommitmentScheme};
24-
use multilinear_extensions::{
25-
Instance, WitnessId,
26-
mle::{FieldType, MultilinearExtension},
27-
monomialize_expr_to_wit_terms,
28-
util::ceil_log2,
29-
};
30-
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
21+
use multilinear_extensions::{mle::MultilinearExtension, util::ceil_log2};
3122
use std::{collections::BTreeMap, sync::Arc};
3223
use sumcheck::{
3324
macros::{entered_span, exit_span},
@@ -37,16 +28,20 @@ use sumcheck::{
3728
use transcript::{BasicTranscript, Transcript};
3829
use witness::next_pow2_instance_padding;
3930

40-
use crate::circuit_builder::ConstraintSystem;
41-
use gkr_iop::hal::MultilinearPolynomial;
42-
4331
#[cfg(feature = "gpu")]
4432
use gkr_iop::gpu::gpu_prover::*;
4533

4634
pub struct GpuTowerProver;
4735

48-
use crate::{e2e::ShardContext, scheme::constants::NUM_FANIN};
49-
use gkr_iop::gpu::{ArcMultilinearExtensionGpu, MultilinearExtensionGpu};
36+
use crate::{
37+
e2e::ShardContext,
38+
scheme::{constants::NUM_FANIN, cpu::CpuEccProver},
39+
structs::EccQuarkProof,
40+
};
41+
use gkr_iop::{
42+
gpu::{ArcMultilinearExtensionGpu, MultilinearExtensionGpu},
43+
selector::SelectorContext,
44+
};
5045

5146
// Extract out_evals from GPU-built tower witnesses
5247
#[allow(clippy::type_complexity)]
@@ -102,7 +97,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> TraceCommitter<GpuBa
10297
for GpuProver<GpuBackend<E, PCS>>
10398
{
10499
fn commit_traces<'a>(
105-
&mut self,
100+
&self,
106101
traces: BTreeMap<usize, witness::RowMajorMatrix<E::BaseField>>,
107102
) -> (
108103
Vec<MultilinearExtensionGpu<'a, E>>,
@@ -534,13 +529,48 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<G
534529
gkr_circuit,
535530
} = composed_cs;
536531

532+
let num_instances = input.num_instances();
537533
let log2_num_instances = input.log2_num_instances();
538534
let num_threads = optimal_sumcheck_threads(log2_num_instances);
539535
let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0);
540536

541537
let Some(gkr_circuit) = gkr_circuit else {
542538
panic!("empty gkr circuit")
543539
};
540+
let selector_ctxs = if cs.ec_final_sum.is_empty() {
541+
// it's not global chip
542+
vec![
543+
SelectorContext {
544+
offset: 0,
545+
num_instances,
546+
num_vars: num_var_with_rotation,
547+
};
548+
gkr_circuit
549+
.layers
550+
.first()
551+
.map(|layer| layer.out_sel_and_eval_exprs.len())
552+
.unwrap_or(0)
553+
]
554+
} else {
555+
// it's global chip
556+
vec![
557+
SelectorContext {
558+
offset: 0,
559+
num_instances: input.num_instances[0],
560+
num_vars: num_var_with_rotation,
561+
},
562+
SelectorContext {
563+
offset: input.num_instances[0],
564+
num_instances: input.num_instances[1],
565+
num_vars: num_var_with_rotation,
566+
},
567+
SelectorContext {
568+
offset: 0,
569+
num_instances,
570+
num_vars: num_var_with_rotation,
571+
},
572+
]
573+
};
544574
let pub_io_mles = cs
545575
.instance_openings
546576
.iter()
@@ -574,7 +604,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<G
574604
.collect_vec(),
575605
challenges,
576606
transcript,
577-
num_instances,
607+
&selector_ctxs,
578608
)?;
579609
assert_eq!(rt.len(), 1, "TODO support multi-layer gkr iop");
580610
Ok((
@@ -600,6 +630,34 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<G
600630
}
601631
}
602632

633+
impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> EccQuarkProver<GpuBackend<E, PCS>>
634+
for GpuProver<GpuBackend<E, PCS>>
635+
{
636+
fn prove_ec_sum_quark<'a>(
637+
&self,
638+
num_instances: usize,
639+
xs: Vec<Arc<MultilinearExtensionGpu<'a, E>>>,
640+
ys: Vec<Arc<MultilinearExtensionGpu<'a, E>>>,
641+
invs: Vec<Arc<MultilinearExtensionGpu<'a, E>>>,
642+
transcript: &mut impl Transcript<E>,
643+
) -> Result<EccQuarkProof<E>, ZKVMError> {
644+
// TODO implement GPU version of `create_ecc_proof`
645+
let xs = xs.iter().map(|mle| mle.inner_to_mle().into()).collect_vec();
646+
let ys = ys.iter().map(|mle| mle.inner_to_mle().into()).collect_vec();
647+
let invs = invs
648+
.iter()
649+
.map(|mle| mle.inner_to_mle().into())
650+
.collect_vec();
651+
Ok(CpuEccProver::create_ecc_proof(
652+
num_instances,
653+
xs,
654+
ys,
655+
invs,
656+
transcript,
657+
))
658+
}
659+
}
660+
603661
impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> OpeningProver<GpuBackend<E, PCS>>
604662
for GpuProver<GpuBackend<E, PCS>>
605663
{
@@ -743,8 +801,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> DeviceTransporter<Gp
743801
std::mem::forget(pcs_data_basefold);
744802
let pcs_data = Arc::new(pcs_data);
745803

746-
let fixed_mles =
747-
PCS::get_arc_mle_witness_from_commitment(pk.fixed_commit_wd.as_ref().unwrap());
804+
let fixed_mles = PCS::get_arc_mle_witness_from_commitment(pcs_data_original.as_ref());
748805
let fixed_mles = fixed_mles
749806
.iter()
750807
.map(|mle| Arc::new(MultilinearExtensionGpu::from_ceno(&cuda_hal, mle)))

ceno_zkvm/src/scheme/utils.rs

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -326,23 +326,6 @@ pub fn build_main_witness<
326326
);
327327
assert_eq!(input.fixed.len(), cs.num_fixed);
328328

329-
// check all witness size are power of 2
330-
assert!(
331-
input
332-
.witness
333-
.iter()
334-
.all(|v| { v.evaluations_len() == 1 << num_var_with_rotation })
335-
);
336-
337-
if !input.structural_witness.is_empty() {
338-
assert!(
339-
input
340-
.structural_witness
341-
.iter()
342-
.all(|v| { v.evaluations_len() == 1 << num_var_with_rotation })
343-
);
344-
}
345-
346329
let Some(gkr_circuit) = gkr_circuit else {
347330
panic!("empty gkr-iop")
348331
};
@@ -365,6 +348,17 @@ pub fn build_main_witness<
365348
.map(|instance| input.public_input[instance.0].clone())
366349
.collect_vec();
367350

351+
// check all witness size are power of 2
352+
assert!(
353+
input
354+
.witness
355+
.iter()
356+
.chain(&input.structural_witness)
357+
.chain(&input.fixed)
358+
.chain(&pub_io_mles)
359+
.all(|v| { v.evaluations_len() == 1 << num_var_with_rotation })
360+
);
361+
368362
let (_, gkr_circuit_out) = gkr_witness::<E, PCS, PB, PD>(
369363
gkr_circuit,
370364
&input.witness,

gkr_iop/src/gkr/layer/gpu/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use crate::{
4343
use crate::gpu::{MultilinearExtensionGpu, gpu_prover::*};
4444

4545
pub mod utils;
46+
use crate::selector::SelectorContext;
4647
use utils::*;
4748

4849
impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> LinearLayerProver<GpuBackend<E, PCS>>
@@ -112,7 +113,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
112113
pub_io_evals: &[<GpuBackend<E, PCS> as ProverBackend>::E],
113114
challenges: &[<GpuBackend<E, PCS> as ProverBackend>::E],
114115
transcript: &mut impl Transcript<<GpuBackend<E, PCS> as ProverBackend>::E>,
115-
num_instances: usize,
116+
selector_ctxs: &[SelectorContext],
116117
) -> (
117118
LayerProof<<GpuBackend<E, PCS> as ProverBackend>::E>,
118119
Point<<GpuBackend<E, PCS> as ProverBackend>::E>,
@@ -175,8 +176,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
175176
.out_sel_and_eval_exprs
176177
.iter()
177178
.zip(out_points.iter())
178-
.map(|((sel_type, _), point)| {
179-
build_eq_x_r_with_sel_gpu(&cuda_hal, point, num_instances, sel_type)
179+
.zip(selector_ctxs.iter())
180+
.map(|(((sel_type, _), point), selector_ctx)| {
181+
build_eq_x_r_with_sel_gpu(&cuda_hal, point, selector_ctx, sel_type)
180182
})
181183
// for rotation left point
182184
.chain(

gkr_iop/src/gkr/layer/gpu/utils.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use multilinear_extensions::{
1010
Expression, mle::Point, monomial::Term, utils::eval_by_expr_constant,
1111
};
1212

13-
use crate::selector::SelectorType;
13+
use crate::selector::{SelectorContext, SelectorType};
1414

1515
use crate::gpu::{MultilinearExtensionGpu, gpu_prover::*};
1616

@@ -63,7 +63,7 @@ pub fn extract_mle_relationships_from_monomial_terms<'a, E: ExtensionField>(
6363
pub fn build_eq_x_r_with_sel_gpu<E: ExtensionField>(
6464
hal: &CudaHalBB31,
6565
point: &Point<E>,
66-
num_instances: usize,
66+
selector_ctx: &SelectorContext,
6767
selector: &SelectorType<E>,
6868
) -> MultilinearExtensionGpu<'static, E> {
6969
if std::any::TypeId::of::<E::BaseField>() != std::any::TypeId::of::<BB31Base>() {
@@ -74,12 +74,16 @@ pub fn build_eq_x_r_with_sel_gpu<E: ExtensionField>(
7474
let (num_instances, is_sp32, indices) = match selector {
7575
SelectorType::None => panic!("SelectorType::None"),
7676
SelectorType::Whole(_expr) => (eq_len, false, vec![]),
77-
SelectorType::Prefix(_, _expr) => (num_instances, false, vec![]),
78-
SelectorType::OrderedSparse32 { indices, .. } => (num_instances, true, indices.clone()),
77+
SelectorType::Prefix(_expr) => (selector_ctx.num_instances, false, vec![]),
78+
SelectorType::OrderedSparse32 { indices, .. } => {
79+
(selector_ctx.num_instances, true, indices.clone())
80+
}
81+
SelectorType::QuarkBinaryTreeLessThan(..) => unimplemented!(),
7982
};
8083

8184
// type eq
8285
let eq_mle = if is_sp32 {
86+
assert_eq!(selector_ctx.offset, 0);
8387
let eq = build_eq_x_r_gpu(hal, point);
8488
let mut eq_buf = match eq.mle {
8589
GpuFieldType::Base(_) => panic!("should be ext field"),
@@ -103,6 +107,7 @@ pub fn build_eq_x_r_with_sel_gpu<E: ExtensionField>(
103107
&hal.inner,
104108
&gpu_points,
105109
&mut gpu_output,
110+
selector_ctx.offset,
106111
num_instances,
107112
)
108113
.unwrap();
@@ -135,6 +140,7 @@ pub fn build_eq_x_r_gpu<E: ExtensionField>(
135140
&hal.inner,
136141
&gpu_points,
137142
&mut gpu_output,
143+
0,
138144
eq_len,
139145
)
140146
.unwrap();

0 commit comments

Comments
 (0)