11use std:: collections:: HashMap ;
22use std:: io:: { BufRead , BufReader , Write } ;
33use std:: net:: { Shutdown , SocketAddr , TcpListener , TcpStream } ;
4- use std:: sync:: mpsc:: { Sender , SyncSender , TrySendError } ;
4+ use std:: sync:: atomic:: AtomicBool ;
5+ use std:: sync:: mpsc:: { Receiver , Sender } ;
56use std:: sync:: { Arc , Mutex } ;
67use std:: thread;
78
@@ -100,6 +101,7 @@ struct Connection {
100101 chan : SyncChannel < Message > ,
101102 stats : Arc < Stats > ,
102103 txs_limit : usize ,
104+ die_please : Option < Receiver < ( ) > > ,
103105 #[ cfg( feature = "electrum-discovery" ) ]
104106 discovery : Option < Arc < DiscoveryManager > > ,
105107}
@@ -111,6 +113,7 @@ impl Connection {
111113 addr : SocketAddr ,
112114 stats : Arc < Stats > ,
113115 txs_limit : usize ,
116+ die_please : Receiver < ( ) > ,
114117 #[ cfg( feature = "electrum-discovery" ) ] discovery : Option < Arc < DiscoveryManager > > ,
115118 ) -> Connection {
116119 Connection {
@@ -122,6 +125,7 @@ impl Connection {
122125 chan : SyncChannel :: new ( 10 ) ,
123126 stats,
124127 txs_limit,
128+ die_please : Some ( die_please) ,
125129 #[ cfg( feature = "electrum-discovery" ) ]
126130 discovery,
127131 }
@@ -501,40 +505,52 @@ impl Connection {
501505 Ok ( ( ) )
502506 }
503507
504- fn handle_replies ( & mut self ) -> Result < ( ) > {
508+ fn handle_replies ( & mut self , shutdown : crossbeam_channel :: Receiver < ( ) > ) -> Result < ( ) > {
505509 let empty_params = json ! ( [ ] ) ;
506510 loop {
507- let msg = self . chan . receiver ( ) . recv ( ) . chain_err ( || "channel closed" ) ?;
508- trace ! ( "RPC {:?}" , msg) ;
509- match msg {
510- Message :: Request ( line) => {
511- let cmd: Value = from_str ( & line) . chain_err ( || "invalid JSON format" ) ?;
512- let reply = match (
513- cmd. get ( "method" ) ,
514- cmd. get ( "params" ) . unwrap_or_else ( || & empty_params) ,
515- cmd. get ( "id" ) ,
516- ) {
517- (
518- Some ( & Value :: String ( ref method) ) ,
519- & Value :: Array ( ref params) ,
520- Some ( ref id) ,
521- ) => self . handle_command ( method, params, id) ?,
522- _ => bail ! ( "invalid command: {}" , cmd) ,
523- } ;
524- self . send_values ( & [ reply] ) ?
511+ crossbeam_channel:: select! {
512+ recv( self . chan. receiver( ) ) -> msg => {
513+ let msg = msg. chain_err( || "channel closed" ) ?;
514+ trace!( "RPC {:?}" , msg) ;
515+ match msg {
516+ Message :: Request ( line) => {
517+ let cmd: Value = from_str( & line) . chain_err( || "invalid JSON format" ) ?;
518+ let reply = match (
519+ cmd. get( "method" ) ,
520+ cmd. get( "params" ) . unwrap_or( & empty_params) ,
521+ cmd. get( "id" ) ,
522+ ) {
523+ ( Some ( Value :: String ( method) ) , Value :: Array ( params) , Some ( id) ) => {
524+ self . handle_command( method, params, id) ?
525+ }
526+ _ => bail!( "invalid command: {}" , cmd) ,
527+ } ;
528+ self . send_values( & [ reply] ) ?
529+ }
530+ Message :: PeriodicUpdate => {
531+ let values = self
532+ . update_subscriptions( )
533+ . chain_err( || "failed to update subscriptions" ) ?;
534+ self . send_values( & values) ?
535+ }
536+ Message :: Done => {
537+ self . chan. close( ) ;
538+ return Ok ( ( ) ) ;
539+ }
540+ }
525541 }
526- Message :: PeriodicUpdate => {
527- let values = self
528- . update_subscriptions ( )
529- . chain_err ( || "failed to update subscriptions" ) ?;
530- self . send_values ( & values) ?
542+ recv( shutdown) -> _ => {
543+ self . chan. close( ) ;
544+ return Ok ( ( ) ) ;
531545 }
532- Message :: Done => return Ok ( ( ) ) ,
533546 }
534547 }
535548 }
536549
537- fn handle_requests ( mut reader : BufReader < TcpStream > , tx : SyncSender < Message > ) -> Result < ( ) > {
550+ fn handle_requests (
551+ mut reader : BufReader < TcpStream > ,
552+ tx : crossbeam_channel:: Sender < Message > ,
553+ ) -> Result < ( ) > {
538554 loop {
539555 let mut line = Vec :: < u8 > :: new ( ) ;
540556 reader
@@ -566,8 +582,24 @@ impl Connection {
566582 self . stats . clients . inc ( ) ;
567583 let reader = BufReader :: new ( self . stream . try_clone ( ) . expect ( "failed to clone TcpStream" ) ) ;
568584 let tx = self . chan . sender ( ) ;
585+
586+ let die_please = self . die_please . take ( ) . unwrap ( ) ;
587+ let ( reply_killer, reply_receiver) = crossbeam_channel:: unbounded ( ) ;
588+
589+ // We create a clone of the stream and put it in an Arc
590+ // This will drop at the end of the function.
591+ let arc_stream = Arc :: new ( self . stream . try_clone ( ) . expect ( "failed to clone TcpStream" ) ) ;
592+ // We don't want to keep the stream alive until SIGINT
593+ // It should drop (close) no matter what.
594+ let maybe_stream = Arc :: downgrade ( & arc_stream) ;
595+ spawn_thread ( "properly-die" , move || {
596+ let _ = die_please. recv ( ) ;
597+ let _ = maybe_stream. upgrade ( ) . map ( |s| s. shutdown ( Shutdown :: Both ) ) ;
598+ let _ = reply_killer. send ( ( ) ) ;
599+ } ) ;
600+
569601 let child = spawn_thread ( "reader" , || Connection :: handle_requests ( reader, tx) ) ;
570- if let Err ( e) = self . handle_replies ( ) {
602+ if let Err ( e) = self . handle_replies ( reply_receiver ) {
571603 error ! (
572604 "[{}] connection handling failed: {}" ,
573605 self . addr,
@@ -580,6 +612,8 @@ impl Connection {
580612 . sub ( self . status_hashes . len ( ) as i64 ) ;
581613
582614 debug ! ( "[{}] shutting down connection" , self . addr) ;
615+ // Drop the Arc so that the stream properly closes.
616+ drop ( arc_stream) ;
583617 let _ = self . stream . shutdown ( Shutdown :: Both ) ;
584618 if let Err ( err) = child. join ( ) . expect ( "receiver panicked" ) {
585619 error ! ( "[{}] receiver failed: {}" , self . addr, err) ;
@@ -633,30 +667,38 @@ struct Stats {
633667impl RPC {
634668 fn start_notifier (
635669 notification : Channel < Notification > ,
636- senders : Arc < Mutex < Vec < SyncSender < Message > > > > ,
670+ senders : Arc < Mutex < Vec < crossbeam_channel :: Sender < Message > > > > ,
637671 acceptor : Sender < Option < ( TcpStream , SocketAddr ) > > ,
672+ acceptor_shutdown : Sender < ( ) > ,
638673 ) {
639674 spawn_thread ( "notification" , move || {
640675 for msg in notification. receiver ( ) . iter ( ) {
641676 let mut senders = senders. lock ( ) . unwrap ( ) ;
642677 match msg {
643678 Notification :: Periodic => {
644679 for sender in senders. split_off ( 0 ) {
645- if let Err ( TrySendError :: Disconnected ( _) ) =
680+ if let Err ( crossbeam_channel :: TrySendError :: Disconnected ( _) ) =
646681 sender. try_send ( Message :: PeriodicUpdate )
647682 {
648683 continue ;
649684 }
650685 senders. push ( sender) ;
651686 }
652687 }
653- Notification :: Exit => acceptor. send ( None ) . unwrap ( ) , // mark acceptor as done
688+ Notification :: Exit => {
689+ acceptor_shutdown. send ( ( ) ) . unwrap ( ) ; // Stop the acceptor itself
690+ acceptor. send ( None ) . unwrap ( ) ; // mark acceptor as done
691+ break ;
692+ }
654693 }
655694 }
656695 } ) ;
657696 }
658697
659- fn start_acceptor ( addr : SocketAddr ) -> Channel < Option < ( TcpStream , SocketAddr ) > > {
698+ fn start_acceptor (
699+ addr : SocketAddr ,
700+ shutdown_channel : Channel < ( ) > ,
701+ ) -> Channel < Option < ( TcpStream , SocketAddr ) > > {
660702 let chan = Channel :: unbounded ( ) ;
661703 let acceptor = chan. sender ( ) ;
662704 spawn_thread ( "acceptor" , move || {
@@ -666,10 +708,29 @@ impl RPC {
666708 . set_nonblocking ( false )
667709 . expect ( "cannot set nonblocking to false" ) ;
668710 let listener = TcpListener :: from ( socket) ;
711+ let local_addr = listener. local_addr ( ) . unwrap ( ) ;
712+ let shutdown_bool = Arc :: new ( AtomicBool :: new ( false ) ) ;
713+
714+ {
715+ let shutdown_bool = Arc :: clone ( & shutdown_bool) ;
716+ crate :: util:: spawn_thread ( "shutdown-acceptor" , move || {
717+ // Block until shutdown is sent.
718+ let _ = shutdown_channel. receiver ( ) . recv ( ) ;
719+ // Store the bool so after the next accept it will break the loop
720+ shutdown_bool. store ( true , std:: sync:: atomic:: Ordering :: Release ) ;
721+ // Connect to the socket to cause it to unblock
722+ let _ = TcpStream :: connect ( local_addr) ;
723+ } ) ;
724+ }
669725
670726 info ! ( "Electrum RPC server running on {}" , addr) ;
671727 loop {
672728 let ( stream, addr) = listener. accept ( ) . expect ( "accept failed" ) ;
729+
730+ if shutdown_bool. load ( std:: sync:: atomic:: Ordering :: Acquire ) {
731+ break ;
732+ }
733+
673734 stream
674735 . set_nonblocking ( false )
675736 . expect ( "failed to set connection as blocking" ) ;
@@ -726,10 +787,18 @@ impl RPC {
726787 RPC {
727788 notification : notification. sender ( ) ,
728789 server : Some ( spawn_thread ( "rpc" , move || {
729- let senders = Arc :: new ( Mutex :: new ( Vec :: < SyncSender < Message > > :: new ( ) ) ) ;
730-
731- let acceptor = RPC :: start_acceptor ( rpc_addr) ;
732- RPC :: start_notifier ( notification, senders. clone ( ) , acceptor. sender ( ) ) ;
790+ let senders =
791+ Arc :: new ( Mutex :: new ( Vec :: < crossbeam_channel:: Sender < Message > > :: new ( ) ) ) ;
792+
793+ let acceptor_shutdown = Channel :: unbounded ( ) ;
794+ let acceptor_shutdown_sender = acceptor_shutdown. sender ( ) ;
795+ let acceptor = RPC :: start_acceptor ( rpc_addr, acceptor_shutdown) ;
796+ RPC :: start_notifier (
797+ notification,
798+ senders. clone ( ) ,
799+ acceptor. sender ( ) ,
800+ acceptor_shutdown_sender,
801+ ) ;
733802
734803 let mut threads = HashMap :: new ( ) ;
735804 let ( garbage_sender, garbage_receiver) = crossbeam_channel:: unbounded ( ) ;
@@ -740,6 +809,10 @@ impl RPC {
740809 let senders = Arc :: clone ( & senders) ;
741810 let stats = Arc :: clone ( & stats) ;
742811 let garbage_sender = garbage_sender. clone ( ) ;
812+
813+ // Kill the peers properly
814+ let ( killer, peace_receiver) = std:: sync:: mpsc:: channel ( ) ;
815+
743816 #[ cfg( feature = "electrum-discovery" ) ]
744817 let discovery = discovery. clone ( ) ;
745818
@@ -751,6 +824,7 @@ impl RPC {
751824 addr,
752825 stats,
753826 txs_limit,
827+ peace_receiver,
754828 #[ cfg( feature = "electrum-discovery" ) ]
755829 discovery,
756830 ) ;
@@ -761,24 +835,29 @@ impl RPC {
761835 } ) ;
762836
763837 trace ! ( "[{}] spawned {:?}" , addr, spawned. thread( ) . id( ) ) ;
764- threads. insert ( spawned. thread ( ) . id ( ) , spawned) ;
838+ threads. insert ( spawned. thread ( ) . id ( ) , ( spawned, killer ) ) ;
765839 while let Ok ( id) = garbage_receiver. try_recv ( ) {
766- if let Some ( thread) = threads. remove ( & id) {
840+ if let Some ( ( thread, killer ) ) = threads. remove ( & id) {
767841 trace ! ( "[{}] joining {:?}" , addr, id) ;
842+ let _ = killer. send ( ( ) ) ;
768843 if let Err ( error) = thread. join ( ) {
769844 error ! ( "failed to join {:?}: {:?}" , id, error) ;
770845 }
771846 }
772847 }
773848 }
849+ // Drop these
850+ drop ( acceptor) ;
851+ drop ( garbage_receiver) ;
774852
775853 trace ! ( "closing {} RPC connections" , senders. lock( ) . unwrap( ) . len( ) ) ;
776854 for sender in senders. lock ( ) . unwrap ( ) . iter ( ) {
777- let _ = sender. send ( Message :: Done ) ;
855+ let _ = sender. try_send ( Message :: Done ) ;
778856 }
779857
780- for ( id, thread) in threads {
858+ for ( id, ( thread, killer ) ) in threads {
781859 trace ! ( "joining {:?}" , id) ;
860+ let _ = killer. send ( ( ) ) ;
782861 if let Err ( error) = thread. join ( ) {
783862 error ! ( "failed to join {:?}: {:?}" , id, error) ;
784863 }
@@ -802,5 +881,8 @@ impl Drop for RPC {
802881 handle. join ( ) . unwrap ( ) ;
803882 }
804883 trace ! ( "RPC server is stopped" ) ;
884+ crate :: util:: with_spawned_threads ( |threads| {
885+ trace ! ( "Threads after dropping RPC: {:?}" , threads) ;
886+ } ) ;
805887 }
806888}
0 commit comments