11use super :: hal:: {
2- DeviceTransporter , MainSumcheckProver , OpeningProver , ProverDevice , TowerProver , TraceCommitter ,
2+ DeviceTransporter , EccQuarkProver , MainSumcheckProver , OpeningProver , ProverDevice ,
3+ TowerProver , TraceCommitter ,
34} ;
45use crate :: {
56 error:: ZKVMError ,
@@ -9,25 +10,15 @@ use crate::{
910 } ,
1011 structs:: { ComposedConstrainSystem , PointAndEval , TowerProofs } ,
1112} ;
12- use ceno_gpu:: bb31:: GpuPolynomialExt ;
1313use ff_ext:: { ExtensionField , GoldilocksExt2 } ;
1414use gkr_iop:: {
15- gkr:: {
16- self , Evaluation , GKRProof , GKRProverOutput ,
17- layer:: { LayerWitness , gpu:: utils:: extract_mle_relationships_from_monomial_terms} ,
18- } ,
15+ gkr:: { self , Evaluation , GKRProof , GKRProverOutput , layer:: LayerWitness } ,
1916 gpu:: { GpuBackend , GpuProver } ,
20- hal:: ProverBackend ,
17+ hal:: { MultilinearPolynomial , ProverBackend } ,
2118} ;
2219use itertools:: { Itertools , chain} ;
2320use mpcs:: { Point , PolynomialCommitmentScheme } ;
24- use multilinear_extensions:: {
25- Instance , WitnessId ,
26- mle:: { FieldType , MultilinearExtension } ,
27- monomialize_expr_to_wit_terms,
28- util:: ceil_log2,
29- } ;
30- use rayon:: iter:: { IntoParallelRefIterator , ParallelIterator } ;
21+ use multilinear_extensions:: { mle:: MultilinearExtension , util:: ceil_log2} ;
3122use std:: { collections:: BTreeMap , sync:: Arc } ;
3223use sumcheck:: {
3324 macros:: { entered_span, exit_span} ,
@@ -37,16 +28,20 @@ use sumcheck::{
3728use transcript:: { BasicTranscript , Transcript } ;
3829use witness:: next_pow2_instance_padding;
3930
40- use crate :: circuit_builder:: ConstraintSystem ;
41- use gkr_iop:: hal:: MultilinearPolynomial ;
42-
4331#[ cfg( feature = "gpu" ) ]
4432use gkr_iop:: gpu:: gpu_prover:: * ;
4533
4634pub struct GpuTowerProver ;
4735
48- use crate :: { e2e:: ShardContext , scheme:: constants:: NUM_FANIN } ;
49- use gkr_iop:: gpu:: { ArcMultilinearExtensionGpu , MultilinearExtensionGpu } ;
36+ use crate :: {
37+ e2e:: ShardContext ,
38+ scheme:: { constants:: NUM_FANIN , cpu:: CpuEccProver } ,
39+ structs:: EccQuarkProof ,
40+ } ;
41+ use gkr_iop:: {
42+ gpu:: { ArcMultilinearExtensionGpu , MultilinearExtensionGpu } ,
43+ selector:: SelectorContext ,
44+ } ;
5045
5146// Extract out_evals from GPU-built tower witnesses
5247#[ allow( clippy:: type_complexity) ]
@@ -102,7 +97,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> TraceCommitter<GpuBa
10297 for GpuProver < GpuBackend < E , PCS > >
10398{
10499 fn commit_traces < ' a > (
105- & mut self ,
100+ & self ,
106101 traces : BTreeMap < usize , witness:: RowMajorMatrix < E :: BaseField > > ,
107102 ) -> (
108103 Vec < MultilinearExtensionGpu < ' a , E > > ,
@@ -534,13 +529,48 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<G
534529 gkr_circuit,
535530 } = composed_cs;
536531
532+ let num_instances = input. num_instances ( ) ;
537533 let log2_num_instances = input. log2_num_instances ( ) ;
538534 let num_threads = optimal_sumcheck_threads ( log2_num_instances) ;
539535 let num_var_with_rotation = log2_num_instances + composed_cs. rotation_vars ( ) . unwrap_or ( 0 ) ;
540536
541537 let Some ( gkr_circuit) = gkr_circuit else {
542538 panic ! ( "empty gkr circuit" )
543539 } ;
540+ let selector_ctxs = if cs. ec_final_sum . is_empty ( ) {
541+ // it's not global chip
542+ vec ! [
543+ SelectorContext {
544+ offset: 0 ,
545+ num_instances,
546+ num_vars: num_var_with_rotation,
547+ } ;
548+ gkr_circuit
549+ . layers
550+ . first( )
551+ . map( |layer| layer. out_sel_and_eval_exprs. len( ) )
552+ . unwrap_or( 0 )
553+ ]
554+ } else {
555+ // it's global chip
556+ vec ! [
557+ SelectorContext {
558+ offset: 0 ,
559+ num_instances: input. num_instances[ 0 ] ,
560+ num_vars: num_var_with_rotation,
561+ } ,
562+ SelectorContext {
563+ offset: input. num_instances[ 0 ] ,
564+ num_instances: input. num_instances[ 1 ] ,
565+ num_vars: num_var_with_rotation,
566+ } ,
567+ SelectorContext {
568+ offset: 0 ,
569+ num_instances,
570+ num_vars: num_var_with_rotation,
571+ } ,
572+ ]
573+ } ;
544574 let pub_io_mles = cs
545575 . instance_openings
546576 . iter ( )
@@ -574,7 +604,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<G
574604 . collect_vec ( ) ,
575605 challenges,
576606 transcript,
577- num_instances ,
607+ & selector_ctxs ,
578608 ) ?;
579609 assert_eq ! ( rt. len( ) , 1 , "TODO support multi-layer gkr iop" ) ;
580610 Ok ( (
@@ -600,6 +630,34 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<G
600630 }
601631}
602632
633+ impl < E : ExtensionField , PCS : PolynomialCommitmentScheme < E > > EccQuarkProver < GpuBackend < E , PCS > >
634+ for GpuProver < GpuBackend < E , PCS > >
635+ {
636+ fn prove_ec_sum_quark < ' a > (
637+ & self ,
638+ num_instances : usize ,
639+ xs : Vec < Arc < MultilinearExtensionGpu < ' a , E > > > ,
640+ ys : Vec < Arc < MultilinearExtensionGpu < ' a , E > > > ,
641+ invs : Vec < Arc < MultilinearExtensionGpu < ' a , E > > > ,
642+ transcript : & mut impl Transcript < E > ,
643+ ) -> Result < EccQuarkProof < E > , ZKVMError > {
644+ // TODO implement GPU version of `create_ecc_proof`
645+ let xs = xs. iter ( ) . map ( |mle| mle. inner_to_mle ( ) . into ( ) ) . collect_vec ( ) ;
646+ let ys = ys. iter ( ) . map ( |mle| mle. inner_to_mle ( ) . into ( ) ) . collect_vec ( ) ;
647+ let invs = invs
648+ . iter ( )
649+ . map ( |mle| mle. inner_to_mle ( ) . into ( ) )
650+ . collect_vec ( ) ;
651+ Ok ( CpuEccProver :: create_ecc_proof (
652+ num_instances,
653+ xs,
654+ ys,
655+ invs,
656+ transcript,
657+ ) )
658+ }
659+ }
660+
603661impl < E : ExtensionField , PCS : PolynomialCommitmentScheme < E > > OpeningProver < GpuBackend < E , PCS > >
604662 for GpuProver < GpuBackend < E , PCS > >
605663{
@@ -743,8 +801,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> DeviceTransporter<Gp
743801 std:: mem:: forget ( pcs_data_basefold) ;
744802 let pcs_data = Arc :: new ( pcs_data) ;
745803
746- let fixed_mles =
747- PCS :: get_arc_mle_witness_from_commitment ( pk. fixed_commit_wd . as_ref ( ) . unwrap ( ) ) ;
804+ let fixed_mles = PCS :: get_arc_mle_witness_from_commitment ( pcs_data_original. as_ref ( ) ) ;
748805 let fixed_mles = fixed_mles
749806 . iter ( )
750807 . map ( |mle| Arc :: new ( MultilinearExtensionGpu :: from_ceno ( & cuda_hal, mle) ) )
0 commit comments