diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs index 928e620fed..9172b64d99 100644 --- a/lib/bindings/c/src/lib.rs +++ b/lib/bindings/c/src/lib.rs @@ -60,7 +60,6 @@ pub enum DynamoLlmResult { pub unsafe extern "C" fn dynamo_llm_init( namespace_c_str: *const c_char, component_c_str: *const c_char, - worker_id: i64, kv_block_size: u32, ) -> DynamoLlmResult { initialize_tracing(); @@ -102,7 +101,7 @@ pub unsafe extern "C" fn dynamo_llm_init( match result { Ok(_) => match KV_PUB.get_or_try_init(move || { - dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size) + dynamo_create_kv_publisher(namespace, component, kv_block_size) }) { Ok(_) => DynamoLlmResult::OK, Err(e) => { @@ -144,7 +143,6 @@ pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult { fn dynamo_create_kv_publisher( namespace: String, component: String, - worker_id: i64, kv_block_size: u32, ) -> Result { tracing::info!("Creating KV Publisher for model: {}", component); @@ -154,7 +152,7 @@ fn dynamo_create_kv_publisher( { Ok(drt) => { let backend = drt.namespace(namespace)?.component(component)?; - KvEventPublisher::new(backend, worker_id as u64, kv_block_size, None) + KvEventPublisher::new(backend, kv_block_size, None) } Err(e) => Err(e), } diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 24b6c8aa32..48e18df0ce 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -143,7 +143,6 @@ impl ZmqKvEventPublisher { fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult { let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( component.inner, - config.worker_id, config.kv_block_size as u32, Some(KvEventSourceConfig::Zmq { endpoint: config.zmq_endpoint, @@ -239,20 +238,14 @@ pub(crate) struct KvEventPublisher { #[pymethods] impl KvEventPublisher { #[new] - #[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))] - fn new( - component: Component, - worker_id: WorkerId, - kv_block_size: usize, - dp_rank: DpRank, - ) -> PyResult { + #[pyo3(signature = (component, kv_block_size, dp_rank=0))] + fn new(component: Component, kv_block_size: usize, dp_rank: DpRank) -> PyResult { if kv_block_size == 0 { return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); } let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( component.inner, - worker_id, kv_block_size as u32, None, ) diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs index bbeb7c40e8..2ee9d58953 100644 --- a/lib/llm/src/kv_router/indexer.rs +++ b/lib/llm/src/kv_router/indexer.rs @@ -354,9 +354,10 @@ impl RadixTree { None => { tracing::warn!( worker_id = worker.worker_id.to_string(), - dp_rank = ?worker.dp_rank, + dp_rank = worker.dp_rank, id, parent_hash = ?op.parent_hash, + num_blocks = op.blocks.len(), "Failed to find parent block; skipping store operation" ); return Err(KvCacheEventError::ParentBlockNotFound); @@ -412,8 +413,10 @@ impl RadixTree { Some(entry) => entry.clone(), None => { tracing::warn!( - worker_id = worker_id.to_string(), + worker_id = worker.worker_id.to_string(), + dp_rank = worker.dp_rank, id, + block_hash = ?block, "Failed to find block to remove; skipping remove operation" ); return Err(KvCacheEventError::BlockNotFound); diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index f620c416e2..ab595a21e1 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -97,7 +97,6 @@ pub struct KvEventPublisher { impl KvEventPublisher { pub fn new( component: Component, - worker_id: u64, kv_block_size: u32, source_config: Option, ) -> Result { @@ -105,6 +104,13 @@ impl KvEventPublisher { let (tx, rx) = mpsc::unbounded_channel::(); + // Infer worker_id from component's primary lease + let worker_id = component + .drt() + .primary_lease() + .expect("Cannot publish KV events without lease") + .id(); + // Create our event source (if any) let mut source = None; if let Some(config) = source_config { diff --git a/lib/llm/src/mocker.rs b/lib/llm/src/mocker.rs index 4e21c8d81f..671855247a 100644 --- a/lib/llm/src/mocker.rs +++ b/lib/llm/src/mocker.rs @@ -5,5 +5,6 @@ pub mod engine; pub mod evictor; pub mod kv_manager; pub mod protocols; +pub mod running_mean; pub mod scheduler; pub mod sequence; diff --git a/lib/llm/src/mocker/engine.rs b/lib/llm/src/mocker/engine.rs index 0a43445fe7..a703106c14 100644 --- a/lib/llm/src/mocker/engine.rs +++ b/lib/llm/src/mocker/engine.rs @@ -8,7 +8,7 @@ use crate::kv_router::publisher::WorkerMetricsPublisher; use crate::mocker::protocols::DirectRequest; -use crate::mocker::protocols::{MockEngineArgs, OutputSignal}; +use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType}; use crate::mocker::scheduler::Scheduler; use crate::protocols::TokenIdType; use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}; @@ -23,9 +23,6 @@ use dynamo_runtime::{ pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait}, traits::DistributedRuntimeProvider, }; - -use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData}; -use crate::kv_router::publisher::KvEventPublisher; use futures::StreamExt; use rand::Rng; use std::collections::HashMap; @@ -37,10 +34,9 @@ use uuid::Uuid; pub const MOCKER_COMPONENT: &str = "mocker"; -/// Generate a random token ID from 1k to 5k fn generate_random_token() -> TokenIdType { let mut rng = rand::rng(); - rng.random_range(1000..5000) + rng.random_range(100..200) } /// AsyncEngine wrapper around the Scheduler that generates random character tokens @@ -71,26 +67,25 @@ impl MockVllmEngine { tracing::info!("Engine startup simulation completed"); } - let (schedulers, kv_event_receiver) = self.start_schedulers( + // Pass component to schedulers only if prefix caching is enabled and not a decode worker + let scheduler_component = if self.engine_args.enable_prefix_caching + && self.engine_args.worker_type != WorkerType::Decode + { + Some(component.clone()) + } else { + None + }; + + let schedulers = self.start_schedulers( self.engine_args.clone(), self.active_requests.clone(), + scheduler_component, cancel_token.clone(), ); Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone()) .await?; - // Start KV events publishing with the actual receivers from schedulers - if self.engine_args.enable_prefix_caching { - Self::start_kv_events_publishing( - kv_event_receiver, - Some(component.clone()), - self.engine_args.block_size, - cancel_token.clone(), - ) - .await?; - } - Ok(()) } @@ -100,18 +95,14 @@ impl MockVllmEngine { } /// Create schedulers and spawn their background tasks for distributing token notifications - /// Returns schedulers and their corresponding KV event receivers fn start_schedulers( &self, args: MockEngineArgs, active_requests: Arc>>>, + component: Option, cancel_token: CancellationToken, - ) -> ( - Vec, - Vec>, - ) { + ) -> Vec { let mut schedulers = Vec::::new(); - let mut kv_event_receivers = Vec::new(); let mut senders = Vec::with_capacity(args.dp_size as usize); // Create multiple schedulers and their background tasks @@ -119,20 +110,16 @@ impl MockVllmEngine { // Create a shared output channel that this scheduler will use let (output_tx, mut output_rx) = mpsc::unbounded_channel::(); - // Create a channel for KV events from this scheduler - let (kv_events_tx, kv_events_rx) = mpsc::unbounded_channel::(); - let scheduler = Scheduler::new( args.clone(), dp_rank, Some(output_tx), - Some(kv_events_tx), // Pass the KV events sender to scheduler + component.clone(), Some(cancel_token.clone()), ); senders.push(scheduler.request_sender()); schedulers.push(scheduler); - kv_event_receivers.push(kv_events_rx); // Spawn a background task for this scheduler to distribute token notifications to active requests // let output_rx = Arc::new(Mutex::new(output_rx)); @@ -166,7 +153,7 @@ impl MockVllmEngine { .set(senders) .expect("Already initialized"); - (schedulers, kv_event_receivers) + schedulers } /// Start background tasks to publish metrics on change @@ -228,83 +215,6 @@ impl MockVllmEngine { tracing::info!("Metrics background tasks started"); Ok(()) } - - /// Start background tasks to collect and publish KV events from schedulers - async fn start_kv_events_publishing( - kv_event_receivers: Vec>, - component: Option, - block_size: usize, - cancel_token: CancellationToken, - ) -> Result<()> { - tracing::debug!("Starting KV events publishing"); - - // Only start KV events publishing if we have a component - let Some(comp) = component else { - tracing::warn!("No component provided, skipping KV events publishing"); - return Ok(()); - }; - tracing::debug!("Component found for KV events publishing"); - - tracing::debug!("Getting worker_id"); - let worker_id = comp - .drt() - .primary_lease() - .expect("Cannot publish KV events without lease") // ← This will PANIC on static! - .id(); - // let worker_id = 0; - tracing::debug!("Worker_id set to: {worker_id}"); - - tracing::debug!("Creating KV event publisher"); - let kv_event_publisher = Arc::new(KvEventPublisher::new( - comp.clone(), - worker_id, - block_size as u32, - None, - )?); - tracing::debug!("KV event publisher created"); - - tracing::debug!( - "Starting KV event background tasks for {} receivers", - kv_event_receivers.len() - ); - for (dp_rank, mut kv_events_rx) in kv_event_receivers.into_iter().enumerate() { - tracing::debug!("Starting background task for DP rank {dp_rank}"); - let publisher = kv_event_publisher.clone(); - let dp_rank = dp_rank as u32; - let cancel_token = cancel_token.clone(); - - tokio::spawn(async move { - tracing::debug!("Background task started for DP rank {dp_rank}"); - loop { - tokio::select! { - // Receive actual KV events from the scheduler - Some(event_data) = kv_events_rx.recv() => { - // Convert KvCacheEventData to KvCacheEvent with random UUID as event_id - let event = KvCacheEvent { - event_id: Uuid::new_v4().as_u128() as u64, - data: event_data, - dp_rank, - }; - - // Publish the event - if let Err(e) = publisher.publish(event) { - tracing::warn!("Failed to publish KV event for DP rank {dp_rank}: {e}"); - } else { - tracing::trace!("Published KV event for DP rank {dp_rank}"); - } - } - _ = cancel_token.cancelled() => { - tracing::debug!("KV events publishing cancelled for DP rank {dp_rank}"); - break; - } - } - } - }); - } - tracing::info!("All KV event background tasks started"); - - Ok(()) - } } #[async_trait] @@ -356,7 +266,13 @@ impl AsyncEngine, ManyOut, Error> let active_requests = self.active_requests.clone(); let async_context = ctx.context(); - let max_tokens = request.stop_conditions.max_tokens.unwrap_or(100) as usize; + let is_prefill = self.engine_args.worker_type == WorkerType::Prefill; + // Override max_tokens to 1 for prefill workers + let max_tokens = if is_prefill { + 1 + } else { + request.stop_conditions.max_tokens.unwrap() as usize + }; // Spawn a task to handle the complex async logic tokio::spawn(async move { @@ -383,7 +299,12 @@ impl AsyncEngine, ManyOut, Error> top_logprobs: None, finish_reason: None, index: None, - disaggregated_params: None, + // Add dummy disaggregated_params for prefill workers + disaggregated_params: if is_prefill { + Some(serde_json::json!("dummy")) + } else { + None + }, extra_args: None, }; diff --git a/lib/llm/src/mocker/kv_manager.rs b/lib/llm/src/mocker/kv_manager.rs index 2e4f1ec68b..17d7491162 100644 --- a/lib/llm/src/mocker/kv_manager.rs +++ b/lib/llm/src/mocker/kv_manager.rs @@ -33,13 +33,20 @@ //! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror //! implementation of the main block manager. +use crate::kv_router::protocols::{ + ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, + KvCacheStoredBlockData, LocalBlockHash, +}; +use crate::kv_router::publisher::KvEventPublisher; use crate::mocker::evictor::LRUEvictor; -use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost}; +use crate::mocker::protocols::{MoveBlock, PrefillCost}; use crate::mocker::sequence::ActiveSequence; use crate::tokens::blocks::UniqueBlock; +use crate::tokens::{BlockHash, SequenceHash}; use derive_getters::Getters; +use dynamo_runtime::component::Component; use std::collections::{HashMap, HashSet}; -use tokio::sync::mpsc; +use std::sync::Arc; #[derive(Getters)] pub struct KvManager { @@ -55,60 +62,113 @@ pub struct KvManager { all_blocks: HashSet, - move_block_response_tx: Option>, + kv_event_publisher: Option>, + + #[getter(copy)] + dp_rank: u32, + + next_event_id: u64, } impl KvManager { pub fn new(max_capacity: usize, block_size: usize) -> Self { - Self::new_with_sender(max_capacity, block_size, None) + Self::new_with_publisher(max_capacity, block_size, None, 0) } - pub fn new_with_sender( + pub fn new_with_publisher( max_capacity: usize, block_size: usize, - move_block_response_tx: Option>, + component: Option, + dp_rank: u32, ) -> Self { let active_blocks = HashMap::new(); let inactive_blocks = LRUEvictor::default(); let all_blocks = HashSet::new(); + let kv_event_publisher = component.map(|comp| { + tracing::info!( + "Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}" + ); + Arc::new( + KvEventPublisher::new(comp, block_size as u32, None) + .expect("Failed to create KV event publisher"), + ) + }); + KvManager { max_capacity, block_size, active_blocks, inactive_blocks, all_blocks, - move_block_response_tx, + kv_event_publisher, + dp_rank, + next_event_id: 0, } } - /// Utility method to send block responses with optional reversing - fn send_block_response( - &self, - mut blocks: Vec, - reverse: bool, - store: bool, + /// Converts stored/removed blocks into KvCacheEventData and publishes if publisher is available + fn publish_kv_event( + &mut self, + full_blocks: Vec, + local_hashes: &[BlockHash], parent_hash: Option, + is_store: bool, ) { - if let Some(ref tx) = self.move_block_response_tx - && !blocks.is_empty() - { - if reverse { - blocks.reverse(); - } - let response = if store { - MoveBlockResponse::Store(blocks, parent_hash) - } else { - MoveBlockResponse::Remove(blocks) - }; - tx.send(response).unwrap(); + if full_blocks.is_empty() { + return; + } + + let Some(ref publisher) = self.kv_event_publisher else { + return; + }; + + let event_data = if is_store { + let num_blocks = full_blocks.len(); + let local_hashes_slice = &local_hashes[local_hashes + .len() + .checked_sub(num_blocks) + .expect("local hashes fewer than stored blocks")..]; + + KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: parent_hash.map(ExternalSequenceBlockHash), + blocks: full_blocks + .into_iter() + .zip(local_hashes_slice.iter()) + .map(|(global_hash, local_hash)| KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(global_hash), + tokens_hash: LocalBlockHash(*local_hash), + }) + .collect(), + }) + } else { + KvCacheEventData::Removed(KvCacheRemoveData { + block_hashes: full_blocks + .into_iter() + .map(ExternalSequenceBlockHash) + .collect(), + }) + }; + + // Use incremental event ID starting from 0 + let event_id = self.next_event_id; + self.next_event_id += 1; + + let event = KvCacheEvent { + event_id, + data: event_data, + dp_rank: self.dp_rank, + }; + + if let Err(e) = publisher.publish(event) { + tracing::warn!("Failed to publish KV event: {e}"); } } /// Process a MoveBlock instruction synchronously pub fn process(&mut self, event: &MoveBlock) -> bool { match event { - MoveBlock::Use(hashes) => { + MoveBlock::Use(hashes, local_hashes) => { let mut blocks_stored = Vec::::new(); let mut parent_block: Option<&UniqueBlock> = None; @@ -138,16 +198,20 @@ impl KvManager { let Some(evicted) = self.inactive_blocks.evict() else { return false; }; + tracing::trace!( + "Evicting block from inactive pool: {evicted:?}, dp_rank={}", + self.dp_rank + ); self.all_blocks.remove(&evicted); if let UniqueBlock::FullBlock(evicted_full_block) = evicted { - self.send_block_response(vec![evicted_full_block], false, false, None); + self.publish_kv_event(vec![evicted_full_block], &[], None, false); } } // Now insert the new block in active blocks with reference count 1 self.active_blocks.insert(hash.clone(), 1); self.all_blocks.insert(hash.clone()); - if self.move_block_response_tx.is_some() + if self.kv_event_publisher.is_some() && let UniqueBlock::FullBlock(stored_full_block) = hash { blocks_stored.push(*stored_full_block); @@ -159,32 +223,32 @@ impl KvManager { Some(UniqueBlock::FullBlock(block)) => Some(*block), Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"), }; - self.send_block_response(blocks_stored, false, true, parent_hash); + self.publish_kv_event(blocks_stored, local_hashes, parent_hash, true); } MoveBlock::Destroy(hashes) => { let mut blocks_destroyed = Vec::::new(); - // Loop in inverse direction - for hash in hashes.iter().rev() { + // Process blocks in order (already reversed by caller if needed) + for hash in hashes.iter() { self.active_blocks.remove(hash).unwrap(); // Remove from all_blocks when destroyed assert!(self.all_blocks.remove(hash)); // Track blocks for batch sending - if self.move_block_response_tx.is_some() + if self.kv_event_publisher.is_some() && let UniqueBlock::FullBlock(destroyed_full_block) = hash { blocks_destroyed.push(*destroyed_full_block); } } - self.send_block_response(blocks_destroyed, true, false, None); + self.publish_kv_event(blocks_destroyed, &[], None, false); } MoveBlock::Deref(hashes) => { - // Loop in inverse direction - for hash in hashes.iter().rev() { + // Process blocks in order (already reversed by caller if needed) + for hash in hashes.iter() { // Decrement reference count and check if we need to move to inactive if let Some(ref_count) = self.active_blocks.get_mut(hash) { if *ref_count == 0 { @@ -202,24 +266,30 @@ impl KvManager { } } - MoveBlock::Promote(uuid, hash, parent_hash) => { + MoveBlock::Promote(uuid, hash, parent_hash, local_hash) => { let uuid_block = UniqueBlock::PartialBlock(*uuid); let hash_block = UniqueBlock::FullBlock(*hash); - let Some(ref_count) = self.active_blocks.remove(&uuid_block) else { - let in_all_blocks = self.all_blocks.contains(&uuid_block); - panic!( - "Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}" - ); + assert_eq!( + self.active_blocks.remove(&uuid_block), + Some(1), + "uuid_block {uuid_block:?} should exist and be unique with ref_count=1" + ); + + let hash_ref_count = if let Some(ref_count) = self.active_blocks.get(&hash_block) { + *ref_count + } else if self.inactive_blocks.remove(&hash_block) { + 0 + } else { + self.publish_kv_event(vec![*hash], &[*local_hash], *parent_hash, true); + 0 }; - // Replace with hash block, keeping the same reference count - self.active_blocks.insert(hash_block.clone(), ref_count); + self.active_blocks + .insert(hash_block.clone(), hash_ref_count + 1); - // Update all_blocks assert!(self.all_blocks.remove(&uuid_block)); self.all_blocks.insert(hash_block); - self.send_block_response(vec![*hash], false, true, *parent_hash); } } @@ -291,7 +361,6 @@ impl KvManager { #[cfg(test)] mod tests { use super::*; - use tokio::sync::mpsc; #[test] fn test_failure_on_max_capacity() { @@ -300,8 +369,9 @@ mod tests { // Helper function to use multiple blocks that returns the response fn use_blocks(manager: &mut KvManager, ids: Vec) -> bool { - let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); - manager.process(&MoveBlock::Use(blocks)) + let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect(); + let hashes: Vec<_> = ids.into_iter().collect(); + manager.process(&MoveBlock::Use(blocks, hashes)) } // First use 10 blocks (0 to 9) in a batch @@ -321,16 +391,14 @@ mod tests { #[test] fn test_block_lifecycle_stringent() { - // Create a channel to listen to block responses - let (tx, mut rx) = mpsc::unbounded_channel::(); - - // Create a KvManager with 10 blocks capacity and the response sender - let mut manager = KvManager::new_with_sender(10, 16, Some(tx)); + // Create a KvManager with 10 blocks capacity (no KV event publisher for tests) + let mut manager = KvManager::new(10, 16); // Helper function to use multiple blocks fn use_blocks(manager: &mut KvManager, ids: Vec) { - let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); - manager.process(&MoveBlock::Use(blocks)); + let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect(); + let hashes: Vec<_> = ids.into_iter().collect(); + manager.process(&MoveBlock::Use(blocks, hashes)); } // Helper function to destroy multiple blocks @@ -345,56 +413,6 @@ mod tests { manager.process(&MoveBlock::Deref(blocks)); } - // Helper function to assert block responses - fn assert_block_response( - rx: &mut mpsc::UnboundedReceiver, - expected_type: &str, - expected_blocks: Vec, - description: &str, - ) { - let response = rx - .try_recv() - .unwrap_or_else(|_| panic!("Expected {expected_type} response {description}")); - - match (&response, expected_type) { - (MoveBlockResponse::Store(blocks, _parent_hash), "Store") => { - assert_eq!( - blocks.len(), - expected_blocks.len(), - "Expected {} blocks in Store response {}", - expected_blocks.len(), - description - ); - assert_eq!( - *blocks, expected_blocks, - "Store blocks don't match expected {description}" - ); - } - (MoveBlockResponse::Remove(blocks), "Remove") => { - assert_eq!( - blocks.len(), - expected_blocks.len(), - "Expected {} blocks in Remove response {}", - expected_blocks.len(), - description - ); - assert_eq!( - *blocks, expected_blocks, - "Remove blocks don't match expected {description}" - ); - } - _ => panic!("Expected {expected_type} response, got {response:?} {description}"), - } - } - - // Helper function to assert no response is received - fn assert_no_response( - rx: &mut mpsc::UnboundedReceiver, - description: &str, - ) { - assert!(rx.try_recv().is_err(), "Expected no response {description}",); - } - // Helper function to check if active blocks contain expected blocks with expected ref counts fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) { assert_eq!( @@ -442,11 +460,9 @@ mod tests { // First use blocks 0, 1, 2, 3, 4 in a batch use_blocks(&mut manager, (0..5).collect()); - assert_block_response(&mut rx, "Store", vec![0, 1, 2, 3, 4], "after first use"); // Then use blocks 0, 1, 5, 6 in a batch use_blocks(&mut manager, vec![0, 1, 5, 6]); - assert_block_response(&mut rx, "Store", vec![5, 6], "after second use"); // Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2 assert_active_blocks( @@ -456,11 +472,9 @@ mod tests { // Now destroy block 4 destroy_blocks(&mut manager, vec![4]); - assert_block_response(&mut rx, "Remove", vec![4], "after destroy block 4"); // And deref blocks 3, 2, 1, 0 in this order as a batch deref_blocks(&mut manager, vec![0, 1, 2, 3]); - assert_no_response(&mut rx, "after deref operation"); // Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2 assert_inactive_blocks(&manager, 2, &[3, 2]); @@ -468,7 +482,6 @@ mod tests { // Now destroy block 6 destroy_blocks(&mut manager, vec![6]); - assert_block_response(&mut rx, "Remove", vec![6], "after block 6 eviction"); // And deref blocks 5, 1, 0 as a batch deref_blocks(&mut manager, vec![0, 1, 5]); @@ -479,7 +492,6 @@ mod tests { // Now use 0, 1, 2, 7, 8, 9 as a batch use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]); - assert_block_response(&mut rx, "Store", vec![7, 8, 9], "after [7, 8, 9] use"); // Check that the inactive_blocks is size 2, and contains 3 and 5 assert_inactive_blocks(&manager, 2, &[3, 5]); @@ -494,14 +506,10 @@ mod tests { // Now use blocks 10, 11, 12 as a batch use_blocks(&mut manager, vec![10, 11, 12]); - assert_block_response(&mut rx, "Remove", vec![3], "after block 5 eviction"); - assert_block_response(&mut rx, "Store", vec![10, 11, 12], "after [10, 11, 12] use"); // Check that the inactive_blocks is size 1 and contains only 5 assert_inactive_blocks(&manager, 1, &[5]); use_blocks(&mut manager, vec![13]); - assert_block_response(&mut rx, "Remove", vec![5], "after block 5 eviction"); - assert_block_response(&mut rx, "Store", vec![13], "after block 13 use"); } } diff --git a/lib/llm/src/mocker/protocols.rs b/lib/llm/src/mocker/protocols.rs index 8e0bb425d5..29f216dc46 100644 --- a/lib/llm/src/mocker/protocols.rs +++ b/lib/llm/src/mocker/protocols.rs @@ -7,23 +7,19 @@ use std::collections::{HashMap, HashSet}; use std::path::Path; use uuid::Uuid; -use crate::kv_router::protocols::{ - ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, - KvCacheStoredBlockData, LocalBlockHash, -}; use crate::tokens::blocks::UniqueBlock; use crate::tokens::{BlockHash, SequenceHash, Token}; pub type NumBlocks = usize; /// Represents different block movement operations in the cache -/// For Use and Promote variants, parent hash is the second field +/// For Use and Promote variants, block hashes are included for KV event publishing #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum MoveBlock { - Use(Vec), + Use(Vec, Vec), Destroy(Vec), Deref(Vec), - Promote(Uuid, SequenceHash, Option), + Promote(Uuid, SequenceHash, Option, BlockHash), } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -50,7 +46,7 @@ pub struct PrefillCost { impl PrefillCost { pub fn predict_prefill_compute(&self, new_tokens: Option) -> f64 { let tokens = new_tokens.unwrap_or(self.new_tokens); - 1.25e-6 * (tokens as f64).powi(2) + 7.41e-2 * (tokens as f64) + 2.62e1 + 4.209989e-07 * (tokens as f64).powi(2) + 1.518344e-02 * (tokens as f64) + 1.650142e+01 } } @@ -260,49 +256,6 @@ impl MockEngineArgs { } } -/// Converts a MoveBlockResponse from the mocker backend into a KvCacheEventData. -/// -/// This function assumes that the stored sequence hashes in the response always -/// correspond to the tail part of the local hashes array. This is the expected -/// behavior of KV block storage, where blocks are stored sequentially and the -/// response contains the most recent blocks that were stored. -/// -/// # Panics -/// Panics if the number of blocks in the Store response exceeds the length -/// of local_hashes. -pub fn block_response_to_kv_event( - response: MoveBlockResponse, - local_hashes: &[BlockHash], -) -> KvCacheEventData { - match response { - MoveBlockResponse::Store(full_blocks, parent_hash) => { - let num_blocks = full_blocks.len(); - let local_hashes_slice = &local_hashes[local_hashes - .len() - .checked_sub(num_blocks) - .expect("local hashes fewer than block response signal")..]; - - KvCacheEventData::Stored(KvCacheStoreData { - parent_hash: parent_hash.map(ExternalSequenceBlockHash), - blocks: full_blocks - .into_iter() - .zip(local_hashes_slice.iter()) - .map(|(global_hash, local_hash)| KvCacheStoredBlockData { - block_hash: ExternalSequenceBlockHash(global_hash), - tokens_hash: LocalBlockHash(*local_hash), - }) - .collect(), - }) - } - MoveBlockResponse::Remove(full_blocks) => KvCacheEventData::Removed(KvCacheRemoveData { - block_hashes: full_blocks - .into_iter() - .map(ExternalSequenceBlockHash) - .collect(), - }), - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/lib/llm/src/mocker/running_mean.rs b/lib/llm/src/mocker/running_mean.rs new file mode 100644 index 0000000000..5ae65e2b87 --- /dev/null +++ b/lib/llm/src/mocker/running_mean.rs @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::VecDeque; +use std::ops::{Add, Div, Sub}; + +/// A generic running mean calculator with a fixed-size sliding window. +/// Maintains a running sum and count to compute the mean in O(1) time. +#[derive(Debug, Clone)] +pub struct RunningMean +where + T: Copy + Add + Sub + Div + Default + From, +{ + max_size: u16, + sum: T, + values: VecDeque, +} + +impl RunningMean +where + T: Copy + Add + Sub + Div + Default + From, +{ + pub fn new(max_size: u16) -> Self { + Self { + max_size, + sum: T::default(), + values: VecDeque::with_capacity(max_size as usize), + } + } + + pub fn push(&mut self, value: T) { + // If at capacity, remove the oldest value from sum + if self.values.len() >= self.max_size as usize + && let Some(old_value) = self.values.pop_front() + { + self.sum = self.sum - old_value; + } + + // Add new value + self.sum = self.sum + value; + self.values.push_back(value); + } + + pub fn mean(&self) -> T { + if self.values.is_empty() { + T::default() + } else { + self.sum / T::from(self.values.len() as u16) + } + } + + pub fn len(&self) -> usize { + self.values.len() + } + + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + /// Clear all values from the window. + pub fn clear(&mut self) { + self.sum = T::default(); + self.values.clear(); + } +} diff --git a/lib/llm/src/mocker/scheduler.rs b/lib/llm/src/mocker/scheduler.rs index adc764c475..4026232765 100644 --- a/lib/llm/src/mocker/scheduler.rs +++ b/lib/llm/src/mocker/scheduler.rs @@ -28,16 +28,16 @@ //! ## NOTE //! The current prefill and decoding time simulations are not scientific at all and are WIP -use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData, KvStats, WorkerStats}; +use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats}; use crate::mocker::evictor::LRUEvictor; use crate::mocker::kv_manager::KvManager; -use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse}; -use crate::mocker::protocols::{MoveBlock, OutputSignal, PrefillCost, block_response_to_kv_event}; +use crate::mocker::protocols::{ + DirectRequest, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost, WorkerType, +}; +use crate::mocker::running_mean::RunningMean; use crate::mocker::sequence::ActiveSequence; -use crate::tokens::BlockHash; use crate::tokens::blocks::UniqueBlock; -use std::collections::HashMap; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use tokio::sync::mpsc; use tokio::time::Duration; use tokio_util::sync::CancellationToken; @@ -111,9 +111,8 @@ impl SchedulerState { /// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where: /// - `prefill_compute`: The compute time in milliseconds for this prefill operation /// - `creation_signal`: Optional MoveBlock signal for KV cache block creation - /// - `block_hashes`: Block hashes of the sequence beign prefilled /// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked - fn try_prefill(&mut self) -> Option<(f64, Option, Vec, bool)> { + fn try_prefill(&mut self) -> Option<(f64, Option, bool)> { let uuid = self.prefill.pop_front()?; // Remove and extract prefill_compute from prefill_costs @@ -168,7 +167,6 @@ impl SchedulerState { Some(( prefill_compute, sequence.take_creation_signal(), - sequence.block_hashes(), is_full_prefill, )) } @@ -247,17 +245,9 @@ impl Scheduler { args: MockEngineArgs, dp_rank: u32, output_tx: Option>, - kv_events_tx: Option>, + component: Option, cancellation_token: Option, ) -> Self { - // Create internal channel for KV events only if needed - let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() { - let (tx, rx) = mpsc::unbounded_channel::(); - (Some(tx), Some(rx)) - } else { - (None, None) - }; - // Assert speedup_ratio is greater than 0 assert!( args.speedup_ratio > 0.0, @@ -278,121 +268,64 @@ impl Scheduler { tokio::spawn(async move { // Create state and kv_manager as local variables owned by this task let mut state = SchedulerState::new(args.max_num_batched_tokens); - let mut kv_manager = - KvManager::new_with_sender(args.num_gpu_blocks, args.block_size, block_resp_tx); - let mut hit_rates = VecDeque::with_capacity(1000); - let mut should_schedule = true; + let mut kv_manager = KvManager::new_with_publisher( + args.num_gpu_blocks, + args.block_size, + component, + dp_rank, + ); + let mut hit_rates = RunningMean::new(1000); loop { - { - // Enqueue new request, blocks until at least one is received, so no redundant work is done - if state.is_empty() { - let Some(request) = request_rx.recv().await else { - tracing::warn!("request sender is dropped"); + // 1. Receive requests + if state.is_empty() { + // Fully idle - block until new request arrives + tokio::select! { + biased; + Some(request) = request_rx.recv() => { + state.receive(request); + } + _ = cancel_token_clone.cancelled() => { break; - }; - state.receive(request); + } } - } - - tokio::select! { - biased; - - // Enqueue new request - Some(request) = request_rx.recv() => { + } else { + // Has active/waiting work - collect any pending requests without blocking + while let Ok(request) = request_rx.try_recv() { state.receive(request); } - // Try Scheduling Requests - runs on normal interval or after simulation - _ = tokio::task::yield_now() => { - // Skip if we just ran scheduling after simulation to prevent consecutive runs - if !should_schedule { - continue; - } - - // Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't - // schedule anymore. - let mut current_blocks = kv_manager.num_active_blocks(); - let mut current_tokens = state.active_tokens + state.waiting_tokens; - let mut current_seqs = state.num_active_requests(); - - while let Some((uuid, request)) = state.next() { - let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching); - - // Update predictive budgets - let prefill_cost = kv_manager.get_prefill_cost(&active_sequence); - let total_tokens = active_sequence.len(); - // this is conservative, assumes no cache hit so never over-schedules - let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize; - let new_tokens = prefill_cost.new_tokens; - - current_blocks += new_blocks; - current_tokens += new_tokens; - current_seqs += 1; - - // Check various budgets to see if possible to schedule - let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64; - // If chunked prefill is enabled, we can be under token budget when scheduling - let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens}; - let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit); - let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit); - - // Cannot schedule, put first in line instead - if !(under_block_budget && under_token_budget && under_seq_budget) { - state.first_in_line(uuid, Request::Active(active_sequence)); - break; - } - - // Compute and store hit rate - let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 }; - hit_rates.push_back(hit_rate); - if hit_rates.len() > 1000 { - hit_rates.pop_front(); - } - - state.move_to_prefill(uuid, active_sequence, prefill_cost); - should_schedule = false; - } - } - // Check for cancellation - _ = cancel_token_clone.cancelled() => { + if cancel_token_clone.is_cancelled() { break; } } - // Simulates prefill + decode - // Base time needed for decoding using active percentage and quadratic formula - let active_perc = kv_manager.get_active_perc(); - let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44; - let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0); + // Start timing for this forward pass (schedule + simulate) + let iteration_start = std::time::Instant::now(); + + // 2. Schedule waiting requests (once per iteration) + try_schedule(&mut state, &kv_manager, &mut hit_rates, &args); + + // 3. Simulate prefill + decode + let mut total_time = Duration::ZERO; // Process prefilling - while let Some(( - prefill_compute, - maybe_creation_signal, - block_hashes, - is_full_prefill, - )) = state.try_prefill() + while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = + state.try_prefill() { // NOTE: Prefill cost/time is always incremented for new blocks, even if they // could be cached by other requests in the same batch. This matches vLLM behavior. - total_time += Duration::from_secs_f64(prefill_compute / 1000.0); - - if let Some(creation_signal) = maybe_creation_signal { - if !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal)) - { - panic!("Block allocation for prefilling cannot fail."); - } + // For decode workers, skip adding prefill compute time + if args.worker_type != WorkerType::Decode { + total_time += Duration::from_secs_f64(prefill_compute / 1000.0); + } - // Drain KV events and forward to relay after prefill signal processing - if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) { - while let Ok(event) = rx.try_recv() { - let _ = - relay_tx.send(block_response_to_kv_event(event, &block_hashes)); - } - } - }; + if let Some(creation_signal) = maybe_creation_signal + && !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal)) + { + panic!("Block allocation for prefilling cannot fail."); + } // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill if !is_full_prefill { @@ -400,13 +333,14 @@ impl Scheduler { } } + let active_perc = kv_manager.get_active_perc(); + let decoding_time = -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74; + total_time += Duration::from_secs_f64(decoding_time / 1000.0); + state.reset_active_tokens(); // Process decoding let uuids: Vec = state.decode.keys().cloned().collect(); - if !uuids.is_empty() { - should_schedule = true - }; for uuid in uuids { let Some(sequence) = state.run(uuid) else { continue; @@ -423,14 +357,6 @@ impl Scheduler { continue; } - // Drain KV events and forward to relay after decode signal processing - if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) { - while let Ok(event) = rx.try_recv() { - let _ = relay_tx - .send(block_response_to_kv_event(event, &sequence.block_hashes())); - } - } - // Check completion and send notification let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens(); let should_output = @@ -465,11 +391,13 @@ impl Scheduler { let _ = metrics_tx.send(metrics); } - // Sleep once for the adjusted duration - let adjusted_time = + // 4. Sleep to maintain target iteration timing + let target_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio); - if adjusted_time.as_millis() > 0 { - tokio::time::sleep(adjusted_time).await; + let elapsed = iteration_start.elapsed(); + + if elapsed < target_duration { + tokio::time::sleep(target_duration - elapsed).await; } } }); @@ -499,7 +427,7 @@ impl Scheduler { fn get_fwd_pass_metrics( state: &SchedulerState, kv_manager: &KvManager, - hit_rates: &VecDeque, + hit_rates: &RunningMean, dp_rank: u32, ) -> ForwardPassMetrics { // Get state metrics @@ -507,7 +435,7 @@ fn get_fwd_pass_metrics( let num_requests_waiting = state.waiting.len() as u64; // Get KV manager metrics - let active_blocks_count = kv_manager.active_blocks().len() as u64; + let active_blocks_count = kv_manager.num_active_blocks() as u64; let total_capacity = kv_manager.max_capacity() as u64; let gpu_cache_usage_perc = if total_capacity > 0 { active_blocks_count as f32 / total_capacity as f32 @@ -515,13 +443,8 @@ fn get_fwd_pass_metrics( 0.0 }; - // Get hit rate metrics - let gpu_prefix_cache_hit_rate = if hit_rates.is_empty() { - 0.0 - } else { - let sum: f32 = hit_rates.iter().sum(); - sum / hit_rates.len() as f32 - }; + // Get hit rate metrics - O(1) access + let gpu_prefix_cache_hit_rate = hit_rates.mean(); let worker_stats = WorkerStats { data_parallel_rank: Some(dp_rank), @@ -546,26 +469,75 @@ fn get_fwd_pass_metrics( } } -/// Convert a Request to an ActiveSequence -fn get_active_sequence( - request: Request, - block_size: usize, - enable_prefix_caching: bool, -) -> ActiveSequence { - if let Request::Active(active_seq) = request { - return active_seq; - } +/// Attempts to schedule waiting requests from the state queue. +/// Returns the number of requests successfully scheduled. +fn try_schedule( + state: &mut SchedulerState, + kv_manager: &KvManager, + hit_rates: &mut RunningMean, + args: &MockEngineArgs, +) -> usize { + let mut scheduled_count = 0; + let mut current_blocks = kv_manager.num_active_blocks(); + let mut current_tokens = state.active_tokens + state.waiting_tokens; + let mut current_seqs = state.num_active_requests(); + + while let Some((uuid, request)) = state.next() { + // Convert Request to ActiveSequence + let active_sequence = match request { + Request::Active(active_seq) => active_seq, + Request::Direct(direct_request) => ActiveSequence::new( + direct_request.tokens, + direct_request.max_output_tokens, + Some(args.block_size), + args.enable_prefix_caching, + ), + }; - let Request::Direct(direct_request) = request else { - unreachable!("Request must be either Direct or Active"); - }; + // Update predictive budgets + let prefill_cost = kv_manager.get_prefill_cost(&active_sequence); + let total_tokens = active_sequence.len(); + // this is conservative, assumes no cache hit so never over-schedules + let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize; + let new_tokens = prefill_cost.new_tokens; + + current_blocks += new_blocks; + current_tokens += new_tokens; + current_seqs += 1; + + // Check various budgets to see if possible to schedule + let under_block_budget = + current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64; + // If chunked prefill is enabled, we can be under token budget when scheduling + let comparison_tokens = if args.enable_chunked_prefill { + current_tokens - new_tokens + } else { + current_tokens + }; + let under_token_budget = args + .max_num_batched_tokens + .is_none_or(|limit| comparison_tokens <= limit); + let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit); + + // Cannot schedule, put first in line instead + if !(under_block_budget && under_token_budget && under_seq_budget) { + state.first_in_line(uuid, Request::Active(active_sequence)); + break; + } + + // Compute and store hit rate + let hit_rate = if !active_sequence.is_empty() { + 1.0 - (new_tokens as f32 / active_sequence.len() as f32) + } else { + 0.0 + }; + hit_rates.push(hit_rate); + + state.move_to_prefill(uuid, active_sequence, prefill_cost); + scheduled_count += 1; + } - ActiveSequence::new( - direct_request.tokens, - direct_request.max_output_tokens, - Some(block_size), - enable_prefix_caching, - ) + scheduled_count } /// Processes MoveBlock signals with the KvManager. @@ -582,7 +554,7 @@ fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool { } // Check we have a Use signal with blocks - let MoveBlock::Use(blocks) = signal else { + let MoveBlock::Use(blocks, _hashes) = signal else { panic!( "Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}" ); diff --git a/lib/llm/src/mocker/sequence.rs b/lib/llm/src/mocker/sequence.rs index d467ebaa29..6f19da7897 100644 --- a/lib/llm/src/mocker/sequence.rs +++ b/lib/llm/src/mocker/sequence.rs @@ -6,12 +6,10 @@ use crate::tokens::blocks::UniqueBlock; use crate::tokens::{TokenBlockSequence, Tokens}; use derive_getters::Getters; use rand::random; -use uuid; /// Create unique blocks from a TokenBlockSequence fn create_unique_blocks_from_sequence( tokens: &TokenBlockSequence, - uuid: Option, block_size: usize, enable_prefix_caching: bool, ) -> Vec { @@ -29,10 +27,7 @@ fn create_unique_blocks_from_sequence( // Only push the partial block if tokens count isn't a multiple of block_size if !tokens.total_tokens().is_multiple_of(block_size) { - unique_blocks.push(match uuid { - Some(uuid) => UniqueBlock::PartialBlock(uuid), - None => UniqueBlock::default(), - }); + unique_blocks.push(UniqueBlock::default()); } unique_blocks } @@ -80,8 +75,9 @@ impl ActiveSequence { let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337)); let unique_blocks = - create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching); - let creation_signal = Some(MoveBlock::Use(unique_blocks.clone())); + create_unique_blocks_from_sequence(&tokens, block_size, enable_prefix_caching); + let block_hashes = tokens.blocks().iter().map(|b| b.block_hash()).collect(); + let creation_signal = Some(MoveBlock::Use(unique_blocks.clone(), block_hashes)); Self { unique_blocks, @@ -132,17 +128,6 @@ impl ActiveSequence { (sequence, signal) } - /// Get the parent hash from the second-to-last block if it exists and is a FullBlock - fn get_parent_hash(&self) -> Option { - if self.unique_blocks.len() < 2 { - return None; - } - match &self.unique_blocks[self.unique_blocks.len() - 2] { - UniqueBlock::FullBlock(hash) => Some(*hash), - _ => panic!("Cannot have a partial block as parent"), - } - } - /// Push a token to the sequence pub fn push(&mut self, token: u32) -> Option> { self.tokens.append(token).expect("Token push failed."); @@ -158,24 +143,33 @@ impl ActiveSequence { // Replace last partial block with full block if it exists if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() { - let last_block_hash = if self.enable_prefix_caching { + let last_seq_hash = if self.enable_prefix_caching { self.tokens.last_complete_block().unwrap().sequence_hash() } else { random::() }; + let last_block_hash = self.tokens.last_complete_block().unwrap().block_hash(); self.unique_blocks.pop(); + + // After pop, the last element is the parent block + let second_to_last_hash = self.unique_blocks.last().map(|block| match block { + UniqueBlock::FullBlock(hash) => *hash, + UniqueBlock::PartialBlock(_) => panic!("Cannot have a partial block as parent"), + }); + self.unique_blocks - .push(UniqueBlock::FullBlock(last_block_hash)); + .push(UniqueBlock::FullBlock(last_seq_hash)); signals.push(MoveBlock::Promote( uuid, + last_seq_hash, + second_to_last_hash, last_block_hash, - self.get_parent_hash(), )); } let new_partial_block = UniqueBlock::default(); self.unique_blocks.push(new_partial_block.clone()); - signals.push(MoveBlock::Use(vec![new_partial_block])); + signals.push(MoveBlock::Use(vec![new_partial_block], vec![])); Some(signals) } @@ -241,13 +235,15 @@ impl ActiveSequence { self.tokens.truncate(self.num_input_tokens).unwrap(); self.unique_blocks = create_unique_blocks_from_sequence( &self.tokens, - None, self.block_size, self.enable_prefix_caching, ); self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens); self.generated_tokens = 0; - self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone())); + self.creation_signal = Some(MoveBlock::Use( + self.unique_blocks.clone(), + self.block_hashes(), + )); free_signal } @@ -280,7 +276,7 @@ mod tests { // Check that we got a Use signal assert!(signal1.is_some()); match &signal1 { - Some(MoveBlock::Use(blocks)) => { + Some(MoveBlock::Use(blocks, _hashes)) => { assert_eq!(blocks.len(), 1); } _ => panic!("Expected Use signal"), @@ -301,7 +297,7 @@ mod tests { // First signal should be Promote for the previous block match &signal_16[0] { - MoveBlock::Promote(_, _, parent_hash) => { + MoveBlock::Promote(_, _, parent_hash, _hash) => { assert_eq!(*parent_hash, None); } _ => panic!("Expected Promote signal as second signal"), @@ -309,7 +305,7 @@ mod tests { // Second signal should be Use for new partial block match &signal_16[1] { - MoveBlock::Use(blocks) => { + MoveBlock::Use(blocks, _hashes) => { assert_eq!(blocks.len(), 1); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); } @@ -396,7 +392,7 @@ mod tests { // Check that signal[0] is promote match &signal[0] { - MoveBlock::Promote(_, _, parent_hash) => { + MoveBlock::Promote(_, _, parent_hash, _hash) => { // Check that the parent_hash matches unique_blocks[1], which should be a full block if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] { assert_eq!( @@ -430,7 +426,7 @@ mod tests { // Initial signal - should have received a Use signal for the partial block assert!(signal.is_some()); match signal { - Some(MoveBlock::Use(blocks)) => { + Some(MoveBlock::Use(blocks, _hashes)) => { assert_eq!(blocks.len(), 1); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); } @@ -448,7 +444,7 @@ mod tests { // First signal should be Promote match &signals_second[0] { - MoveBlock::Promote(_, _, parent_hash) => { + MoveBlock::Promote(_, _, parent_hash, _hash) => { assert_eq!(*parent_hash, None); } _ => panic!("Expected Promote signal as first signal after second token"), @@ -456,7 +452,7 @@ mod tests { // Second signal should be Use for new partial block match &signals_second[1] { - MoveBlock::Use(blocks) => { + MoveBlock::Use(blocks, _hashes) => { assert_eq!(blocks.len(), 1); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); }