11import asyncio
22import logging
3+ import time
34from dataclasses import dataclass
45from typing import Optional
56
@@ -26,11 +27,16 @@ def __init__(self, host: str, port: int = 5000):
2627 self ._port = port
2728 self ._reader : Optional [asyncio .StreamReader ] = None
2829 self ._writer : Optional [asyncio .StreamWriter ] = None
30+ self ._read_buffer = bytearray ()
2931
3032 if get_capture_or_validation_active ():
3133 raise RuntimeError ("Cannot create a new TCP object while capture or validation is active" )
3234
3335 async def setup (self ):
36+ await self ._open_connection ()
37+ self ._read_buffer = bytearray ()
38+
39+ async def _open_connection (self ):
3440 self ._reader , self ._writer = await asyncio .open_connection (self ._host , self ._port )
3541
3642 async def stop (self ):
@@ -40,44 +46,122 @@ async def stop(self):
4046 self ._reader = None
4147 self ._writer = None
4248
43- async def write (self , data : bytes ):
49+ async def write (self , data : bytes , num_tries : int = 3 ):
4450 assert self ._writer is not None , "forgot to call setup?"
45- self ._writer .write (data + b"\n " )
46- await self ._writer .drain ()
47- logger .log (LOG_LEVEL_IO , "[%s:%d] write %s" , self ._host , self ._port , data )
48- capturer .record (
49- TCPCommand (
50- device_id = f"{ self ._host } :{ self ._port } " , action = "write" , data = data .decode ("unicode_escape" )
51- )
52- )
5351
54- async def read (self , num_bytes : int = - 1 ) -> bytes :
52+ last_exc : Optional [Exception ] = None
53+
54+ for attempt in range (num_tries ):
55+ try :
56+ self ._writer .write (data + b"\n " ) # TODO: this should be a part of PF400, not io.TCP
57+ await self ._writer .drain ()
58+ except (ConnectionResetError , BrokenPipeError , OSError ) as exc :
59+ last_exc = exc
60+ logger .warning (
61+ "TCP write failed with %r on attempt %d/%d; reopening connection" ,
62+ exc ,
63+ attempt + 1 ,
64+ num_tries ,
65+ )
66+ await self ._open_connection ()
67+ assert self ._writer is not None , "_open_connection() failed to set _writer"
68+ else : # success
69+ logger .log (LOG_LEVEL_IO , "[%s:%d] write %s" , self ._host , self ._port , data )
70+ capturer .record (
71+ TCPCommand (
72+ device_id = f"{ self ._host } :{ self ._port } " ,
73+ action = "write" ,
74+ data = data .hex (),
75+ )
76+ )
77+ return
78+
79+ raise ConnectionResetError (f"Max number of retries reached ({ num_tries } )" ) from last_exc
80+
81+ async def _raw_read (self , num_bytes : int , num_tries : int ) -> bytes :
82+ """Single low-level read with retries; does not use buffer."""
5583 assert self ._reader is not None , "forgot to call setup?"
56- data = await self ._reader .read (num_bytes )
57- logger .log (LOG_LEVEL_IO , "[%s:%d] read %s" , self ._host , self ._port , data )
58- capturer .record (
59- TCPCommand (
60- device_id = f"{ self ._host } :{ self ._port } " , action = "read" , data = data .decode ("unicode_escape" )
84+
85+ last_exc : Optional [Exception ] = None
86+
87+ for attempt in range (num_tries ):
88+ try :
89+ data = await self ._reader .read (num_bytes )
90+ if data == b"" :
91+ return b"" # EOF
92+ except (ConnectionResetError , OSError ) as exc :
93+ last_exc = exc
94+ logger .warning (
95+ "TCP read failed with %r on attempt %d/%d; reopening connection" ,
96+ exc ,
97+ attempt + 1 ,
98+ num_tries ,
99+ )
100+ await self ._open_connection ()
101+ assert self ._reader is not None , "_open_connection() failed to set _reader"
102+ continue
103+
104+ logger .log (LOG_LEVEL_IO , "[%s:%d] read %s" , self ._host , self ._port , data )
105+ capturer .record (
106+ TCPCommand (
107+ device_id = f"{ self ._host } :{ self ._port } " ,
108+ action = "read" ,
109+ data = data .hex (),
110+ )
61111 )
62- )
63- return data
112+ return data
113+
114+ raise ConnectionResetError (f"Max number of read retries reached ({ num_tries } )" ) from last_exc
64115
65- async def readline (self ) -> bytes :
116+ async def read (self , num_bytes : int = 128 , num_tries : int = 3 ) -> bytes :
66117 assert self ._reader is not None , "forgot to call setup?"
67118
68- data = await self ._reader .read (128 )
69- last_line = data .split (b"\r \n " )[0 ] # fix for errors with multiplate lines returned
70- last_line += b"\r \n "
119+ if num_bytes <= 0 :
120+ return b""
71121
72- logger .log (LOG_LEVEL_IO , "[%s:%d] readline %s" , self ._host , self ._port , last_line )
73- capturer .record (
74- TCPCommand (
75- device_id = f"{ self ._host } :{ self ._port } " ,
76- action = "readline" ,
77- data = last_line .decode ("unicode_escape" ),
78- )
79- )
80- return last_line
122+ # Fill buffer until we have enough bytes.
123+
124+ while len (self ._read_buffer ) < num_bytes :
125+ chunk = await self ._raw_read (num_bytes - len (self ._read_buffer ), num_tries )
126+ self ._read_buffer .extend (chunk )
127+ if len (chunk ) == 0 : # EOF
128+ break
129+
130+ if len (self ._read_buffer ) < num_bytes :
131+ raise TimeoutError (f"Timeout while waiting for { num_bytes } bytes" )
132+
133+ # Consume from buffer, or return empty if buffer is empty.
134+ if len (self ._read_buffer ) == 0 :
135+ return b""
136+ chunk = bytes (self ._read_buffer [:num_bytes ])
137+ del self ._read_buffer [:num_bytes ]
138+ return chunk
139+
140+ async def readline (
141+ self , num_tries : int = 3 , timeout : float = 60 , line_ending : bytes = b"\r \n "
142+ ) -> bytes :
143+ assert self ._reader is not None , "forgot to call setup?"
144+
145+ CHUNK = 1024
146+
147+ timeout_time = time .time () + timeout
148+
149+ while time .time () < timeout_time :
150+ idx = self ._read_buffer .find (line_ending )
151+ if idx != - 1 :
152+ end = idx + len (line_ending )
153+ line = bytes (self ._read_buffer [:end ])
154+ del self ._read_buffer [:end ]
155+ return line
156+
157+ chunk = await self ._raw_read (num_bytes = CHUNK , num_tries = num_tries )
158+ self ._read_buffer .extend (chunk )
159+ if len (chunk ) == 0 : # EOF; return what we have
160+ line = bytes (self ._read_buffer )
161+ self ._read_buffer .clear ()
162+ return line
163+
164+ raise TimeoutError (f"Timeout while waiting for line ending with '{ line_ending .decode ()} '" )
81165
82166 def serialize (self ):
83167 return {
@@ -122,14 +206,4 @@ async def read(self, num_bytes: int = 128) -> bytes:
122206 and len (next_command .data ) == num_bytes
123207 ):
124208 raise ValidationError (f"Next line is { next_command } , expected TCP read { num_bytes } " )
125- return next_command .data .encode ()
126-
127- async def readline (self ) -> bytes :
128- next_command = TCPCommand (** self .cr .next_command ())
129- if not (
130- next_command .module == "tcp"
131- and next_command .device_id == f"{ self ._host } :{ self ._port } "
132- and next_command .action == "readline"
133- ):
134- raise ValidationError (f"Next line is { next_command } , expected TCP readline" )
135- return next_command .data .encode ()
209+ return bytes .fromhex (next_command .data )
0 commit comments