Commit daa3adf

Added proactive heartbeat timeout failure propagation (#164) (#188)

1 parent b84c5a6 commit daa3adf

14 files changed: +1255 −49 lines

Cargo.toml

Lines changed: 2 additions & 0 deletions

@@ -21,7 +21,9 @@ slog-stdlog = "4.1.1"
 stderrlog = "0.6.0"
 structopt = "0.3.26"
 tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
+tokio-stream = {version = "0.1.14", features = ["sync"]}
 tonic = "0.12.2"
+futures-core = "0.3"

 [build-dependencies]
 tonic-build = "0.12.2"

README.md

Lines changed: 32 additions & 0 deletions

@@ -246,6 +246,38 @@ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --mast

 By observing the outputs from both shells, you should observe process group reconfiguration and live checkpoint recovery.

+### Proactive Failure Recovery Mode (Experimental)
+
+You can experiment with proactive failure recovery mode by setting:
+
+```sh
+export TORCHFT_PROACTIVE_RECOVERY=1
+```
+
+With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break out of a hanging allreduce.
+
+You can test this out by running `train_ddp_proactive.py`.
+
+On shell 1 (one replica group starts initial training):
+```sh
+export REPLICA_GROUP_ID=0
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
+```
+
+On shell 2 (a second replica group joins):
+```sh
+export REPLICA_GROUP_ID=1
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
+```
+
+You should observe that the process with replica group id 1 exits early, and the process with replica group id 0 quickly resumes training. If the same script is run after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, the process with replica group id 1 will hang for dozens of seconds before continuing.
+
 ### Example Parameter Server

 torchft has a fault tolerant parameter server implementation built on it's
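
The failure notifications behind this mode can also be consumed directly from Python through the new `LighthouseClient.subscribe_failures` binding introduced in `src/lib.rs` below. A minimal sketch, assuming the compiled extension is importable as `torchft._torchft` and that a lighthouse is already running on `localhost:29510` (both assumptions, not part of this diff):

```python
# Sketch: watch the lighthouse for heartbeat failures from Python.
from datetime import timedelta

from torchft._torchft import LighthouseClient  # assumed import path

client = LighthouseClient(
    addr="http://localhost:29510",
    connect_timeout=timedelta(seconds=5),
)

# Each iteration blocks until a failure notification arrives; if none arrives
# within the per-call timeout, a TimeoutError is raised instead.
for note in client.subscribe_failures(timeout=timedelta(seconds=60)):
    print(f"replica {note.replica_id} failed: {note.error_message}")
```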

proto/torchft.proto

Lines changed: 14 additions & 0 deletions

@@ -67,9 +67,17 @@ message LighthouseHeartbeatRequest {

 message LighthouseHeartbeatResponse {}

+message SubscribeFailuresRequest {}
+
+message FailureNotification {
+  string replica_id = 1;
+  string error_message = 2;
+}
+
 service LighthouseService {
   rpc Quorum (LighthouseQuorumRequest) returns (LighthouseQuorumResponse);
   rpc Heartbeat (LighthouseHeartbeatRequest) returns (LighthouseHeartbeatResponse);
+  rpc SubscribeFailures (SubscribeFailuresRequest) returns (stream FailureNotification);
 }

 message ManagerQuorumRequest {
@@ -126,3 +134,9 @@ service ManagerService {
   rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse);
   rpc Kill(KillRequest) returns (KillResponse);
 }
+
+message LighthouseClientRequest {
+  string replica_id = 1;
+}
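
`SubscribeFailures` is a plain server-streaming RPC, so it can also be tailed with any gRPC client. A hedged sketch using `grpcio` and stubs generated by `protoc` from this file (the `torchft_pb2*` module names follow protoc's default naming and are an assumption, not part of this commit):

```python
# Sketch: tail the failure stream with generated gRPC stubs.
import grpc

import torchft_pb2       # generated: python -m grpc_tools.protoc ... torchft.proto
import torchft_pb2_grpc  # generated service stubs

with grpc.insecure_channel("localhost:29510") as channel:
    stub = torchft_pb2_grpc.LighthouseServiceStub(channel)
    # A server-streaming RPC returns an iterator of FailureNotification.
    for note in stub.SubscribeFailures(torchft_pb2.SubscribeFailuresRequest()):
        print(note.replica_id, note.error_message)
```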

src/lib.rs

Lines changed: 91 additions & 10 deletions

@@ -13,14 +13,15 @@ mod timeout;
 use anyhow::Result;
 use atty::Stream;
 use core::time::Duration;
-use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
+use pyo3::exceptions::{PyRuntimeError, PyStopIteration, PyTimeoutError};
 use std::cmp;
 use std::env;
 use std::sync::Arc;
 use std::thread::available_parallelism;
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
+use tokio_stream::StreamExt;
 use tonic::transport::Channel;
 use tonic::Status;

@@ -35,11 +36,13 @@ pub mod torchftpb {
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
 use crate::torchftpb::{
-    CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
-    ManagerQuorumRequest, ShouldCommitRequest,
+    CheckpointMetadataRequest, FailureNotification as ProtoFailureNotification,
+    LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest,
+    SubscribeFailuresRequest,
 };
 use pyo3::prelude::*;
 use pyo3::types::{PyDict, PyString};
+use pyo3::{PyRef, PyRefMut};

 // Get the number of threads to use for the tokio runtime
 fn num_threads() -> usize {
@@ -290,6 +293,45 @@ struct QuorumResult {
     heal: bool,
 }

+#[pyclass(unsendable)]
+struct FailureStream {
+    runtime: Arc<Runtime>,
+    stream: tonic::Streaming<ProtoFailureNotification>,
+    timeout: Duration,
+}
+
+#[pymethods]
+impl FailureStream {
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+    fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> {
+        let runtime = slf.runtime.clone();
+        let timeout = slf.timeout;
+        // borrow stream mutably for the whole async block
+        let fut = async { tokio::time::timeout(timeout, slf.stream.next()).await };
+
+        match runtime.block_on(fut) {
+            Ok(Some(Ok(note))) => Ok(FailureNotification {
+                replica_id: note.replica_id,
+                error_message: note.error_message,
+            }),
+            Ok(Some(Err(status))) => Err(StatusError(status).into()),
+            Ok(None) => Err(PyStopIteration::new_err(())),
+            Err(_) => Err(PyTimeoutError::new_err(
+                "Timeout waiting for failure notification",
+            )),
+        }
+    }
+}
+
+#[pyclass(get_all, set_all)]
+#[derive(Clone)]
+struct FailureNotification {
+    replica_id: String,
+    error_message: String,
+}
+
 #[pymethods]
 impl QuorumResult {
     #[new]
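
`__next__` above maps the three stream outcomes onto Python exceptions: an exhausted stream raises `StopIteration`, an expired wait raises `TimeoutError`, and a gRPC error surfaces through `StatusError`. A sketch of handling these from the Python side (same assumed `torchft._torchft` import path as above):

```python
# Sketch: explicit handling of FailureStream's exception mapping.
from datetime import timedelta

from torchft._torchft import LighthouseClient  # assumed import path

client = LighthouseClient("http://localhost:29510", timedelta(seconds=5))
stream = client.subscribe_failures(timeout=timedelta(seconds=1))
while True:
    try:
        note = next(stream)
    except TimeoutError:
        continue  # no failure within the 1 s window; poll again
    except StopIteration:
        break     # lighthouse closed the stream
    print(f"failure on {note.replica_id}: {note.error_message}")
```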
@@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
 #[pyclass]
 struct LighthouseClient {
     client: LighthouseServiceClient<Channel>,
-    runtime: Runtime,
+    runtime: Arc<Runtime>,
 }

@@ -487,11 +529,13 @@ impl LighthouseClient {
     #[new]
     fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
         py.allow_threads(move || {
-            let runtime = tokio::runtime::Builder::new_multi_thread()
-                .worker_threads(num_threads())
-                .thread_name("torchft-lhclnt")
-                .enable_all()
-                .build()?;
+            let runtime = Arc::new(
+                tokio::runtime::Builder::new_multi_thread()
+                    .worker_threads(num_threads())
+                    .thread_name("torchft-lhclnt")
+                    .enable_all()
+                    .build()?,
+            );
             let client = runtime
                 .block_on(manager::lighthouse_client_new(addr, connect_timeout))
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -586,6 +630,22 @@ impl LighthouseClient {
             Ok(())
         })
     }
+
+    #[pyo3(signature = (timeout = Duration::from_secs(5)))]
+    fn subscribe_failures(&self, py: Python<'_>, timeout: Duration) -> PyResult<FailureStream> {
+        py.allow_threads(move || {
+            let req = tonic::Request::new(SubscribeFailuresRequest {});
+            let response = self
+                .runtime
+                .block_on(self.client.clone().subscribe_failures(req))
+                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+            Ok(FailureStream {
+                runtime: self.runtime.clone(),
+                stream: response.into_inner(),
+                timeout: timeout,
+            })
+        })
+    }
 }

 /// LighthouseServer is a GRPC server for the lighthouse service.
@@ -610,7 +670,7 @@ struct LighthouseServer {

 #[pymethods]
 impl LighthouseServer {
-    #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
+    #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None, failure_tick_ms=None))]
     #[new]
     fn new(
         py: Python<'_>,

@@ -619,10 +679,12 @@ impl LighthouseServer {
         join_timeout_ms: Option<u64>,
         quorum_tick_ms: Option<u64>,
         heartbeat_timeout_ms: Option<u64>,
+        failure_tick_ms: Option<u64>,
     ) -> PyResult<Self> {
         let join_timeout_ms = join_timeout_ms.unwrap_or(100);
         let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
         let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000);
+        let failure_tick_ms = failure_tick_ms.unwrap_or(1000);

         py.allow_threads(move || {
             let rt = tokio::runtime::Builder::new_multi_thread()

@@ -638,6 +700,7 @@ impl LighthouseServer {
                 join_timeout_ms: join_timeout_ms,
                 quorum_tick_ms: quorum_tick_ms,
                 heartbeat_timeout_ms: heartbeat_timeout_ms,
+                failure_tick_ms: failure_tick_ms,
             }))
             .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -663,6 +726,22 @@ impl LighthouseServer {
             self.handle.abort();
         })
     }
+
+    /// inject_failure broadcasts a failure notification for the given replica.
+    ///
+    /// This helper is intended for testing `subscribe_failures` from Python.
+    #[pyo3(signature = (replica_id))]
+    fn inject_failure(&self, py: Python<'_>, replica_id: String) {
+        let lighthouse = self.lighthouse.clone();
+        let runtime = &self._runtime;
+        py.allow_threads(move || {
+            let _ = runtime.block_on(async {
+                if let Err(e) = lighthouse.inject_failure(replica_id).await {
+                    eprintln!("Failed to inject failure: {}", e);
+                }
+            });
+        });
+    }
 }

 struct StatusError(Status);

@@ -750,6 +829,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<LighthouseServer>()?;
     m.add_class::<LighthouseClient>()?;
     m.add_class::<QuorumResult>()?;
+    m.add_class::<FailureNotification>()?;
+    m.add_class::<FailureStream>()?;
     m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;

     Ok(())
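
Per its doc comment, `inject_failure` exists so `subscribe_failures` can be exercised from Python without killing a real replica. A hedged end-to-end sketch (again assuming the `torchft._torchft` import path):

```python
# Sketch: broadcast a synthetic failure and observe it on a subscriber.
import threading
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer  # assumed path

server = LighthouseServer(bind="[::]:29510", min_replicas=1)
client = LighthouseClient("http://localhost:29510", timedelta(seconds=5))
stream = client.subscribe_failures(timeout=timedelta(seconds=10))

# Fire the injection after a short delay so the subscription is active first.
threading.Timer(1.0, server.inject_failure, args=["replica_0"]).start()

note = next(stream)  # blocks until the injected notification arrives
assert note.replica_id == "replica_0"
server.shutdown()
```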
