diff --git a/lib/runtime/src/transports/etcd.rs b/lib/runtime/src/transports/etcd.rs
index a6b16032c9..530fd8988e 100644
--- a/lib/runtime/src/transports/etcd.rs
+++ b/lib/runtime/src/transports/etcd.rs
@@ -20,31 +20,18 @@ use etcd_client::{
 pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
 use tokio::time::{Duration, interval};
 
+mod connector;
 mod lease;
 mod lock;
 mod path;
 
+use connector::Connector;
 use lease::*;
 pub use lock::*;
 pub use path::*;
 
 use super::utils::build_in_runtime;
 
-/// ETCD Client
-#[derive(Clone)]
-pub struct Client {
-    client: etcd_client::Client,
-    primary_lease: u64,
-    runtime: Runtime,
-    rt: Arc,
-}
-
-impl std::fmt::Debug for Client {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "etcd::Client primary_lease={}", self.primary_lease)
-    }
-}
-
 #[derive(Debug, Clone)]
 pub struct Lease {
     /// ETCD lease ID
@@ -86,6 +73,21 @@ impl Lease {
     }
 }
 
+/// ETCD Client
+#[derive(Clone)]
+pub struct Client {
+    connector: Arc<Connector>,
+    primary_lease: u64,
+    runtime: Runtime,
+    rt: Arc,
+}
+
+impl std::fmt::Debug for Client {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "etcd::Client primary_lease={}", self.primary_lease)
+    }
+}
+
 impl Client {
     pub fn builder() -> ClientOptionsBuilder {
         ClientOptionsBuilder::default()
@@ -102,24 +104,23 @@ impl Client {
     pub async fn new(config: ClientOptions, runtime: Runtime) -> Result {
         let token = runtime.primary_token();
 
-        let ((client, lease_id), rt) = build_in_runtime(
+        let ((connector, lease_id), rt) = build_in_runtime(
             async move {
-                let client = etcd_client::Client::connect(
-                    config.etcd_url.clone(),
-                    config.etcd_connect_options,
-                )
-                .await
-                .with_context(|| {
-                    format!(
-                        "Unable to connect to etcd server at {}. Check etcd server status",
-                        config.etcd_url.join(", ")
-                    )
-                })?;
+                let etcd_urls = config.etcd_url.clone();
+                let connect_options = config.etcd_connect_options.clone();
+
+                // Create the connector
+                let connector = Connector::new(etcd_urls, connect_options)
+                    .await
+                    .with_context(|| {
+                        format!(
+                            "Unable to connect to etcd server at {}. Check etcd server status",
+                            config.etcd_url.join(", ")
+                        )
+                    })?;
 
                 let lease_id = if config.attach_lease {
-                    let lease_client = client.lease_client();
-
-                    let lease = create_lease(lease_client, 10, token)
+                    let lease = create_lease(connector.clone(), 10, token)
                         .await
                         .with_context(|| {
                             format!(
@@ -133,23 +134,24 @@ impl Client {
                     0
                 };
 
-                Ok((client, lease_id))
+                Ok((connector, lease_id))
             },
             1,
         )
         .await?;
 
         Ok(Client {
-            client,
+            connector,
            primary_lease: lease_id,
            rt,
            runtime,
        })
    }
 
-    /// Get a reference to the underlying [`etcd_client::Client`] instance.
-    pub(crate) fn etcd_client(&self) -> &etcd_client::Client {
-        &self.client
+    /// Get a clone of the underlying [`etcd_client::Client`] instance.
+    /// This returns a clone since the client is behind an RwLock.
+    pub fn etcd_client(&self) -> etcd_client::Client {
+        self.connector.get_client()
    }
 
    /// Get the primary lease ID.
@@ -169,16 +171,16 @@ impl Client {
    /// This [`Lease`] will be tied to the [`Runtime`], specifically a child [`CancellationToken`].
    pub async fn create_lease(&self, ttl: u64) -> Result<Lease> {
        let token = self.runtime.child_token();
-        let lease_client = self.client.lease_client();
        self.rt
-            .spawn(create_lease(lease_client, ttl, token))
+            .spawn(create_lease(self.connector.clone(), ttl, token))
            .await?
    }
 
    // Revoke an etcd lease given its lease id. A wrapper over etcd_client::LeaseClient::revoke
    pub async fn revoke_lease(&self, lease_id: u64) -> Result<()> {
-        let lease_client = self.client.lease_client();
-        self.rt.spawn(revoke_lease(lease_client, lease_id)).await?
+        self.rt
+            .spawn(revoke_lease(self.connector.clone(), lease_id))
+            .await?
    }
 
    pub async fn kv_create(&self, key: &str, value: Vec<u8>, lease_id: Option<i64>) -> Result<()> {
@@ -193,7 +195,7 @@ impl Client {
        ]);
 
        // Execute the transaction
-        let result = self.client.kv_client().txn(txn).await?;
+        let result = self.connector.get_client().kv_client().txn(txn).await?;
 
        if result.succeeded() {
            Ok(())
@@ -232,7 +234,7 @@ impl Client {
        ]);
 
        // Execute the transaction
-        let result = self.client.kv_client().txn(txn).await?;
+        let result = self.connector.get_client().kv_client().txn(txn).await?;
 
        // We have to enumerate the response paths to determine if the transaction succeeded
        if result.succeeded() {
@@ -266,7 +268,8 @@ impl Client {
        let id = lease_id.unwrap_or(self.lease_id());
        let put_options = PutOptions::new().with_lease(id as i64);
        let _ = self
-            .client
+            .connector
+            .get_client()
            .kv_client()
            .put(key.as_ref(), value.as_ref(), Some(put_options))
            .await?;
@@ -282,7 +285,8 @@ impl Client {
        let options = options
            .unwrap_or_default()
            .with_lease(self.primary_lease().id() as i64);
-        self.client
+        self.connector
+            .get_client()
            .kv_client()
            .put(key.as_ref(), value.as_ref(), Some(options))
            .await
@@ -294,7 +298,12 @@ impl Client {
        key: impl Into<Vec<u8>>,
        options: Option<GetOptions>,
    ) -> Result<Vec<KeyValue>> {
-        let mut get_response = self.client.kv_client().get(key, options).await?;
+        let mut get_response = self
+            .connector
+            .get_client()
+            .kv_client()
+            .get(key, options)
+            .await?;
        Ok(get_response.take_kvs())
    }
 
@@ -303,7 +312,8 @@ impl Client {
        key: impl Into<Vec<u8>>,
        options: Option<DeleteOptions>,
    ) -> Result {
-        self.client
+        self.connector
+            .get_client()
            .kv_client()
            .delete(key, options)
            .await
@@ -313,7 +323,8 @@ impl Client {
 
    pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> {
        let mut get_response = self
-            .client
+            .connector
+            .get_client()
            .kv_client()
            .get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
            .await?;
@@ -328,7 +339,7 @@ impl Client {
        key: impl Into<Vec<u8>>,
        lease_id: Option<i64>,
    ) -> Result<LockResponse> {
-        let mut lock_client = self.client.lock_client();
+        let mut lock_client = self.connector.get_client().lock_client();
        let id = lease_id.unwrap_or(self.lease_id());
        let options = LockOptions::new().with_lease(id as i64);
        lock_client
@@ -339,7 +350,7 @@ impl Client {
 
    /// Release a distributed lock using the key from the LockResponse
    pub async fn unlock(&self, lock_key: impl Into<Vec<u8>>) -> Result<()> {
-        let mut lock_client = self.client.lock_client();
+        let mut lock_client = self.connector.get_client().lock_client();
        lock_client
            .unlock(lock_key)
            .await
@@ -367,8 +378,9 @@ impl Client {
        prefix: impl AsRef<str> + std::fmt::Display,
        include_existing: bool,
    ) -> Result {
-        let mut kv_client = self.client.kv_client();
-        let mut watch_client = self.client.watch_client();
+        let client = self.connector.get_client();
+        let mut kv_client = client.kv_client();
+        let mut watch_client = client.watch_client();
 
        let mut get_response = kv_client
            .get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
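Note (illustrative, not part of the patch): with the change above, `Client::etcd_client()` hands out a clone of the current connection taken from behind the `Connector`'s lock instead of a long-lived reference. A minimal sketch of the resulting calling pattern, assuming an already-constructed `Client`; the helper name `dump_prefix` is hypothetical:

```rust
// Hypothetical helper, not part of this PR: callers grab a fresh handle per operation,
// so they never hold a connection that a later reconnect would have invalidated.
async fn dump_prefix(client: &Client, prefix: &str) -> Result<()> {
    let mut kv = client.etcd_client().kv_client(); // clone of the current connection
    let resp = kv
        .get(prefix, Some(etcd_client::GetOptions::new().with_prefix()))
        .await?;
    for entry in resp.kvs() {
        tracing::info!("{} = {} bytes", entry.key_str()?, entry.value().len());
    }
    Ok(())
}
```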
diff --git a/lib/runtime/src/transports/etcd/connector.rs b/lib/runtime/src/transports/etcd/connector.rs
new file mode 100644
index 0000000000..5ff6563ae0
--- /dev/null
+++ b/lib/runtime/src/transports/etcd/connector.rs
@@ -0,0 +1,169 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{ErrorContext, Result, error};
+use etcd_client::ConnectOptions;
+use parking_lot::RwLock;
+use std::{sync::Arc, time::Duration};
+use tokio::{sync::Mutex, time::sleep};
+
+/// Manages ETCD client connections with reconnection support
+pub struct Connector {
+    /// The actual ETCD client, protected by RwLock for safe updates during reconnection
+    /// WARNING: Do not recursively acquire a read lock when the current thread already holds one
+    client: RwLock<etcd_client::Client>,
+    /// Configuration for connecting to ETCD
+    etcd_urls: Vec<String>,
+    connect_options: Option<ConnectOptions>,
+    /// Tracks the current backoff duration and last successful connect time
+    /// The Mutex ensures only one reconnect operation runs at a time
+    backoff_state: Mutex<BackoffState>,
+}
+
+impl Connector {
+    /// Create a new connector with an established connection
+    pub async fn new(
+        etcd_urls: Vec<String>,
+        connect_options: Option<ConnectOptions>,
+    ) -> Result<Arc<Self>> {
+        // Connect to ETCD
+        let client = Self::connect(&etcd_urls, &connect_options).await?;
+
+        Ok(Arc::new(Self {
+            client: RwLock::new(client),
+            etcd_urls,
+            connect_options,
+            backoff_state: Mutex::new(BackoffState::default()),
+        }))
+    }
+
+    /// Connect to ETCD cluster
+    async fn connect(
+        etcd_urls: &[String],
+        connect_options: &Option<ConnectOptions>,
+    ) -> Result<etcd_client::Client> {
+        etcd_client::Client::connect(etcd_urls.to_vec(), connect_options.clone())
+            .await
+            .with_context(|| {
+                format!(
+                    "Unable to connect to etcd server at {}. Check etcd server status",
+                    etcd_urls.join(", ")
+                )
+            })
+    }
+
+    /// Get a clone of the current ETCD client
+    pub fn get_client(&self) -> etcd_client::Client {
+        self.client.read().clone()
+    }
+
+    /// Reconnect to ETCD cluster with retry logic
+    /// Respects the deadline and returns error if exceeded
+    ///
+    /// Backoff behavior:
+    /// - Starts at 0 (immediate reconnect) if this is the first reconnect or enough time has passed
+    ///   since the last reconnect
+    /// - Increments exponentially for continuous failures
+    /// - Resets to 0 only when: this is a new call AND current_time > last_connect_time + residual_backoff
+    ///
+    /// The mutex ensures only one reconnect operation runs at a time globally
+    pub async fn reconnect(&self, deadline: std::time::Instant) -> Result<()> {
+        let mut backoff_state = self.backoff_state.lock().await;
+
+        tracing::warn!("Reconnecting to ETCD cluster at: {:?}", self.etcd_urls);
+        backoff_state.attempt_reset();
+
+        loop {
+            backoff_state.apply_backoff(deadline).await;
+            if std::time::Instant::now() >= deadline {
+                return Err(error!(
+                    "Unable to reconnect to ETCD cluster: deadline exceeded"
+                ));
+            }
+
+            match Self::connect(&self.etcd_urls, &self.connect_options).await {
+                Ok(new_client) => {
+                    tracing::info!("Successfully reconnected to ETCD cluster");
+                    // Update the client behind the lock
+                    let mut client_guard = self.client.write();
+                    *client_guard = new_client;
+                    return Ok(());
+                }
+                Err(e) => {
+                    tracing::warn!(
+                        "Reconnection failed (remaining time: {:?}): {}",
+                        deadline.saturating_duration_since(std::time::Instant::now()),
+                        e
+                    );
+                }
+            }
+        }
+    }
+
+    /// Get the ETCD URLs
+    pub fn etcd_urls(&self) -> &[String] {
+        &self.etcd_urls
+    }
+
+    /// Get the connection options
+    pub fn connect_options(&self) -> &Option<ConnectOptions> {
+        &self.connect_options
+    }
+}
+
+#[derive(Debug)]
+struct BackoffState {
+    /// Initial backoff duration for reconnection attempts
+    pub initial_backoff: Duration,
+    /// Minimum backoff duration for reconnection attempts
+    pub min_backoff: Duration,
+    /// Maximum backoff duration for reconnection attempts
+    pub max_backoff: Duration,
+    /// Current backoff duration (starts at 0 for immediate reconnect)
+    current_backoff: Duration,
+    /// Last time a connection establishment was attempted
+    last_connect_attempt: std::time::Instant,
+}
+
+impl Default for BackoffState {
+    fn default() -> Self {
+        Self {
+            initial_backoff: Duration::from_millis(500),
+            min_backoff: Duration::from_millis(50),
+            max_backoff: Duration::from_secs(5),
+            current_backoff: Duration::ZERO,
+            last_connect_attempt: std::time::Instant::now(),
+        }
+    }
+}
+
+impl BackoffState {
+    /// Reset backoff to 0 if enough time has passed since the last connection
+    pub fn attempt_reset(&mut self) {
+        if std::time::Instant::now() > self.last_connect_attempt + self.current_backoff {
+            tracing::debug!("Resetting backoff to 0 (first reconnect or enough time has passed)");
+            self.current_backoff = Duration::ZERO;
+        }
+    }
+
+    /// Apply backoff and update backoff state for possible next connection attempt
+    pub async fn apply_backoff(&mut self, deadline: std::time::Instant) {
+        if self.current_backoff > Duration::ZERO {
+            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+            let backoff = std::cmp::min(self.current_backoff, remaining / 2);
+            let backoff = std::cmp::min(backoff, self.max_backoff);
+            let backoff = std::cmp::max(backoff, self.min_backoff);
+            self.current_backoff = backoff * 2;
+
+            tracing::debug!(
+                "Applying backoff of {:?} (remaining time: {:?})",
+                backoff,
+                remaining
+            );
+            sleep(backoff).await;
+        } else {
+            self.current_backoff = self.initial_backoff;
+        }
+        self.last_connect_attempt = std::time::Instant::now();
+    }
+}
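Note (illustrative, not part of the patch): the policy implemented by `BackoffState` amounts to "retry immediately the first time, then double the delay, clamped by `min_backoff`, `max_backoff`, and half the time left before the deadline". A simplified, synchronous sketch of that policy, with constants mirroring `BackoffState::default()`:

```rust
use std::time::{Duration, Instant};

// Simplified model of BackoffState::apply_backoff: returns (delay to sleep now, value to
// carry into the next attempt). The first attempt sleeps zero and arms the initial
// backoff; later attempts double the delay, clamped to [min, max] and to half of the
// remaining time before the deadline.
fn next_backoff(current: Duration, deadline: Instant) -> (Duration, Duration) {
    let initial = Duration::from_millis(500);
    let min = Duration::from_millis(50);
    let max = Duration::from_secs(5);

    if current.is_zero() {
        return (Duration::ZERO, initial);
    }
    let remaining = deadline.saturating_duration_since(Instant::now());
    let delay = current.min(remaining / 2).min(max).max(min);
    (delay, delay * 2)
}
```

Driving this in a loop yields sleeps of roughly 0, 500ms, 1s, 2s, 4s, then 5s repeatedly, shrinking once half of the remaining deadline budget drops below the computed delay.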
diff --git a/lib/runtime/src/transports/etcd/lease.rs b/lib/runtime/src/transports/etcd/lease.rs
index 493f70e104..086e58b6a1 100644
--- a/lib/runtime/src/transports/etcd/lease.rs
+++ b/lib/runtime/src/transports/etcd/lease.rs
@@ -1,14 +1,19 @@
 // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // SPDX-License-Identifier: Apache-2.0
 
+use super::connector::Connector;
 use super::*;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::time::{sleep, timeout};
 
 /// Create a [`Lease`] with a given time-to-live (TTL) attached to the [`CancellationToken`].
 pub async fn create_lease(
-    mut lease_client: LeaseClient,
+    connector: Arc<Connector>,
     ttl: u64,
     token: CancellationToken,
 ) -> Result<Lease> {
+    let mut lease_client = connector.get_client().lease_client();
     let lease = lease_client.grant(ttl as i64, None).await?;
     let id = lease.id() as u64;
 
@@ -17,7 +22,7 @@ pub async fn create_lease(
     let clone = token.clone();
 
     tokio::spawn(async move {
-        match keep_alive(lease_client, id, ttl, child).await {
+        match keep_alive(connector, id, ttl, child).await {
             Ok(_) => tracing::trace!("keep alive task exited successfully"),
             Err(e) => {
                 tracing::error!(
@@ -36,7 +41,8 @@ pub async fn create_lease(
 }
 
 /// Revoke a lease given its lease id. A wrapper over etcd_client::LeaseClient::revoke
-pub async fn revoke_lease(mut lease_client: LeaseClient, lease_id: u64) -> Result<()> {
+pub async fn revoke_lease(connector: Arc<Connector>, lease_id: u64) -> Result<()> {
+    let mut lease_client = connector.get_client().lease_client();
     match lease_client.revoke(lease_id as i64).await {
         Ok(_) => Ok(()),
         Err(e) => {
@@ -46,72 +52,115 @@ pub async fn revoke_lease(mut lease_client: LeaseClient, lease_id: u64) -> Resul
     }
 }
 
-/// Task to keep leases alive.
+/// Task to keep leases alive with reconnection support.
 ///
 /// If this task returns an error, the cancellation token will be invoked on the runtime.
-/// If
-pub async fn keep_alive(
-    client: LeaseClient,
+async fn keep_alive(
+    connector: Arc<Connector>,
     lease_id: u64,
-    ttl: u64,
+    mut ttl: u64,
     token: CancellationToken,
 ) -> Result<()> {
-    let mut ttl = ttl;
     let mut deadline = create_deadline(ttl)?;
-    let mut client = client;
-    let (mut heartbeat_sender, mut heartbeat_receiver) = client.keep_alive(lease_id as i64).await?;
-
     loop {
-        // if the deadline is exceeded, then we have failed to issue a heartbeat in time
-        // we may be permanently disconnected from the etcd server, so we are now officially done
-        if deadline < std::time::Instant::now() {
-            return Err(error!(
-                "Unable to refresh lease - deadline exceeded. Check etcd server status"
-            ));
-        }
-
-        tokio::select! {
-            biased;
-
-            status = heartbeat_receiver.message() => {
-                if let Some(resp) = status? {
-                    tracing::trace!(lease_id, "keep alive response received: {:?}", resp);
+        // Try to establish or re-establish the keep-alive stream
+        let mut lease_client = connector.get_client().lease_client();
+        let (mut heartbeat_sender, mut heartbeat_receiver) = match lease_client
+            .keep_alive(lease_id as i64)
+            .await
+        {
+            Ok((sender, receiver)) => {
+                tracing::debug!(lease_id, "Established keep-alive stream");
+                (sender, receiver)
+            }
+            Err(e) => {
+                tracing::warn!(lease_id, error = %e, "Failed to establish keep-alive stream");
 
-                    // update ttl and deadline
-                    ttl = resp.ttl() as u64;
-                    deadline = create_deadline(ttl)?;
+                // Try to reconnect with the deadline, but also check for cancellation
+                tokio::select! {
+                    biased;
 
-                    if resp.ttl() == 0 {
-                        return Err(error!("Unable to maintain lease - expired or revoked. Check etcd server status"));
+                    reconnect_result = connector.reconnect(deadline) => {
+                        match reconnect_result {
+                            Err(e) => return Err(e),
+                            _ => continue,
+                        }
                     }
+                    _ = token.cancelled() => {
+                        tracing::debug!(lease_id, "Cancellation token triggered during reconnection");
+                        return Ok(());
+                    }
                 }
             }
-
-            _ = token.cancelled() => {
-                tracing::trace!(lease_id, "cancellation token triggered; revoking lease");
-                let _ = client.revoke(lease_id as i64).await?;
-                return Ok(());
+        };
+
+        // Keep-alive loop with the established stream
+        loop {
+            if deadline < std::time::Instant::now() {
+                return Err(error!(
+                    "Unable to refresh lease - deadline exceeded. Check etcd server status"
+                ));
             }
 
-            _ = tokio::time::sleep(tokio::time::Duration::from_secs(ttl / 2)) => {
-                tracing::trace!(lease_id, "sending keep alive");
-
-                // if we get a error issuing the heartbeat, set the ttl to 0
-                // this will allow us to poll the response stream once and the cancellation token once, then
-                // immediately try to tick the heartbeat
-                // this will repeat until either the heartbeat is reestablished or the deadline is exceeded
-                if let Err(e) = heartbeat_sender.keep_alive().await {
-                    tracing::warn!(
-                        lease_id,
-                        error = %e,
-                        "Unable to send lease heartbeat. Check etcd server status"
-                    );
-                    ttl = 0;
+            tokio::select! {
+                biased;
+
+                status = heartbeat_receiver.message() => {
+                    match status {
+                        Ok(Some(resp)) => {
+                            tracing::trace!(lease_id, "keep alive response received: {:?}", resp);
+
+                            // Update ttl and deadline from response
+                            ttl = resp.ttl() as u64;
+                            deadline = create_deadline(ttl)?;
+
+                            if resp.ttl() == 0 {
+                                return Err(error!("Unable to maintain lease - expired or revoked. Check etcd server status"));
+                            }
+                        }
+                        Ok(None) => {
+                            tracing::warn!(lease_id, "Keep-alive stream unexpectedly ended");
+                            break;
+                        }
+                        Err(e) => {
+                            tracing::warn!(lease_id, error = %e, "Keep-alive stream error");
+                            break;
+                        }
+                    }
+                }
+
+                _ = token.cancelled() => {
+                    tracing::debug!(lease_id, "cancellation token triggered; revoking lease");
+                    if let Err(e) = lease_client.revoke(lease_id as i64).await {
+                        tracing::warn!(
+                            lease_id,
+                            error = %e,
+                            "Failed to revoke lease during cancellation. Cleanup may be incomplete."
+                        );
+                    }
+                    return Ok(());
                 }
-            }
+
+                _ = tokio::time::sleep(Duration::from_secs(ttl / 2)) => {
+                    tracing::trace!(lease_id, "sending keep alive");
+
+                    // if we get an error issuing the heartbeat, set the ttl to 0
+                    // this will allow us to poll the response stream once and the cancellation
+                    // token once, then immediately try to tick the heartbeat
+                    // this will repeat until either the heartbeat is reestablished or the deadline
+                    // is exceeded
+                    if let Err(e) = heartbeat_sender.keep_alive().await {
+                        tracing::warn!(
+                            lease_id,
+                            error = %e,
+                            "Unable to send lease heartbeat. Check etcd server status"
+                        );
+                        ttl = 0;
+                    }
+                }
+            }
         }
     }
 }
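Note (illustrative, not part of the patch): the reworked `keep_alive` above layers reconnection on top of the etcd-client keep-alive primitives. A stand-alone sketch of just those primitives (no `Connector`, no `CancellationToken`), using only calls the patch itself relies on (lease grant, keep-alive stream, heartbeat, response message):

```rust
use etcd_client::Client;
use std::time::Duration;

// Minimal lease keep-alive loop: grant a lease, open the keep-alive stream, then send a
// heartbeat every ttl/2 seconds and read the response to detect expiry (ttl == 0).
async fn keep_lease_alive(mut client: Client, ttl: i64) -> Result<(), etcd_client::Error> {
    let lease = client.lease_grant(ttl, None).await?;
    let (mut keeper, mut stream) = client.lease_keep_alive(lease.id()).await?;

    loop {
        tokio::time::sleep(Duration::from_secs(ttl as u64 / 2)).await;
        keeper.keep_alive().await?; // send one heartbeat
        if let Some(resp) = stream.message().await? {
            if resp.ttl() == 0 {
                // Expired or revoked on the server side; a real caller would recreate
                // the lease or shut down, as the keep_alive task above does.
                return Ok(());
            }
        }
    }
}
```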
diff --git a/tests/fault_tolerance/etcd_ha/__init__.py b/tests/fault_tolerance/etcd_ha/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/fault_tolerance/etcd_ha/test_vllm.py b/tests/fault_tolerance/etcd_ha/test_vllm.py
new file mode 100644
index 0000000000..f5d26b7b9f
--- /dev/null
+++ b/tests/fault_tolerance/etcd_ha/test_vllm.py
@@ -0,0 +1,192 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import shutil
+
+import pytest
+
+from tests.conftest import NatsServer
+from tests.fault_tolerance.etcd_ha.utils import (
+    DynamoFrontendProcess,
+    EtcdCluster,
+    send_inference_request,
+    wait_for_processes_to_terminate,
+)
+from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
+from tests.utils.engine_process import FRONTEND_PORT
+from tests.utils.managed_process import ManagedProcess
+from tests.utils.payloads import check_health_generate, check_models_api
+
+logger = logging.getLogger(__name__)
+
+
+class DynamoWorkerProcess(ManagedProcess):
+    """Process manager for Dynamo worker with vLLM backend and ETCD HA support"""
+
+    def __init__(self, request, etcd_endpoints: list):
+        command = [
+            "python3",
+            "-m",
+            "dynamo.vllm",
+            "--model",
+            FAULT_TOLERANCE_MODEL_NAME,
+            "--enforce-eager",
+            "--gpu-memory-utilization",
+            "0.45",
+            "--max-model-len",
+            "8192",
+        ]
+
+        # Health checks - frontend model registration
+        health_check_urls = [
+            (f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
+            (f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
+        ]
+
+        # Set debug logging and ETCD endpoints
+        env = os.environ.copy()
+        env["DYN_LOG"] = "debug"
+        env["ETCD_ENDPOINTS"] = ",".join(etcd_endpoints)
+
+        log_dir = f"{request.node.name}_worker"
+
+        # Clean up any existing log directory from previous runs
+        try:
+            shutil.rmtree(log_dir)
+            logger.info(f"Cleaned up existing log directory: {log_dir}")
+        except FileNotFoundError:
+            pass
+
+        super().__init__(
+            command=command,
+            env=env,
+            health_check_urls=health_check_urls,
+            timeout=120,
+            display_output=True,
+            terminate_existing=False,
+            stragglers=[
+                "VLLM::EngineCore",
+            ],
+            straggler_commands=[
+                "-m dynamo.vllm",
+            ],
+            log_dir=log_dir,
+        )
+
+
+@pytest.mark.vllm
+@pytest.mark.gpu_1
+@pytest.mark.e2e
+@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
+@pytest.mark.xfail(
+    strict=False,
+    reason="Lease watch failover not yet implemented, only lease keep alive failover is implemented",
+)
+def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
+    """
+    Test ETCD High Availability with leader failover.
+
+    This test:
+    1. Starts a 3-node ETCD cluster
+    2. Starts NATS, frontend, and a vLLM worker
+    3. Sends an inference request to verify the system works
+    4. Terminates the ETCD leader node
+    5. Sends another inference request to verify the system still works
+    """
+    # Step 1: Start NATS server
+    with NatsServer(request):
+        logger.info("NATS server started successfully")
+
+        # Step 2: Start 3-node ETCD cluster
+        with EtcdCluster(request) as etcd_cluster:
+            logger.info("3-node ETCD cluster started successfully")
+
+            # Get the endpoints for all ETCD nodes
+            etcd_endpoints = etcd_cluster.get_client_endpoints()
+            logger.info(f"ETCD endpoints: {etcd_endpoints}")
+
+            # Step 3: Start the frontend with ETCD endpoints
+            with DynamoFrontendProcess(request, etcd_endpoints):
+                logger.info("Frontend started successfully")
+
+                # Step 4: Start a vLLM worker
+                with DynamoWorkerProcess(request, etcd_endpoints):
+                    logger.info("Worker started successfully")
+
+                    # Step 5: Send first inference request to verify system is working
+                    logger.info("Sending first inference request (before failover)")
+                    result1 = send_inference_request("What is 2+2? The answer is")
+                    assert (
+                        "4" in result1.lower() or "four" in result1.lower()
+                    ), f"Expected '4' or 'four' in response, got: '{result1}'"
+
+                    # Step 6: Identify and terminate the ETCD leader
+                    logger.info("Terminating ETCD leader to test failover")
+                    terminated_idx = etcd_cluster.terminate_leader()
+                    if terminated_idx is None:
+                        pytest.fail("Failed to identify and terminate ETCD leader")
+
+                    logger.info(f"Terminated ETCD node {terminated_idx}")
+
+                    # Step 7: Send second inference request to verify system still works
+                    logger.info("Sending second inference request (after failover)")
+                    result2 = send_inference_request("The capital of France is")
+                    assert (
+                        "paris" in result2.lower()
+                    ), f"Expected 'Paris' in response, got: '{result2}'"
+
+
+@pytest.mark.vllm
+@pytest.mark.gpu_1
+@pytest.mark.e2e
+@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
+def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
+    """
+    Test that frontend and worker shut down when single ETCD node is terminated.
+
+    This test:
+    1. Starts a single ETCD node (no cluster)
+    2. Starts NATS, frontend, and a vLLM worker
+    3. Sends an inference request to verify the system works
+    4. Terminates the single ETCD node
+    5. Verifies that frontend and worker shut down gracefully
+    """
+    # Step 1: Start NATS server
+    with NatsServer(request):
+        logger.info("NATS server started successfully")
+
+        # Step 2: Start single ETCD node using EtcdCluster with num_replicas=1
+        with EtcdCluster(request, num_replicas=1) as etcd_cluster:
+            logger.info("Single ETCD node started successfully")
+
+            # Get the endpoint for the single ETCD node
+            etcd_endpoints = etcd_cluster.get_client_endpoints()
+            logger.info(f"ETCD endpoint: {etcd_endpoints}")
+
+            # Step 3: Start the frontend with ETCD endpoint
+            with DynamoFrontendProcess(request, etcd_endpoints) as frontend:
+                logger.info("Frontend started successfully")
+
+                # Step 4: Start a vLLM worker
+                with DynamoWorkerProcess(request, etcd_endpoints) as worker:
+                    logger.info("Worker started successfully")
+
+                    # Step 5: Send inference request to verify system is working
+                    logger.info("Sending inference request")
+                    result = send_inference_request("What is 2+2? The answer is")
+                    assert (
+                        "4" in result.lower() or "four" in result.lower()
+                    ), f"Expected '4' or 'four' in response, got: '{result}'"
+
+                    logger.info("System is working correctly with single ETCD node")
+
+                    # Step 6: Terminate the ETCD node
+                    logger.info("Terminating single ETCD node")
+                    etcd_cluster.stop()
+
+                    # Step 7: Wait and verify frontend and worker detect the loss
+                    wait_for_processes_to_terminate(
+                        {"Worker": worker, "Frontend": frontend}
+                    )
"--advertise-client-urls", + f"http://localhost:{client_port}", + "--listen-peer-urls", + f"http://0.0.0.0:{peer_port}", + "--initial-advertise-peer-urls", + f"http://localhost:{peer_port}", + "--initial-cluster", + initial_cluster, + "--initial-cluster-state", + "new", + "--initial-cluster-token", + "etcd-cluster", + ] + + super().__init__( + env=etcd_env, + command=command, + timeout=timeout, + display_output=False, + terminate_existing=False, + data_dir=data_dir, + log_dir=log_dir, + ) + + def get_status(self) -> dict: + """Get the status of this ETCD node""" + try: + response = requests.post( + f"http://localhost:{self.client_port}/v3/maintenance/status", + json={}, + timeout=2, + ) + if response.status_code == 200: + return response.json() + except Exception as e: + logger.warning(f"Failed to get status for {self.name}: {e}") + return {} + + def is_leader(self) -> bool: + """Check if this node is the current leader""" + status = self.get_status() + # In etcd v3 API, we check if this member ID matches the leader ID + if status: + member_id = status.get("header", {}).get("member_id", "") + leader_id = status.get("leader", "") + return member_id == leader_id + return False + + +class EtcdCluster: + """Manager for an ETCD cluster with configurable number of replicas""" + + def __init__( + self, + request, + num_replicas: int = 3, + base_client_port: int = 2379, + base_peer_port: int = 12380, + ): + self.request = request + self.num_replicas = num_replicas + self.base_client_port = base_client_port + self.base_peer_port = base_peer_port + self.replicas: List[Optional[EtcdReplicaServer]] = [] + self.data_dirs: List[str] = [] + self.log_base_dir = f"{request.node.name}_etcd_cluster" + + # Clean up any existing log directory + try: + shutil.rmtree(self.log_base_dir) + logger.info(f"Cleaned up existing log directory: {self.log_base_dir}") + except FileNotFoundError: + pass + + os.makedirs(self.log_base_dir, exist_ok=True) + + def start(self): + """Start ETCD cluster with configured number of replicas""" + logger.info(f"Starting {self.num_replicas}-node ETCD cluster") + + # Build initial cluster configuration + initial_cluster_parts = [] + for i in range(self.num_replicas): + name = f"etcd-{i}" + peer_port = self.base_peer_port + i + initial_cluster_parts.append(f"{name}=http://localhost:{peer_port}") + + initial_cluster = ",".join(initial_cluster_parts) + + # Start each replica + for i in range(self.num_replicas): + name = f"etcd-{i}" + client_port = self.base_client_port + i + peer_port = self.base_peer_port + i + data_dir = tempfile.mkdtemp(prefix=f"etcd_{i}_") + log_dir = os.path.join(self.log_base_dir, name) + + self.data_dirs.append(data_dir) + os.makedirs(log_dir, exist_ok=True) + + logger.info( + f"Starting {name} on client port {client_port}, peer port {peer_port}" + ) + + replica = EtcdReplicaServer( + request=self.request, + name=name, + client_port=client_port, + peer_port=peer_port, + initial_cluster=initial_cluster, + data_dir=data_dir, + log_dir=log_dir, + ) + + replica.__enter__() + self.replicas.append(replica) + + logger.info(f"All {self.num_replicas} ETCD replicas started successfully") + + # Wait for cluster to stabilize and elect a leader + self._wait_for_healthy_cluster(timeout=30) + + leader_idx = self.find_leader() + if leader_idx is not None: + logger.info(f"Initial leader elected: etcd-{leader_idx}") + else: + logger.warning("No leader elected yet") + + def _wait_for_healthy_cluster(self, timeout: int = 30): + """Wait for all replicas to be healthy and responsive. 
+ + Args: + timeout: Maximum time to wait in seconds + + Raises: + RuntimeError: If cluster doesn't become healthy within timeout + """ + logger.info("Waiting for all replicas to be healthy...") + start_time = time.time() + + while time.time() - start_time < timeout: + time.sleep(1) + + # Check if all replicas are responding + all_healthy = True + for i, replica in enumerate(self.replicas): + if replica: + status = replica.get_status() + if not status: + logger.debug(f"etcd-{i} not yet responsive") + all_healthy = False + break + + if all_healthy: + logger.info("All replicas are healthy") + return + + raise RuntimeError(f"ETCD cluster failed to become healthy within {timeout}s") + + def find_leader(self) -> Optional[int]: + """Find which replica is currently the leader""" + for i, replica in enumerate(self.replicas): + if replica and replica.is_leader(): + return i + return None + + def terminate_leader(self) -> Optional[int]: + """Terminate the current leader and return its index""" + leader_idx = self.find_leader() + + if leader_idx is None: + logger.warning("No leader found to terminate") + return None + + logger.info(f"Terminating current leader: etcd-{leader_idx}") + replica = self.replicas[leader_idx] + + if replica: + replica.__exit__(None, None, None) + self.replicas[leader_idx] = None + logger.info(f"Leader etcd-{leader_idx} has been terminated") + + return leader_idx + + def get_client_endpoints(self) -> List[str]: + """Get list of active client endpoints""" + endpoints = [] + for i, replica in enumerate(self.replicas): + if replica: # Only include active replicas + client_port = self.base_client_port + i + endpoints.append(f"http://localhost:{client_port}") + return endpoints + + def stop(self): + """Clean up all replicas and temporary directories""" + logger.info("Cleaning up ETCD cluster") + + # Stop all running replicas + for replica in self.replicas: + if replica: + try: + replica.__exit__(None, None, None) + except Exception as e: + logger.warning(f"Error stopping replica: {e}") + self.replicas = [] + + # Clean up data directories + for data_dir in self.data_dirs: + try: + shutil.rmtree(data_dir) + except Exception as e: + logger.warning(f"Error removing data directory {data_dir}: {e}") + self.data_dirs = [] + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + +def send_inference_request(prompt: str, max_tokens: int = 50) -> str: + """Send a simple inference request to the frontend and return the generated text""" + payload = { + "model": FAULT_TOLERANCE_MODEL_NAME, + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0.0, # Make output deterministic + } + + headers = {"Content-Type": "application/json"} + + logger.info(f"Sending inference request: '{prompt}'") + try: + response = requests.post( + f"http://localhost:{FRONTEND_PORT}/v1/completions", + headers=headers, + json=payload, + timeout=round(max_tokens * 0.6), + ) + + if response.status_code == 200: + result = response.json() + text = result.get("choices", [{}])[0].get("text", "") + logger.info(f"Inference generated text: '{text.strip()}'") + return text + else: + pytest.fail( + f"Inference request failed with code {response.status_code}: {response.text}" + ) + except Exception as e: + pytest.fail(f"Inference request failed: {e}") + + +def wait_for_processes_to_terminate( + processes: dict, timeout: int = 30, poll_interval: int = 1 +) -> None: + """ + Wait for multiple processes to terminate and fail if they don't within timeout. 
+ + Args: + processes: Dictionary mapping process names to ManagedProcess instances + timeout: Maximum time to wait in seconds + poll_interval: Time between checks in seconds + + Raises: + pytest.fail: If any process is still running after timeout + """ + logger.info(f"Waiting for {len(processes)} process(es) to terminate") + elapsed = 0 + terminated = {name: False for name in processes} + + while elapsed < timeout: + time.sleep(poll_interval) + elapsed += poll_interval + + # Check each process + for name, process in processes.items(): + if ( + not terminated[name] + and process.proc + and process.proc.poll() is not None + ): + logger.info(f"{name} process has terminated after {elapsed}s") + terminated[name] = True + + # Exit early if all processes have terminated + if all(terminated.values()): + return + + # Check for any processes still running and fail + still_running = [name for name, term in terminated.items() if not term] + if still_running: + pytest.fail( + f"Process(es) still running after {elapsed}s: {', '.join(still_running)}" + )