Added proactive heartbeat timeout failure propagation (#164) (#188) #196

Open · wants to merge 1 commit into `main`

2 changes: 2 additions & 0 deletions Cargo.toml
@@ -21,7 +21,9 @@ slog-stdlog = "4.1.1"
stderrlog = "0.6.0"
structopt = "0.3.26"
tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
tokio-stream = {version = "0.1.14", features = ["sync"]}
tonic = "0.12.2"
futures-core = "0.3"

[build-dependencies]
tonic-build = "0.12.2"
32 changes: 32 additions & 0 deletions README.md
@@ -246,6 +246,38 @@ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --mast

By observing the outputs from both shells, you should see process group reconfiguration and live checkpoint recovery.

### Proactive Failure Recovery Mode (Experimental)

You can experiment with proactive failure recovery mode by:

```sh
export TORCHFT_PROACTIVE_RECOVERY=1
```

With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break out of a hanging allreduce when a failure is reported.

You can test this out by running `train_ddp_proactive.py`:

On shell 1 (one replica group starts initial training):
```sh
export REPLICA_GROUP_ID=0
export NUM_REPLICA_GROUPS=2
export TORCHFT_PROACTIVE_RECOVERY=1

CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
```

On shell 2 (a second replica group joins):
```sh
export REPLICA_GROUP_ID=1
export NUM_REPLICA_GROUPS=2
export TORCHFT_PROACTIVE_RECOVERY=1

CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
```

You should observe that the process with replica group id 1 exits early, and the process with replica group id 0 quickly resumes training. If the same script is re-run after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, you should observe that the process with replica group id 1 hangs for dozens of seconds before continuing.
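
The recovery path above is driven by the new `subscribe_failures` API on `LighthouseClient`. Below is a minimal monitoring sketch; the `torchft._torchft` module path and the use of `datetime.timedelta` for the `Duration` parameters are assumptions for illustration, not verified API:

```py
# Minimal failure monitor sketch. The module path and timedelta conversions
# are assumptions; the class, method, and field names come from this PR.
from datetime import timedelta

from torchft._torchft import LighthouseClient

client = LighthouseClient(
    addr="http://localhost:29510",
    connect_timeout=timedelta(seconds=5),
)

# subscribe_failures returns an iterator of FailureNotification objects;
# each step waits at most `timeout` before raising TimeoutError, and the
# loop ends (StopIteration) when the server closes the stream.
stream = client.subscribe_failures(timeout=timedelta(seconds=30))
try:
    for note in stream:
        print(f"replica {note.replica_id} failed: {note.error_message}")
except TimeoutError:
    print("no failure notifications within 30s")
```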

### Example Parameter Server

torchft has a fault tolerant parameter server implementation built on its
14 changes: 14 additions & 0 deletions proto/torchft.proto
@@ -67,9 +67,17 @@ message LighthouseHeartbeatRequest {

message LighthouseHeartbeatResponse {}

message SubscribeFailuresRequest {}

message FailureNotification {
string replica_id = 1;
string error_message = 2;
}

service LighthouseService {
rpc Quorum (LighthouseQuorumRequest) returns (LighthouseQuorumResponse);
rpc Heartbeat (LighthouseHeartbeatRequest) returns (LighthouseHeartbeatResponse);
rpc SubscribeFailures (SubscribeFailuresRequest) returns (stream FailureNotification);
}

message ManagerQuorumRequest {
@@ -126,3 +134,9 @@ service ManagerService {
rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse);
rpc Kill(KillRequest) returns (KillResponse);
}

message LighthouseClientRequest {
string replica_id = 1;
}
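
The new `SubscribeFailures` RPC is a plain server-streaming gRPC method, so it can also be consumed without the Rust bindings. A sketch using Python stubs generated from `torchft.proto` with `grpcio-tools` (the `torchft_pb2`/`torchft_pb2_grpc` module names are the default codegen output and are assumptions here, not part of this PR):

```py
# Consume the server-streaming RPC with raw grpc stubs. Stub module names
# assume default grpcio-tools codegen from torchft.proto.
import grpc

import torchft_pb2
import torchft_pb2_grpc

with grpc.insecure_channel("localhost:29510") as channel:
    stub = torchft_pb2_grpc.LighthouseServiceStub(channel)
    # The call returns an iterator of FailureNotification messages.
    for note in stub.SubscribeFailures(torchft_pb2.SubscribeFailuresRequest()):
        print(note.replica_id, note.error_message)
```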


101 changes: 91 additions & 10 deletions src/lib.rs
@@ -13,14 +13,15 @@ mod timeout;
use anyhow::Result;
use atty::Stream;
use core::time::Duration;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use pyo3::exceptions::{PyRuntimeError, PyStopIteration, PyTimeoutError};
use std::cmp;
use std::env;
use std::sync::Arc;
use std::thread::available_parallelism;
use structopt::StructOpt;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
use tonic::transport::Channel;
use tonic::Status;

@@ -35,11 +36,13 @@ pub mod torchftpb {
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
use crate::torchftpb::manager_service_client::ManagerServiceClient;
use crate::torchftpb::{
CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
ManagerQuorumRequest, ShouldCommitRequest,
CheckpointMetadataRequest, FailureNotification as ProtoFailureNotification,
LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest,
SubscribeFailuresRequest,
};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};
use pyo3::{PyRef, PyRefMut};

// Get the number of threads to use for the tokio runtime
fn num_threads() -> usize {
@@ -290,6 +293,45 @@ struct QuorumResult {
heal: bool,
}

#[pyclass(unsendable)]
struct FailureStream {
runtime: Arc<Runtime>,
stream: tonic::Streaming<ProtoFailureNotification>,
timeout: Duration,
}

#[pymethods]
impl FailureStream {
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> {
Review thread on the `__next__` signature:

Member: nit: just use `self`

Contributor (author): `PyRef<'_, FailureStream>` cannot be used as the type of `self` without the `arbitrary_self_types` feature (see rust-lang/rust#44874); rustc suggests changing to `self`, `&self`, `&mut self`, or a type implementing `Receiver` such as `self: Box<Self>`, `self: Rc<Self>`, or `self: Arc<Self>`. I get this error when using `self`.

let runtime = slf.runtime.clone();
let timeout = slf.timeout;
// borrow stream mutably for the whole async block
let fut = async { tokio::time::timeout(timeout, slf.stream.next()).await };

match runtime.block_on(fut) {
Ok(Some(Ok(note))) => Ok(FailureNotification {
replica_id: note.replica_id,
error_message: note.error_message,
}),
Ok(Some(Err(status))) => Err(StatusError(status).into()),
Ok(None) => Err(PyStopIteration::new_err(())),
Err(_) => Err(PyTimeoutError::new_err(
"Timeout waiting for failure notification",
)),
}
}
}

#[pyclass(get_all, set_all)]
#[derive(Clone)]
struct FailureNotification {
replica_id: String,
error_message: String,
}

#[pymethods]
impl QuorumResult {
#[new]
@@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
#[pyclass]
struct LighthouseClient {
client: LighthouseServiceClient<Channel>,
runtime: Runtime,
runtime: Arc<Runtime>,
}

#[pymethods]
@@ -487,11 +529,13 @@ impl LighthouseClient {
#[new]
fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
py.allow_threads(move || {
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_threads())
.thread_name("torchft-lhclnt")
.enable_all()
.build()?;
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_threads())
.thread_name("torchft-lhclnt")
.enable_all()
.build()?,
);
let client = runtime
.block_on(manager::lighthouse_client_new(addr, connect_timeout))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -586,6 +630,22 @@ impl LighthouseClient {
Ok(())
})
}

#[pyo3(signature = (timeout = Duration::from_secs(5)))]
fn subscribe_failures(&self, py: Python<'_>, timeout: Duration) -> PyResult<FailureStream> {
py.allow_threads(move || {
let req = tonic::Request::new(SubscribeFailuresRequest {});
let response = self
.runtime
.block_on(self.client.clone().subscribe_failures(req))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
Ok(FailureStream {
runtime: self.runtime.clone(),
stream: response.into_inner(),
timeout: timeout,
})
})
}
}

/// LighthouseServer is a GRPC server for the lighthouse service.
@@ -610,7 +670,7 @@ struct LighthouseServer {

#[pymethods]
impl LighthouseServer {
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None, failure_tick_ms=None))]
#[new]
fn new(
py: Python<'_>,
Expand All @@ -619,10 +679,12 @@ impl LighthouseServer {
join_timeout_ms: Option<u64>,
quorum_tick_ms: Option<u64>,
heartbeat_timeout_ms: Option<u64>,
failure_tick_ms: Option<u64>,
) -> PyResult<Self> {
let join_timeout_ms = join_timeout_ms.unwrap_or(100);
let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000);
let failure_tick_ms = failure_tick_ms.unwrap_or(1000);

py.allow_threads(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
Expand All @@ -638,6 +700,7 @@ impl LighthouseServer {
join_timeout_ms: join_timeout_ms,
quorum_tick_ms: quorum_tick_ms,
heartbeat_timeout_ms: heartbeat_timeout_ms,
failure_tick_ms: failure_tick_ms,
}))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

@@ -663,6 +726,22 @@ impl LighthouseServer {
self.handle.abort();
})
}

/// inject_failure broadcasts a failure notification for the given replica.
///
/// This helper is intended for testing `subscribe_failures` from Python.
#[pyo3(signature = (replica_id))]
fn inject_failure(&self, py: Python<'_>, replica_id: String) {
let lighthouse = self.lighthouse.clone();
let runtime = &self._runtime;
py.allow_threads(move || {
let _ = runtime.block_on(async {
if let Err(e) = lighthouse.inject_failure(replica_id).await {
eprintln!("Failed to inject failure: {}", e);
}
});
});
}
}

struct StatusError(Status);
@@ -750,6 +829,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<LighthouseServer>()?;
m.add_class::<LighthouseClient>()?;
m.add_class::<QuorumResult>()?;
m.add_class::<FailureNotification>()?;
m.add_class::<FailureStream>()?;
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;

Ok(())
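
Together, `inject_failure` and `subscribe_failures` make it possible to smoke-test the notification path end to end from Python, which is what the `inject_failure` docstring above suggests. A hedged sketch, with the same module-path and `timedelta` assumptions as the README example; the `shutdown()` call is assumed from the surrounding server API (its definition is elided in this diff):

```py
# Smoke test for failure propagation: subscribe, inject, assert.
# Module path, timedelta conversion, and shutdown() are assumptions.
import time
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer

server = LighthouseServer(bind="[::]:29510", min_replicas=1)
client = LighthouseClient(
    addr="http://localhost:29510",
    connect_timeout=timedelta(seconds=5),
)

stream = client.subscribe_failures(timeout=timedelta(seconds=10))
time.sleep(0.5)  # give the subscription a moment to register server-side
server.inject_failure("replica_0")

note = next(stream)  # blocks until the notification arrives or times out
assert note.replica_id == "replica_0"
server.shutdown()
```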