Skip to content

Commit eabfeb5

Browse files
committed
refactor io.TCP
1 parent 1464fdd commit eabfeb5

File tree

1 file changed

+116
-41
lines changed

1 file changed

+116
-41
lines changed

pylabrobot/io/tcp.py

Lines changed: 116 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
import time
34
from dataclasses import dataclass
45
from typing import Optional
56

@@ -26,11 +27,16 @@ def __init__(self, host: str, port: int = 5000):
2627
self._port = port
2728
self._reader: Optional[asyncio.StreamReader] = None
2829
self._writer: Optional[asyncio.StreamWriter] = None
30+
self._read_buffer = bytearray()
2931

3032
if get_capture_or_validation_active():
3133
raise RuntimeError("Cannot create a new TCP object while capture or validation is active")
3234

3335
async def setup(self):
36+
await self._open_connection()
37+
self._read_buffer = bytearray()
38+
39+
async def _open_connection(self):
3440
self._reader, self._writer = await asyncio.open_connection(self._host, self._port)
3541

3642
async def stop(self):
@@ -40,44 +46,123 @@ async def stop(self):
4046
self._reader = None
4147
self._writer = None
4248

43-
async def write(self, data: bytes):
49+
async def write(self, data: bytes, num_tries: int = 3):
4450
assert self._writer is not None, "forgot to call setup?"
45-
self._writer.write(data + b"\n")
46-
await self._writer.drain()
47-
logger.log(LOG_LEVEL_IO, "[%s:%d] write %s", self._host, self._port, data)
48-
capturer.record(
49-
TCPCommand(
50-
device_id=f"{self._host}:{self._port}", action="write", data=data.decode("unicode_escape")
51-
)
52-
)
5351

