Skip to content

Commit 3f5f532

Browse files
committed
refactor: separate asyncio Protocol from AsyncioEventLoop
- Separate Protocol from AsyncioEventLoop (which were too complex). This makes it much clearer to discern which methods are asyncio- specific specializations of the abstract base event loop class (used for managing the lifecycle of event loop itself) and which serves as callback functions for IPC communication. - Document the design and the lifecycle of the BaseEventLoop class. Although asyncio is the only existing implementation, the current behavior or abstraction is documented (until further refactorings) to avoid confusions and clarify how the subclass should be implemented. - Use `typing.override` in the AsyncioEventLoop subclass (requires typing-extensions >= 4.5.0).
1 parent 2059684 commit 3f5f532

File tree

5 files changed

+184
-94
lines changed

5 files changed

+184
-94
lines changed

pynvim/msgpack_rpc/event_loop/asyncio.py

Lines changed: 119 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
import sys
77
from collections import deque
88
from signal import Signals
9-
from typing import Any, Callable, Deque, List, Optional
9+
from typing import Any, Callable, Deque, List, Optional, cast
1010

11-
from pynvim.msgpack_rpc.event_loop.base import BaseEventLoop
11+
if sys.version_info >= (3, 12):
12+
from typing import Final, override
13+
else:
14+
from typing_extensions import Final, override
15+
16+
from pynvim.msgpack_rpc.event_loop.base import BaseEventLoop, TTransportType
1217

1318
logger = logging.getLogger(__name__)
1419
debug, info, warn = (logger.debug, logger.info, logger.warning,)
@@ -27,88 +32,136 @@
2732

2833
# pylint: disable=logging-fstring-interpolation
2934

30-
class AsyncioEventLoop(BaseEventLoop, asyncio.Protocol,
31-
asyncio.SubprocessProtocol):
32-
"""`BaseEventLoop` subclass that uses `asyncio` as a backend."""
35+
class Protocol(asyncio.Protocol, asyncio.SubprocessProtocol):
36+
"""The protocol class used for asyncio-based RPC communication."""
3337

34-
_queued_data: Deque[bytes]
35-
if os.name != 'nt':
36-
_child_watcher: Optional['asyncio.AbstractChildWatcher']
38+
def __init__(self, on_data, on_error):
39+
"""Initialize the Protocol object."""
40+
assert on_data is not None
41+
assert on_error is not None
42+
self._on_data = on_data
43+
self._on_error = on_error
3744

45+
@override
3846
def connection_made(self, transport):
3947
"""Used to signal `asyncio.Protocol` of a successful connection."""
40-
self._transport = transport
41-
self._raw_transport = transport
42-
if isinstance(transport, asyncio.SubprocessTransport):
43-
self._transport = transport.get_pipe_transport(0)
48+
del transport # no-op
4449

45-
def connection_lost(self, exc):
50+
@override
51+
def connection_lost(self, exc: Optional[Exception]) -> None:
4652
"""Used to signal `asyncio.Protocol` of a lost connection."""
4753
debug(f"connection_lost: exc = {exc}")
48-
self._on_error(exc.args[0] if exc else 'EOF')
54+
self._on_error(exc if exc else EOFError())
4955

56+
@override
5057
def data_received(self, data: bytes) -> None:
5158
"""Used to signal `asyncio.Protocol` of incoming data."""
52-
if self._on_data:
53-
self._on_data(data)
54-
return
55-
self._queued_data.append(data)
59+
self._on_data(data)
5660

57-
def pipe_connection_lost(self, fd, exc):
61+
@override
62+
def pipe_connection_lost(self, fd: int, exc: Optional[Exception]) -> None:
5863
"""Used to signal `asyncio.SubprocessProtocol` of a lost connection."""
5964
debug("pipe_connection_lost: fd = %s, exc = %s", fd, exc)
6065
if os.name == 'nt' and fd == 2: # stderr
6166
# On windows, ignore piped stderr being closed immediately (#505)
6267
return
63-
self._on_error(exc.args[0] if exc else 'EOF')
68+
self._on_error(exc if exc else EOFError())
6469

