Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions lib/bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ pub enum DynamoLlmResult {
pub unsafe extern "C" fn dynamo_llm_init(
namespace_c_str: *const c_char,
component_c_str: *const c_char,
worker_id: i64,
kv_block_size: u32,
) -> DynamoLlmResult {
initialize_tracing();
Expand Down Expand Up @@ -102,7 +101,7 @@ pub unsafe extern "C" fn dynamo_llm_init(

match result {
Ok(_) => match KV_PUB.get_or_try_init(move || {
dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size)
dynamo_create_kv_publisher(namespace, component, kv_block_size)
}) {
Ok(_) => DynamoLlmResult::OK,
Err(e) => {
Expand Down Expand Up @@ -144,7 +143,6 @@ pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
fn dynamo_create_kv_publisher(
namespace: String,
component: String,
worker_id: i64,
kv_block_size: u32,
) -> Result<KvEventPublisher, anyhow::Error> {
tracing::info!("Creating KV Publisher for model: {}", component);
Expand All @@ -154,7 +152,7 @@ fn dynamo_create_kv_publisher(
{
Ok(drt) => {
let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(backend, worker_id as u64, kv_block_size, None)
KvEventPublisher::new(backend, kv_block_size, None)
}
Err(e) => Err(e),
}
Expand Down
11 changes: 2 additions & 9 deletions lib/bindings/python/rust/llm/kv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ impl ZmqKvEventPublisher {
fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner,
config.worker_id,
config.kv_block_size as u32,
Some(KvEventSourceConfig::Zmq {
endpoint: config.zmq_endpoint,
Expand Down Expand Up @@ -239,20 +238,14 @@ pub(crate) struct KvEventPublisher {
#[pymethods]
impl KvEventPublisher {
#[new]
#[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))]
fn new(
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
) -> PyResult<Self> {
#[pyo3(signature = (component, kv_block_size, dp_rank=0))]
fn new(component: Component, kv_block_size: usize, dp_rank: DpRank) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}

let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner,
worker_id,
kv_block_size as u32,
None,
)
Expand Down
7 changes: 5 additions & 2 deletions lib/llm/src/kv_router/indexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,10 @@ impl RadixTree {
None => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = ?worker.dp_rank,
dp_rank = worker.dp_rank,
id,
parent_hash = ?op.parent_hash,
num_blocks = op.blocks.len(),
"Failed to find parent block; skipping store operation"
);
return Err(KvCacheEventError::ParentBlockNotFound);
Expand Down Expand Up @@ -412,8 +413,10 @@ impl RadixTree {
Some(entry) => entry.clone(),
None => {
tracing::warn!(
worker_id = worker_id.to_string(),
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block,
"Failed to find block to remove; skipping remove operation"
);
return Err(KvCacheEventError::BlockNotFound);
Expand Down
8 changes: 7 additions & 1 deletion lib/llm/src/kv_router/publisher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,20 @@ pub struct KvEventPublisher {
impl KvEventPublisher {
pub fn new(
component: Component,
worker_id: u64,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
) -> Result<Self> {
let cancellation_token = CancellationToken::new();

let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();

// Infer worker_id from component's primary lease
let worker_id = component
.drt()
.primary_lease()
.expect("Cannot publish KV events without lease")
.id();

// Create our event source (if any)
let mut source = None;
if let Some(config) = source_config {
Expand Down
1 change: 1 addition & 0 deletions lib/llm/src/mocker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ pub mod engine;
pub mod evictor;
pub mod kv_manager;
pub mod protocols;
pub mod running_mean;
pub mod scheduler;
pub mod sequence;
139 changes: 30 additions & 109 deletions lib/llm/src/mocker/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal};
use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler;
use crate::protocols::TokenIdType;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
Expand All @@ -23,9 +23,6 @@ use dynamo_runtime::{
pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider,
};

use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use crate::kv_router::publisher::KvEventPublisher;
use futures::StreamExt;
use rand::Rng;
use std::collections::HashMap;
Expand All @@ -37,10 +34,9 @@ use uuid::Uuid;

pub const MOCKER_COMPONENT: &str = "mocker";

/// Generate a random token ID from 1k to 5k
// NOTE(review): two `random_range` calls with different bounds appear below
// (1000..5000, then 100..200). This looks like the before/after pair of a diff
// rendered without +/- markers. If 100..200 is the surviving line, the
// "1k to 5k" summary above is stale — confirm against the real file.
fn generate_random_token() -> TokenIdType {
let mut rng = rand::rng();
rng.random_range(1000..5000)
rng.random_range(100..200)
}

/// AsyncEngine wrapper around the Scheduler that generates random character tokens
Expand Down Expand Up @@ -71,26 +67,25 @@ impl MockVllmEngine {
tracing::info!("Engine startup simulation completed");
}

let (schedulers, kv_event_receiver) = self.start_schedulers(
// Pass component to schedulers only if prefix caching is enabled and not a decode worker
let scheduler_component = if self.engine_args.enable_prefix_caching
&& self.engine_args.worker_type != WorkerType::Decode
{
Some(component.clone())
} else {
None
};

let schedulers = self.start_schedulers(
self.engine_args.clone(),
self.active_requests.clone(),
scheduler_component,
cancel_token.clone(),
);

Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone())
.await?;

// Start KV events publishing with the actual receivers from schedulers
if self.engine_args.enable_prefix_caching {
Self::start_kv_events_publishing(
kv_event_receiver,
Some(component.clone()),
self.engine_args.block_size,
cancel_token.clone(),
)
.await?;
}

Ok(())
}

Expand All @@ -100,39 +95,31 @@ impl MockVllmEngine {
}

/// Create schedulers and spawn their background tasks for distributing token notifications
/// Returns schedulers and their corresponding KV event receivers
fn start_schedulers(
&self,
args: MockEngineArgs,
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
component: Option<Component>,
cancel_token: CancellationToken,
) -> (
Vec<Scheduler>,
Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
) {
) -> Vec<Scheduler> {
let mut schedulers = Vec::<Scheduler>::new();
let mut kv_event_receivers = Vec::new();
let mut senders = Vec::with_capacity(args.dp_size as usize);

// Create multiple schedulers and their background tasks
for dp_rank in 0..args.dp_size {
// Create a shared output channel that this scheduler will use
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

// Create a channel for KV events from this scheduler
let (kv_events_tx, kv_events_rx) = mpsc::unbounded_channel::<KvCacheEventData>();

let scheduler = Scheduler::new(
args.clone(),
dp_rank,
Some(output_tx),
Some(kv_events_tx), // Pass the KV events sender to scheduler
component.clone(),
Some(cancel_token.clone()),
);

senders.push(scheduler.request_sender());
schedulers.push(scheduler);
kv_event_receivers.push(kv_events_rx);

// Spawn a background task for this scheduler to distribute token notifications to active requests
// let output_rx = Arc::new(Mutex::new(output_rx));
Expand Down Expand Up @@ -166,7 +153,7 @@ impl MockVllmEngine {
.set(senders)
.expect("Already initialized");

(schedulers, kv_event_receivers)
schedulers
}

/// Start background tasks to publish metrics on change
Expand Down Expand Up @@ -228,83 +215,6 @@ impl MockVllmEngine {
tracing::info!("Metrics background tasks started");
Ok(())
}

/// Start background tasks to collect and publish KV events from schedulers.
///
/// Builds one shared `KvEventPublisher` for `component` and spawns one Tokio task
/// per receiver in `kv_event_receivers`. Each task wraps every received
/// `KvCacheEventData` in a `KvCacheEvent` and publishes it, until `cancel_token`
/// is cancelled.
///
/// # Arguments
/// * `kv_event_receivers` - receivers to drain; the position of a receiver in the
///   vector is used as its DP rank.
/// * `component` - component to publish under; when `None`, a warning is logged
///   and the function returns `Ok(())` without starting anything.
/// * `block_size` - KV block size, handed to the publisher as a `u32`.
/// * `cancel_token` - cancelling it makes every spawned task leave its loop.
///
/// # Errors
/// Propagates the error returned by `KvEventPublisher::new`.
///
/// # Panics
/// Panics if the component's runtime has no primary lease.
async fn start_kv_events_publishing(
kv_event_receivers: Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
component: Option<Component>,
block_size: usize,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::debug!("Starting KV events publishing");

// Only start KV events publishing if we have a component
let Some(comp) = component else {
tracing::warn!("No component provided, skipping KV events publishing");
return Ok(());
};
tracing::debug!("Component found for KV events publishing");

// The worker id is the id of the runtime's primary lease.
tracing::debug!("Getting worker_id");
let worker_id = comp
.drt()
.primary_lease()
.expect("Cannot publish KV events without lease") // NOTE: panics when there is no primary lease; the original author flagged static mode as such a case
.id();
tracing::debug!("Worker_id set to: {worker_id}");

// A single publisher is shared (via `Arc`) by all per-rank tasks below.
tracing::debug!("Creating KV event publisher");
let kv_event_publisher = Arc::new(KvEventPublisher::new(
comp.clone(),
worker_id,
block_size as u32,
None,
)?);
tracing::debug!("KV event publisher created");

tracing::debug!(
"Starting KV event background tasks for {} receivers",
kv_event_receivers.len()
);
// One detached task per receiver; the enumerate index doubles as the DP rank.
for (dp_rank, mut kv_events_rx) in kv_event_receivers.into_iter().enumerate() {
tracing::debug!("Starting background task for DP rank {dp_rank}");
let publisher = kv_event_publisher.clone();
let dp_rank = dp_rank as u32;
let cancel_token = cancel_token.clone();

// The JoinHandle is dropped, so the task is only stopped through `cancel_token`.
tokio::spawn(async move {
tracing::debug!("Background task started for DP rank {dp_rank}");
loop {
tokio::select! {
// Receive actual KV events from the scheduler.
// When `recv()` yields `None` (channel closed) the pattern does not match,
// so the loop then only ends through the cancellation branch.
Some(event_data) = kv_events_rx.recv() => {
// Convert KvCacheEventData to KvCacheEvent with a random UUID as event_id;
// the `as u64` cast keeps only the low 64 bits of the 128-bit UUID.
let event = KvCacheEvent {
event_id: Uuid::new_v4().as_u128() as u64,
data: event_data,
dp_rank,
};

// Publish the event; a failure is logged and the loop keeps running.
if let Err(e) = publisher.publish(event) {
tracing::warn!("Failed to publish KV event for DP rank {dp_rank}: {e}");
} else {
tracing::trace!("Published KV event for DP rank {dp_rank}");
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("KV events publishing cancelled for DP rank {dp_rank}");
break;
}
}
}
});
}
tracing::info!("All KV event background tasks started");

Ok(())
}
}

#[async_trait]
Expand Down Expand Up @@ -356,7 +266,13 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>

let active_requests = self.active_requests.clone();
let async_context = ctx.context();
let max_tokens = request.stop_conditions.max_tokens.unwrap_or(100) as usize;
let is_prefill = self.engine_args.worker_type == WorkerType::Prefill;
// Override max_tokens to 1 for prefill workers
let max_tokens = if is_prefill {
1
} else {
request.stop_conditions.max_tokens.unwrap() as usize
};

// Spawn a task to handle the complex async logic
tokio::spawn(async move {
Expand All @@ -383,7 +299,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
top_logprobs: None,
finish_reason: None,
index: None,
disaggregated_params: None,
// Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill {
Some(serde_json::json!("dummy"))
} else {
None
},
extra_args: None,
};

Expand Down
Loading
Loading