
Commit ebb3953

Added proactive heartbeat timeout failure propagation (#164) (#188)
1 parent b84c5a6 commit ebb3953

15 files changed: 1,426 additions, 56 deletions

Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -21,7 +21,9 @@ slog-stdlog = "4.1.1"
 stderrlog = "0.6.0"
 structopt = "0.3.26"
 tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
+tokio-stream = {version = "0.1.14", features = ["sync"]}
 tonic = "0.12.2"
+futures-core = "0.3"
 
 [build-dependencies]
 tonic-build = "0.12.2"

README.md

Lines changed: 32 additions & 0 deletions
@@ -246,6 +246,38 @@ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --mast
 
 By observing the outputs from both shells, you should observe process group reconfiguration and live checkpoint recovery.
 
+### Proactive Failure Recovery Mode (Experimental)
+
+You can experiment with proactive failure recovery mode by setting:
+
+```sh
+export TORCHFT_PROACTIVE_RECOVERY=1
+```
+
+With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break out of a hanging allreduce.
+
+You can test this out by running `train_ddp_proactive.py`.
+
+On shell 1 (one replica group starts initial training):
+```sh
+export REPLICA_GROUP_ID=0
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
+```
+
+On shell 2 (a second replica group joins):
+```sh
+export REPLICA_GROUP_ID=1
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
+```
+
+You should observe that the process with replica group id 1 exits early, and the process with replica group id 0 quickly resumes training. If the same script is run after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, the process with replica group id 1 will instead hang for dozens of seconds before continuing.
+
 ### Example Parameter Server
 
 torchft has a fault tolerant parameter server implementation built on it's
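Conceptually, the mechanism added in this commit is a subscriber that listens to the Lighthouse failure stream and reacts when another replica group's heartbeat times out. The sketch below illustrates that idea only; it is not the manager's actual code path. It assumes the `LighthouseClient` / `subscribe_failures` bindings introduced in `src/lib.rs` later in this diff, a hypothetical caller-supplied `abort_collectives()` callback, and that `Duration` parameters accept `datetime.timedelta`.

```python
# Illustrative sketch only: a background watcher that reacts to failure
# notifications. Not the manager's actual implementation.
from datetime import timedelta

from torchft._torchft import LighthouseClient  # module path assumed


def watch_failures(lighthouse_addr: str, my_replica_id: str, abort_collectives) -> None:
    """Subscribe to the Lighthouse failure stream and invoke the (hypothetical)
    callback whenever another replica group is reported as failed."""
    client = LighthouseClient(addr=lighthouse_addr, connect_timeout=timedelta(seconds=5))
    stream = client.subscribe_failures(timeout=timedelta(seconds=30))
    while True:
        try:
            note = next(stream)  # blocks until a FailureNotification arrives
        except TimeoutError:
            continue  # no failure within the window; keep listening
        except StopIteration:
            return  # lighthouse closed the stream
        if note.replica_id != my_replica_id:
            # Another replica group missed its heartbeat: break out of the
            # hanging allreduce instead of waiting for the collective timeout.
            abort_collectives(failed_replica=note.replica_id)
```

In practice such a watcher would run in a daemon thread next to the training loop, which is roughly what `TORCHFT_PROACTIVE_RECOVERY=1` enables inside the manager.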

proto/torchft.proto

Lines changed: 13 additions & 0 deletions
@@ -67,9 +67,16 @@ message LighthouseHeartbeatRequest {
 
 message LighthouseHeartbeatResponse {}
 
+message SubscribeFailuresRequest {}
+
+message FailureNotification {
+  string replica_id = 1;
+}
+
 service LighthouseService {
   rpc Quorum (LighthouseQuorumRequest) returns (LighthouseQuorumResponse);
   rpc Heartbeat (LighthouseHeartbeatRequest) returns (LighthouseHeartbeatResponse);
+  rpc SubscribeFailures (SubscribeFailuresRequest) returns (stream FailureNotification);
 }
 
 message ManagerQuorumRequest {
@@ -126,3 +133,9 @@ service ManagerService {
   rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse);
   rpc Kill(KillRequest) returns (KillResponse);
 }
+
+message LighthouseClientRequest {
+  string replica_id = 1;
+}
+
+

src/lib.rs

Lines changed: 95 additions & 10 deletions
@@ -13,14 +13,15 @@ mod timeout;
 use anyhow::Result;
 use atty::Stream;
 use core::time::Duration;
-use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
+use pyo3::exceptions::{PyRuntimeError, PyStopIteration, PyTimeoutError};
 use std::cmp;
 use std::env;
 use std::sync::Arc;
 use std::thread::available_parallelism;
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
+use tokio_stream::StreamExt;
 use tonic::transport::Channel;
 use tonic::Status;
 
@@ -35,11 +36,13 @@ pub mod torchftpb {
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
 use crate::torchftpb::{
-    CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
-    ManagerQuorumRequest, ShouldCommitRequest,
+    CheckpointMetadataRequest, FailureNotification as ProtoFailureNotification,
+    LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest,
+    SubscribeFailuresRequest,
 };
 use pyo3::prelude::*;
 use pyo3::types::{PyDict, PyString};
+use pyo3::{PyRef, PyRefMut};
 
 // Get the number of threads to use for the tokio runtime
 fn num_threads() -> usize {
@@ -290,6 +293,43 @@ struct QuorumResult {
     heal: bool,
 }
 
+#[pyclass(unsendable)]
+struct FailureStream {
+    runtime: Arc<Runtime>,
+    stream: tonic::Streaming<ProtoFailureNotification>,
+    timeout: Duration,
+}
+
+#[pymethods]
+impl FailureStream {
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+    fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> {
+        let runtime = slf.runtime.clone();
+        let timeout = slf.timeout;
+        // borrow stream mutably for the whole async block
+        let fut = async { tokio::time::timeout(timeout, slf.stream.next()).await };
+
+        match runtime.block_on(fut) {
+            Ok(Some(Ok(note))) => Ok(FailureNotification {
+                replica_id: note.replica_id,
+            }),
+            Ok(Some(Err(status))) => Err(StatusError(status).into()),
+            Ok(None) => Err(PyStopIteration::new_err(())),
+            Err(_) => Err(PyTimeoutError::new_err(
+                "Timeout waiting for failure notification",
+            )),
+        }
+    }
+}
+
+#[pyclass(get_all, set_all)]
+#[derive(Clone)]
+struct FailureNotification {
+    replica_id: String,
+}
+
 #[pymethods]
 impl QuorumResult {
     #[new]
@@ -396,6 +436,12 @@ pub struct Timestamp {
     pub nanos: i32,
 }
 
+#[pyclass(get_all, set_all)]
+#[derive(Clone)]
+struct FailureNotificationPy {
+    replica_id: String,
+}
+
 /// quorum result.
 ///
 /// Args:
@@ -478,7 +524,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
 #[pyclass]
 struct LighthouseClient {
     client: LighthouseServiceClient<Channel>,
-    runtime: Runtime,
+    runtime: Arc<Runtime>,
 }
 
 #[pymethods]
@@ -487,11 +533,13 @@ impl LighthouseClient {
     #[new]
     fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
         py.allow_threads(move || {
-            let runtime = tokio::runtime::Builder::new_multi_thread()
-                .worker_threads(num_threads())
-                .thread_name("torchft-lhclnt")
-                .enable_all()
-                .build()?;
+            let runtime = Arc::new(
+                tokio::runtime::Builder::new_multi_thread()
+                    .worker_threads(num_threads())
+                    .thread_name("torchft-lhclnt")
+                    .enable_all()
+                    .build()?,
+            );
             let client = runtime
                 .block_on(manager::lighthouse_client_new(addr, connect_timeout))
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -586,6 +634,22 @@ impl LighthouseClient {
             Ok(())
         })
     }
+
+    #[pyo3(signature = (timeout = Duration::from_secs(5)))]
+    fn subscribe_failures(&self, py: Python<'_>, timeout: Duration) -> PyResult<FailureStream> {
+        py.allow_threads(move || {
+            let req = tonic::Request::new(SubscribeFailuresRequest {});
+            let response = self
+                .runtime
+                .block_on(self.client.clone().subscribe_failures(req))
+                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+            Ok(FailureStream {
+                runtime: self.runtime.clone(),
+                stream: response.into_inner(),
+                timeout: timeout,
+            })
+        })
+    }
 }
 
 /// LighthouseServer is a GRPC server for the lighthouse service.
@@ -610,7 +674,7 @@ struct LighthouseServer {
 
 #[pymethods]
 impl LighthouseServer {
-    #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
+    #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None, failure_tick_ms=None))]
     #[new]
     fn new(
         py: Python<'_>,
@@ -619,10 +683,12 @@ impl LighthouseServer {
         join_timeout_ms: Option<u64>,
        quorum_tick_ms: Option<u64>,
        heartbeat_timeout_ms: Option<u64>,
+        failure_tick_ms: Option<u64>,
     ) -> PyResult<Self> {
         let join_timeout_ms = join_timeout_ms.unwrap_or(100);
         let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
         let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000);
+        let failure_tick_ms = failure_tick_ms.unwrap_or(1000);
 
         py.allow_threads(move || {
             let rt = tokio::runtime::Builder::new_multi_thread()
@@ -638,6 +704,7 @@ impl LighthouseServer {
                 join_timeout_ms: join_timeout_ms,
                 quorum_tick_ms: quorum_tick_ms,
                 heartbeat_timeout_ms: heartbeat_timeout_ms,
+                failure_tick_ms: failure_tick_ms,
             }))
             .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
@@ -663,6 +730,22 @@ impl LighthouseServer {
             self.handle.abort();
         })
     }
+
+    /// inject_failure broadcasts a failure notification for the given replica.
+    ///
+    /// This helper is intended for testing `subscribe_failures` from Python.
+    #[pyo3(signature = (replica_id))]
+    fn inject_failure(&self, py: Python<'_>, replica_id: String) {
+        let lighthouse = self.lighthouse.clone();
+        let runtime = &self._runtime;
+        py.allow_threads(move || {
+            let _ = runtime.block_on(async {
+                if let Err(e) = lighthouse.inject_failure(replica_id).await {
+                    eprintln!("Failed to inject failure: {}", e);
+                }
+            });
+        });
+    }
 }
 
 struct StatusError(Status);
@@ -750,6 +833,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<LighthouseServer>()?;
     m.add_class::<LighthouseClient>()?;
     m.add_class::<QuorumResult>()?;
+    m.add_class::<FailureNotificationPy>()?;
+    m.add_class::<FailureStream>()?;
     m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;
 
     Ok(())
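Taken together, the new bindings can be smoke-tested end to end from Python. The snippet below is a minimal sketch, not part of this commit: the names `LighthouseServer`, `LighthouseClient`, `subscribe_failures`, `inject_failure`, and `failure_tick_ms` come from the diff above, while the `address()` and `shutdown()` accessors, the `torchft._torchft` module path, and `datetime.timedelta` conversion for `Duration` parameters are assumptions about the existing torchft API.

```python
# Minimal smoke-test sketch for the new failure-notification bindings.
# address()/shutdown() and timedelta-for-Duration are assumptions about the
# existing API; everything else is taken from this commit's diff.
import time
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer

# Lighthouse that scans for heartbeat failures every 100 ms (failure_tick_ms).
server = LighthouseServer(bind="[::]:0", min_replicas=1, failure_tick_ms=100)
client = LighthouseClient(addr=server.address(), connect_timeout=timedelta(seconds=5))

# subscribe_failures returns a FailureStream; each __next__ blocks until a
# FailureNotification arrives, or raises TimeoutError after `timeout`.
stream = client.subscribe_failures(timeout=timedelta(seconds=10))
time.sleep(0.5)  # give the broadcast subscription a moment to be registered

# inject_failure broadcasts a notification that the subscriber then observes.
server.inject_failure("replica_0")
note = next(stream)
assert note.replica_id == "replica_0"

server.shutdown()
```

Because `FailureStream` implements both `__iter__` and `__next__` (raising `StopIteration` when the stream ends), it can also be consumed with a plain `for` loop.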
