From c684ef0313846cfe54ebeb361a5bc47376fe7177 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Thu, 2 Apr 2026 08:02:39 +0530 Subject: [PATCH 1/2] Feat: RLM routing --- src/config/types.rs | 15 ++++++++++ src/config/validation.rs | 9 +++++- src/core/worker.rs | 29 +++++++++++++++++++ src/core/worker_registry.rs | 18 +++++++++++- src/lib.rs | 2 ++ src/main.rs | 45 +++++++++++++++++++---------- src/routers/factory.rs | 33 ++++++++++++++------- src/routers/http/pd_router.rs | 14 +++++++++ src/routers/http/vllm_pd_router.rs | 46 ++++++++++++++++++++++++------ src/routers/router_manager.rs | 3 +- src/server.rs | 10 +++++-- tests/api_endpoints_test.rs | 1 + tests/test_dp_routing.rs | 1 + tests/test_pd_routing.rs | 3 ++ 14 files changed, 189 insertions(+), 40 deletions(-) diff --git a/src/config/types.rs b/src/config/types.rs index d46372e8..46e284cb 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -128,6 +128,9 @@ pub enum RoutingMode { PrefillDecode { /// Prefill worker URLs with optional bootstrap ports prefill_urls: Vec<(String, Option)>, + /// Cold prefill worker URLs with optional bootstrap ports (for is_sub_llm requests) + #[serde(default, skip_serializing_if = "Vec::is_empty")] + cold_prefill_urls: Vec<(String, Option)>, /// Decode worker URLs decode_urls: Vec, /// Optional separate policy for prefill workers @@ -146,6 +149,9 @@ pub enum RoutingMode { VllmPrefillDecode { /// Prefill worker URLs with optional bootstrap ports prefill_urls: Vec<(String, Option)>, + /// Cold prefill worker URLs with optional bootstrap ports (for is_sub_llm requests) + #[serde(default, skip_serializing_if = "Vec::is_empty")] + cold_prefill_urls: Vec<(String, Option)>, /// Decode worker URLs decode_urls: Vec, /// Optional separate policy for prefill workers @@ -638,6 +644,7 @@ mod tests { let pd = RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode1".to_string()], prefill_policy: None, decode_policy: None, @@ -661,6 +668,7 @@ mod tests { ("http://prefill1".to_string(), Some(8001)), ("http://prefill2".to_string(), None), ], + cold_prefill_urls: vec![], decode_urls: vec![ "http://decode1".to_string(), "http://decode2".to_string(), @@ -690,6 +698,7 @@ mod tests { // Test PrefillDecode mode let pd = RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode1".to_string()], prefill_policy: None, decode_policy: None, @@ -888,6 +897,7 @@ mod tests { let config = RouterConfig { mode: RoutingMode::PrefillDecode { prefill_urls: vec![], + cold_prefill_urls: vec![], decode_urls: vec![], prefill_policy: None, decode_policy: None, @@ -1010,6 +1020,7 @@ mod tests { ("http://prefill1:8000".to_string(), Some(8001)), ("http://prefill2:8000".to_string(), None), ], + cold_prefill_urls: vec![], decode_urls: vec![ "http://decode1:8000".to_string(), "http://decode2:8000".to_string(), @@ -1213,6 +1224,7 @@ mod tests { // When both prefill and decode policies are specified, they should be used let pd = RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1".to_string(), None)], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode1".to_string()], prefill_policy: Some(PolicyConfig::CacheAware { cache_threshold: 0.5, @@ -1245,6 +1257,7 @@ mod tests { // When only prefill policy is specified, decode should use main policy let pd = RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1".to_string(), None)], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode1".to_string()], prefill_policy: Some(PolicyConfig::CacheAware { cache_threshold: 0.5, @@ -1276,6 +1289,7 @@ mod tests { // When only decode policy is specified, prefill should use main policy let pd = RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1".to_string(), None)], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode1".to_string()], prefill_policy: None, decode_policy: Some(PolicyConfig::PowerOfTwo { @@ -1303,6 +1317,7 @@ mod tests { // When no specific policies are specified, both should use main policy let pd = RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1".to_string(), None)], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode1".to_string()], prefill_policy: None, decode_policy: None, diff --git a/src/config/validation.rs b/src/config/validation.rs index 2e4dffb3..7e1f9af4 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -56,6 +56,7 @@ impl ConfigValidator { decode_urls, prefill_policy, decode_policy, + .. } => { // Only require URLs if service discovery is disabled if !has_service_discovery { @@ -107,7 +108,7 @@ impl ConfigValidator { decode_urls, prefill_policy, decode_policy, - discovery_address: _, + .. } => { // Only require URLs if service discovery is disabled if !has_service_discovery { @@ -480,6 +481,7 @@ impl ConfigValidator { decode_urls, prefill_policy, decode_policy, + .. } = &config.mode { // Check power-of-two for prefill @@ -665,6 +667,7 @@ mod tests { let config = RouterConfig::new( RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode:8000".to_string()], prefill_policy: None, decode_policy: None, @@ -681,6 +684,7 @@ mod tests { let config = RouterConfig::new( RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), None)], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode:8000".to_string()], prefill_policy: None, decode_policy: None, @@ -698,6 +702,7 @@ mod tests { let config = RouterConfig::new( RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), None)], + cold_prefill_urls: vec![], decode_urls: vec!["http://decode:8000".to_string()], prefill_policy: None, decode_policy: None, @@ -743,6 +748,7 @@ mod tests { ("http://prefill1:8000".to_string(), None), ("http://prefill2:8000".to_string(), None), ], + cold_prefill_urls: vec![], decode_urls: vec![ "http://decode1:8000".to_string(), "http://decode2:8000".to_string(), @@ -771,6 +777,7 @@ mod tests { let config = RouterConfig::new( RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill + cold_prefill_urls: vec![], decode_urls: vec![ "http://decode1:8000".to_string(), "http://decode2:8000".to_string(), diff --git a/src/core/worker.rs b/src/core/worker.rs index 07e40da6..8fb8f98e 100644 --- a/src/core/worker.rs +++ b/src/core/worker.rs @@ -291,6 +291,11 @@ pub enum WorkerType { /// Bootstrap port for communication with decode workers bootstrap_port: Option, }, + /// Cold prefill worker for `is_sub_llm` requests + ColdPrefill { + /// Bootstrap port for communication with decode workers + bootstrap_port: Option, + }, /// Decode worker for PD disaggregated mode Decode, } @@ -303,6 +308,10 @@ impl fmt::Display for WorkerType { Some(port) => write!(f, "Prefill(bootstrap:{})", port), None => write!(f, "Prefill"), }, + WorkerType::ColdPrefill { bootstrap_port } => match bootstrap_port { + Some(port) => write!(f, "ColdPrefill(bootstrap:{})", port), + None => write!(f, "ColdPrefill"), + }, WorkerType::Decode => write!(f, "Decode"), } } @@ -762,6 +771,26 @@ impl WorkerFactory { ) } + /// Create a cold prefill worker with optional bootstrap port + pub fn create_cold_prefill(url: String, bootstrap_port: Option) -> Box { + Box::new(BasicWorker::new( + url, + WorkerType::ColdPrefill { bootstrap_port }, + )) + } + + /// Create a cold prefill worker with custom circuit breaker configuration + pub fn create_cold_prefill_with_config( + url: String, + bootstrap_port: Option, + circuit_breaker_config: CircuitBreakerConfig, + ) -> Box { + Box::new( + BasicWorker::new(url, WorkerType::ColdPrefill { bootstrap_port }) + .with_circuit_breaker_config(circuit_breaker_config), + ) + } + /// Create a decode worker pub fn create_decode(url: String) -> Box { Box::new(BasicWorker::new(url, WorkerType::Decode)) diff --git a/src/core/worker_registry.rs b/src/core/worker_registry.rs index 74e1814b..ebce9b9e 100644 --- a/src/core/worker_registry.rs +++ b/src/core/worker_registry.rs @@ -240,6 +240,20 @@ impl WorkerRegistry { .collect() } + /// Get all cold prefill workers (regardless of bootstrap_port) + pub fn get_cold_prefill_workers(&self) -> Vec> { + self.workers + .iter() + .filter_map(|entry| { + let worker = entry.value(); + match worker.worker_type() { + WorkerType::ColdPrefill { .. } => Some(worker.clone()), + _ => None, + } + }) + .collect() + } + /// Get all decode workers pub fn get_decode_workers(&self) -> Vec> { self.get_by_type(&WorkerType::Decode) @@ -416,7 +430,9 @@ impl WorkerRegistry { match worker.worker_type() { WorkerType::Regular => regular_count += 1, - WorkerType::Prefill { .. } => prefill_count += 1, + WorkerType::Prefill { .. } | WorkerType::ColdPrefill { .. } => { + prefill_count += 1 + } WorkerType::Decode => decode_count += 1, } } diff --git a/src/lib.rs b/src/lib.rs index 4150bf76..0fe4c110 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -154,6 +154,7 @@ impl Router { } else if self.vllm_pd_disaggregation { RoutingMode::VllmPrefillDecode { prefill_urls: self.prefill_urls.clone().unwrap_or_default(), + cold_prefill_urls: vec![], decode_urls: self.decode_urls.clone().unwrap_or_default(), prefill_policy: self.prefill_policy.as_ref().map(convert_policy), decode_policy: self.decode_policy.as_ref().map(convert_policy), @@ -162,6 +163,7 @@ impl Router { } else if self.pd_disaggregation { RoutingMode::PrefillDecode { prefill_urls: self.prefill_urls.clone().unwrap_or_default(), + cold_prefill_urls: vec![], decode_urls: self.decode_urls.clone().unwrap_or_default(), prefill_policy: self.prefill_policy.as_ref().map(convert_policy), decode_policy: self.decode_policy.as_ref().map(convert_policy), diff --git a/src/main.rs b/src/main.rs index 93cf5374..a6aa38ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,24 +9,32 @@ use vllm_router_rs::metrics::PrometheusConfig; use vllm_router_rs::server::{self, ServerConfig}; use vllm_router_rs::service_discovery::ServiceDiscoveryConfig; -// Helper function to parse prefill arguments from command line -// Returns prefill_entries with (URL, optional_bootstrap_port) +/// Parse `--prefill [port]` arguments from the raw command line. fn parse_prefill_args() -> Vec<(String, Option)> { + parse_url_port_args("--prefill") +} + +/// Parse `--cold-prefill [port]` arguments from the raw command line. +fn parse_cold_prefill_args() -> Vec<(String, Option)> { + parse_url_port_args("--cold-prefill") +} + +/// Generic parser for `--flag [port]` repeated arguments. +fn parse_url_port_args(flag: &str) -> Vec<(String, Option)> { let args: Vec = std::env::args().collect(); - let mut prefill_entries = Vec::new(); + let mut entries = Vec::new(); let mut i = 0; while i < args.len() { - if args[i] == "--prefill" && i + 1 < args.len() { + if args[i] == flag && i + 1 < args.len() { let url = args[i + 1].clone(); let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") { - // Check if next arg is a port number if let Ok(port) = args[i + 2].parse::() { - i += 1; // Skip the port argument + i += 1; Some(port) } else if args[i + 2].to_lowercase() == "none" { - i += 1; // Skip the "none" argument + i += 1; None } else { None @@ -34,14 +42,14 @@ fn parse_prefill_args() -> Vec<(String, Option)> { } else { None }; - prefill_entries.push((url, bootstrap_port)); - i += 2; // Skip --prefill and URL + entries.push((url, bootstrap_port)); + i += 2; } else { i += 1; } } - prefill_entries + entries } #[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)] @@ -417,6 +425,7 @@ impl CliArgs { fn to_router_config( &self, prefill_urls: Vec<(String, Option)>, + cold_prefill_urls: Vec<(String, Option)>, ) -> ConfigResult { // Validate mutually exclusive modes if self.pd_disaggregation && self.vllm_pd_disaggregation { @@ -449,6 +458,7 @@ impl CliArgs { RoutingMode::PrefillDecode { prefill_urls, + cold_prefill_urls, decode_urls, prefill_policy: self.prefill_policy.as_ref().map(|p| self.parse_policy(p)), decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)), @@ -502,6 +512,7 @@ impl CliArgs { RoutingMode::VllmPrefillDecode { prefill_urls: prefill_urls.clone(), + cold_prefill_urls, decode_urls: final_decode_urls, prefill_policy: self.prefill_policy.as_ref().map(|p| self.parse_policy(p)), decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)), @@ -716,12 +727,14 @@ fn main() -> Result<(), Box> { dotenvy::dotenv().ok(); println!("DEBUG: Main function started"); - // Parse prefill arguments manually before clap parsing + // Parse prefill and cold-prefill arguments manually before clap parsing println!("DEBUG: Parsing prefill arguments"); let prefill_urls = parse_prefill_args(); + let cold_prefill_urls = parse_cold_prefill_args(); println!("DEBUG: Prefill URLs parsed: {:?}", prefill_urls); + println!("DEBUG: Cold prefill URLs parsed: {:?}", cold_prefill_urls); - // Filter out prefill arguments and their values before passing to clap + // Filter out --prefill and --cold-prefill arguments before passing to clap println!("DEBUG: Filtering CLI arguments"); let mut filtered_args: Vec = Vec::new(); let raw_args: Vec = std::env::args().collect(); @@ -729,8 +742,10 @@ fn main() -> Result<(), Box> { let mut i = 0; while i < raw_args.len() { - if raw_args[i] == "--prefill" && i + 1 < raw_args.len() { - // Skip --prefill and its URL + if (raw_args[i] == "--prefill" || raw_args[i] == "--cold-prefill") + && i + 1 < raw_args.len() + { + // Skip the flag and its URL i += 2; // Also skip bootstrap port if present @@ -796,7 +811,7 @@ Provide --worker-urls or PD flags as usual.", // Convert to RouterConfig println!("DEBUG: Converting to RouterConfig"); - let router_config = cli_args.to_router_config(prefill_urls)?; + let router_config = cli_args.to_router_config(prefill_urls, cold_prefill_urls)?; println!("DEBUG: RouterConfig created successfully"); // Validate configuration diff --git a/src/routers/factory.rs b/src/routers/factory.rs index dde15b92..511cf249 100644 --- a/src/routers/factory.rs +++ b/src/routers/factory.rs @@ -31,6 +31,7 @@ impl RouterFactory { decode_urls, prefill_policy, decode_policy, + .. } => { Self::create_grpc_pd_router( prefill_urls, @@ -42,13 +43,9 @@ impl RouterFactory { ) .await } - RoutingMode::VllmPrefillDecode { - prefill_urls: _, - decode_urls: _, - prefill_policy: _, - decode_policy: _, - discovery_address: _, - } => Err("vLLM PD mode requires HTTP connection_mode".to_string()), + RoutingMode::VllmPrefillDecode { .. } => { + Err("vLLM PD mode requires HTTP connection_mode".to_string()) + } RoutingMode::OpenAI { .. } => { Err("OpenAI mode requires HTTP connection_mode".to_string()) } @@ -62,17 +59,20 @@ impl RouterFactory { } RoutingMode::PrefillDecode { prefill_urls, + cold_prefill_urls, decode_urls, prefill_policy, decode_policy, } => { tracing::info!( - "Creating regular PDRouter with prefill_urls: {:?}, decode_urls: {:?}", + "Creating regular PDRouter with prefill_urls: {:?}, cold_prefill_urls: {:?}, decode_urls: {:?}", prefill_urls, + cold_prefill_urls, decode_urls ); Self::create_pd_router( prefill_urls, + cold_prefill_urls, decode_urls, prefill_policy.as_ref(), decode_policy.as_ref(), @@ -83,15 +83,17 @@ impl RouterFactory { } RoutingMode::VllmPrefillDecode { prefill_urls, + cold_prefill_urls, decode_urls, prefill_policy, decode_policy, discovery_address, } => { - tracing::info!("Creating VllmPDRouter with prefill_urls: {:?}, decode_urls: {:?}, discovery: {:?}", - prefill_urls, decode_urls, discovery_address); + tracing::info!("Creating VllmPDRouter with prefill_urls: {:?}, cold_prefill_urls: {:?}, decode_urls: {:?}, discovery: {:?}", + prefill_urls, cold_prefill_urls, decode_urls, discovery_address); Self::create_vllm_pd_router( prefill_urls, + cold_prefill_urls, decode_urls, discovery_address.clone(), prefill_policy.as_ref(), @@ -123,6 +125,7 @@ impl RouterFactory { /// Create a PD router with injected policy pub async fn create_pd_router( prefill_urls: &[(String, Option)], + cold_prefill_urls: &[(String, Option)], decode_urls: &[String], prefill_policy_config: Option<&PolicyConfig>, decode_policy_config: Option<&PolicyConfig>, @@ -140,7 +143,13 @@ impl RouterFactory { ctx.policy_registry.set_decode_policy(decode_policy); // Create PD router with context (policies are in PolicyRegistry) - let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?; + let router = PDRouter::new( + prefill_urls.to_vec(), + cold_prefill_urls.to_vec(), + decode_urls.to_vec(), + ctx, + ) + .await?; Ok(Box::new(router)) } @@ -148,6 +157,7 @@ impl RouterFactory { /// Create a vLLM PD router with service discovery and/or static URLs pub async fn create_vllm_pd_router( prefill_urls: &[(String, Option)], + cold_prefill_urls: &[(String, Option)], decode_urls: &[String], discovery_address: Option, prefill_policy_config: Option<&PolicyConfig>, @@ -182,6 +192,7 @@ impl RouterFactory { let router = VllmPDRouter::new( prefill_urls.to_vec(), + cold_prefill_urls.to_vec(), decode_urls.to_vec(), discovery_address, ctx, diff --git a/src/routers/http/pd_router.rs b/src/routers/http/pd_router.rs index a170850a..6979578a 100644 --- a/src/routers/http/pd_router.rs +++ b/src/routers/http/pd_router.rs @@ -428,6 +428,7 @@ impl PDRouter { #[allow(clippy::too_many_arguments)] pub async fn new( prefill_urls: Vec<(String, Option)>, + cold_prefill_urls: Vec<(String, Option)>, decode_urls: Vec, ctx: &Arc, ) -> Result { @@ -533,6 +534,19 @@ impl PDRouter { ctx.worker_registry.register(worker); } + // Register cold prefill workers in the registry (no DP expansion for now) + for (url, port) in cold_prefill_urls { + let worker_type = WorkerType::ColdPrefill { + bootstrap_port: port, + }; + let worker: Arc = Arc::new( + BasicWorker::new(url, worker_type) + .with_circuit_breaker_config(core_cb_config.clone()) + .with_health_config(health_config.clone()), + ); + ctx.worker_registry.register(worker); + } + // Register decode workers in the registry for url in expanded_decode_urls { decode_workers_urls.push(url.clone()); diff --git a/src/routers/http/vllm_pd_router.rs b/src/routers/http/vllm_pd_router.rs index ccf17fa3..2395791b 100644 --- a/src/routers/http/vllm_pd_router.rs +++ b/src/routers/http/vllm_pd_router.rs @@ -1103,6 +1103,7 @@ impl VllmPDRouter { /// 2. Direct URL mode: discovery_address is None, prefill_urls and decode_urls are provided pub async fn new( prefill_urls: Vec<(String, Option)>, + cold_prefill_urls: Vec<(String, Option)>, decode_urls: Vec, discovery_address: Option, ctx: &Arc, @@ -1115,7 +1116,7 @@ impl VllmPDRouter { ); // Create underlying PD router with empty worker lists (they'll be discovered dynamically) - let pd_router = PDRouter::new(vec![], vec![], ctx).await?; + let pd_router = PDRouter::new(vec![], cold_prefill_urls, vec![], ctx).await?; // Initialize service discovery let mut service_registry = ServiceRegistry::new(); @@ -1148,7 +1149,7 @@ impl VllmPDRouter { ); // Create underlying PD router with provided worker lists - let pd_router = PDRouter::new(prefill_urls, decode_urls, ctx).await?; + let pd_router = PDRouter::new(prefill_urls, cold_prefill_urls, decode_urls, ctx).await?; // No service discovery in direct URL mode let service_registry = ServiceRegistry::new(); @@ -1212,6 +1213,37 @@ impl VllmPDRouter { self.pd_router.remove_decode_server(url).await } + /// Select the appropriate prefill worker pool based on request headers. + /// + /// When the `is_sub_llm: true` header is present and cold-prefill workers are + /// configured, returns the cold-prefill pool. Otherwise returns the normal + /// prefill pool. + fn get_prefill_workers_for_request( + &self, + headers: Option<&HeaderMap>, + ) -> Vec> { + let is_sub_llm = headers + .and_then(|h| h.get("is_sub_llm")) + .and_then(|v| v.to_str().ok()) + .map(|v| v == "true") + .unwrap_or(false); + + if is_sub_llm { + let cold = self.pd_router.worker_registry.get_cold_prefill_workers(); + if !cold.is_empty() { + debug!( + "is_sub_llm=true header detected, routing to cold prefill pool ({} workers)", + cold.len() + ); + return cold; + } + debug!( + "is_sub_llm=true header detected but no cold prefill workers configured, falling back to normal prefill pool" + ); + } + self.pd_router.worker_registry.get_prefill_workers() + } + /// Get a reference to the underlying PDRouter's worker registry /// This allows access to worker information for refresh operations pub fn worker_registry(&self) -> &crate::core::WorkerRegistry { @@ -1316,7 +1348,7 @@ impl RouterTrait for VllmPDRouter { }; // Get prefill and decode workers from worker_registry - let prefill_workers = self.pd_router.worker_registry.get_prefill_workers(); + let prefill_workers = self.get_prefill_workers_for_request(headers); let decode_workers = self.pd_router.worker_registry.get_decode_workers(); info!( @@ -1389,8 +1421,6 @@ impl RouterTrait for VllmPDRouter { let prefill_worker = &prefill_workers[prefill_idx]; let decode_worker = &decode_workers[decode_idx]; - // Load tracking is handled inside process_vllm_two_stage_request for fine-grained - // tracking: prefill load only during prefill phase, decode load only during decode phase. info!( "Chat: Selected prefill={} [policy:{}], decode={} [policy:{}]", @@ -1487,7 +1517,7 @@ impl RouterTrait for VllmPDRouter { }; // Get prefill and decode workers from worker_registry - let prefill_workers = self.pd_router.worker_registry.get_prefill_workers(); + let prefill_workers = self.get_prefill_workers_for_request(headers); let decode_workers = self.pd_router.worker_registry.get_decode_workers(); info!( @@ -1560,8 +1590,6 @@ impl RouterTrait for VllmPDRouter { let prefill_worker = &prefill_workers[prefill_idx]; let decode_worker = &decode_workers[decode_idx]; - // Load tracking is handled inside process_vllm_two_stage_request for fine-grained - // tracking: prefill load only during prefill phase, decode load only during decode phase. info!( "Completion: Selected prefill={} [policy:{}], decode={} [policy:{}]", @@ -1694,7 +1722,7 @@ impl RouterTrait for VllmPDRouter { self.process_vllm_request(request_json, path, headers).await } else { // Direct URL mode - use worker registry, filtered by availability - let all_prefill = self.pd_router.worker_registry.get_prefill_workers(); + let all_prefill = self.get_prefill_workers_for_request(headers); let prefill_workers: Vec> = all_prefill .iter() .filter(|w| w.is_available()) diff --git a/src/routers/router_manager.rs b/src/routers/router_manager.rs index 46701c33..4d6c229e 100644 --- a/src/routers/router_manager.rs +++ b/src/routers/router_manager.rs @@ -115,7 +115,7 @@ impl RouterManager { let has_pd_workers = workers.iter().any(|w| { matches!( w.worker_type(), - WorkerType::Prefill { .. } | WorkerType::Decode + WorkerType::Prefill { .. } | WorkerType::ColdPrefill { .. } | WorkerType::Decode ) }); @@ -362,6 +362,7 @@ impl RouterManager { worker_type: match worker.worker_type() { WorkerType::Regular => "regular".to_string(), WorkerType::Prefill { .. } => "prefill".to_string(), + WorkerType::ColdPrefill { .. } => "cold_prefill".to_string(), WorkerType::Decode => "decode".to_string(), }, is_healthy: worker.is_healthy(), diff --git a/src/server.rs b/src/server.rs index 92160505..13165997 100644 --- a/src/server.rs +++ b/src/server.rs @@ -592,6 +592,7 @@ async fn list_workers_rest( "worker_type": match worker.worker_type() { WorkerType::Regular => "regular", WorkerType::Prefill { .. } => "prefill", + WorkerType::ColdPrefill { .. } => "cold_prefill", WorkerType::Decode => "decode", }, "is_healthy": worker.is_healthy(), @@ -602,8 +603,12 @@ async fn list_workers_rest( }); // Add bootstrap_port for Prefill workers - if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() { - worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port); + match worker.worker_type() { + WorkerType::Prefill { bootstrap_port } + | WorkerType::ColdPrefill { bootstrap_port } => { + worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port); + } + _ => {} } worker_info @@ -923,6 +928,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Date: Thu, 2 Apr 2026 11:49:16 +0530 Subject: [PATCH 2/2] logs --- src/core/worker_registry.rs | 8 +++++--- src/main.rs | 3 +++ src/protocols/worker_spec.rs | 1 + src/routers/factory.rs | 3 ++- src/routers/http/pd_router.rs | 16 ++++++++++++++++ src/routers/http/vllm_pd_router.rs | 7 ++++--- src/routers/router_manager.rs | 1 + 7 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/core/worker_registry.rs b/src/core/worker_registry.rs index ebce9b9e..55b208bc 100644 --- a/src/core/worker_registry.rs +++ b/src/core/worker_registry.rs @@ -420,6 +420,7 @@ impl WorkerRegistry { let mut total_load = 0; let mut regular_count = 0; let mut prefill_count = 0; + let mut cold_prefill_count = 0; let mut decode_count = 0; for worker in self.get_all() { @@ -430,9 +431,8 @@ impl WorkerRegistry { match worker.worker_type() { WorkerType::Regular => regular_count += 1, - WorkerType::Prefill { .. } | WorkerType::ColdPrefill { .. } => { - prefill_count += 1 - } + WorkerType::Prefill { .. } => prefill_count += 1, + WorkerType::ColdPrefill { .. } => cold_prefill_count += 1, WorkerType::Decode => decode_count += 1, } } @@ -444,6 +444,7 @@ impl WorkerRegistry { total_load, regular_workers: regular_count, prefill_workers: prefill_count, + cold_prefill_workers: cold_prefill_count, decode_workers: decode_count, } } @@ -585,6 +586,7 @@ pub struct WorkerRegistryStats { pub total_load: usize, pub regular_workers: usize, pub prefill_workers: usize, + pub cold_prefill_workers: usize, pub decode_workers: usize, } diff --git a/src/main.rs b/src/main.rs index a6aa38ad..6ae1e17f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -806,6 +806,9 @@ Provide --worker-urls or PD flags as usual.", if cli_args.pd_disaggregation && !prefill_urls.is_empty() { println!("Prefill nodes: {:?}", prefill_urls); println!("Decode nodes: {:?}", cli_args.decode); + if !cold_prefill_urls.is_empty() { + println!("Cold prefill nodes: {:?}", cold_prefill_urls); + } } } diff --git a/src/protocols/worker_spec.rs b/src/protocols/worker_spec.rs index ca661d76..5cc781f1 100644 --- a/src/protocols/worker_spec.rs +++ b/src/protocols/worker_spec.rs @@ -115,6 +115,7 @@ pub struct WorkerStats { pub struct WorkerTypeStats { pub regular: usize, pub prefill: usize, + pub cold_prefill: usize, pub decode: usize, } diff --git a/src/routers/factory.rs b/src/routers/factory.rs index 511cf249..341ed465 100644 --- a/src/routers/factory.rs +++ b/src/routers/factory.rs @@ -184,8 +184,9 @@ impl RouterFactory { } if !prefill_urls.is_empty() || !decode_urls.is_empty() { tracing::info!( - "Creating VllmPDRouter with static URLs - prefill: {:?}, decode: {:?}", + "Creating VllmPDRouter with static URLs - prefill: {:?}, cold_prefill: {:?}, decode: {:?}", prefill_urls, + cold_prefill_urls, decode_urls ); } diff --git a/src/routers/http/pd_router.rs b/src/routers/http/pd_router.rs index 6979578a..cc736b5a 100644 --- a/src/routers/http/pd_router.rs +++ b/src/routers/http/pd_router.rs @@ -535,7 +535,9 @@ impl PDRouter { } // Register cold prefill workers in the registry (no DP expansion for now) + let mut cold_prefill_workers_urls = vec![]; for (url, port) in cold_prefill_urls { + cold_prefill_workers_urls.push(url.clone()); let worker_type = WorkerType::ColdPrefill { bootstrap_port: port, }; @@ -592,6 +594,20 @@ impl PDRouter { .await?; } + if !cold_prefill_workers_urls.is_empty() { + info!( + "Waiting for {} cold prefill worker(s) to become healthy: {:?}", + cold_prefill_workers_urls.len(), + cold_prefill_workers_urls + ); + crate::routers::http::router::Router::wait_for_healthy_workers( + &cold_prefill_workers_urls, + ctx.router_config.worker_startup_timeout_secs, + ctx.router_config.worker_startup_check_interval_secs, + ) + .await?; + } + // Initialize cache-aware policies with workers from registry // Note: We need to get workers by type and convert to Box for CacheAwarePolicy // This is a temporary workaround until CacheAwarePolicy is updated to work with Arc diff --git a/src/routers/http/vllm_pd_router.rs b/src/routers/http/vllm_pd_router.rs index 2395791b..f8d42c47 100644 --- a/src/routers/http/vllm_pd_router.rs +++ b/src/routers/http/vllm_pd_router.rs @@ -1143,8 +1143,9 @@ impl VllmPDRouter { } else { // Direct URL mode (same as PDRouter) info!( - "VllmPDRouter::new called in direct URL mode with {} prefill, {} decode workers", + "VllmPDRouter::new called in direct URL mode with {} prefill, {} cold prefill, {} decode workers", prefill_urls.len(), + cold_prefill_urls.len(), decode_urls.len() ); @@ -1231,13 +1232,13 @@ impl VllmPDRouter { if is_sub_llm { let cold = self.pd_router.worker_registry.get_cold_prefill_workers(); if !cold.is_empty() { - debug!( + info!( "is_sub_llm=true header detected, routing to cold prefill pool ({} workers)", cold.len() ); return cold; } - debug!( + warn!( "is_sub_llm=true header detected but no cold prefill workers configured, falling back to normal prefill pool" ); } diff --git a/src/routers/router_manager.rs b/src/routers/router_manager.rs index 4d6c229e..3a7d822e 100644 --- a/src/routers/router_manager.rs +++ b/src/routers/router_manager.rs @@ -312,6 +312,7 @@ impl RouterManager { by_type: WorkerTypeStats { regular: registry_stats.regular_workers, prefill: registry_stats.prefill_workers, + cold_prefill: registry_stats.cold_prefill_workers, decode: registry_stats.decode_workers, }, };