11import asyncio
22import logging
3+ import time
34from dataclasses import dataclass
45from typing import Optional
56
@@ -26,11 +27,16 @@ def __init__(self, host: str, port: int = 5000):
2627 self ._port = port
2728 self ._reader : Optional [asyncio .StreamReader ] = None
2829 self ._writer : Optional [asyncio .StreamWriter ] = None
30+ self ._read_buffer = bytearray ()
2931
3032 if get_capture_or_validation_active ():
3133 raise RuntimeError ("Cannot create a new TCP object while capture or validation is active" )
3234
3335 async def setup (self ):
36+ await self ._open_connection ()
37+ self ._read_buffer = bytearray ()
38+
39+ async def _open_connection (self ):
3440 self ._reader , self ._writer = await asyncio .open_connection (self ._host , self ._port )
3541
3642 async def stop (self ):
@@ -40,44 +46,123 @@ async def stop(self):
4046 self ._reader = None
4147 self ._writer = None
4248
43- async def write (self , data : bytes ):
49+ async def write (self , data : bytes , num_tries : int = 3 ):
4450 assert self ._writer is not None , "forgot to call setup?"
45- self ._writer .write (data + b"\n " )
46- await self ._writer .drain ()
47- logger .log (LOG_LEVEL_IO , "[%s:%d] write %s" , self ._host , self ._port , data )
48- capturer .record (
49- TCPCommand (
50- device_id = f"{ self ._host } :{ self ._port } " , action = "write" , data = data .decode ("unicode_escape" )
51- )
52- )
5351
54- async def read (self , num_bytes : int = - 1 ) -> bytes :
52+ last_exc : Optional [Exception ] = None
53+
54+ for attempt in range (num_tries ):
55+ try :
56+ self ._writer .write (data + b"\n " ) # TODO: this should be a part of PF400, not io.TCP
57+ await self ._writer .drain ()
58+ except (ConnectionResetError , BrokenPipeError , OSError ) as exc :
59+ last_exc = exc
60+ logger .warning (
61+ "TCP write failed with %r on attempt %d/%d; reopening connection" ,
62+ exc ,
63+ attempt + 1 ,
64+ num_tries ,
65+ )
66+ await self ._open_connection ()
67+ assert self ._writer is not None , "_open_connection() failed to set _writer"
68+ else : # success
69+ logger .log (LOG_LEVEL_IO , "[%s:%d] write %s" , self ._host , self ._port , data )
70+ capturer .record (
71+ TCPCommand (
72+ device_id = f"{ self ._host } :{ self ._port } " ,
73+ action = "write" ,
74+ data = data .hex (),
75+ )
76+ )
77+ return
78+
79+ raise ConnectionResetError (f"Max number of retries reached ({ num_tries } )" ) from last_exc
80+
81+ async def _raw_read (self , num_bytes : int , num_tries : int ) -> bytes :
82+ """Single low-level read with retries; does not use buffer."""
5583 assert self ._reader is not None , "forgot to call setup?"
56- data = await self ._reader .read (num_bytes )
57- logger .log (LOG_LEVEL_IO , "[%s:%d] read %s" , self ._host , self ._port , data )
58- capturer .record (
59- TCPCommand (
60- device_id = f"{ self ._host } :{ self ._port } " , action = "read" , data = data .decode ("unicode_escape" )
84+
85+ last_exc : Optional [Exception ] = None
86+
87+ for attempt in range (num_tries ):
88+ try :
89+ data = await self ._reader .read (num_bytes )
90+ if data == b"" :
91+ # EOF: peer closed the connection
92+ raise ConnectionResetError ("Peer closed connection" )
93+ except (ConnectionResetError , OSError ) as exc :
94+ last_exc = exc
95+ logger .warning (
96+ "TCP read failed with %r on attempt %d/%d; reopening connection" ,
97+ exc ,
98+ attempt + 1 ,
99+ num_tries ,
100+ )
101+ await self ._open_connection ()
102+ assert self ._reader is not None , "_open_connection() failed to set _reader"
103+ continue
104+
105+ logger .log (LOG_LEVEL_IO , "[%s:%d] read %s" , self ._host , self ._port , data )
106+ capturer .record (
107+ TCPCommand (
108+ device_id = f"{ self ._host } :{ self ._port } " ,
109+ action = "read" ,
110+ data = data .hex (),
111+ )
61112 )
62- )
63- return data
113+ return data
114+
115+ raise ConnectionResetError (f"Max number of read retries reached ({ num_tries } )" ) from last_exc
64116
65- async def readline (self ) -> bytes :
117+ async def read (self , num_bytes : int = 128 , num_tries : int = 3 ) -> bytes :
66118 assert self ._reader is not None , "forgot to call setup?"
67119
68- data = await self ._reader .read (128 )
69- last_line = data .split (b"\r \n " )[0 ] # fix for errors with multiplate lines returned
70- last_line += b"\r \n "
120+ if num_bytes <= 0 :
121+ return b""
71122
72- logger .log (LOG_LEVEL_IO , "[%s:%d] readline %s" , self ._host , self ._port , last_line )
73- capturer .record (
74- TCPCommand (
75- device_id = f"{ self ._host } :{ self ._port } " ,
76- action = "readline" ,
77- data = last_line .decode ("unicode_escape" ),
78- )
79- )
80- return last_line
123+ # Fill buffer until we have enough bytes.
124+
125+ while len (self ._read_buffer ) < num_bytes :
126+ chunk = await self ._raw_read (num_bytes - len (self ._read_buffer ), num_tries )
127+ self ._read_buffer .extend (chunk )
128+ if len (chunk ) == 0 : # EOF
129+ break
130+
131+ if len (self ._read_buffer ) < num_bytes :
132+ raise TimeoutError (f"Timeout while waiting for { num_bytes } bytes" )
133+
134+ # Consume from buffer, or return empty if buffer is empty.
135+ if len (self ._read_buffer ) == 0 :
136+ return b""
137+ chunk = bytes (self ._read_buffer [:num_bytes ])
138+ del self ._read_buffer [:num_bytes ]
139+ return chunk
140+
141+ async def readline (
142+ self , num_tries : int = 3 , timeout : float = 60 , line_ending : bytes = b"\r \n "
143+ ) -> bytes :
144+ assert self ._reader is not None , "forgot to call setup?"
145+
146+ CHUNK = 1024
147+
148+ timeout_time = time .time () + timeout
149+
150+ while time .time () < timeout_time :
151+ idx = self ._read_buffer .find (line_ending )
152+ if idx != - 1 :
153+ end = idx + len (line_ending )
154+ line = bytes (self ._read_buffer [:end ])
155+ del self ._read_buffer [:end ]
156+ return line
157+
158+ chunk = await self ._raw_read (num_bytes = CHUNK , num_tries = num_tries )
159+ self ._read_buffer .extend (chunk )
160+ if len (chunk ) == 0 : # EOF; return what we have
161+ line = bytes (self ._read_buffer )
162+ self ._read_buffer .clear ()
163+ return line
164+
165+ raise TimeoutError (f"Timeout while waiting for line ending with '{ line_ending .decode ()} '" )
81166
82167 def serialize (self ):
83168 return {
@@ -122,14 +207,4 @@ async def read(self, num_bytes: int = 128) -> bytes:
122207 and len (next_command .data ) == num_bytes
123208 ):
124209 raise ValidationError (f"Next line is { next_command } , expected TCP read { num_bytes } " )
125- return next_command .data .encode ()
126-
127- async def readline (self ) -> bytes :
128- next_command = TCPCommand (** self .cr .next_command ())
129- if not (
130- next_command .module == "tcp"
131- and next_command .device_id == f"{ self ._host } :{ self ._port } "
132- and next_command .action == "readline"
133- ):
134- raise ValidationError (f"Next line is { next_command } , expected TCP readline" )
135- return next_command .data .encode ()
210+ return bytes .fromhex (next_command .data )
0 commit comments