@@ -13,14 +13,15 @@ mod timeout;
13
13
use anyhow:: Result ;
14
14
use atty:: Stream ;
15
15
use core:: time:: Duration ;
16
- use pyo3:: exceptions:: { PyRuntimeError , PyTimeoutError } ;
16
+ use pyo3:: exceptions:: { PyRuntimeError , PyStopIteration , PyTimeoutError } ;
17
17
use std:: cmp;
18
18
use std:: env;
19
19
use std:: sync:: Arc ;
20
20
use std:: thread:: available_parallelism;
21
21
use structopt:: StructOpt ;
22
22
use tokio:: runtime:: Runtime ;
23
23
use tokio:: task:: JoinHandle ;
24
+ use tokio_stream:: StreamExt ;
24
25
use tonic:: transport:: Channel ;
25
26
use tonic:: Status ;
26
27
@@ -35,11 +36,13 @@ pub mod torchftpb {
35
36
use crate :: torchftpb:: lighthouse_service_client:: LighthouseServiceClient ;
36
37
use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
37
38
use crate :: torchftpb:: {
38
- CheckpointMetadataRequest , LighthouseHeartbeatRequest , LighthouseQuorumRequest ,
39
- ManagerQuorumRequest , ShouldCommitRequest ,
39
+ CheckpointMetadataRequest , FailureNotification as ProtoFailureNotification ,
40
+ LighthouseHeartbeatRequest , LighthouseQuorumRequest , ManagerQuorumRequest , ShouldCommitRequest ,
41
+ SubscribeFailuresRequest ,
40
42
} ;
41
43
use pyo3:: prelude:: * ;
42
44
use pyo3:: types:: { PyDict , PyString } ;
45
+ use pyo3:: { PyRef , PyRefMut } ;
43
46
44
47
// Get the number of threads to use for the tokio runtime
45
48
fn num_threads ( ) -> usize {
@@ -290,6 +293,45 @@ struct QuorumResult {
290
293
heal : bool ,
291
294
}
292
295
296
+ #[ pyclass( unsendable) ]
297
+ struct FailureStream {
298
+ runtime : Arc < Runtime > ,
299
+ stream : tonic:: Streaming < ProtoFailureNotification > ,
300
+ timeout : Duration ,
301
+ }
302
+
303
+ #[ pymethods]
304
+ impl FailureStream {
305
+ fn __iter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
306
+ slf
307
+ }
308
+ fn __next__ ( mut slf : PyRefMut < ' _ , Self > ) -> PyResult < FailureNotification > {
309
+ let runtime = slf. runtime . clone ( ) ;
310
+ let timeout = slf. timeout ;
311
+ // borrow stream mutably for the whole async block
312
+ let fut = async { tokio:: time:: timeout ( timeout, slf. stream . next ( ) ) . await } ;
313
+
314
+ match runtime. block_on ( fut) {
315
+ Ok ( Some ( Ok ( note) ) ) => Ok ( FailureNotification {
316
+ replica_id : note. replica_id ,
317
+ error_message : note. error_message ,
318
+ } ) ,
319
+ Ok ( Some ( Err ( status) ) ) => Err ( StatusError ( status) . into ( ) ) ,
320
+ Ok ( None ) => Err ( PyStopIteration :: new_err ( ( ) ) ) ,
321
+ Err ( _) => Err ( PyTimeoutError :: new_err (
322
+ "Timeout waiting for failure notification" ,
323
+ ) ) ,
324
+ }
325
+ }
326
+ }
327
+
328
+ #[ pyclass( get_all, set_all) ]
329
+ #[ derive( Clone ) ]
330
+ struct FailureNotification {
331
+ replica_id : String ,
332
+ error_message : String ,
333
+ }
334
+
293
335
#[ pymethods]
294
336
impl QuorumResult {
295
337
#[ new]
@@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
478
520
#[ pyclass]
479
521
struct LighthouseClient {
480
522
client : LighthouseServiceClient < Channel > ,
481
- runtime : Runtime ,
523
+ runtime : Arc < Runtime > ,
482
524
}
483
525
484
526
#[ pymethods]
@@ -487,11 +529,13 @@ impl LighthouseClient {
487
529
#[ new]
488
530
fn new ( py : Python < ' _ > , addr : String , connect_timeout : Duration ) -> PyResult < Self > {
489
531
py. allow_threads ( move || {
490
- let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
491
- . worker_threads ( num_threads ( ) )
492
- . thread_name ( "torchft-lhclnt" )
493
- . enable_all ( )
494
- . build ( ) ?;
532
+ let runtime = Arc :: new (
533
+ tokio:: runtime:: Builder :: new_multi_thread ( )
534
+ . worker_threads ( num_threads ( ) )
535
+ . thread_name ( "torchft-lhclnt" )
536
+ . enable_all ( )
537
+ . build ( ) ?,
538
+ ) ;
495
539
let client = runtime
496
540
. block_on ( manager:: lighthouse_client_new ( addr, connect_timeout) )
497
541
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
@@ -586,6 +630,22 @@ impl LighthouseClient {
586
630
Ok ( ( ) )
587
631
} )
588
632
}
633
+
634
+ #[ pyo3( signature = ( timeout = Duration :: from_secs( 5 ) ) ) ]
635
+ fn subscribe_failures ( & self , py : Python < ' _ > , timeout : Duration ) -> PyResult < FailureStream > {
636
+ py. allow_threads ( move || {
637
+ let req = tonic:: Request :: new ( SubscribeFailuresRequest { } ) ;
638
+ let response = self
639
+ . runtime
640
+ . block_on ( self . client . clone ( ) . subscribe_failures ( req) )
641
+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
642
+ Ok ( FailureStream {
643
+ runtime : self . runtime . clone ( ) ,
644
+ stream : response. into_inner ( ) ,
645
+ timeout : timeout,
646
+ } )
647
+ } )
648
+ }
589
649
}
590
650
591
651
/// LighthouseServer is a GRPC server for the lighthouse service.
@@ -610,7 +670,7 @@ struct LighthouseServer {
610
670
611
671
#[ pymethods]
612
672
impl LighthouseServer {
613
- #[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None ) ) ]
673
+ #[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None , failure_tick_ms= None ) ) ]
614
674
#[ new]
615
675
fn new (
616
676
py : Python < ' _ > ,
@@ -619,10 +679,12 @@ impl LighthouseServer {
619
679
join_timeout_ms : Option < u64 > ,
620
680
quorum_tick_ms : Option < u64 > ,
621
681
heartbeat_timeout_ms : Option < u64 > ,
682
+ failure_tick_ms : Option < u64 > ,
622
683
) -> PyResult < Self > {
623
684
let join_timeout_ms = join_timeout_ms. unwrap_or ( 100 ) ;
624
685
let quorum_tick_ms = quorum_tick_ms. unwrap_or ( 100 ) ;
625
686
let heartbeat_timeout_ms = heartbeat_timeout_ms. unwrap_or ( 5000 ) ;
687
+ let failure_tick_ms = failure_tick_ms. unwrap_or ( 1000 ) ;
626
688
627
689
py. allow_threads ( move || {
628
690
let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
@@ -638,6 +700,7 @@ impl LighthouseServer {
638
700
join_timeout_ms : join_timeout_ms,
639
701
quorum_tick_ms : quorum_tick_ms,
640
702
heartbeat_timeout_ms : heartbeat_timeout_ms,
703
+ failure_tick_ms : failure_tick_ms,
641
704
} ) )
642
705
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
643
706
@@ -663,6 +726,22 @@ impl LighthouseServer {
663
726
self . handle . abort ( ) ;
664
727
} )
665
728
}
729
+
730
+ /// inject_failure broadcasts a failure notification for the given replica.
731
+ ///
732
+ /// This helper is intended for testing `subscribe_failures` from Python.
733
+ #[ pyo3( signature = ( replica_id) ) ]
734
+ fn inject_failure ( & self , py : Python < ' _ > , replica_id : String ) {
735
+ let lighthouse = self . lighthouse . clone ( ) ;
736
+ let runtime = & self . _runtime ;
737
+ py. allow_threads ( move || {
738
+ let _ = runtime. block_on ( async {
739
+ if let Err ( e) = lighthouse. inject_failure ( replica_id) . await {
740
+ eprintln ! ( "Failed to inject failure: {}" , e) ;
741
+ }
742
+ } ) ;
743
+ } ) ;
744
+ }
666
745
}
667
746
668
747
struct StatusError ( Status ) ;
@@ -750,6 +829,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
750
829
m. add_class :: < LighthouseServer > ( ) ?;
751
830
m. add_class :: < LighthouseClient > ( ) ?;
752
831
m. add_class :: < QuorumResult > ( ) ?;
832
+ m. add_class :: < FailureNotification > ( ) ?;
833
+ m. add_class :: < FailureStream > ( ) ?;
753
834
m. add_function ( wrap_pyfunction ! ( lighthouse_main, m) ?) ?;
754
835
755
836
Ok ( ( ) )
0 commit comments