Skip to content
Open
22 changes: 11 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 24 additions & 24 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ repository = "https://github.com/scroll-tech/ceno"
version = "0.1.0"

[workspace.dependencies]
ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.13" }
mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.13" }
multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.13" }
p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.13" }
poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.13" }
sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.13" }
sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.13" }
transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.13" }
whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.13" }
witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.13" }
ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.14" }
mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.14" }
multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.14" }
p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.14" }
poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.14" }
sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.14" }
sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.14" }
transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.14" }
whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.14" }
witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.14" }

alloy-primitives = "1.3"
anyhow = { version = "1.0", default-features = false }
Expand Down Expand Up @@ -97,17 +97,17 @@ opt-level = 3
[profile.release]
lto = "thin"

# [patch."ssh://[email protected]/scroll-tech/ceno-gpu.git"]
# ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features=["bb31"] }

# [patch."https://github.com/scroll-tech/gkr-backend"]
# ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" }
# mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" }
# multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" }
# p3 = { path = "../gkr-backend/crates/p3", package = "p3" }
# poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" }
# sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" }
# sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" }
# transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" }
# whir = { path = "../gkr-backend/crates/whir", package = "whir" }
# witness = { path = "../gkr-backend/crates/witness", package = "witness" }
#[patch."ssh://[email protected]/scroll-tech/ceno-gpu.git"]
#ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] }
#
#[patch."https://github.com/scroll-tech/gkr-backend"]
#ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" }
#mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" }
#multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" }
#p3 = { path = "../gkr-backend/crates/p3", package = "p3" }
#poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" }
#sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" }
#sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" }
#transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" }
#whir = { path = "../gkr-backend/crates/whir", package = "whir" }
#witness = { path = "../gkr-backend/crates/witness", package = "witness" }
8 changes: 8 additions & 0 deletions gkr_iop/src/gkr/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use multilinear_extensions::{
Expression, Instance, StructuralWitIn, ToExpr,
mle::{Point, PointAndEval},
monomial::Term,
utils::Node,
};
use p3::field::FieldAlgebra;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator};
Expand Down Expand Up @@ -48,6 +49,8 @@ pub enum LayerType {
Linear,
}

pub type DagInfo<E> = (Vec<Node>, Vec<Expression<E>>, u32, usize, usize);

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound(
serialize = "E::BaseField: Serialize",
Expand Down Expand Up @@ -103,6 +106,10 @@ pub struct Layer<E: ExtensionField> {
pub main_sumcheck_expression_monomial_terms: Option<Vec<Term<Expression<E>, Expression<E>>>>,
pub main_sumcheck_expression: Option<Expression<E>>,

// flatten computation dag
// (dag, coeffs, final_out_index, max_dag_depth, max_degree)
pub main_sumcheck_expression_dag: Option<DagInfo<E>>,

// rotation sumcheck expression, only optionally valid for zerocheck
// store in 2 forms: expression & monomial
pub rotation_sumcheck_expression_monomial_terms:
Expand Down Expand Up @@ -175,6 +182,7 @@ impl<E: ExtensionField> Layer<E> {
expr_names,
main_sumcheck_expression_monomial_terms: None,
main_sumcheck_expression: None,
main_sumcheck_expression_dag: None,
rotation_sumcheck_expression_monomial_terms: None,
rotation_sumcheck_expression: None,
};
Expand Down
48 changes: 38 additions & 10 deletions gkr_iop/src/gkr/layer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use multilinear_extensions::{
Expression,
mle::{MultilinearExtension, Point},
monomial::Term,
utils::eval_by_expr_constant,
};
use rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
Expand Down Expand Up @@ -232,8 +233,26 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
layer.n_fixed,
layer.n_instance,
);

// process dag
// (dag, coeffs, final_out_index, max_dag_depth, max_degree)
let (dag, dag_coeffs, final_out_index, max_dag_depth, max_degree) =
layer.main_sumcheck_expression_dag.as_ref().unwrap();

let pub_io_eval_scalar = pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec();
// format: pub_io ++ challenge ++ constant
let dag_coeffs = dag_coeffs
.iter()
.map(|c| eval_by_expr_constant(&pub_io_eval_scalar, &main_sumcheck_challenges, c))
.map(|either_v| match either_v {
Either::Left(base_field_val) => E::from(base_field_val),
Either::Right(ext_field_val) => ext_field_val,
})
.collect_vec();

// process monomial terms
// Calculate max_num_var and max_degree from the extracted relationships
let (term_coefficients, mle_indices_per_term, mle_size_info) =
let (monomial_coefficients, mle_indices_per_term, mle_size_info) =
extract_mle_relationships_from_monomial_terms(
&layer
.main_sumcheck_expression_monomial_terms
Expand All @@ -243,18 +262,18 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
&pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec(),
&main_sumcheck_challenges,
);

let max_num_var = max_num_variables;
let max_degree = mle_indices_per_term
.iter()
.map(|indices| indices.len())
.max()
.unwrap_or(0);

// Convert types for GPU function Call
let monomial_coefficients: Vec<BB31Ext> =
unsafe { std::mem::transmute(monomial_coefficients) };

// Convert types for GPU function Call
let basic_tr: &mut BasicTranscript<BB31Ext> =
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<BB31Ext>) };
let term_coefficients_gl64: Vec<BB31Ext> =
unsafe { std::mem::transmute(term_coefficients) };
let dag_coeffs: Vec<BB31Ext> = unsafe { std::mem::transmute(dag_coeffs) };

