From 7b550aab4fe0bb7d4fc7ef740d41445c812e5a7e Mon Sep 17 00:00:00 2001
From: WarrenZhu050413
Date: Mon, 19 May 2025 11:33:56 +0800
Subject: [PATCH] Added proactive heartbeat timeout failure propagation (#164) (#188)

---
 Cargo.toml                 |   2 +
 README.md                  |  32 +++
 proto/torchft.proto        |  14 ++
 src/lib.rs                 | 101 +++++++-
 src/lighthouse.rs          | 464 +++++++++++++++++++++++++++++++++++--
 src/manager.rs             |  15 +-
 torchft/_torchft.pyi       |  16 +-
 torchft/data.py            |  10 +-
 torchft/data_test.py       |   2 +-
 torchft/lighthouse_test.py |  69 ++++++
 torchft/manager.py         | 209 ++++++++++++++++-
 torchft/manager_test.py    | 148 +++++++++++-
 train_ddp.py               |   2 +-
 train_ddp_proactive.py     | 218 +++++++++++++++++
 14 files changed, 1254 insertions(+), 48 deletions(-)
 create mode 100644 train_ddp_proactive.py

diff --git a/Cargo.toml b/Cargo.toml
index 0c6ae6e..ec90c11 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,7 +21,9 @@ slog-stdlog = "4.1.1"
 stderrlog = "0.6.0"
 structopt = "0.3.26"
 tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
+tokio-stream = {version = "0.1.14", features = ["sync"]}
 tonic = "0.12.2"
+futures-core = "0.3"

 [build-dependencies]
 tonic-build = "0.12.2"

diff --git a/README.md b/README.md
index cb07b47..ff3cfea 100644
--- a/README.md
+++ b/README.md
@@ -246,6 +246,38 @@ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --mast
 By observing the outputs from both shells, you should observe process group reconfiguration and live checkpoint recovery.

+### Proactive Failure Recovery Mode (Experimental)
+
+You can experiment with proactive failure recovery mode by setting:
+
+```sh
+export TORCHFT_PROACTIVE_RECOVERY=1
+```
+
+With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break out of a hanging allreduce.
+
+You can test this out by running `train_ddp_proactive.py`.
+
+On shell 1 (one replica group starts initial training):
+```sh
+export REPLICA_GROUP_ID=0
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
+```
+
+On shell 2 (a second replica group joins):
+```sh
+export REPLICA_GROUP_ID=1
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
+```
+
+You should observe that the process with replica group id 1 will exit early, and the process with replica group id 0 will quickly resume training. If the same script is rerun after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, you should observe that the process with replica group id 1 will hang for dozens of seconds before continuing.
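Editor's note (not part of the patch): proactive recovery can also be enabled from code via the `proactive_recovery` flag that this patch adds to the `Manager` constructor. The snippet below is a minimal sketch modeled on the unit tests in this diff; the store setup and replica id are placeholder values, and it assumes a Lighthouse is already reachable at `TORCHFT_LIGHTHOUSE`.

```py
import os
from datetime import timedelta

import torch.distributed as dist

from torchft import Manager, ProcessGroupGloo

# Assumes a Lighthouse is already running (see the shells above).
os.environ.setdefault("TORCHFT_LIGHTHOUSE", "http://localhost:29510")

# Per-replica-group TCPStore used by the manager for rendezvous (placeholder values).
store = dist.TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

manager = Manager(
    pg=ProcessGroupGloo(),
    min_replica_size=1,
    load_state_dict=lambda state: None,
    state_dict=lambda: None,
    replica_id="replica_group_0",
    store_addr="localhost",
    store_port=store.port,
    rank=0,
    world_size=1,
    # Same effect as TORCHFT_PROACTIVE_RECOVERY=1: spawn a background listener
    # process that subscribes to Lighthouse failure notifications and aborts a
    # hanging allreduce when a peer replica group's heartbeat times out.
    proactive_recovery=True,
)
```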
+ ### Example Parameter Server torchft has a fault tolerant parameter server implementation built on it's diff --git a/proto/torchft.proto b/proto/torchft.proto index 7c086eb..cf7c403 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -67,9 +67,17 @@ message LighthouseHeartbeatRequest { message LighthouseHeartbeatResponse {} +message SubscribeFailuresRequest {} + +message FailureNotification { + string replica_id = 1; + string error_message = 2; +} + service LighthouseService { rpc Quorum (LighthouseQuorumRequest) returns (LighthouseQuorumResponse); rpc Heartbeat (LighthouseHeartbeatRequest) returns (LighthouseHeartbeatResponse); + rpc SubscribeFailures (SubscribeFailuresRequest) returns (stream FailureNotification); } message ManagerQuorumRequest { @@ -126,3 +134,9 @@ service ManagerService { rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse); rpc Kill(KillRequest) returns (KillResponse); } + +message LighthouseClientRequest { + string replica_id = 1; +} + + diff --git a/src/lib.rs b/src/lib.rs index 32a7a37..5e9c53a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ mod timeout; use anyhow::Result; use atty::Stream; use core::time::Duration; -use pyo3::exceptions::{PyRuntimeError, PyTimeoutError}; +use pyo3::exceptions::{PyRuntimeError, PyStopIteration, PyTimeoutError}; use std::cmp; use std::env; use std::sync::Arc; @@ -21,6 +21,7 @@ use std::thread::available_parallelism; use structopt::StructOpt; use tokio::runtime::Runtime; use tokio::task::JoinHandle; +use tokio_stream::StreamExt; use tonic::transport::Channel; use tonic::Status; @@ -35,11 +36,13 @@ pub mod torchftpb { use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; use crate::torchftpb::{ - CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest, - ManagerQuorumRequest, ShouldCommitRequest, + CheckpointMetadataRequest, FailureNotification as ProtoFailureNotification, + LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest, + SubscribeFailuresRequest, }; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; +use pyo3::{PyRef, PyRefMut}; // Get the number of threads to use for the tokio runtime fn num_threads() -> usize { @@ -290,6 +293,45 @@ struct QuorumResult { heal: bool, } +#[pyclass(unsendable)] +struct FailureStream { + runtime: Arc, + stream: tonic::Streaming, + timeout: Duration, +} + +#[pymethods] +impl FailureStream { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult { + let runtime = slf.runtime.clone(); + let timeout = slf.timeout; + // borrow stream mutably for the whole async block + let fut = async { tokio::time::timeout(timeout, slf.stream.next()).await }; + + match runtime.block_on(fut) { + Ok(Some(Ok(note))) => Ok(FailureNotification { + replica_id: note.replica_id, + error_message: note.error_message, + }), + Ok(Some(Err(status))) => Err(StatusError(status).into()), + Ok(None) => Err(PyStopIteration::new_err(())), + Err(_) => Err(PyTimeoutError::new_err( + "Timeout waiting for failure notification", + )), + } + } +} + +#[pyclass(get_all, set_all)] +#[derive(Clone)] +struct FailureNotification { + replica_id: String, + error_message: String, +} + #[pymethods] impl QuorumResult { #[new] @@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult { #[pyclass] struct LighthouseClient { client: LighthouseServiceClient, - 
runtime: Runtime, + runtime: Arc, } #[pymethods] @@ -487,11 +529,13 @@ impl LighthouseClient { #[new] fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult { py.allow_threads(move || { - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(num_threads()) - .thread_name("torchft-lhclnt") - .enable_all() - .build()?; + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_threads()) + .thread_name("torchft-lhclnt") + .enable_all() + .build()?, + ); let client = runtime .block_on(manager::lighthouse_client_new(addr, connect_timeout)) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -586,6 +630,22 @@ impl LighthouseClient { Ok(()) }) } + + #[pyo3(signature = (timeout = Duration::from_secs(5)))] + fn subscribe_failures(&self, py: Python<'_>, timeout: Duration) -> PyResult { + py.allow_threads(move || { + let req = tonic::Request::new(SubscribeFailuresRequest {}); + let response = self + .runtime + .block_on(self.client.clone().subscribe_failures(req)) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(FailureStream { + runtime: self.runtime.clone(), + stream: response.into_inner(), + timeout: timeout, + }) + }) + } } /// LighthouseServer is a GRPC server for the lighthouse service. @@ -610,7 +670,7 @@ struct LighthouseServer { #[pymethods] impl LighthouseServer { - #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))] + #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None, failure_tick_ms=None))] #[new] fn new( py: Python<'_>, @@ -619,10 +679,12 @@ impl LighthouseServer { join_timeout_ms: Option, quorum_tick_ms: Option, heartbeat_timeout_ms: Option, + failure_tick_ms: Option, ) -> PyResult { let join_timeout_ms = join_timeout_ms.unwrap_or(100); let quorum_tick_ms = quorum_tick_ms.unwrap_or(100); let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000); + let failure_tick_ms = failure_tick_ms.unwrap_or(1000); py.allow_threads(move || { let rt = tokio::runtime::Builder::new_multi_thread() @@ -638,6 +700,7 @@ impl LighthouseServer { join_timeout_ms: join_timeout_ms, quorum_tick_ms: quorum_tick_ms, heartbeat_timeout_ms: heartbeat_timeout_ms, + failure_tick_ms: failure_tick_ms, })) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -663,6 +726,22 @@ impl LighthouseServer { self.handle.abort(); }) } + + /// inject_failure broadcasts a failure notification for the given replica. + /// + /// This helper is intended for testing `subscribe_failures` from Python. 
+ #[pyo3(signature = (replica_id))] + fn inject_failure(&self, py: Python<'_>, replica_id: String) { + let lighthouse = self.lighthouse.clone(); + let runtime = &self._runtime; + py.allow_threads(move || { + let _ = runtime.block_on(async { + if let Err(e) = lighthouse.inject_failure(replica_id).await { + eprintln!("Failed to inject failure: {}", e); + } + }); + }); + } } struct StatusError(Status); @@ -750,6 +829,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?; Ok(()) diff --git a/src/lighthouse.rs b/src/lighthouse.rs index a576003..063fcd2 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -12,6 +12,7 @@ use std::sync::Arc; use std::time::Duration; use std::time::{Instant, SystemTime}; +use crate::torchftpb::FailureNotification; use anyhow::{anyhow, Result}; use askama::Template; use axum::{ @@ -28,16 +29,23 @@ use tokio::sync::broadcast; use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::interval; +use tokio_stream::wrappers::{ + errors::BroadcastStreamRecvError as TokioStreamBroadcastStreamRecvError, BroadcastStream, +}; +use tokio_stream::StreamExt; use tonic::service::Routes; use tonic::transport::server::TcpIncoming; use tonic::transport::Server; use tonic::{Request, Response, Status}; +use futures_core::Stream; +use std::pin::Pin; + use crate::manager::manager_client_new; use crate::torchftpb::{ lighthouse_service_server::{LighthouseService, LighthouseServiceServer}, KillRequest, LighthouseHeartbeatRequest, LighthouseHeartbeatResponse, LighthouseQuorumRequest, - LighthouseQuorumResponse, Quorum, QuorumMember, + LighthouseQuorumResponse, Quorum, QuorumMember, SubscribeFailuresRequest, }; #[derive(Clone)] @@ -47,14 +55,28 @@ struct QuorumMemberDetails { } struct State { - channel: broadcast::Sender, + quorum_channel: broadcast::Sender, + // Tracks currently active participants in the process of forming a quorum. + // Replicas are added upon receiving a `LighthouseQuorumRequest`. + // Replicas are cleared after a quorum is successfully formed OR + // removed by `_failure_tick` if their heartbeat expires. participants: HashMap, prev_quorum: Option, quorum_id: i64, - // heartbeat information - // replica_id -> last heartbeat + // Stores the last heartbeat time for each replica ID. + // Replicas are added/updated upon receiving `LighthouseHeartbeatRequest` or `LighthouseQuorumRequest`. + // Replicas are removed by `_failure_tick` if their heartbeat expires and a failure notification is sent. heartbeats: HashMap, + + // Stores the timestamp of when a replica was first detected as failed (heartbeat expired). + // This is used to ensure only one `FailureNotification` is sent per failure event. + // Replicas are added by `_failure_tick` upon detecting a new failure. + // Replicas are removed by `_failure_tick` if a subsequent heartbeat is received (signifying recovery). + failures: HashMap, + + // Broadcast channel for sending failure notifications to subscribers. + pub failure_channel: broadcast::Sender, } pub struct Lighthouse { @@ -83,7 +105,7 @@ impl ChangeLogger { } } -#[derive(StructOpt, Debug)] +#[derive(StructOpt, Debug, Clone)] #[structopt()] pub struct LighthouseOpt { // bind is the address to bind the server to. @@ -120,6 +142,13 @@ pub struct LighthouseOpt { help = "How long to wait for a heartbeat before considering a replica dead." 
)] pub heartbeat_timeout_ms: u64, + + #[structopt( + long = "failure_tick_ms", + default_value = "1000", + help = "How frequently to check for failures." + )] + pub failure_tick_ms: u64, } fn quorum_changed(a: &Vec, b: &Vec) -> bool { @@ -265,14 +294,45 @@ impl Lighthouse { let listener = tokio::net::TcpListener::bind(&opt.bind).await?; let (tx, _) = broadcast::channel(16); + let (failure_tx, failure_rx) = broadcast::channel::(16); + + // Create a task to monitor the failure channel + let mut failure_rx_cloned: broadcast::Receiver = + failure_rx.resubscribe(); + tokio::spawn(async move { + use tokio::time::{sleep, Duration}; + info!("Starting permanent failure channel subscriber"); + loop { + match failure_rx_cloned.recv().await { + Ok(note) => { + info!( + "Healthy replicas received failure notification for {} with error message: {}", + note.replica_id, + note.error_message + ); + } + Err(e) => { + error!("Healthy replicas error: {}", e); + // If the channel is closed, break the loop + if matches!(e, tokio::sync::broadcast::error::RecvError::Closed) { + break; + } + } + } + sleep(Duration::from_millis(100)).await; // Prevent thrashing if there are continuous errors + } + info!("Permanent failure channel subscriber exiting"); + }); Ok(Arc::new(Self { state: Mutex::new(State { participants: HashMap::new(), - channel: tx, + quorum_channel: tx, prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: failure_tx, }), opt: opt, local_addr: listener.local_addr()?, @@ -326,7 +386,7 @@ impl Lighthouse { state.prev_quorum = Some(quorum.clone()); state.participants.clear(); - match state.channel.send(quorum) { + match state.quorum_channel.send(quorum) { Ok(_) => (), Err(e) => error!("failed to send quorum {}", e), } @@ -391,6 +451,76 @@ impl Lighthouse { .map_err(|e| e.into()) } + async fn _run_failure_tick(self: Arc) -> Result<()> { + let mut interval = interval(Duration::from_millis(self.opt.failure_tick_ms)); + loop { + interval.tick().await; // Wait for the next tick + let mut state = self.state.lock().await; + self.clone()._failure_tick(&mut state)?; + } + } + + fn _failure_tick(self: Arc, state: &mut State) -> Result<()> { + let now = Instant::now(); + let timeout = Duration::from_millis(self.opt.heartbeat_timeout_ms); + + // Use a temporary list to collect replica IDs to remove from heartbeats + // to avoid modifying the map while iterating over it. + let mut failed_replica_ids_to_remove_from_heartbeats = Vec::new(); + let mut failure_detected = false; + + for (replica_id, last_heartbeat) in state.heartbeats.iter() { + if now.duration_since(*last_heartbeat) > timeout { + if !state.failures.contains_key(replica_id) { + info!( + "Replica {} timed out (last heartbeat: {:?}), sending failure notification.", + replica_id, + last_heartbeat + ); + if let Err(e) = state.failure_channel.send(FailureNotification { + replica_id: replica_id.clone(), + error_message: "heartbeat timeout".to_string(), + }) { + error!( + "Failed to send failure notification for {}: {} (receiver count: {})", + replica_id, + e, + state.failure_channel.receiver_count() + ); + } else { + failure_detected = true; // Set flag if notification sent successfully + } + // Record failure information + state.failures.insert(replica_id.clone(), now); + state.participants.remove(replica_id); + failed_replica_ids_to_remove_from_heartbeats.push(replica_id.clone()); + } + } else { + // If the participant sends heartbeat again, remove it from failures. 
+ if state.failures.remove(replica_id).is_some() { + info!("Replica {} recovered from failure.", replica_id); + } + } + } + + // Remove failed replicas from heartbeats + for replica_id in failed_replica_ids_to_remove_from_heartbeats { + state.heartbeats.remove(&replica_id); + info!( + "Removed replica {} from heartbeats and participants due to timeout.", + replica_id + ); + } + + // If a new failure was detected and broadcasted, reset participants to restart quorum formation + if failure_detected { + info!("New failure detected, resetting all participants for quorum formation."); + state.participants.clear(); + } + + Ok(()) + } + pub async fn run(self: Arc) -> Result<()> { let mut set = JoinSet::new(); @@ -398,6 +528,8 @@ impl Lighthouse { set.spawn(self.clone()._run_grpc()); + set.spawn(self.clone()._run_failure_tick()); + while let Some(res) = set.join_next().await { res??; } @@ -469,6 +601,18 @@ impl Lighthouse { Ok(()) } + + pub async fn inject_failure(self: Arc, replica_id: String) -> Result<()> { + let state = self.state.lock().await; + state + .failure_channel + .send(FailureNotification { + replica_id, + error_message: "injected failure".to_string(), + }) + .map_err(|e| anyhow!("Failed to send failure notification: {}", e))?; + Ok(()) + } } #[tonic::async_trait] @@ -502,7 +646,7 @@ impl LighthouseService for Arc { member: requester.clone(), }, ); - let rx = state.channel.subscribe(); + let rx = state.quorum_channel.subscribe(); // proactively run quorum tick self.clone() @@ -556,6 +700,35 @@ impl LighthouseService for Arc { let reply = LighthouseHeartbeatResponse {}; Ok(Response::new(reply)) } + + type SubscribeFailuresStream = + Pin> + Send + 'static>>; + + async fn subscribe_failures( + &self, + _req: Request, + ) -> Result, Status> { + // clone a receiver + let rx = { + let state = self.state.lock().await; + let receiver_count = state.failure_channel.receiver_count(); + info!( + "subscribe_failures: Creating new subscriber (current count: {})", + receiver_count + ); + state.failure_channel.subscribe() + }; + + // Wrap the receiver; map its *internal* error into `tonic::Status` + let stream = BroadcastStream::new(rx).filter_map(|res| match res { + Ok(note) => Some(Ok(note)), + Err(TokioStreamBroadcastStreamRecvError::Lagged(n)) => Some(Err( + Status::resource_exhausted(format!("client lagged {n} messages")), + )), + }); + + Ok(Response::new(Box::pin(stream))) + } } #[derive(Template)] @@ -605,6 +778,8 @@ where mod tests { use super::*; use std::ops::Sub; + use tokio::sync::broadcast::error::RecvError as TokioBroadcastRecvError; + use tokio::time::timeout as tokio_timeout; use tonic::transport::Channel; @@ -624,14 +799,17 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -703,14 +881,17 @@ mod tests { join_timeout_ms: 0, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -789,14 +970,17 
@@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -879,14 +1063,17 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -974,6 +1161,7 @@ mod tests { join_timeout_ms: 1, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let lighthouse = Lighthouse::new(opt).await?; @@ -1020,14 +1208,17 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -1103,6 +1294,185 @@ mod tests { assert!(quorum_changed(&a, &c)); } + // Helper to create a default QuorumMember for tests + fn test_quorum_member(replica_id: &str) -> QuorumMember { + QuorumMember { + replica_id: replica_id.to_string(), + address: format!("addr_{}", replica_id), + store_address: format!("store_{}", replica_id), + step: 1, + world_size: 2, // Assuming 2 for this test context + shrink_only: false, + data: String::new(), + commit_failures: 0, + } + } + + /// Test that `_failure_tick` correctly identifies timed-out replicas, + /// broadcasts a failure notification exactly once per failure, and + /// cleans up the replica from `heartbeats` and `participants` while + /// adding it to `failures`. Subsequent ticks should not re-notify + /// or change the state for an already failed replica. 
+ #[tokio::test] + async fn test_failure_tick_single_notification_and_cleanup() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 0, // Not relevant for this test + quorum_tick_ms: 10, // Not directly relevant but keep it small + heartbeat_timeout_ms: 100, // Reasonably short for testing + failure_tick_ms: 50, // How often _failure_tick would be called + }; + let lighthouse = Lighthouse::new(opt.clone()).await?; + + let mut failure_rx = { + let state_guard = lighthouse.state.lock().await; + state_guard.failure_channel.subscribe() + }; + + let replica_id_failing = "failing_one"; + + let now = Instant::now(); + // Ensure expired_time is definitively older than heartbeat_timeout_ms + let expired_time = now - Duration::from_millis(opt.heartbeat_timeout_ms * 2); + + // Setup initial state: one about to fail + { + let mut state_guard = lighthouse.state.lock().await; + let state = &mut *state_guard; + + // Failing replica + state.participants.insert( + replica_id_failing.to_string(), + QuorumMemberDetails { + joined: now, // Joined time doesn't prevent failure due to heartbeat + member: test_quorum_member(replica_id_failing), + }, + ); + state + .heartbeats + .insert(replica_id_failing.to_string(), expired_time); + } + + // --- First call to _failure_tick --- + // This call should detect the failure, send a notification, and update state. + { + let mut state_guard = lighthouse.state.lock().await; + lighthouse.clone()._failure_tick(&mut *state_guard)?; + } + + // Assertions after first tick + // 1. Check notification for failing_replica + match tokio_timeout( + Duration::from_millis(opt.failure_tick_ms * 2), + failure_rx.recv(), + ) + .await + { + Ok(Ok(notification)) => { + assert_eq!( + notification.replica_id, replica_id_failing, + "Notification should be for the failing replica" + ); + } + Ok(Err(TokioBroadcastRecvError::Lagged(n))) => { + panic!( + "Broadcast channel lagged by {} messages, missed the failure notification", + n + ); + } + Ok(Err(TokioBroadcastRecvError::Closed)) => { + panic!("Broadcast channel closed unexpectedly after first tick"); + } + Err(_) => panic!( + "Did not receive failure notification for {} in time", + replica_id_failing + ), + } + + // 2. Verify state changes + { + let state_guard = lighthouse.state.lock().await; + let state = &*state_guard; + + // Failing replica assertions + assert!( + state.failures.contains_key(replica_id_failing), + "{} should be in failures map", + replica_id_failing + ); + assert!( + !state.heartbeats.contains_key(replica_id_failing), + "{} should be removed from heartbeats", + replica_id_failing + ); + assert!( + !state.participants.contains_key(replica_id_failing), + "{} should be removed from participants", + replica_id_failing + ); + } + + // --- Second call to _failure_tick --- + // This call should *not* detect a *new* failure for the same replica + // and should not send another notification. + { + let mut state_guard = lighthouse.state.lock().await; + lighthouse.clone()._failure_tick(&mut *state_guard)?; + } + + // Assertions after second tick + // 1. 
No new notification for failing_replica + match tokio_timeout( + Duration::from_millis(opt.failure_tick_ms * 2), + failure_rx.recv(), + ) + .await + { + Ok(Ok(notification)) => { + panic!( + "Received unexpected second failure notification for {}", + notification.replica_id + ); + } + Ok(Err(TokioBroadcastRecvError::Lagged(n))) => { + // This might happen if the test environment is slow and ticks are processed faster than receives. + // For this specific assertion (no *new* message), lagging is an acceptable outcome. + info!("Broadcast channel lagged by {} messages on second check, implies no new distinct message.", n); + } + Ok(Err(TokioBroadcastRecvError::Closed)) => { + // Channel might close if sender is dropped, implies no new message. + info!("Broadcast channel closed on second check, implies no new distinct message."); + } + Err(_) => { + // Expected: Timeout, meaning no new message was received for failing_replica. + } + } + + // 2. Verify state remains consistent for failing_replica + { + let state_guard = lighthouse.state.lock().await; + let state = &*state_guard; + + assert!( + state.failures.contains_key(replica_id_failing), + "{} should remain in failures map", + replica_id_failing + ); + assert!( + !state.heartbeats.contains_key(replica_id_failing), + "{} should remain removed from heartbeats", + replica_id_failing + ); + assert!( + !state.participants.contains_key(replica_id_failing), + "{} should remain removed from participants", + replica_id_failing + ); + } + Ok(()) + } + #[tokio::test] async fn test_lighthouse_join_during_shrink() -> Result<()> { fn create_member(id: &str, addr_num: &str, step: i64, shrink_only: bool) -> QuorumMember { @@ -1130,6 +1500,7 @@ mod tests { join_timeout_ms: 1000, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; // Start the lighthouse service @@ -1237,6 +1608,7 @@ mod tests { join_timeout_ms: 1000, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; // Start the lighthouse service @@ -1281,4 +1653,70 @@ mod tests { lighthouse_task.abort(); Ok(()) } + + #[tokio::test] + async fn test_lighthouse_subscribe_failures_basic() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, // 1hr + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + }; + + let lighthouse = Lighthouse::new(opt).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + let mut client = lighthouse_client_new(lighthouse.address()).await?; + let request = tonic::Request::new(SubscribeFailuresRequest {}); + client.subscribe_failures(request).await?; + + lighthouse_task.abort(); + Ok(()) + } + + #[tokio::test] + async fn test_subscribe_failures_delivers_notifications() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + }; + let lighthouse = Lighthouse::new(opt).await?; + let mut client = lighthouse_client_new(lighthouse.address()).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + // 1. Subscribe with a deadline + let mut req = tonic::Request::new(SubscribeFailuresRequest {}); + req.set_timeout(Duration::from_secs(5)); + let mut stream = client.subscribe_failures(req).await?.into_inner(); + + // 2. 
Trigger a failure notification + { + let state = lighthouse.state.lock().await; + state + .failure_channel + .send(FailureNotification { + replica_id: "replica_id_X".into(), + error_message: "injected failure".to_string(), + }) + .unwrap(); + } + + // 3. Ensure we receive it + match stream.next().await { + Some(Ok(note)) => { + assert_eq!(note.replica_id, "replica_id_X"); + assert_eq!(note.error_message, "injected failure"); + } + other => panic!("Expected notification, got {:?}", other), + } + + lighthouse_task.abort(); + Ok(()) + } } diff --git a/src/manager.rs b/src/manager.rs index e28cbeb..affda55 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -15,6 +15,7 @@ use tokio::sync::broadcast; use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::sleep; +use tokio::time::timeout as tokio_timeout; use tonic::transport::server::TcpIncoming; use tonic::transport::Channel; use tonic::transport::Server; @@ -54,7 +55,7 @@ macro_rules! info_with_replica { struct ManagerState { checkpoint_metadata: HashMap, - channel: broadcast::Sender, + quorum_channel: broadcast::Sender, participants: HashMap, should_commit_channel: broadcast::Sender, @@ -126,7 +127,7 @@ impl Manager { heartbeat_interval: heartbeat_interval, state: Mutex::new(ManagerState { checkpoint_metadata: HashMap::new(), - channel: tx, + quorum_channel: tx, participants: HashMap::new(), should_commit_channel: should_commit_tx, @@ -204,7 +205,7 @@ impl Manager { }); lighthouse_request.set_timeout(timeout); - let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request)) + let response = tokio_timeout(timeout, client.quorum(lighthouse_request)) .await .unwrap_or_else(|e| { Err(Status::cancelled(format!( @@ -217,7 +218,7 @@ impl Manager { info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp); state - .channel + .quorum_channel .send( resp.quorum .ok_or_else(|| Status::internal("missing quorum"))?, @@ -273,7 +274,7 @@ impl ManagerService for Arc { }; // TODO check step state.participants.insert(group_rank, member.clone()); - let rx = state.channel.subscribe(); + let rx = state.quorum_channel.subscribe(); self._run_quorum(&mut state, member, timeout).await?; @@ -550,6 +551,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -597,6 +599,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -652,6 +655,7 @@ mod tests { min_replicas: 2, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -724,6 +728,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi index 9614d1b..01d1f9e 100644 --- a/torchft/_torchft.pyi +++ b/torchft/_torchft.pyi @@ -11,8 +11,8 @@ class ManagerClient: checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, - commit_failures: int, init_sync: bool = True, + commit_failures: int = 0, ) -> QuorumResult: ... def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ... 
def should_commit( @@ -60,9 +60,11 @@ class LighthouseServer: join_timeout_ms: Optional[int] = None, quorum_tick_ms: Optional[int] = None, heartbeat_timeout_ms: Optional[int] = None, + failure_tick_ms: Optional[int] = None, ) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ... + def inject_failure(self, replica_id: str) -> None: ... @dataclass class QuorumMember: @@ -85,6 +87,14 @@ class Quorum: participants: List[QuorumMember] created: Timestamp +@dataclass +class FailureNotification: + replica_id: str + +class FailureStream: + def __iter__(self) -> "FailureStream": ... + def __next__(self) -> FailureNotification: ... + @dataclass class LighthouseClient: addr: str @@ -106,3 +116,7 @@ class LighthouseClient: replica_id: str, timeout: timedelta = timedelta(seconds=5), ) -> None: ... + def subscribe_failures( + self, + timeout: timedelta = timedelta(seconds=5), + ) -> FailureStream: ... diff --git a/torchft/data.py b/torchft/data.py index 02e5b3b..77ec1de 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -38,15 +38,15 @@ class DistributedSampler(data.distributed.DistributedSampler): This will shard the input dataset into ``num_replicas*num_replica_group`` number of shards. - Each shard rank is calculated via: ``rank + num_replicas*replica_rank`` + Each shard rank is calculated via: ``rank + num_replicas*replica_group_id`` - num_replicas and replica_rank must be the same on all workers. + num_replicas and replica_group_id must be the same on all workers. """ def __init__( self, dataset: data.Dataset, - replica_rank: int, + replica_group_id: int, num_replica_groups: int, group_rank: Optional[int] = None, num_replicas: Optional[int] = None, @@ -55,7 +55,7 @@ def __init__( """ Args: data: the dataset to use - replica_rank: the group ID (0-num_replica_groups) to use for this shard of data. + replica_group_id: the group ID (0-num_replica_groups) to use for this shard of data. 
num_replica_groups: the max number of global replica groups rank: the local group rank num_replicas: the local group world size @@ -65,7 +65,7 @@ def __init__( if num_replicas is None: num_replicas = dist.get_world_size() - self.global_rank: int = group_rank + num_replicas * replica_rank + self.global_rank: int = group_rank + num_replicas * replica_group_id self.global_world_size: int = num_replicas * num_replica_groups super().__init__( diff --git a/torchft/data_test.py b/torchft/data_test.py index 8dae190..5b7c6b6 100644 --- a/torchft/data_test.py +++ b/torchft/data_test.py @@ -27,7 +27,7 @@ def test_distributed_sampler(self) -> None: dataset = DummyDataset(1000) sampler = DistributedSampler( dataset, - replica_rank=1, + replica_group_id=1, num_replica_groups=2, group_rank=3, num_replicas=4, diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py index 067a622..bbe3a97 100644 --- a/torchft/lighthouse_test.py +++ b/torchft/lighthouse_test.py @@ -155,3 +155,72 @@ def test_heartbeat_round_trip(self) -> None: finally: lighthouse.shutdown() + + def test_subscribe_failures(self) -> None: + """Test that subscribe_failures can be called without raising an exception.""" + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + ) + try: + client = LighthouseClient( + addr=lighthouse.address(), + connect_timeout=timedelta(seconds=1), + ) + stream = client.subscribe_failures(timeout=timedelta(milliseconds=100)) + finally: + lighthouse.shutdown() + + def test_subscribe_failures_notification(self) -> None: + """Test that failure notifications are delivered to subscribers.""" + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + ) + try: + client = LighthouseClient( + addr=lighthouse.address(), + connect_timeout=timedelta(seconds=1), + ) + stream = client.subscribe_failures(timeout=timedelta(seconds=1)) + lighthouse.inject_failure("nodeX") + note = next(stream) + assert note.replica_id == "nodeX" + finally: + lighthouse.shutdown() + + def test_inject_failure(self) -> None: + """Test that inject failure delivers a failure notification to subscribers""" + # Start a lighthouse server + server = LighthouseServer( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + print(f"Server address: {server.address()}") + + # Create a client to subscribe to failures + client = LighthouseClient(server.address(), timedelta(seconds=5)) + failure_stream = client.subscribe_failures(timedelta(seconds=5)) + + # Inject a failure + replica_id = "test_replica" + print(f"Injecting failure for replica: {replica_id}") + server.inject_failure(replica_id) + + # Wait a bit for the notification to be processed + time.sleep(1) + + # Try to get the failure notification + try: + notification = next(failure_stream) + print( + f"Received failure notification for replica: {notification.replica_id}" + ) + assert notification.replica_id == replica_id, "Received wrong replica_id" + print("Test passed!") + except Exception as e: + print(f"Error: {e}") + + # Clean up + server.shutdown() diff --git a/torchft/manager.py b/torchft/manager.py index 2c1c640..ae48cfd 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -24,25 +24,29 @@ and Hybrid FSDP. 
""" - import concurrent.futures import logging +import multiprocessing import os import socket +import threading +import time import traceback import uuid from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from datetime import timedelta from enum import Enum +from multiprocessing.connection import Connection from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast import torch from torch.distributed import ReduceOp, TCPStore -from torchft._torchft import ManagerClient, ManagerServer +from torchft._torchft import LighthouseClient, ManagerClient, ManagerServer from torchft.checkpointing import CheckpointTransport, HTTPTransport from torchft.futures import future_timeout +from torchft.multiprocessing import _MonitoredPipe if TYPE_CHECKING: from torchft.process_group import ProcessGroup @@ -103,6 +107,7 @@ def __init__( timeout: timedelta = timedelta(seconds=60), quorum_timeout: timedelta = timedelta(seconds=60), connect_timeout: timedelta = timedelta(seconds=60), + proactive_recovery_subscribe_timeout: timedelta = timedelta(milliseconds=100), rank: Optional[int] = None, world_size: Optional[int] = None, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, @@ -116,6 +121,7 @@ def __init__( checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None, init_sync: bool = True, max_retries: Optional[int] = None, + proactive_recovery: bool = False, ) -> None: """ Args: @@ -166,6 +172,9 @@ def __init__( self._timeout = timeout self._quorum_timeout = quorum_timeout self._connect_timeout = connect_timeout + self._proactive_recovery_subscribe_timeout = ( + proactive_recovery_subscribe_timeout + ) self._replica_world_size_mode = world_size_mode self._init_sync = init_sync self._max_retries = max_retries @@ -187,9 +196,7 @@ def __init__( self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = ( checkpoint_transport ) - self._executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="async_quorum" - ) + self._executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="") self._quorum_future: Optional[concurrent.futures.Future] = None self._store = TCPStore( @@ -205,12 +212,57 @@ def __init__( torch.cuda.Stream() if torch.cuda.is_available() else None ) + lighthouse_addr: Optional[str] = lighthouse_addr + if os.environ.get("TORCHFT_LIGHTHOUSE") is not None: + lighthouse_addr = ( + lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"] + ) # Else error in tests, since TORCHFT_LIGHTHOUSE may not be set + + self._proactive_recovery = proactive_recovery or int( + os.environ.get("TORCHFT_PROACTIVE_RECOVERY", 0) + ) + + if lighthouse_addr is not None and self._proactive_recovery: + ctx = multiprocessing.get_context("spawn") + error_local, error_remote = ctx.Pipe() + self._error_pipe = _MonitoredPipe(error_local) + self._error_remote = _MonitoredPipe(error_remote) + self._failure_listener_stop_event = ctx.Event() + + self._failure_listener_process = ctx.Process( + target=_failure_listener_process_main, + args=( + lighthouse_addr, + self._connect_timeout, + self._failure_listener_stop_event, + error_remote, + self._proactive_recovery_subscribe_timeout, + ), + daemon=True, + ) + self._failure_listener_process.start() + else: + self._failure_listener_process = None + self._error_pipe = None + self._failure_listener_stop_event = None + + # Initialize and start the error processing thread if the listener process is active + self._error_processor_thread: Optional[threading.Thread] = None + self._error_processor_stop_event: 
Optional[threading.Event] = None + if self._failure_listener_process is not None: + self._error_processor_stop_event = threading.Event() + self._error_processor_thread = threading.Thread( + target=self._error_processor_loop, + name="TorchFTErrorProcessor", + daemon=True, + ) + self._error_processor_thread.start() + if self._group_rank == 0: if port is None: port = int(os.environ.get(MANAGER_PORT_ENV, 0)) bind = f"[::]:{port}" - lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"] # We need a unique identifier in the case that a worker restarts quickly and # replaces the previous worker with the same ID. @@ -219,6 +271,7 @@ def __init__( replica_id = new_uuid else: replica_id = f"{replica_id}:{new_uuid}" + self._manager = ManagerServer( replica_id=replica_id, lighthouse_addr=lighthouse_addr, @@ -229,13 +282,11 @@ def __init__( heartbeat_interval=heartbeat_interval, connect_timeout=connect_timeout, ) - self._store.set(MANAGER_ADDR_KEY, self._manager.address()) self._store.set(REPLICA_ID_KEY, replica_id) addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8") self._client = ManagerClient(addr, connect_timeout=connect_timeout) - replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8") self._logger = _ManagerLogger( manager=self, replica_id=replica_id or "", group_rank=group_rank @@ -258,13 +309,96 @@ def set_state_dict_fns( self._load_state_dict = load_state_dict self._user_state_dict = state_dict + def _error_handler(self, err): + self._logger.info(f"Received error: {err}") + self.report_error(err) + self._pg.abort() + + def _error_processor_loop(self) -> None: + """Continuously checks the error pipe from the listener process and reports errors.""" + assert ( + self._error_pipe is not None + ), "Error pipe must be initialized for error processor loop." + assert ( + self._error_processor_stop_event is not None + ), "Stop event must be initialized for error processor loop." + + try: + while not self._error_processor_stop_event.is_set(): + try: + item = self._error_pipe.recv(0.1) + except TimeoutError: + continue + except OSError: + break + except Exception as e: + self._error_handler(e) + finally: + pass + def shutdown(self, wait: bool = True) -> None: """ Shutdown the manager and checkpoint server. """ - self._checkpoint_transport.shutdown(wait=wait) if self._manager is not None: self._manager.shutdown() + + # Stop the error processor thread first + if ( + self._error_processor_thread is not None + and self._error_processor_stop_event is not None + ): + self._logger.info("Setting error processor thread stop event") + self._error_processor_stop_event.set() + if wait: + self._logger.info("Waiting for error processor thread to complete") + try: + self._error_processor_thread.join(timeout=5) # Short timeout + if self._error_processor_thread.is_alive(): + self._logger.warn( + "Error processor thread did not terminate in time." 
+ ) + else: + self._logger.info("Error processor thread shutdown completed.") + except Exception as e: + self._logger.warn(f"Error waiting for error processor thread: {e}") + + # Stop the failure listener process if it exists + if ( + hasattr(self, "_failure_listener_process") + and self._failure_listener_process is not None + ): + self._logger.info("Setting failure listener stop event for process") + if ( + hasattr(self, "_failure_listener_stop_event") + and self._failure_listener_stop_event is not None + ): + self._failure_listener_stop_event.set() + + if wait: + self._logger.info("Waiting for failure listener process to complete") + try: + self._failure_listener_process.join(timeout=10) # Process join + if self._failure_listener_process.is_alive(): + self._logger.warn( + "Failure listener process did not terminate, attempting to terminate." + ) + self._failure_listener_process.terminate() # Force terminate if join times out + self._failure_listener_process.join( + timeout=1 + ) # Wait for terminate + else: + self._logger.info("Failure listener process shutdown completed") + except Exception as e: + self._logger.warn( + f"Error waiting for/terminating failure listener process: {e}" + ) + + # Clean up pipe + if hasattr(self, "_error_pipe") and self._error_pipe is not None: + self._error_pipe.close() + + self._checkpoint_transport.shutdown(wait=wait) self._executor.shutdown(wait=wait) def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]: @@ -824,3 +958,60 @@ def warn(self, msg: str) -> None: def exception(self, msg: str) -> None: self._logger.exception(f"{self.prefix()} {msg}") + + +def _failure_listener_process_main( + lighthouse_addr_str: Optional[str], + connect_timeout: timedelta, + stop_event: multiprocessing.Event, + error_pipe: Connection, + subscribe_timeout: timedelta = timedelta(milliseconds=100), +): + """ + Background process that monitors lighthouse for failures through gRPC stream (with an iterator interface) and reports them via error_pipe. + """ + if not lighthouse_addr_str: + return + + while not stop_event.is_set(): + try: + lighthouse_client = LighthouseClient( + lighthouse_addr_str, connect_timeout=connect_timeout + ) + stream = lighthouse_client.subscribe_failures(timeout=subscribe_timeout) + while not stop_event.is_set(): + try: + note = next( + stream + ) # This will block until a new item or timeout if stream supports it + if note: + if stop_event.is_set(): + break + error = Exception( + f"Peer failure detected in listener process: replica {note.replica_id} has failed" + ) + error_pipe.send(ExceptionWithTraceback(error)) + except StopIteration: + # Stream has ended, break out to outer loop to reconnect + if not stop_event.is_set(): + logging.warning( + "Failure Listener: Stream ended unexpectedly, attempting to reconnect..." + ) + break # Break the inner loop to reconnect + else: + break + except Exception as e_stream: + if not stop_event.is_set(): + continue # Break due to subscribe_timeout. Allows the process to check stop_event again. + else: + break + if stop_event.is_set(): + break + time.sleep(0.01) # Prevent CPU thrashing + except Exception as e_outer: + if not stop_event.is_set(): + logging.warning( + f"Failure Listener: Connection error: {e_outer}, retrying in 1 second..." + ) + time.sleep(1) + pass diff --git a/torchft/manager_test.py b/torchft/manager_test.py index bb058e4..2fb0373 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -5,18 +5,34 @@ # LICENSE file in the root directory of this source tree. 
import concurrent +import multiprocessing +import time +from dataclasses import dataclass from datetime import timedelta from typing import Optional from unittest import TestCase from unittest.mock import MagicMock, create_autospec, patch import torch +import torch.distributed as dist from torch.distributed import TCPStore -from torchft._torchft import QuorumResult +from torchft._torchft import ( + FailureStream, + LighthouseClient, + LighthouseServer, + QuorumResult, +) from torchft.checkpointing.transport import CheckpointTransport -from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode -from torchft.process_group import ProcessGroup, _DummyWork +from torchft.manager import ( + MANAGER_ADDR_KEY, + REPLICA_ID_KEY, + ExceptionWithTraceback, + Manager, + WorldSizeMode, + _failure_listener_process_main, +) +from torchft.process_group import ProcessGroup, ProcessGroupGloo, _DummyWork def mock_should_commit( @@ -43,6 +59,7 @@ def _create_manager( timeout: timedelta = timedelta(seconds=10), init_sync: bool = True, max_retries: Optional[int] = None, + proactive_recovery: bool = False, ) -> Manager: pg = create_autospec(ProcessGroup) pg.errored.return_value = None @@ -72,6 +89,7 @@ def _create_manager( timeout=timeout, init_sync=init_sync, max_retries=max_retries, + proactive_recovery=proactive_recovery, ) self.manager = manager return manager @@ -773,3 +791,127 @@ def test_max_retries(self, client_mock: MagicMock) -> None: # This should succeed and reset the counter self.assertTrue(manager.should_commit()) self.assertEqual(manager._commit_failures, 0) + + @patch("torchft.manager.ManagerClient", autospec=True) + def test_manager_error_handler(self, client_mock: MagicMock) -> None: + """Test that the Manager correctly processes exceptions sent from the failure_listener_process.""" + # Create a manager + manager = self._create_manager() + + # Create an exception simulating what would be sent from _failure_listener_process_main + error = Exception("Peer failure detected: replica failed_replica has failed") + exception = ExceptionWithTraceback(error) + + # Directly test the error handling mechanism + manager._error_handler(error) + + # Verify the error was properly processed + captured_error = manager.errored() + self.assertIsNotNone(captured_error) + self.assertEqual(str(captured_error.original_exception), str(error)) + + def test_direct_error_pipe(self) -> None: + """Test sending an exception to the Manager's _error_pipe.""" + # Create a manager with proactive_recovery=True to ensure it has an error pipe + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + + # Create a manager that tries to join + store = dist.TCPStore( + host_name="localhost", + port=0, + is_master=True, + wait_for_workers=False, + ) + pg = ProcessGroupGloo() + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=lambda x: None, + state_dict=lambda: None, + replica_id=f"lighthouse_test", + store_addr="localhost", + store_port=store.port, + rank=0, + world_size=1, + use_async_quorum=False, + lighthouse_addr=lighthouse.address(), + proactive_recovery=True, + ) + + # Make sure the error pipe is created + self.assertIsNotNone(manager._error_pipe, "Manager should have an error pipe") + time.sleep(1) + # Create a mock error message + mock_error_msg = "Test failure detected from direct pipe test" + test_exception = Exception(mock_error_msg) + + # Create an ExceptionWithTraceback and send it through the pipe + exc_with_tb = 
ExceptionWithTraceback(test_exception) + manager._error_remote.send(exc_with_tb) + + # Wait a short time for the error processor thread to process the message + time.sleep(1) + + # Verify that the error was properly processed by the Manager + error_obj = manager.errored() + self.assertIsNotNone( + error_obj, "Error should have been captured by the Manager" + ) + + # Clean up + manager.shutdown(wait=True) + + def test_manager_failure_e2e(self) -> None: + """Test that the Manager correctly handles errors from the failure_listener_process.""" + # Create a manager with proactive_recovery=True to ensure it has an error pipe + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + + # Create a manager that tries to join + store = dist.TCPStore( + host_name="localhost", + port=0, + is_master=True, + wait_for_workers=False, + ) + pg = ProcessGroupGloo() + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=lambda x: None, + state_dict=lambda: None, + replica_id=f"lighthouse_test", + store_addr="localhost", + store_port=store.port, + rank=0, + world_size=1, + use_async_quorum=False, + lighthouse_addr=lighthouse.address(), + proactive_recovery=True, + ) + + time.sleep(1.5) + + failed_replica_id = "failed_replica" + lighthouse.inject_failure(failed_replica_id) + + time.sleep(1.5) # Prevent flakyness + error_obj = manager.errored() + + # Verify that the manager received the error notification + self.assertIsNotNone(error_obj, "Manager should have captured the failure") + self.assertIn( + failed_replica_id, + str(error_obj.original_exception), + f"Error should mention the failed replica: {error_obj.original_exception}", + ) + + # Clean up resources + manager.shutdown(wait=True) diff --git a/train_ddp.py b/train_ddp.py index fd79b8a..96c2c13 100644 --- a/train_ddp.py +++ b/train_ddp.py @@ -51,7 +51,7 @@ def main() -> None: # majority of groups will be available so few batches will be dropped. sampler = DistributedSampler( trainset, - replica_group=REPLICA_GROUP_ID, + replica_group_id=REPLICA_GROUP_ID, num_replica_groups=NUM_REPLICA_GROUPS, group_rank=0, # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. diff --git a/train_ddp_proactive.py b/train_ddp_proactive.py new file mode 100644 index 0000000..3d0002c --- /dev/null +++ b/train_ddp_proactive.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import os +import sys +import time +from datetime import timedelta + +REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) +os.environ["CUDA_VISIBLE_DEVICES"] = str(REPLICA_GROUP_ID % 4) +os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID) + +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch import nn, optim +from torch.distributed.elastic.multiprocessing.errors import record +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchft import ( + DistributedDataParallel, + DistributedSampler, + Manager, + Optimizer, + ProcessGroupGloo, + ProcessGroupNCCL, +) +from torchft.checkpointing.pg_transport import PGTransport + +logging.basicConfig(level=logging.INFO) + + +@record +def main() -> None: + REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) + NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2)) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + trainset = torchvision.datasets.CIFAR10( + root="./cifar", train=True, download=True, transform=transform + ) + + # This shards the training set across all ranks and replica groups. We manage + # the dataloaders on a per replica group basis with the assumption that the + # majority of groups will be available so few batches will be dropped. + sampler = DistributedSampler( + trainset, + replica_group_id=REPLICA_GROUP_ID, + num_replica_groups=NUM_REPLICA_GROUPS, + group_rank=0, + # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. + num_replicas=1, + shuffle=True, + ) + + # This uses the torchdata StatefulDataLoader to be able to checkpoint and + # restore the per worker dataloader position. + trainloader = StatefulDataLoader( + trainset, batch_size=64, num_workers=2, sampler=sampler + ) + + def load_state_dict(state_dict): + m.load_state_dict(state_dict["model"]) + optimizer.load_state_dict(state_dict["optim"]) + + def state_dict(): + return { + "model": m.state_dict(), + "optim": optimizer.state_dict(), + } + + device = "cuda" if torch.cuda.is_available() else "cpu" + pg = ( + ProcessGroupNCCL( + timeout=timedelta(seconds=30), + ) + if torch.cuda.is_available() + else ProcessGroupGloo(timeout=timedelta(seconds=5)) + ) + + transport = PGTransport( + pg, + timeout=timedelta(seconds=10), + device=("cuda" if torch.cuda.is_available() else "cpu"), + ) + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=30), + checkpoint_transport=transport, + ) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.cnn = nn.Sequential( + nn.Conv2d(3, 6, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Conv2d(6, 16, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + ) + + final_dim = 10 + # We add a useless 1GB intermediate layer so we spend more time in dist + # communication so injected failures are more likely to cause issues + # if they exist. 
+ target_size = 1_000_000_000 + self.useless = nn.Embedding(target_size // final_dim // 4, final_dim) + + self.classifier = nn.Sequential( + nn.Linear(16 * 5 * 5, 120), + nn.ReLU(), + nn.Linear(120, 84), + nn.ReLU(), + nn.Linear(84, final_dim), + ) + + def forward(self, x): + x = self.cnn(x) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = self.classifier(x) + x += self.useless.weight[0] + return x + + m = Net().to(device) + m = DistributedDataParallel(manager, m) + optimizer = Optimizer(manager, optim.AdamW(m.parameters())) + + print(m) + num_params = sum(p.numel() for p in m.parameters()) + print(f"Total number of parameters: {num_params}") + + sort_by_keyword = "self_" + device + "_time_total" + + def trace_handler(p): + output = p.key_averages().table( + sort_by=sort_by_keyword, + row_limit=100, + ) + print(output) + p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json") + + # You can use an epoch based training but with faults it's easier to use step + # based training. + prof = torch.profiler.profile( + schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2), + on_trace_ready=trace_handler, + record_shapes=True, + profile_memory=True, + ) + + prof.start() + while True: + for i, (inputs, labels) in enumerate(trainloader): + prof.step() + + time.sleep(0.5) # Else each iteration runs too quickly + + inputs = inputs.to(device) + labels = labels.to(device) + + # must be called at the beginning of each train loop + # Quorum computation is triggered here but only needed in the backwards pass. + optimizer.zero_grad() + + out = m(inputs) + criterion = nn.CrossEntropyLoss() + loss = criterion(out, labels) + + # Gradient allreduce overlaps with the backwards pass. + loss.backward() + if manager.current_step() == 3: + if REPLICA_GROUP_ID == 0: + manager.shutdown() + exit(0) + # If proactive recovery, then the surviving process will reconfigure + # If not proactive recovery, then the surviving process will wait until timeout + + test_tensor = torch.tensor([1.0]).to(device) + manager.allreduce(test_tensor) + + # must be called at the end of the train loop + # This may not actually step the optimizer if an error occured during grad allreduce. + optimizer.step() + + if manager.current_step() % 100 == 0: + print(f"[{manager.current_step()}] loss = {loss.item()}") + + # TODO (by the user): periodically checkpoint model, optim, manager and dataloader + + # You typically want to checkpoint dataloader frequently (every step?) to + # avoid repeated batches as it's replica group specific. + + # Model, optim and manager checkpoints can be done more infrequently as + # they're shared across all groups and will load from existing replicas as + # long as not every worker goes down. + + if manager.current_step() >= 10000: + # complete training + prof.stop() + exit() + + +if __name__ == "__main__": + main()
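Editor's appendix (not part of the patch): a minimal sketch of the failure-subscription surface exercised by `torchft/lighthouse_test.py` above, namely `LighthouseServer.inject_failure`, `LighthouseClient.subscribe_failures`, and the blocking `FailureStream` iterator. The replica id is a placeholder and the sleep mirrors the test's brief wait before reading.

```py
import time
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer

# Start an in-process Lighthouse; failure_tick_ms controls how often the server
# scans heartbeats for timeouts (default 1000 ms).
server = LighthouseServer(bind="[::]:0", min_replicas=1, failure_tick_ms=100)
try:
    client = LighthouseClient(server.address(), timedelta(seconds=5))

    # Open the server-streaming RPC. Each next() blocks for up to `timeout`
    # and raises TimeoutError if no notification arrives in that window.
    stream = client.subscribe_failures(timeout=timedelta(seconds=5))

    # Broadcast a synthetic failure, as the tests do.
    server.inject_failure("replica_7")
    time.sleep(0.5)  # give the notification a moment to propagate

    note = next(stream)
    print(f"replica {note.replica_id} failed: {note.error_message}")
finally:
    server.shutdown()
```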