@@ -13,14 +13,15 @@ mod timeout;
13
13
use anyhow:: Result ;
14
14
use atty:: Stream ;
15
15
use core:: time:: Duration ;
16
- use pyo3:: exceptions:: { PyRuntimeError , PyTimeoutError } ;
16
+ use pyo3:: exceptions:: { PyRuntimeError , PyStopIteration , PyTimeoutError } ;
17
17
use std:: cmp;
18
18
use std:: env;
19
19
use std:: sync:: Arc ;
20
20
use std:: thread:: available_parallelism;
21
21
use structopt:: StructOpt ;
22
22
use tokio:: runtime:: Runtime ;
23
23
use tokio:: task:: JoinHandle ;
24
+ use tokio_stream:: StreamExt ;
24
25
use tonic:: transport:: Channel ;
25
26
use tonic:: Status ;
26
27
@@ -35,11 +36,13 @@ pub mod torchftpb {
35
36
use crate :: torchftpb:: lighthouse_service_client:: LighthouseServiceClient ;
36
37
use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
37
38
use crate :: torchftpb:: {
38
- CheckpointMetadataRequest , LighthouseHeartbeatRequest , LighthouseQuorumRequest ,
39
- ManagerQuorumRequest , ShouldCommitRequest ,
39
+ CheckpointMetadataRequest , FailureNotification as ProtoFailureNotification ,
40
+ LighthouseHeartbeatRequest , LighthouseQuorumRequest , ManagerQuorumRequest , ShouldCommitRequest ,
41
+ SubscribeFailuresRequest ,
40
42
} ;
41
43
use pyo3:: prelude:: * ;
42
44
use pyo3:: types:: { PyDict , PyString } ;
45
+ use pyo3:: { PyRef , PyRefMut } ;
43
46
44
47
// Get the number of threads to use for the tokio runtime
45
48
fn num_threads ( ) -> usize {
@@ -290,6 +293,43 @@ struct QuorumResult {
290
293
heal : bool ,
291
294
}
292
295
296
+ #[ pyclass( unsendable) ]
297
+ struct FailureStream {
298
+ runtime : Arc < Runtime > ,
299
+ stream : tonic:: Streaming < ProtoFailureNotification > ,
300
+ timeout : Duration ,
301
+ }
302
+
303
+ #[ pymethods]
304
+ impl FailureStream {
305
+ fn __iter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
306
+ slf
307
+ }
308
+ fn __next__ ( mut slf : PyRefMut < ' _ , Self > ) -> PyResult < FailureNotification > {
309
+ let runtime = slf. runtime . clone ( ) ;
310
+ let timeout = slf. timeout ;
311
+ // borrow stream mutably for the whole async block
312
+ let fut = async { tokio:: time:: timeout ( timeout, slf. stream . next ( ) ) . await } ;
313
+
314
+ match runtime. block_on ( fut) {
315
+ Ok ( Some ( Ok ( note) ) ) => Ok ( FailureNotification {
316
+ replica_id : note. replica_id ,
317
+ } ) ,
318
+ Ok ( Some ( Err ( status) ) ) => Err ( StatusError ( status) . into ( ) ) ,
319
+ Ok ( None ) => Err ( PyStopIteration :: new_err ( ( ) ) ) ,
320
+ Err ( _) => Err ( PyTimeoutError :: new_err (
321
+ "Timeout waiting for failure notification" ,
322
+ ) ) ,
323
+ }
324
+ }
325
+ }
326
+
327
+ #[ pyclass( get_all, set_all) ]
328
+ #[ derive( Clone ) ]
329
+ struct FailureNotification {
330
+ replica_id : String ,
331
+ }
332
+
293
333
#[ pymethods]
294
334
impl QuorumResult {
295
335
#[ new]
@@ -396,6 +436,12 @@ pub struct Timestamp {
396
436
pub nanos : i32 ,
397
437
}
398
438
439
+ #[ pyclass( get_all, set_all) ]
440
+ #[ derive( Clone ) ]
441
+ struct FailureNotificationPy {
442
+ replica_id : String ,
443
+ }
444
+
399
445
/// quorum result.
400
446
///
401
447
/// Args:
@@ -478,7 +524,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
478
524
#[ pyclass]
479
525
struct LighthouseClient {
480
526
client : LighthouseServiceClient < Channel > ,
481
- runtime : Runtime ,
527
+ runtime : Arc < Runtime > ,
482
528
}
483
529
484
530
#[ pymethods]
@@ -487,11 +533,13 @@ impl LighthouseClient {
487
533
#[ new]
488
534
fn new ( py : Python < ' _ > , addr : String , connect_timeout : Duration ) -> PyResult < Self > {
489
535
py. allow_threads ( move || {
490
- let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
491
- . worker_threads ( num_threads ( ) )
492
- . thread_name ( "torchft-lhclnt" )
493
- . enable_all ( )
494
- . build ( ) ?;
536
+ let runtime = Arc :: new (
537
+ tokio:: runtime:: Builder :: new_multi_thread ( )
538
+ . worker_threads ( num_threads ( ) )
539
+ . thread_name ( "torchft-lhclnt" )
540
+ . enable_all ( )
541
+ . build ( ) ?,
542
+ ) ;
495
543
let client = runtime
496
544
. block_on ( manager:: lighthouse_client_new ( addr, connect_timeout) )
497
545
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
@@ -586,6 +634,22 @@ impl LighthouseClient {
586
634
Ok ( ( ) )
587
635
} )
588
636
}
637
+
638
+ #[ pyo3( signature = ( timeout = Duration :: from_secs( 5 ) ) ) ]
639
+ fn subscribe_failures ( & self , py : Python < ' _ > , timeout : Duration ) -> PyResult < FailureStream > {
640
+ py. allow_threads ( move || {
641
+ let req = tonic:: Request :: new ( SubscribeFailuresRequest { } ) ;
642
+ let response = self
643
+ . runtime
644
+ . block_on ( self . client . clone ( ) . subscribe_failures ( req) )
645
+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
646
+ Ok ( FailureStream {
647
+ runtime : self . runtime . clone ( ) ,
648
+ stream : response. into_inner ( ) ,
649
+ timeout : timeout,
650
+ } )
651
+ } )
652
+ }
589
653
}
590
654
591
655
/// LighthouseServer is a GRPC server for the lighthouse service.
@@ -610,7 +674,7 @@ struct LighthouseServer {
610
674
611
675
#[ pymethods]
612
676
impl LighthouseServer {
613
- #[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None ) ) ]
677
+ #[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None , failure_tick_ms= None ) ) ]
614
678
#[ new]
615
679
fn new (
616
680
py : Python < ' _ > ,
@@ -619,10 +683,12 @@ impl LighthouseServer {
619
683
join_timeout_ms : Option < u64 > ,
620
684
quorum_tick_ms : Option < u64 > ,
621
685
heartbeat_timeout_ms : Option < u64 > ,
686
+ failure_tick_ms : Option < u64 > ,
622
687
) -> PyResult < Self > {
623
688
let join_timeout_ms = join_timeout_ms. unwrap_or ( 100 ) ;
624
689
let quorum_tick_ms = quorum_tick_ms. unwrap_or ( 100 ) ;
625
690
let heartbeat_timeout_ms = heartbeat_timeout_ms. unwrap_or ( 5000 ) ;
691
+ let failure_tick_ms = failure_tick_ms. unwrap_or ( 1000 ) ;
626
692
627
693
py. allow_threads ( move || {
628
694
let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
@@ -638,6 +704,7 @@ impl LighthouseServer {
638
704
join_timeout_ms : join_timeout_ms,
639
705
quorum_tick_ms : quorum_tick_ms,
640
706
heartbeat_timeout_ms : heartbeat_timeout_ms,
707
+ failure_tick_ms : failure_tick_ms,
641
708
} ) )
642
709
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
643
710
@@ -663,6 +730,22 @@ impl LighthouseServer {
663
730
self . handle . abort ( ) ;
664
731
} )
665
732
}
733
+
734
+ /// inject_failure broadcasts a failure notification for the given replica.
735
+ ///
736
+ /// This helper is intended for testing `subscribe_failures` from Python.
737
+ #[ pyo3( signature = ( replica_id) ) ]
738
+ fn inject_failure ( & self , py : Python < ' _ > , replica_id : String ) {
739
+ let lighthouse = self . lighthouse . clone ( ) ;
740
+ let runtime = & self . _runtime ;
741
+ py. allow_threads ( move || {
742
+ let _ = runtime. block_on ( async {
743
+ if let Err ( e) = lighthouse. inject_failure ( replica_id) . await {
744
+ eprintln ! ( "Failed to inject failure: {}" , e) ;
745
+ }
746
+ } ) ;
747
+ } ) ;
748
+ }
666
749
}
667
750
668
751
struct StatusError ( Status ) ;
@@ -750,6 +833,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
750
833
m. add_class :: < LighthouseServer > ( ) ?;
751
834
m. add_class :: < LighthouseClient > ( ) ?;
752
835
m. add_class :: < QuorumResult > ( ) ?;
836
+ m. add_class :: < FailureNotificationPy > ( ) ?;
837
+ m. add_class :: < FailureStream > ( ) ?;
753
838
m. add_function ( wrap_pyfunction ! ( lighthouse_main, m) ?) ?;
754
839
755
840
Ok ( ( ) )
0 commit comments