70+
@override
6571
def pipe_data_received(self, fd, data):
6672
"""Used to signal `asyncio.SubprocessProtocol` of incoming data."""
6773
if fd == 2: # stderr fd number
6874
# Ignore stderr message, log only for debugging
6975
debug("stderr: %s", str(data))
70-
elif self._on_data:
71-
self._on_data(data)
72-
else:
73-
self._queued_data.append(data)
76+
elif fd == 1: # stdout
77+
self.data_received(data)
7478

79+
@override
7580
def process_exited(self) -> None:
7681
"""Used to signal `asyncio.SubprocessProtocol` when the child exits."""
7782
debug("process_exited")
78-
self._on_error('EOF')
83+
self._on_error(EOFError())
84+
85+
86+
class AsyncioEventLoop(BaseEventLoop):
87+
"""`BaseEventLoop` subclass that uses core `asyncio` as a backend."""
88+
89+
_protocol: Optional[Protocol]
90+
_transport: Optional[asyncio.WriteTransport]
91+
_signals: List[Signals]
92+
_data_buffer: Deque[bytes]
93+
if os.name != 'nt':
94+
_child_watcher: Optional['asyncio.AbstractChildWatcher']
7995

80-
def _init(self) -> None:
81-
self._loop = loop_cls()
82-
self._queued_data = deque()
83-
self._fact = lambda: self
96+
def __init__(self,
97+
transport_type: TTransportType,
98+
*args: Any, **kwargs: Any):
99+
"""asyncio-specific initialization. see BaseEventLoop.__init__."""
100+
101+
# The underlying asyncio event loop.
102+
self._loop: Final[asyncio.AbstractEventLoop] = loop_cls()
103+
104+
# Handle messages from nvim that may arrive before run() starts.
105+
self._data_buffer = deque()
106+
107+
def _on_data(data: bytes) -> None:
108+
if self._on_data is None:
109+
self._data_buffer.append(data)
110+
return
111+
self._on_data(data)
112+
113+
# pylint: disable-next=unnecessary-lambda
114+
self._protocol_factory = lambda: Protocol(
115+
on_data=_on_data,
116+
on_error=self._on_error,
117+
)
118+
self._protocol = None
119+
120+
# The communication channel (endpoint) created by _connect_*() method.
121+
self._transport = None
84122
self._raw_transport = None
85123
self._child_watcher = None
86124

125+
super().__init__(transport_type, *args, **kwargs)
126+
127+
@override
87128
def _connect_tcp(self, address: str, port: int) -> None:
88129
async def connect_tcp():
89-
await self._loop.create_connection(self._fact, address, port)
130+
transport, protocol = await self._loop.create_connection(
131+
self._protocol_factory, address, port)
90132
debug(f"tcp connection successful: {address}:{port}")
133+
self._transport = transport
134+
self._protocol = protocol
91135

92136
self._loop.run_until_complete(connect_tcp())
93137

138+
@override
94139
def _connect_socket(self, path: str) -> None:
95140
async def connect_socket():
96141
if os.name == 'nt':
97-
transport, _ = await self._loop.create_pipe_connection(self._fact, path)
142+
_create_connection = self._loop.create_pipe_connection
98143
else:
99-
transport, _ = await self._loop.create_unix_connection(self._fact, path)
100-
debug("socket connection successful: %s", transport)
144+
_create_connection = self._loop.create_unix_connection
145+
146+
transport, protocol = await _create_connection(
147+
self._protocol_factory, path)
148+
debug("socket connection successful: %s", self._transport)
149+
self._transport = transport
150+
self._protocol = protocol
101151

102152
self._loop.run_until_complete(connect_socket())
103153

