Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/config/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ pub enum RoutingMode {
PrefillDecode {
/// Prefill worker URLs with optional bootstrap ports
prefill_urls: Vec<(String, Option<u16>)>,
/// Cold prefill worker URLs with optional bootstrap ports (for is_sub_llm requests)
#[serde(default, skip_serializing_if = "Vec::is_empty")]
cold_prefill_urls: Vec<(String, Option<u16>)>,
/// Decode worker URLs
decode_urls: Vec<String>,
/// Optional separate policy for prefill workers
Expand All @@ -146,6 +149,9 @@ pub enum RoutingMode {
VllmPrefillDecode {
/// Prefill worker URLs with optional bootstrap ports
prefill_urls: Vec<(String, Option<u16>)>,
/// Cold prefill worker URLs with optional bootstrap ports (for is_sub_llm requests)
#[serde(default, skip_serializing_if = "Vec::is_empty")]
cold_prefill_urls: Vec<(String, Option<u16>)>,
/// Decode worker URLs
decode_urls: Vec<String>,
/// Optional separate policy for prefill workers
Expand Down Expand Up @@ -638,6 +644,7 @@ mod tests {

let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
Expand All @@ -661,6 +668,7 @@ mod tests {
("http://prefill1".to_string(), Some(8001)),
("http://prefill2".to_string(), None),
],
cold_prefill_urls: vec![],
decode_urls: vec![
"http://decode1".to_string(),
"http://decode2".to_string(),
Expand Down Expand Up @@ -690,6 +698,7 @@ mod tests {
// Test PrefillDecode mode
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
Expand Down Expand Up @@ -888,6 +897,7 @@ mod tests {
let config = RouterConfig {
mode: RoutingMode::PrefillDecode {
prefill_urls: vec![],
cold_prefill_urls: vec![],
decode_urls: vec![],
prefill_policy: None,
decode_policy: None,
Expand Down Expand Up @@ -1010,6 +1020,7 @@ mod tests {
("http://prefill1:8000".to_string(), Some(8001)),
("http://prefill2:8000".to_string(), None),
],
cold_prefill_urls: vec![],
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
Expand Down Expand Up @@ -1213,6 +1224,7 @@ mod tests {
// When both prefill and decode policies are specified, they should be used
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
Expand Down Expand Up @@ -1245,6 +1257,7 @@ mod tests {
// When only prefill policy is specified, decode should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
Expand Down Expand Up @@ -1276,6 +1289,7 @@ mod tests {
// When only decode policy is specified, prefill should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: Some(PolicyConfig::PowerOfTwo {
Expand Down Expand Up @@ -1303,6 +1317,7 @@ mod tests {
// When no specific policies are specified, both should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
Expand Down
9 changes: 8 additions & 1 deletion src/config/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ impl ConfigValidator {
decode_urls,
prefill_policy,
decode_policy,
..
} => {
// Only require URLs if service discovery is disabled
if !has_service_discovery {
Expand Down Expand Up @@ -107,7 +108,7 @@ impl ConfigValidator {
decode_urls,
prefill_policy,
decode_policy,
discovery_address: _,
..
} => {
// Only require URLs if service discovery is disabled
if !has_service_discovery {
Expand Down Expand Up @@ -480,6 +481,7 @@ impl ConfigValidator {
decode_urls,
prefill_policy,
decode_policy,
..
} = &config.mode
{
// Check power-of-two for prefill
Expand Down Expand Up @@ -665,6 +667,7 @@ mod tests {
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
Expand All @@ -681,6 +684,7 @@ mod tests {
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
Expand All @@ -698,6 +702,7 @@ mod tests {
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), None)],
cold_prefill_urls: vec![],
decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
Expand Down Expand Up @@ -743,6 +748,7 @@ mod tests {
("http://prefill1:8000".to_string(), None),
("http://prefill2:8000".to_string(), None),
],
cold_prefill_urls: vec![],
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
Expand Down Expand Up @@ -771,6 +777,7 @@ mod tests {
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
cold_prefill_urls: vec![],
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
Expand Down
29 changes: 29 additions & 0 deletions src/core/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ pub enum WorkerType {
/// Bootstrap port for communication with decode workers
bootstrap_port: Option<u16>,
},
/// Cold prefill worker for `is_sub_llm` requests
ColdPrefill {
/// Bootstrap port for communication with decode workers
bootstrap_port: Option<u16>,
},
/// Decode worker for PD disaggregated mode
Decode,
}
Expand All @@ -303,6 +308,10 @@ impl fmt::Display for WorkerType {
Some(port) => write!(f, "Prefill(bootstrap:{})", port),
None => write!(f, "Prefill"),
},
WorkerType::ColdPrefill { bootstrap_port } => match bootstrap_port {
Some(port) => write!(f, "ColdPrefill(bootstrap:{})", port),
None => write!(f, "ColdPrefill"),
},
WorkerType::Decode => write!(f, "Decode"),
}
}
Expand Down Expand Up @@ -762,6 +771,26 @@ impl WorkerFactory {
)
}

/// Build a boxed cold prefill worker for `url`.
///
/// The worker is tagged with `WorkerType::ColdPrefill`, carrying the given
/// `bootstrap_port` (which may be `None`).
pub fn create_cold_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
    let worker_type = WorkerType::ColdPrefill { bootstrap_port };
    Box::new(BasicWorker::new(url, worker_type))
}

/// Build a boxed cold prefill worker for `url` using a caller-supplied
/// circuit breaker configuration.
///
/// The worker is tagged with `WorkerType::ColdPrefill` carrying the given
/// `bootstrap_port`, and `circuit_breaker_config` is applied through
/// `with_circuit_breaker_config` before the worker is boxed.
pub fn create_cold_prefill_with_config(
    url: String,
    bootstrap_port: Option<u16>,
    circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
    let worker_type = WorkerType::ColdPrefill { bootstrap_port };
    let worker =
        BasicWorker::new(url, worker_type).with_circuit_breaker_config(circuit_breaker_config);
    Box::new(worker)
}

/// Create a decode worker
pub fn create_decode(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Decode))
Expand Down
18 changes: 18 additions & 0 deletions src/core/worker_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,20 @@ impl WorkerRegistry {
.collect()
}

/// Collect every registered worker whose type is `WorkerType::ColdPrefill`.
///
/// The `bootstrap_port` carried by the variant is ignored, so workers with
/// and without a port are both returned. Each returned handle is a clone of
/// the `Arc` stored in the registry.
pub fn get_cold_prefill_workers(&self) -> Vec<Arc<dyn Worker>> {
    self.workers
        .iter()
        .filter(|entry| {
            matches!(
                entry.value().worker_type(),
                WorkerType::ColdPrefill { .. }
            )
        })
        .map(|entry| Arc::clone(entry.value()))
        .collect()
}

/// Get all decode workers
pub fn get_decode_workers(&self) -> Vec<Arc<dyn Worker>> {
self.get_by_type(&WorkerType::Decode)
Expand Down Expand Up @@ -406,6 +420,7 @@ impl WorkerRegistry {
let mut total_load = 0;
let mut regular_count = 0;
let mut prefill_count = 0;
let mut cold_prefill_count = 0;
let mut decode_count = 0;

for worker in self.get_all() {
Expand All @@ -417,6 +432,7 @@ impl WorkerRegistry {
match worker.worker_type() {
WorkerType::Regular => regular_count += 1,
WorkerType::Prefill { .. } => prefill_count += 1,
WorkerType::ColdPrefill { .. } => cold_prefill_count += 1,
WorkerType::Decode => decode_count += 1,
}
}
Expand All @@ -428,6 +444,7 @@ impl WorkerRegistry {
total_load,
regular_workers: regular_count,
prefill_workers: prefill_count,
cold_prefill_workers: cold_prefill_count,
decode_workers: decode_count,
}
}
Expand Down Expand Up @@ -569,6 +586,7 @@ pub struct WorkerRegistryStats {
pub total_load: usize,
pub regular_workers: usize,
pub prefill_workers: usize,
pub cold_prefill_workers: usize,
pub decode_workers: usize,
}

Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ impl Router {
} else if self.vllm_pd_disaggregation {
RoutingMode::VllmPrefillDecode {
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
cold_prefill_urls: vec![],
decode_urls: self.decode_urls.clone().unwrap_or_default(),
prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
decode_policy: self.decode_policy.as_ref().map(convert_policy),
Expand All @@ -162,6 +163,7 @@ impl Router {
} else if self.pd_disaggregation {
RoutingMode::PrefillDecode {
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
cold_prefill_urls: vec![],
decode_urls: self.decode_urls.clone().unwrap_or_default(),
prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
decode_policy: self.decode_policy.as_ref().map(convert_policy),
Expand Down
Loading
Loading