let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<BB31Ext>> =
unsafe { std::mem::transmute(all_witins_gpu) };
let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec();
Expand All @@ -264,13 +283,18 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
&cuda_hal,
all_witins_gpu_type_gl64,
&mle_size_info,
&term_coefficients_gl64,
&monomial_coefficients,
&mle_indices_per_term,
max_num_var,
max_degree,
*max_degree,
dag,
*max_dag_depth,
&dag_coeffs,
*final_out_index,
basic_tr,
)
.unwrap();

let evals_gpu = evals_gpu.into_iter().flatten().collect_vec();
let row_challenges = challenges_gpu.iter().map(|c| c.elements).collect_vec();

Expand Down Expand Up @@ -389,6 +413,10 @@ pub(crate) fn prove_rotation_gpu<E: ExtensionField, PCS: PolynomialCommitmentSch
&mle_indices_per_term,
max_num_var,
max_degree,
&[],
0,
&[],
0,
basic_tr,
)
.unwrap();
Expand Down
41 changes: 35 additions & 6 deletions gkr_iop/src/gkr/layer/zerocheck_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use multilinear_extensions::{
macros::{entered_span, exit_span},
mle::{IntoMLE, Point},
monomialize_expr_to_wit_terms,
utils::{eval_by_expr, eval_by_expr_with_instance, expr_convert_to_witins},
utils::{
build_factored_dag_commutative, dag_stats, eval_by_expr, eval_by_expr_with_instance,
expr_convert_to_witins,
},
virtual_poly::VPAuxInfo,
};
use p3::field::{FieldAlgebra, dot_product};
Expand Down Expand Up @@ -165,18 +168,44 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
self.n_fixed as WitnessId,
self.n_instance,
);
tracing::debug!("main sumcheck degree: {}", zero_expr.degree());
let zero_expr_degree = zero_expr.degree();
self.main_sumcheck_expression = Some(zero_expr);
self.main_sumcheck_expression_monomial_terms = self
.main_sumcheck_expression
.as_ref()
.map(|expr| expr.get_monomial_terms());
tracing::debug!(
"main sumcheck monomial terms count: {}",

{
if let Some(terms) = self.main_sumcheck_expression_monomial_terms.as_ref() {
let num_mul: usize = terms.iter().map(|term| term.product.len()).sum();
let num_add = terms.iter().len() - 1;

tracing::debug!(
"layer name {} monomial num_add: {num_add} num_mul: {num_mul}",
self.name,
);
}
}

self.main_sumcheck_expression_dag = {
self.main_sumcheck_expression_monomial_terms
.as_ref()
.map_or(0, |terms| terms.len()),
);
.map(|terms| {
// selector are structural witin, which is used to be the largest id.
let (dag, coeffs, Some(final_out_index), max_dag_depth) = build_factored_dag_commutative(terms, false) else { panic!() };
let max_degree = zero_expr_degree;

let (num_add, num_mul) = dag_stats(&dag);
tracing::debug!(
"layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \
max_dag_depth {max_dag_depth} num_scalar {} final_out_index {final_out_index}",
self.name,
coeffs.len(),
);
(dag, coeffs, final_out_index, max_dag_depth as usize, zero_expr_degree)
})
};

exit_span!(span);
}

Expand Down