-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path_protocol.py
132 lines (97 loc) · 4.74 KB
/
_protocol.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import annotations
import asyncio
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast
from typing_extensions import Buffer
from ._transport import (
_BaseTransportProxy,
_DatagramTransportProxy,
_make_transport_proxy,
)
if TYPE_CHECKING:
from ._loop import LoopProxy
class _BaseProtocolProxy(asyncio.BaseProtocol):
    """Wrap a user protocol and forward ``asyncio.BaseProtocol`` callbacks to it.

    Every callback is dispatched through ``loop._wrap_cb`` (a ``LoopProxy``
    helper defined in ``._loop``; presumably it runs the callback in the proxy
    loop's context — confirm there). The transport handed to the wrapped
    protocol is a proxy built by ``_make_transport_proxy``.
    """

    def __init__(self, protocol: asyncio.BaseProtocol, loop: LoopProxy) -> None:
        self._loop = loop
        # The user-supplied protocol that every callback is forwarded to.
        self.protocol = protocol
        # Proxy around the real transport; stays None until connection_made().
        self.transport: _BaseTransportProxy | None = None
        # Resolved with None once connection_lost() has been handled.
        self.wait_closed = self._loop.create_future()

    def __repr__(self) -> str:
        # Present as the wrapped protocol so debug output shows the user's object.
        return repr(self.protocol)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Wrap *transport* in a proxy and hand that proxy to the protocol."""
        self.transport = _make_transport_proxy(transport, self._loop)
        self._loop._wrap_cb(self.protocol.connection_made, self.transport)

    def connection_lost(self, exc: Exception | None) -> None:
        """Forward the connection-lost notification, then resolve ``wait_closed``.

        ``wait_closed`` is resolved in a ``finally`` so awaiters never hang,
        even if the forwarded callback raises. It is only resolved while still
        pending. Cancelling a task that awaits the future also cancels the
        shared future. In that case, or on a repeated call, ``set_result``
        would raise ``InvalidStateError`` from inside the transport's callback.
        """
        try:
            self._loop._wrap_cb(self.protocol.connection_lost, exc)
        finally:
            if not self.wait_closed.done():
                self.wait_closed.set_result(None)

    def pause_writing(self) -> None:
        """Forward ``pause_writing`` to the wrapped protocol."""
        self._loop._wrap_cb(self.protocol.pause_writing)

    def resume_writing(self) -> None:
        """Forward ``resume_writing`` to the wrapped protocol."""
        self._loop._wrap_cb(self.protocol.resume_writing)
class _ProtocolProxy(_BaseProtocolProxy, asyncio.Protocol):
    """Proxy for streaming ``asyncio.Protocol`` implementations.

    Adds forwarding of ``data_received`` and ``eof_received`` on top of the
    callbacks inherited from ``_BaseProtocolProxy``.
    """

    def data_received(self, data: bytes) -> None:
        """Forward a chunk of received bytes to the wrapped protocol."""
        # cast() only narrows the stored BaseProtocol for type checkers.
        protocol = cast(asyncio.Protocol, self.protocol)
        self._loop._wrap_cb(protocol.data_received, data)

    def eof_received(self) -> bool | None:
        """Forward EOF and return the wrapped protocol's answer.

        asyncio acts on this return value: a false value makes the transport
        close itself, while a true value leaves closing to the protocol
        (half-close). Returning None unconditionally would force-close every
        connection on EOF. This relies on ``_wrap_cb`` returning the callback's
        result, as ``get_buffer`` in the buffered proxy already does.
        """
        protocol = cast(asyncio.Protocol, self.protocol)
        return self._loop._wrap_cb(protocol.eof_received)
class _BufferedProtocolProxy(_BaseProtocolProxy, asyncio.BufferedProtocol):
    """Proxy for ``asyncio.BufferedProtocol`` implementations.

    Forwards the buffer-management callbacks and EOF to the wrapped protocol
    through ``loop._wrap_cb``. Results are returned where asyncio consumes
    them (the buffer from ``get_buffer`` and the flag from ``eof_received``).
    """

    def get_buffer(self, sizehint: int) -> Buffer:
        """Return the buffer the wrapped protocol wants data written into."""
        protocol = cast(asyncio.BufferedProtocol, self.protocol)
        return self._loop._wrap_cb(protocol.get_buffer, sizehint)

    def buffer_updated(self, nbytes: int) -> None:
        """Forward the number of bytes written into the buffer."""
        protocol = cast(asyncio.BufferedProtocol, self.protocol)
        self._loop._wrap_cb(protocol.buffer_updated, nbytes)

    def eof_received(self) -> bool | None:
        """Forward EOF and return the wrapped protocol's answer.

        A true value asks asyncio to keep the transport open (half-close).
        Discarding it would make the transport close itself on every EOF.
        """
        protocol = cast(asyncio.BufferedProtocol, self.protocol)
        return self._loop._wrap_cb(protocol.eof_received)
class _UniversalProtocolProxy(_BufferedProtocolProxy, _ProtocolProxy):
    """Proxy for protocols deriving from both BufferedProtocol and Protocol.

    Defines nothing of its own. Through the method resolution order,
    ``get_buffer``, ``buffer_updated`` and ``eof_received`` come from
    ``_BufferedProtocolProxy``, and ``data_received`` comes from
    ``_ProtocolProxy``.
    """
class _DatagramProtocolProxy(_BaseProtocolProxy, asyncio.DatagramProtocol):
    """Proxy for ``asyncio.DatagramProtocol`` implementations."""

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Wrap *transport* in a datagram transport proxy and hand it over."""
        # asyncio has wrong DatagramTransport inheritance, so picking the proxy
        # by the original transport type (what the base class does through
        # _make_transport_proxy) doesn't work; build the datagram proxy here.
        # See https://bugs.python.org/issue46194
        proxied = _DatagramTransportProxy(transport, self._loop)
        self.transport = proxied
        self._loop._wrap_cb(self.protocol.connection_made, proxied)

    def datagram_received(self, data: bytes, addr: Any) -> None:
        """Forward a received datagram and its source address."""
        handler = cast(asyncio.DatagramProtocol, self.protocol).datagram_received
        self._loop._wrap_cb(handler, data, addr)

    def error_received(self, exc: Exception) -> None:
        """Forward a send/receive error to the wrapped protocol."""
        handler = cast(asyncio.DatagramProtocol, self.protocol).error_received
        self._loop._wrap_cb(handler, exc)
