-
Notifications
You must be signed in to change notification settings - Fork 33
Added proactive heartbeat timeout failure propagation (#164) (#188) #196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,14 +13,15 @@ mod timeout; | |
use anyhow::Result; | ||
use atty::Stream; | ||
use core::time::Duration; | ||
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError}; | ||
use pyo3::exceptions::{PyRuntimeError, PyStopIteration, PyTimeoutError}; | ||
use std::cmp; | ||
use std::env; | ||
use std::sync::Arc; | ||
use std::thread::available_parallelism; | ||
use structopt::StructOpt; | ||
use tokio::runtime::Runtime; | ||
use tokio::task::JoinHandle; | ||
use tokio_stream::StreamExt; | ||
use tonic::transport::Channel; | ||
use tonic::Status; | ||
|
||
|
@@ -35,11 +36,13 @@ pub mod torchftpb { | |
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; | ||
use crate::torchftpb::manager_service_client::ManagerServiceClient; | ||
use crate::torchftpb::{ | ||
CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest, | ||
ManagerQuorumRequest, ShouldCommitRequest, | ||
CheckpointMetadataRequest, FailureNotification as ProtoFailureNotification, | ||
LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest, | ||
SubscribeFailuresRequest, | ||
}; | ||
use pyo3::prelude::*; | ||
use pyo3::types::{PyDict, PyString}; | ||
use pyo3::{PyRef, PyRefMut}; | ||
|
||
// Get the number of threads to use for the tokio runtime | ||
fn num_threads() -> usize { | ||
|
@@ -290,6 +293,45 @@ struct QuorumResult { | |
heal: bool, | ||
} | ||
|
||
#[pyclass(unsendable)] | ||
struct FailureStream { | ||
runtime: Arc<Runtime>, | ||
stream: tonic::Streaming<ProtoFailureNotification>, | ||
timeout: Duration, | ||
} | ||
|
||
#[pymethods] | ||
impl FailureStream { | ||
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { | ||
slf | ||
} | ||
fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I get this error when use self. |
||
let runtime = slf.runtime.clone(); | ||
let timeout = slf.timeout; | ||
// borrow stream mutably for the whole async block | ||
let fut = async { tokio::time::timeout(timeout, slf.stream.next()).await }; | ||
|
||
match runtime.block_on(fut) { | ||
Ok(Some(Ok(note))) => Ok(FailureNotification { | ||
replica_id: note.replica_id, | ||
error_message: note.error_message, | ||
}), | ||
Ok(Some(Err(status))) => Err(StatusError(status).into()), | ||
Ok(None) => Err(PyStopIteration::new_err(())), | ||
Err(_) => Err(PyTimeoutError::new_err( | ||
"Timeout waiting for failure notification", | ||
)), | ||
} | ||
} | ||
} | ||
|
||
#[pyclass(get_all, set_all)] | ||
#[derive(Clone)] | ||
struct FailureNotification { | ||
replica_id: String, | ||
error_message: String, | ||
} | ||
|
||
#[pymethods] | ||
impl QuorumResult { | ||
#[new] | ||
|
@@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> { | |
#[pyclass] | ||
struct LighthouseClient { | ||
client: LighthouseServiceClient<Channel>, | ||
runtime: Runtime, | ||
runtime: Arc<Runtime>, | ||
} | ||
|
||
#[pymethods] | ||
|
@@ -487,11 +529,13 @@ impl LighthouseClient { | |
#[new] | ||
fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> { | ||
py.allow_threads(move || { | ||
let runtime = tokio::runtime::Builder::new_multi_thread() | ||
.worker_threads(num_threads()) | ||
.thread_name("torchft-lhclnt") | ||
.enable_all() | ||
.build()?; | ||
let runtime = Arc::new( | ||
tokio::runtime::Builder::new_multi_thread() | ||
.worker_threads(num_threads()) | ||
.thread_name("torchft-lhclnt") | ||
.enable_all() | ||
.build()?, | ||
); | ||
let client = runtime | ||
.block_on(manager::lighthouse_client_new(addr, connect_timeout)) | ||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?; | ||
|
@@ -586,6 +630,22 @@ impl LighthouseClient { | |
Ok(()) | ||
}) | ||
} | ||
|
||
#[pyo3(signature = (timeout = Duration::from_secs(5)))] | ||
fn subscribe_failures(&self, py: Python<'_>, timeout: Duration) -> PyResult<FailureStream> { | ||
py.allow_threads(move || { | ||
let req = tonic::Request::new(SubscribeFailuresRequest {}); | ||
let response = self | ||
.runtime | ||
.block_on(self.client.clone().subscribe_failures(req)) | ||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?; | ||
Ok(FailureStream { | ||
runtime: self.runtime.clone(), | ||
stream: response.into_inner(), | ||
timeout: timeout, | ||
}) | ||
}) | ||
} | ||
} | ||
|
||
/// LighthouseServer is a GRPC server for the lighthouse service. | ||
|
@@ -610,7 +670,7 @@ struct LighthouseServer { | |
|
||
#[pymethods] | ||
impl LighthouseServer { | ||
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))] | ||
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None, failure_tick_ms=None))] | ||
#[new] | ||
fn new( | ||
py: Python<'_>, | ||
|
@@ -619,10 +679,12 @@ impl LighthouseServer { | |
join_timeout_ms: Option<u64>, | ||
quorum_tick_ms: Option<u64>, | ||
heartbeat_timeout_ms: Option<u64>, | ||
failure_tick_ms: Option<u64>, | ||
) -> PyResult<Self> { | ||
let join_timeout_ms = join_timeout_ms.unwrap_or(100); | ||
let quorum_tick_ms = quorum_tick_ms.unwrap_or(100); | ||
let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000); | ||
let failure_tick_ms = failure_tick_ms.unwrap_or(1000); | ||
|
||
py.allow_threads(move || { | ||
let rt = tokio::runtime::Builder::new_multi_thread() | ||
|
@@ -638,6 +700,7 @@ impl LighthouseServer { | |
join_timeout_ms: join_timeout_ms, | ||
quorum_tick_ms: quorum_tick_ms, | ||
heartbeat_timeout_ms: heartbeat_timeout_ms, | ||
failure_tick_ms: failure_tick_ms, | ||
})) | ||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?; | ||
|
||
|
@@ -663,6 +726,22 @@ impl LighthouseServer { | |
self.handle.abort(); | ||
}) | ||
} | ||
|
||
/// inject_failure broadcasts a failure notification for the given replica. | ||
/// | ||
/// This helper is intended for testing `subscribe_failures` from Python. | ||
#[pyo3(signature = (replica_id))] | ||
fn inject_failure(&self, py: Python<'_>, replica_id: String) { | ||
WarrenZhu050413 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let lighthouse = self.lighthouse.clone(); | ||
let runtime = &self._runtime; | ||
py.allow_threads(move || { | ||
let _ = runtime.block_on(async { | ||
if let Err(e) = lighthouse.inject_failure(replica_id).await { | ||
eprintln!("Failed to inject failure: {}", e); | ||
} | ||
}); | ||
}); | ||
} | ||
} | ||
|
||
struct StatusError(Status); | ||
|
@@ -750,6 +829,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { | |
m.add_class::<LighthouseServer>()?; | ||
m.add_class::<LighthouseClient>()?; | ||
m.add_class::<QuorumResult>()?; | ||
m.add_class::<FailureNotification>()?; | ||
m.add_class::<FailureStream>()?; | ||
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?; | ||
|
||
Ok(()) | ||
|
Uh oh!
There was an error while loading. Please reload this page.