@@ -13,14 +13,15 @@ mod timeout;
1313use  anyhow:: Result ; 
1414use  atty:: Stream ; 
1515use  core:: time:: Duration ; 
16- use  pyo3:: exceptions:: { PyRuntimeError ,  PyTimeoutError } ; 
16+ use  pyo3:: exceptions:: { PyRuntimeError ,  PyStopIteration ,   PyTimeoutError } ; 
1717use  std:: cmp; 
1818use  std:: env; 
1919use  std:: sync:: Arc ; 
2020use  std:: thread:: available_parallelism; 
2121use  structopt:: StructOpt ; 
2222use  tokio:: runtime:: Runtime ; 
2323use  tokio:: task:: JoinHandle ; 
24+ use  tokio_stream:: StreamExt ; 
2425use  tonic:: transport:: Channel ; 
2526use  tonic:: Status ; 
2627
@@ -35,11 +36,13 @@ pub mod torchftpb {
3536use  crate :: torchftpb:: lighthouse_service_client:: LighthouseServiceClient ; 
3637use  crate :: torchftpb:: manager_service_client:: ManagerServiceClient ; 
3738use  crate :: torchftpb:: { 
38-     CheckpointMetadataRequest ,  LighthouseHeartbeatRequest ,  LighthouseQuorumRequest , 
39-     ManagerQuorumRequest ,  ShouldCommitRequest , 
39+     CheckpointMetadataRequest ,  FailureNotification  as  ProtoFailureNotification , 
40+     LighthouseHeartbeatRequest ,  LighthouseQuorumRequest ,  ManagerQuorumRequest ,  ShouldCommitRequest , 
41+     SubscribeFailuresRequest , 
4042} ; 
4143use  pyo3:: prelude:: * ; 
4244use  pyo3:: types:: { PyDict ,  PyString } ; 
45+ use  pyo3:: { PyRef ,  PyRefMut } ; 
4346
4447// Get the number of threads to use for the tokio runtime 
4548fn  num_threads ( )  -> usize  { 
@@ -290,6 +293,43 @@ struct QuorumResult {
290293    heal :  bool , 
291294} 
292295
296+ #[ pyclass( unsendable) ]  
297+ struct  FailureStream  { 
298+     runtime :  Arc < Runtime > , 
299+     stream :  tonic:: Streaming < ProtoFailureNotification > , 
300+     timeout :  Duration , 
301+ } 
302+ 
303+ #[ pymethods]  
304+ impl  FailureStream  { 
305+     fn  __iter__ ( slf :  PyRef < ' _ ,  Self > )  -> PyRef < ' _ ,  Self >  { 
306+         slf
307+     } 
308+     fn  __next__ ( mut  slf :  PyRefMut < ' _ ,  Self > )  -> PyResult < FailureNotification >  { 
309+         let  runtime = slf. runtime . clone ( ) ; 
310+         let  timeout = slf. timeout ; 
311+         // borrow stream mutably for the whole async block 
312+         let  fut = async  {  tokio:: time:: timeout ( timeout,  slf. stream . next ( ) ) . await  } ; 
313+ 
314+         match  runtime. block_on ( fut)  { 
315+             Ok ( Some ( Ok ( note) ) )  => Ok ( FailureNotification  { 
316+                 replica_id :  note. replica_id , 
317+             } ) , 
318+             Ok ( Some ( Err ( status) ) )  => Err ( StatusError ( status) . into ( ) ) , 
319+             Ok ( None )  => Err ( PyStopIteration :: new_err ( ( ) ) ) , 
320+             Err ( _)  => Err ( PyTimeoutError :: new_err ( 
321+                 "Timeout waiting for failure notification" , 
322+             ) ) , 
323+         } 
324+     } 
325+ } 
326+ 
327+ #[ pyclass( get_all,  set_all) ]  
328+ #[ derive( Clone ) ]  
329+ struct  FailureNotification  { 
330+     replica_id :  String , 
331+ } 
332+ 
293333#[ pymethods]  
294334impl  QuorumResult  { 
295335    #[ new]  
@@ -396,6 +436,12 @@ pub struct Timestamp {
396436    pub  nanos :  i32 , 
397437} 
398438
439+ #[ pyclass( get_all,  set_all) ]  
440+ #[ derive( Clone ) ]  
441+ struct  FailureNotificationPy  { 
442+     replica_id :  String , 
443+ } 
444+ 
399445/// quorum result. 
400446/// 
401447/// Args: 
@@ -478,7 +524,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
478524#[ pyclass]  
479525struct  LighthouseClient  { 
480526    client :  LighthouseServiceClient < Channel > , 
481-     runtime :  Runtime , 
527+     runtime :  Arc < Runtime > , 
482528} 
483529
484530#[ pymethods]  
@@ -487,11 +533,13 @@ impl LighthouseClient {
487533    #[ new]  
488534    fn  new ( py :  Python < ' _ > ,  addr :  String ,  connect_timeout :  Duration )  -> PyResult < Self >  { 
489535        py. allow_threads ( move  || { 
490-             let  runtime = tokio:: runtime:: Builder :: new_multi_thread ( ) 
491-                 . worker_threads ( num_threads ( ) ) 
492-                 . thread_name ( "torchft-lhclnt" ) 
493-                 . enable_all ( ) 
494-                 . build ( ) ?; 
536+             let  runtime = Arc :: new ( 
537+                 tokio:: runtime:: Builder :: new_multi_thread ( ) 
538+                     . worker_threads ( num_threads ( ) ) 
539+                     . thread_name ( "torchft-lhclnt" ) 
540+                     . enable_all ( ) 
541+                     . build ( ) ?, 
542+             ) ; 
495543            let  client = runtime
496544                . block_on ( manager:: lighthouse_client_new ( addr,  connect_timeout) ) 
497545                . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?; 
@@ -586,6 +634,22 @@ impl LighthouseClient {
586634            Ok ( ( ) ) 
587635        } ) 
588636    } 
637+ 
638+     #[ pyo3( signature = ( timeout = Duration :: from_secs( 5 ) ) ) ]  
639+     fn  subscribe_failures ( & self ,  py :  Python < ' _ > ,  timeout :  Duration )  -> PyResult < FailureStream >  { 
640+         py. allow_threads ( move  || { 
641+             let  req = tonic:: Request :: new ( SubscribeFailuresRequest  { } ) ; 
642+             let  response = self 
643+                 . runtime 
644+                 . block_on ( self . client . clone ( ) . subscribe_failures ( req) ) 
645+                 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?; 
646+             Ok ( FailureStream  { 
647+                 runtime :  self . runtime . clone ( ) , 
648+                 stream :  response. into_inner ( ) , 
649+                 timeout :  timeout, 
650+             } ) 
651+         } ) 
652+     } 
589653} 
590654
591655/// LighthouseServer is a GRPC server for the lighthouse service. 
@@ -610,7 +674,7 @@ struct LighthouseServer {
610674
611675#[ pymethods]  
612676impl  LighthouseServer  { 
613-     #[ pyo3( signature = ( bind,  min_replicas,  join_timeout_ms=None ,  quorum_tick_ms=None ,  heartbeat_timeout_ms=None ) ) ]  
677+     #[ pyo3( signature = ( bind,  min_replicas,  join_timeout_ms=None ,  quorum_tick_ms=None ,  heartbeat_timeout_ms=None ,  failure_tick_ms= None ) ) ]  
614678    #[ new]  
615679    fn  new ( 
616680        py :  Python < ' _ > , 
@@ -619,10 +683,12 @@ impl LighthouseServer {
619683        join_timeout_ms :  Option < u64 > , 
620684        quorum_tick_ms :  Option < u64 > , 
621685        heartbeat_timeout_ms :  Option < u64 > , 
686+         failure_tick_ms :  Option < u64 > , 
622687    )  -> PyResult < Self >  { 
623688        let  join_timeout_ms = join_timeout_ms. unwrap_or ( 100 ) ; 
624689        let  quorum_tick_ms = quorum_tick_ms. unwrap_or ( 100 ) ; 
625690        let  heartbeat_timeout_ms = heartbeat_timeout_ms. unwrap_or ( 5000 ) ; 
691+         let  failure_tick_ms = failure_tick_ms. unwrap_or ( 1000 ) ; 
626692
627693        py. allow_threads ( move  || { 
628694            let  rt = tokio:: runtime:: Builder :: new_multi_thread ( ) 
@@ -638,6 +704,7 @@ impl LighthouseServer {
638704                    join_timeout_ms :  join_timeout_ms, 
639705                    quorum_tick_ms :  quorum_tick_ms, 
640706                    heartbeat_timeout_ms :  heartbeat_timeout_ms, 
707+                     failure_tick_ms :  failure_tick_ms, 
641708                } ) ) 
642709                . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?; 
643710
@@ -663,6 +730,22 @@ impl LighthouseServer {
663730            self . handle . abort ( ) ; 
664731        } ) 
665732    } 
733+ 
734+     /// inject_failure broadcasts a failure notification for the given replica. 
735+      /// 
736+      /// This helper is intended for testing `subscribe_failures` from Python. 
737+      #[ pyo3( signature = ( replica_id) ) ]  
738+     fn  inject_failure ( & self ,  py :  Python < ' _ > ,  replica_id :  String )  { 
739+         let  lighthouse = self . lighthouse . clone ( ) ; 
740+         let  runtime = & self . _runtime ; 
741+         py. allow_threads ( move  || { 
742+             let  _ = runtime. block_on ( async  { 
743+                 if  let  Err ( e)  = lighthouse. inject_failure ( replica_id) . await  { 
744+                     eprintln ! ( "Failed to inject failure: {}" ,  e) ; 
745+                 } 
746+             } ) ; 
747+         } ) ; 
748+     } 
666749} 
667750
668751struct  StatusError ( Status ) ; 
@@ -750,6 +833,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
750833    m. add_class :: < LighthouseServer > ( ) ?; 
751834    m. add_class :: < LighthouseClient > ( ) ?; 
752835    m. add_class :: < QuorumResult > ( ) ?; 
836+     m. add_class :: < FailureNotificationPy > ( ) ?; 
837+     m. add_class :: < FailureStream > ( ) ?; 
753838    m. add_function ( wrap_pyfunction ! ( lighthouse_main,  m) ?) ?; 
754839
755840    Ok ( ( ) ) 
0 commit comments