class _SubprocessProtocolProxy(_BaseProtocolProxy, asyncio.SubprocessProtocol):
    """Proxy for ``asyncio.SubprocessProtocol`` implementations.

    Forwards the per-pipe callbacks and the process-exit notification to the
    wrapped protocol through ``loop._wrap_cb``.
    """

    def pipe_data_received(self, fd: int, data: bytes) -> None:
        """Forward data read from the child's pipe *fd*."""
        handler = cast(asyncio.SubprocessProtocol, self.protocol).pipe_data_received
        self._loop._wrap_cb(handler, fd, data)

    def pipe_connection_lost(self, fd: int, exc: Exception | None) -> None:
        """Forward the closing of the child's pipe *fd*."""
        handler = cast(asyncio.SubprocessProtocol, self.protocol).pipe_connection_lost
        self._loop._wrap_cb(handler, fd, exc)

    def process_exited(self) -> None:
        """Forward the child-process exit notification."""
        handler = cast(asyncio.SubprocessProtocol, self.protocol).process_exited
        self._loop._wrap_cb(handler)
# Protocol base class -> proxy class used to wrap it. _proto_proxy picks the
# FIRST entry whose protocol type matches, so the asyncio.BaseProtocol
# catch-all (the base of every other entry) has to stay last.
_MAP: tuple[tuple[type[asyncio.BaseProtocol], type[_BaseProtocolProxy]], ...] = (
    (asyncio.SubprocessProtocol, _SubprocessProtocolProxy),
    (asyncio.DatagramProtocol, _DatagramProtocolProxy),
    (asyncio.BufferedProtocol, _BufferedProtocolProxy),
    (asyncio.Protocol, _ProtocolProxy),
    (asyncio.BaseProtocol, _BaseProtocolProxy),
)
def _proto_proxy(original: asyncio.BaseProtocol, loop: LoopProxy) -> _BaseProtocolProxy:
    """Wrap *original* in the proxy class that matches its protocol type.

    A protocol that is both a ``BufferedProtocol`` and a ``Protocol`` gets the
    universal proxy. Otherwise the first matching entry of ``_MAP`` wins.

    Raises:
        RuntimeError: if *original* matches none of the types in ``_MAP``.
    """
    if isinstance(original, asyncio.BufferedProtocol) and isinstance(
        original, asyncio.Protocol
    ):
        return _UniversalProtocolProxy(original, loop)

    # Lazily test _MAP entries in order; next() stops at the first match.
    candidates = (
        proxy_cls for proto_cls, proxy_cls in _MAP if isinstance(original, proto_cls)
    )
    chosen = next(candidates, None)
    if chosen is None:
        raise RuntimeError(f"Cannot find protocol proxy for {original!r}")
    return chosen(original, loop)
def _proto_proxy_factory(
    original_factory: Callable[[], asyncio.BaseProtocol],
    loop: LoopProxy,
) -> Callable[[], _BaseProtocolProxy]:
    """Adapt a protocol factory so each protocol it builds comes back proxied.

    The returned zero-argument callable calls *original_factory* on every
    invocation and wraps the new protocol with ``_proto_proxy`` for *loop*.
    """

    def make_proxied() -> _BaseProtocolProxy:
        return _proto_proxy(original_factory(), loop)

    return make_proxied