6
6
import sys
7
7
from collections import deque
8
8
from signal import Signals
9
- from typing import Any , Callable , Deque , List , Optional
9
+ from typing import Any , Callable , Deque , List , Optional , cast
10
10
11
- from pynvim .msgpack_rpc .event_loop .base import BaseEventLoop
11
+ if sys .version_info >= (3 , 12 ):
12
+ from typing import Final , override
13
+ else :
14
+ from typing_extensions import Final , override
15
+
16
+ from pynvim .msgpack_rpc .event_loop .base import BaseEventLoop , TTransportType
12
17
13
18
logger = logging .getLogger (__name__ )
14
19
debug , info , warn = (logger .debug , logger .info , logger .warning ,)
27
32
28
33
# pylint: disable=logging-fstring-interpolation
29
34
30
- class AsyncioEventLoop (BaseEventLoop , asyncio .Protocol ,
31
- asyncio .SubprocessProtocol ):
32
- """`BaseEventLoop` subclass that uses `asyncio` as a backend."""
35
+ class Protocol (asyncio .Protocol , asyncio .SubprocessProtocol ):
36
+ """The protocol class used for asyncio-based RPC communication."""
33
37
34
- _queued_data : Deque [bytes ]
35
- if os .name != 'nt' :
36
- _child_watcher : Optional ['asyncio.AbstractChildWatcher' ]
38
+ def __init__ (self , on_data , on_error ):
39
+ """Initialize the Protocol object."""
40
+ assert on_data is not None
41
+ assert on_error is not None
42
+ self ._on_data = on_data
43
+ self ._on_error = on_error
37
44
45
+ @override
38
46
def connection_made (self , transport ):
39
47
"""Used to signal `asyncio.Protocol` of a successful connection."""
40
- self ._transport = transport
41
- self ._raw_transport = transport
42
- if isinstance (transport , asyncio .SubprocessTransport ):
43
- self ._transport = transport .get_pipe_transport (0 )
48
+ del transport # no-op
44
49
45
- def connection_lost (self , exc ):
50
+ @override
51
+ def connection_lost (self , exc : Optional [Exception ]) -> None :
46
52
"""Used to signal `asyncio.Protocol` of a lost connection."""
47
53
debug (f"connection_lost: exc = { exc } " )
48
- self ._on_error (exc . args [ 0 ] if exc else 'EOF' )
54
+ self ._on_error (exc if exc else EOFError () )
49
55
56
+ @override
50
57
def data_received (self , data : bytes ) -> None :
51
58
"""Used to signal `asyncio.Protocol` of incoming data."""
52
- if self ._on_data :
53
- self ._on_data (data )
54
- return
55
- self ._queued_data .append (data )
59
+ self ._on_data (data )
56
60
57
- def pipe_connection_lost (self , fd , exc ):
61
+ @override
62
+ def pipe_connection_lost (self , fd : int , exc : Optional [Exception ]) -> None :
58
63
"""Used to signal `asyncio.SubprocessProtocol` of a lost connection."""
59
64
debug ("pipe_connection_lost: fd = %s, exc = %s" , fd , exc )
60
65
if os .name == 'nt' and fd == 2 : # stderr
61
66
# On windows, ignore piped stderr being closed immediately (#505)
62
67
return
63
- self ._on_error (exc . args [ 0 ] if exc else 'EOF' )
68
+ self ._on_error (exc if exc else EOFError () )
64
69
70
+ @override
65
71
def pipe_data_received (self , fd , data ):
66
72
"""Used to signal `asyncio.SubprocessProtocol` of incoming data."""
67
73
if fd == 2 : # stderr fd number
68
74
# Ignore stderr message, log only for debugging
69
75
debug ("stderr: %s" , str (data ))
70
- elif self ._on_data :
71
- self ._on_data (data )
72
- else :
73
- self ._queued_data .append (data )
76
+ elif fd == 1 : # stdout
77
+ self .data_received (data )
74
78
79
+ @override
75
80
def process_exited (self ) -> None :
76
81
"""Used to signal `asyncio.SubprocessProtocol` when the child exits."""
77
82
debug ("process_exited" )
78
- self ._on_error ('EOF' )
83
+ self ._on_error (EOFError ())
84
+
85
+
86
+ class AsyncioEventLoop (BaseEventLoop ):
87
+ """`BaseEventLoop` subclass that uses core `asyncio` as a backend."""
88
+
89
+ _protocol : Optional [Protocol ]
90
+ _transport : Optional [asyncio .WriteTransport ]
91
+ _signals : List [Signals ]
92
+ _data_buffer : Deque [bytes ]
93
+ if os .name != 'nt' :
94
+ _child_watcher : Optional ['asyncio.AbstractChildWatcher' ]
79
95
80
- def _init (self ) -> None :
81
- self ._loop = loop_cls ()
82
- self ._queued_data = deque ()
83
- self ._fact = lambda : self
96
+ def __init__ (self ,
97
+ transport_type : TTransportType ,
98
+ * args : Any , ** kwargs : Any ):
99
+ """asyncio-specific initialization. see BaseEventLoop.__init__."""
100
+
101
+ # The underlying asyncio event loop.
102
+ self ._loop : Final [asyncio .AbstractEventLoop ] = loop_cls ()
103
+
104
+ # Handle messages from nvim that may arrive before run() starts.
105
+ self ._data_buffer = deque ()
106
+
107
+ def _on_data (data : bytes ) -> None :
108
+ if self ._on_data is None :
109
+ self ._data_buffer .append (data )
110
+ return
111
+ self ._on_data (data )
112
+
113
+ # pylint: disable-next=unnecessary-lambda
114
+ self ._protocol_factory = lambda : Protocol (
115
+ on_data = _on_data ,
116
+ on_error = self ._on_error ,
117
+ )
118
+ self ._protocol = None
119
+
120
+ # The communication channel (endpoint) created by _connect_*() method.
121
+ self ._transport = None
84
122
self ._raw_transport = None
85
123
self ._child_watcher = None
86
124
125
+ super ().__init__ (transport_type , * args , ** kwargs )
126
+
127
+ @override
87
128
def _connect_tcp (self , address : str , port : int ) -> None :
88
129
async def connect_tcp ():
89
- await self ._loop .create_connection (self ._fact , address , port )
130
+ transport , protocol = await self ._loop .create_connection (
131
+ self ._protocol_factory , address , port )
90
132
debug (f"tcp connection successful: { address } :{ port } " )
133
+ self ._transport = transport
134
+ self ._protocol = protocol
91
135
92
136
self ._loop .run_until_complete (connect_tcp ())
93
137
138
+ @override
94
139
def _connect_socket (self , path : str ) -> None :
95
140
async def connect_socket ():
96
141
if os .name == 'nt' :
97
- transport , _ = await self ._loop .create_pipe_connection ( self . _fact , path )
142
+ _create_connection = self ._loop .create_pipe_connection
98
143
else :
99
- transport , _ = await self ._loop .create_unix_connection (self ._fact , path )
100
- debug ("socket connection successful: %s" , transport )
144
+ _create_connection = self ._loop .create_unix_connection
145
+
146
+ transport , protocol = await _create_connection (
147
+ self ._protocol_factory , path )
148
+ debug ("socket connection successful: %s" , self ._transport )
149
+ self ._transport = transport
150
+ self ._protocol = protocol
101
151
102
152
self ._loop .run_until_complete (connect_socket ())
103
153
154
+ @override
104
155
def _connect_stdio (self ) -> None :
105
156
async def connect_stdin ():
106
157
if os .name == 'nt' :
107
158
pipe = PipeHandle (msvcrt .get_osfhandle (sys .stdin .fileno ()))
108
159
else :
109
160
pipe = sys .stdin
110
- await self ._loop .connect_read_pipe (self ._fact , pipe )
161
+ transport , protocol = await self ._loop .connect_read_pipe (
162
+ self ._protocol_factory , pipe )
111
163
debug ("native stdin connection successful" )
164
+ del transport , protocol
112
165
self ._loop .run_until_complete (connect_stdin ())
113
166
114
167
# Make sure subprocesses don't clobber stdout,
@@ -122,52 +175,74 @@ async def connect_stdout():
122
175
else :
123
176
pipe = os .fdopen (rename_stdout , 'wb' )
124
177
125
- await self ._loop .connect_write_pipe (self ._fact , pipe )
178
+ transport , protocol = await self ._loop .connect_write_pipe (
179
+ self ._protocol_factory , pipe )
126
180
debug ("native stdout connection successful" )
127
-
181
+ self ._transport = transport
182
+ self ._protocol = protocol
128
183
self ._loop .run_until_complete (connect_stdout ())
129
184
185
+ @override
130
186
def _connect_child (self , argv : List [str ]) -> None :
131
187
if os .name != 'nt' :
132
188
# see #238, #241
133
- _child_watcher = asyncio .get_child_watcher ()
134
- _child_watcher .attach_loop (self ._loop )
189
+ self . _child_watcher = asyncio .get_child_watcher ()
190
+ self . _child_watcher .attach_loop (self ._loop )
135
191
136
192
async def create_subprocess ():
137
- transport : asyncio .SubprocessTransport
138
- transport , protocol = await self ._loop .subprocess_exec (self ._fact , * argv )
193
+ transport : asyncio .SubprocessTransport # type: ignore
194
+ transport , protocol = await self ._loop .subprocess_exec (
195
+ self ._protocol_factory , * argv )
139
196
pid = transport .get_pid ()
140
197
debug ("child subprocess_exec successful, PID = %s" , pid )
141
198
199
+ self ._transport = cast (asyncio .WriteTransport ,
200
+ transport .get_pipe_transport (0 )) # stdin
201
+ self ._protocol = protocol
202
+
203
+ # await until child process have been launched and the transport has
204
+ # been established
142
205
self ._loop .run_until_complete (create_subprocess ())
143
206
207
+ @override
144
208
def _start_reading (self ) -> None :
145
209
pass
146
210
211
+ @override
147
212
def _send (self , data : bytes ) -> None :
213
+ assert self ._transport , "connection has not been established."
148
214
self ._transport .write (data )
149
215
216
+ @override
150
217
def _run (self ) -> None :
151
- while self ._queued_data :
152
- data = self ._queued_data .popleft ()
218
+ # process the early messages that arrived as soon as the transport
219
+ # channels are open and on_data is fully ready to receive messages.
220
+ while self ._data_buffer :
221
+ data : bytes = self ._data_buffer .popleft ()
153
222
if self ._on_data is not None :
154
223
self ._on_data (data )
224
+
155
225
self ._loop .run_forever ()
156
226
227
+ @override
157
228
def _stop (self ) -> None :
158
229
self ._loop .stop ()
159
230
231
+ @override
160
232
def _close (self ) -> None :
233
+ # TODO close all the transports
161
234
if self ._raw_transport is not None :
162
- self ._raw_transport .close ()
235
+ self ._raw_transport .close () # type: ignore[unreachable]
163
236
self ._loop .close ()
164
237
if self ._child_watcher is not None :
165
238
self ._child_watcher .close ()
166
239
self ._child_watcher = None
167
240
241
+ @override
168
242
def _threadsafe_call (self , fn : Callable [[], Any ]) -> None :
169
243
self ._loop .call_soon_threadsafe (fn )
170
244
245
+ @override
171
246
def _setup_signals (self , signals : List [Signals ]) -> None :
172
247
if os .name == 'nt' :
173
248
# add_signal_handler is not supported in win32
@@ -178,6 +253,7 @@ def _setup_signals(self, signals: List[Signals]) -> None:
178
253
for signum in self ._signals :
179
254
self ._loop .add_signal_handler (signum , self ._on_signal , signum )
180
255
256
+ @override
181
257
def _teardown_signals (self ) -> None :
182
258
for signum in self ._signals :
183
259
self ._loop .remove_signal_handler (signum )
0 commit comments