Skip to content

Commit

Permalink
feat: revamp zero prove function (#793)
Browse files Browse the repository at this point in the history
* feat: revamp zero prove function

* fix: improvements

* fix: formatting

* fix: further parallelize segment proofs

* fix: remove worker number limit

* fix: deadlock

* fix: comment

* fix: review
  • Loading branch information
atanmarko authored Nov 22, 2024
1 parent c79ef67 commit 4a99bcc
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 27 deletions.
2 changes: 1 addition & 1 deletion scripts/prove_stdio.sh
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ cargo build --release --jobs "$num_procs"
start_time=$(date +%s%N)

cmd=("${REPO_ROOT}/target/release/leader" --runtime in-memory \
--load-strategy on-demand -n 1 \
--load-strategy on-demand \
--block-batch-size "$BLOCK_BATCH_SIZE")

if [[ "$USE_TEST_CONFIG" == "use_test_config" ]]; then
Expand Down
6 changes: 3 additions & 3 deletions zero/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{debug_utils::save_inputs_to_disk, prover_state::p_state};

registry!();

#[derive(Deserialize, Serialize, RemoteExecute)]
#[derive(Deserialize, Serialize, RemoteExecute, Clone)]
pub struct SegmentProof {
pub save_inputs_on_error: bool,
}
Expand Down Expand Up @@ -207,7 +207,7 @@ impl Drop for SegmentProofSpan {
}
}

#[derive(Deserialize, Serialize, RemoteExecute)]
#[derive(Deserialize, Serialize, RemoteExecute, Clone)]
pub struct SegmentAggProof {
pub save_inputs_on_error: bool,
}
Expand Down Expand Up @@ -289,7 +289,7 @@ impl Monoid for SegmentAggProof {
}
}

#[derive(Deserialize, Serialize, RemoteExecute)]
#[derive(Deserialize, Serialize, RemoteExecute, Clone)]
pub struct BatchAggProof {
pub save_inputs_on_error: bool,
}
Expand Down
198 changes: 175 additions & 23 deletions zero/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use anyhow::{Context, Result};
use evm_arithmetization::Field;
use evm_arithmetization::SegmentDataIterator;
use futures::{
future, future::BoxFuture, stream::FuturesUnordered, FutureExt, TryFutureExt, TryStreamExt,
future::BoxFuture,
future::{self, try_join, try_join_all},
stream::FuturesUnordered,
FutureExt as _, StreamExt as _, TryFutureExt as _, TryStreamExt as _,
};
use hashbrown::HashMap;
use num_traits::ToPrimitive as _;
Expand All @@ -23,10 +26,10 @@ use plonky2::plonk::circuit_data::CircuitConfig;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Receiver;
use tokio::sync::{oneshot, Semaphore};
use tokio::sync::{mpsc, oneshot, Semaphore};
use trace_decoder::observer::DummyObserver;
use trace_decoder::{BlockTrace, OtherBlockData, WireDisposition};
use tracing::{error, info};
use tracing::{debug, error, info};

use crate::fs::generate_block_proof_file_name;
use crate::ops;
Expand Down Expand Up @@ -116,6 +119,8 @@ impl BlockProverInput {
WIRE_DISPOSITION,
)?;

let batch_count = block_generation_inputs.len();