154+
@override
104155
def _connect_stdio(self) -> None:
105156
async def connect_stdin():
106157
if os.name == 'nt':
107158
pipe = PipeHandle(msvcrt.get_osfhandle(sys.stdin.fileno()))
108159
else:
109160
pipe = sys.stdin
110-
await self._loop.connect_read_pipe(self._fact, pipe)
161+
transport, protocol = await self._loop.connect_read_pipe(
162+
self._protocol_factory, pipe)
111163
debug("native stdin connection successful")
164+
del transport, protocol
112165
self._loop.run_until_complete(connect_stdin())
113166

114167
# Make sure subprocesses don't clobber stdout,
@@ -122,52 +175,74 @@ async def connect_stdout():
122175
else:
123176
pipe = os.fdopen(rename_stdout, 'wb')
124177

125-
await self._loop.connect_write_pipe(self._fact, pipe)
178+
transport, protocol = await self._loop.connect_write_pipe(
179+
self._protocol_factory, pipe)
126180
debug("native stdout connection successful")
127-
181+
self._transport = transport
182+
self._protocol = protocol
128183
self._loop.run_until_complete(connect_stdout())
129184

185+
@override
130186
def _connect_child(self, argv: List[str]) -> None:
131187
if os.name != 'nt':
132188
# see #238, #241
133-
_child_watcher = asyncio.get_child_watcher()
134-
_child_watcher.attach_loop(self._loop)
189+
self._child_watcher = asyncio.get_child_watcher()
190+
self._child_watcher.attach_loop(self._loop)
135191

136192
async def create_subprocess():
137-
transport: asyncio.SubprocessTransport
138-
transport, protocol = await self._loop.subprocess_exec(self._fact, *argv)
193+
transport: asyncio.SubprocessTransport # type: ignore
194+
transport, protocol = await self._loop.subprocess_exec(
195+
self._protocol_factory, *argv)
139196
pid = transport.get_pid()
140197
debug("child subprocess_exec successful, PID = %s", pid)
141198

199+
self._transport = cast(asyncio.WriteTransport,
200+
transport.get_pipe_transport(0)) # stdin
201+
self._protocol = protocol
202+
203+
# await until child process have been launched and the transport has
204+
# been established
142205
self._loop.run_until_complete(create_subprocess())
143206

207+
@override
144208
def _start_reading(self) -> None:
145209
pass
146210

211+
@override
147212
def _send(self, data: bytes) -> None:
213+
assert self._transport, "connection has not been established."
148214
self._transport.write(data)
149215

216+
@override
150217
def _run(self) -> None:
151-
while self._queued_data:
152-
data = self._queued_data.popleft()
218+
# process the early messages that arrived as soon as the transport
219+
# channels are open and on_data is fully ready to receive messages.
220+
while self._data_buffer:
221+
data: bytes = self._data_buffer.popleft()
153222
if self._on_data is not None:
154223
self._on_data(data)
224+
155225
self._loop.run_forever()
156226

227+
@override
157228
def _stop(self) -> None:
158229
self._loop.stop()
159230

231+
@override
160232
def _close(self) -> None:
233+
# TODO close all the transports
161234
if self._raw_transport is not None:
162-
self._raw_transport.close()
235+
self._raw_transport.close() # type: ignore[unreachable]
163236
self._loop.close()
164237
if self._child_watcher is not None:
165238
self._child_watcher.close()
166239
self._child_watcher = None
167240

241+
@override
168242
def _threadsafe_call(self, fn: Callable[[], Any]) -> None:
169243
self._loop.call_soon_threadsafe(fn)
170244

245+
@override
171246
def _setup_signals(self, signals: List[Signals]) -> None:
172247
if os.name == 'nt':
173248
# add_signal_handler is not supported in win32
@@ -178,6 +253,7 @@ def _setup_signals(self, signals: List[Signals]) -> None:
178253
for signum in self._signals:
179254
self._loop.add_signal_handler(signum, self._on_signal, signum)
180255

256+
@override
181257
def _teardown_signals(self) -> None:
182258
for signum in self._signals:
183259
self._loop.remove_signal_handler(signum)

0 commit comments

Comments
 (0)