54-
async def read(self, num_bytes: int = -1) -> bytes:
52+
last_exc: Optional[Exception] = None
53+
54+
for attempt in range(num_tries):
55+
try:
56+
self._writer.write(data + b"\n") # TODO: this should be a part of PF400, not io.TCP
57+
await self._writer.drain()
58+
except (ConnectionResetError, BrokenPipeError, OSError) as exc:
59+
last_exc = exc
60+
logger.warning(
61+
"TCP write failed with %r on attempt %d/%d; reopening connection",
62+
exc,
63+
attempt + 1,
64+
num_tries,
65+
)
66+
await self._open_connection()
67+
assert self._writer is not None, "_open_connection() failed to set _writer"
68+
else: # success
69+
logger.log(LOG_LEVEL_IO, "[%s:%d] write %s", self._host, self._port, data)
70+
capturer.record(
71+
TCPCommand(
72+
device_id=f"{self._host}:{self._port}",
73+
action="write",
74+
data=data.hex(),
75+
)
76+
)
77+
return
78+
79+
raise ConnectionResetError(f"Max number of retries reached ({num_tries})") from last_exc
80+
81+
async def _raw_read(self, num_bytes: int, num_tries: int) -> bytes:
82+
"""Single low-level read with retries; does not use buffer."""
5583
assert self._reader is not None, "forgot to call setup?"
56-
data = await self._reader.read(num_bytes)
57-
logger.log(LOG_LEVEL_IO, "[%s:%d] read %s", self._host, self._port, data)
58-
capturer.record(
59-
TCPCommand(
60-
device_id=f"{self._host}:{self._port}", action="read", data=data.decode("unicode_escape")
84+
85+
last_exc: Optional[Exception] = None
86+
87+
for attempt in range(num_tries):
88+
try:
89+
data = await self._reader.read(num_bytes)
90+
if data == b"":
91+
# EOF: peer closed the connection
92+
raise ConnectionResetError("Peer closed connection")
93+
except (ConnectionResetError, OSError) as exc:
94+
last_exc = exc
95+
logger.warning(
96+
"TCP read failed with %r on attempt %d/%d; reopening connection",
97+
exc,
98+
attempt + 1,
99+
num_tries,
100+
)
101+
await self._open_connection()
102+
assert self._reader is not None, "_open_connection() failed to set _reader"
103+
continue
104+
105+
logger.log(LOG_LEVEL_IO, "[%s:%d] read %s", self._host, self._port, data)
106+
capturer.record(
107+
TCPCommand(
108+
device_id=f"{self._host}:{self._port}",
109+
action="read",
110+
data=data.hex(),
111+
)
61112
)
62-
)
63-
return data
113+
return data
114+
115+
raise ConnectionResetError(f"Max number of read retries reached ({num_tries})") from last_exc
64116

65-
async def readline(self) -> bytes:
117+
async def read(self, num_bytes: int = 128, num_tries: int = 3) -> bytes:
66118
assert self._reader is not None, "forgot to call setup?"
67119

68-
data = await self._reader.read(128)
69-
last_line = data.split(b"\r\n")[0] # fix for errors with multiplate lines returned
70-
last_line += b"\r\n"
120+
if num_bytes <= 0:
121+
return b""
71122

72-
logger.log(LOG_LEVEL_IO, "[%s:%d] readline %s", self._host, self._port, last_line)
73-
capturer.record(
74-
TCPCommand(
75-
device_id=f"{self._host}:{self._port}",
76-
action="readline",
77-
data=last_line.decode("unicode_escape"),
78-
)
79-
)
80-
return last_line
123+
# Fill buffer until we have enough bytes.
124+
125+
while len(self._read_buffer) < num_bytes:
126+
chunk = await self._raw_read(num_bytes - len(self._read_buffer), num_tries)
127+
self._read_buffer.extend(chunk)
128+
if len(chunk) == 0: # EOF
129+
break
130+
131+
if len(self._read_buffer) < num_bytes:
132+
raise TimeoutError(f"Timeout while waiting for {num_bytes} bytes")
133+
134+
# Consume from buffer, or return empty if buffer is empty.
135+
if len(self._read_buffer) == 0:
136+
return b""
137+
chunk = bytes(self._read_buffer[:num_bytes])
138+
del self._read_buffer[:num_bytes]
139+
return chunk
140+
141+
async def readline(
142+
self, num_tries: int = 3, timeout: float = 60, line_ending: bytes = b"\r\n"
143+
) -> bytes:
144+
assert self._reader is not None, "forgot to call setup?"
145+
146+
CHUNK = 1024
147+
148+
timeout_time = time.time() + timeout
149+
150+
while time.time() < timeout_time:
151+
idx = self._read_buffer.find(line_ending)
152+
if idx != -1:
153+
end = idx + len(line_ending)
154+
line = bytes(self._read_buffer[:end])
155+
del self._read_buffer[:end]
156+
return line
157+
158+
chunk = await self._raw_read(num_bytes=CHUNK, num_tries=num_tries)
159+
self._read_buffer.extend(chunk)
160+
if len(chunk) == 0: # EOF; return what we have
161+
line = bytes(self._read_buffer)
162+
self._read_buffer.clear()
163+
return line
164+
165+
raise TimeoutError(f"Timeout while waiting for line ending with '{line_ending.decode()}'")
81166

82167
def serialize(self):
83168
return {
@@ -122,14 +207,4 @@ async def read(self, num_bytes: int = 128) -> bytes:
122207
and len(next_command.data) == num_bytes
123208
):
124209
raise ValidationError(f"Next line is {next_command}, expected TCP read {num_bytes}")
125-
return next_command.data.encode()
126-
127-
async def readline(self) -> bytes:
128-
next_command = TCPCommand(**self.cr.next_command())
129-
if not (
130-
next_command.module == "tcp"
131-
and next_command.device_id == f"{self._host}:{self._port}"
132-
and next_command.action == "readline"
133-
):
134-
raise ValidationError(f"Next line is {next_command}, expected TCP readline")
135-
return next_command.data.encode()
210+
return bytes.fromhex(next_command.data)

0 commit comments

Comments
 (0)