// Create segment proof.
let seg_prove_ops = ops::SegmentProof {
save_inputs_on_error,
Expand All @@ -131,29 +136,176 @@ impl BlockProverInput {
save_inputs_on_error,
};

// Segment the batches, prove segments and aggregate them to resulting batch
// proofs.
let batch_proof_futs: FuturesUnordered<_> = block_generation_inputs
.iter()
.enumerate()
.map(|(idx, txn_batch)| {
let segment_data_iterator =
SegmentDataIterator::<Field>::new(txn_batch, Some(max_cpu_len_log));

Directive::map(IndexedStream::from(segment_data_iterator), &seg_prove_ops)
.fold(&seg_agg_ops)
.run(&proof_runtime.heavy_proof)
.map(move |e| {
e.map(|p| (idx, crate::proof_types::BatchAggregatableProof::from(p)))
})
// Generate channels to communicate segments of each batch to a batch proving
// task. We generate segments and send them to the proving task, where they
// are proven in parallel.
let (segment_senders, segment_receivers): (Vec<_>, Vec<_>) = (0..batch_count)
.map(|_idx| {
let (segment_tx, segment_rx) =
mpsc::channel::<Option<evm_arithmetization::AllData>>(1);
(segment_tx, segment_rx)
})
.collect();
.unzip();

// The size of this channel does not matter much, as it is only used to collect
// batch proofs.
let (batch_proof_tx, mut batch_proof_rx) =
mpsc::channel::<(usize, crate::proof_types::BatchAggregatableProof)>(32);

// Spin up a task for each batch to generate segments for that batch
// and send them to the proving task.
let segment_generation_task = tokio::spawn(async move {
let mut batch_segment_futures: FuturesUnordered<_> = FuturesUnordered::new();

for (batch_idx, (txn_batch, segment_tx)) in block_generation_inputs
.into_iter()
.zip(segment_senders)
.enumerate()
{
batch_segment_futures.push(async move {
let segment_data_iterator =
SegmentDataIterator::<Field>::new(&txn_batch, Some(max_cpu_len_log));
for (segment_idx, segment_data) in segment_data_iterator.enumerate() {
segment_tx
.send(Some(segment_data))
.await
.context(format!("failed to send segment data for batch {batch_idx} segment {segment_idx}"))?;
}
// Mark the end of the batch segments by sending `None`
segment_tx
.send(None)
.await
.context(format!("failed to send end segment data indicator for batch {batch_idx}"))?;
anyhow::Ok(())
});
}
while let Some(it) = batch_segment_futures.next().await {
// In case of an error, propagate the error to the main task
it?;
}
let () = batch_segment_futures.try_collect().await?;
anyhow::Ok(())
});

let proof_runtime_ = proof_runtime.clone();
let batches_proving_task = tokio::spawn(async move {
let mut batch_proving_futures = FuturesUnordered::new();
// Spawn a proving subtask for each batch, in which we generate segment proofs
// and aggregate them into a batch proof.
for (batch_idx, mut segment_rx) in segment_receivers.into_iter().enumerate() {
let batch_proof_tx = batch_proof_tx.clone();
let seg_prove_ops = seg_prove_ops.clone();
let seg_agg_ops = seg_agg_ops.clone();
let proof_runtime = proof_runtime_.clone();
// Tasks to dispatch proving jobs and aggregate segment proofs of one batch
batch_proving_futures.push(async move {
let mut batch_segment_aggregatable_proofs = Vec::new();

// This channel collects the segment proofs of one batch, which are
// proven in parallel. The size of this channel does not matter much,
// as it is only used to collect segment aggregatable proofs.
let (segment_proof_tx, mut segment_proof_rx) =
mpsc::channel::<(usize, crate::proof_types::SegmentAggregatableProof)>(32);

// Wait for segments and dispatch them to the segment proof worker task.
// The segment proof worker task will prove the segment and send it back.
let mut segment_counter = 0;
let mut segment_proving_tasks = Vec::new();
while let Some(Some(segment_data)) = segment_rx.recv().await {
let seg_prove_ops = seg_prove_ops.clone();
let proof_runtime = proof_runtime.clone();
let segment_proof_tx = segment_proof_tx.clone();
// Prove one segment in a dedicated async task.
let segment_proving_task = tokio::spawn(async move {
debug!(%batch_idx, %segment_counter, "proving batch segment");
let seg_aggregatable_proof= Directive::map(
IndexedStream::from([segment_data]),
&seg_prove_ops,
)
.run(&proof_runtime.heavy_proof)
.await?
.into_values_sorted()
.await?
.into_iter()
.next()
.context(format!(
"failed to get segment proof, batch: {batch_idx}, segment: {segment_counter}"
))?;

segment_proof_tx
.send((segment_counter, seg_aggregatable_proof))
.await
.context(format!(
"unable to send segment proof, batch: {batch_idx}, segment: {segment_counter}"
))?;
anyhow::Ok(())
});

segment_proving_tasks.push(segment_proving_task);
segment_counter += 1;
}
drop(segment_proof_tx);
// Wait for all the segment proving tasks of one batch to finish.
while let Some((segment_idx, segment_aggregatable_proof)) = segment_proof_rx.recv().await {
batch_segment_aggregatable_proofs.push((segment_idx, segment_aggregatable_proof));
}
try_join_all(segment_proving_tasks).await?;
batch_segment_aggregatable_proofs.sort_by(|(a, _), (b, _)| a.cmp(b));
debug!(%block_number, batch=%batch_idx, "finished proving all segments");
// We have proved all the segments in a batch,
// now we need to aggregate them to the batch proof.
// Fold the segment aggregated proof stream into a single batch proof.
let batch_proof = if batch_segment_aggregatable_proofs.len() == 1 {
// If there is only one segment aggregated proof, just transform it to batch proof.
(batch_idx, crate::proof_types::BatchAggregatableProof::from(
batch_segment_aggregatable_proofs.pop().map(|(_, it)| it).unwrap(),
))
} else {
Directive::fold(IndexedStream::from(batch_segment_aggregatable_proofs.into_iter().map(|(_, it)| it)), &seg_agg_ops)
.run(&proof_runtime.light_proof)
.map(move |e| {
e.map(|p| {
(
batch_idx,
crate::proof_types::BatchAggregatableProof::from(p),
)
})
})
.await?
};
debug!(%block_number, batch=%batch_idx, "generated batch proof for block");
batch_proof_tx.send(batch_proof).await.context(format!(
"unable to send batch proof, block: {block_number}, batch: {batch_idx}"
))?;
anyhow::Ok(())
});
}
// Wait for all the batch proving tasks to finish. Exit early on error.
while let Some(it) = batch_proving_futures.next().await {
it?;
}
anyhow::Ok(())
});

// Collect all the batch proofs.
let mut batch_proofs: Vec<(usize, crate::proof_types::BatchAggregatableProof)> = Vec::new();
while let Some((batch_idx, batch_proof)) = batch_proof_rx.recv().await {
batch_proofs.push((batch_idx, batch_proof));
}
debug!(%block_number, "collected all batch proofs");

// Wait for the segment generation and proving tasks to finish.
let _ = try_join(segment_generation_task, batches_proving_task).await?;

batch_proofs.sort_by(|(a, _), (b, _)| a.cmp(b));

// Fold the batch aggregated proof stream into a single proof.
let final_batch_proof =
Directive::fold(IndexedStream::new(batch_proof_futs), &batch_agg_ops)
.run(&proof_runtime.light_proof)
.await?;
let final_batch_proof = Directive::fold(
IndexedStream::from(batch_proofs.into_iter().map(|(_, it)| it)),
&batch_agg_ops,
)
.run(&proof_runtime.light_proof)
.await?;

if let crate::proof_types::BatchAggregatableProof::BatchAgg(proof) = final_batch_proof {
let block_number = block_number
Expand Down

0 comments on commit 4a99bcc

Please sign in to comment.