Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ The subclass of `socket.socket` provides three key method overrides. All other `
* `sendall` - takes a Numpy array as an argument to send an Array
* `recv` - outputs a Numpy array (`len() == 0` if no data received)

### Known Issues
With a large value for `bufsize` in `recv()` there it is not guaranteed that multiple small arrays will be fully received. You should always set your `bufsize` to be small enough that it will not "reach over" to the next array.

## Installation
```
pip install numpysocket
Expand Down
11 changes: 10 additions & 1 deletion examples/Simple/simple_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,23 @@

import logging
import numpy as np
from numpysocket import NumpySocket

from numpysocket import NumpySocket

logger = logging.getLogger("simple client")
logger.setLevel(logging.INFO)

with NumpySocket() as s:
s.connect(("localhost", 9999))
logger.info("Connecteed to localhost:9999")

logger.info("sending numpy array:")
frame = np.arange(1000)
s.sendall(frame)

logger.info("sending reversed numpy array:")
frame = np.arange(1000, 0, -1)
s.sendall(frame)

logger.info("sending numpy array:")
frame = np.arange(1000)
Expand Down
15 changes: 9 additions & 6 deletions examples/Simple/simple_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
with NumpySocket() as s:
s.bind(("", 9999))
s.listen()
logger.info("Server listening on port 9999")
conn, addr = s.accept()
with conn:
logger.info(f"connected: {addr}")
frame = conn.recv()

logger.info("array received")
logger.info(frame)

logger.info(f"disconnected: {addr}")
while True:
frame = conn.recv()
if frame is None or (hasattr(frame, "size") and frame.size == 0):
logger.info("Client disconnected")
break
logger.info("array received")
logger.info(frame)
logger.info(f"disconnected: {addr}")
53 changes: 34 additions & 19 deletions numpysocket/numpysocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,44 @@ def sendall(self, frame: np.ndarray) -> None: # type: ignore[override]
def recv(self, bufsize: int = 1024) -> np.ndarray: # type: ignore[override]
length = None
frame_buffer = bytearray()

while True:
data = super().recv(bufsize)
receive_size = bufsize
if length is not None:
remaining = length - len(frame_buffer)
receive_size = min(bufsize, remaining)
logging.debug(
f"Receiving {receive_size} of {remaining} remaining bytes"
)

data = super().recv(receive_size)
if len(data) == 0:
return np.array([])

frame_buffer += data
if len(frame_buffer) == length:
break
while True:
if length is None:
if b":" not in frame_buffer:
break
length_str, _, frame_buffer = frame_buffer.partition(b":")
length = int(length_str)
if len(frame_buffer) < length:
break

frame_buffer = frame_buffer[length:]
length = None
break

frame = np.load(BytesIO(frame_buffer), allow_pickle=True)["frame"]
logging.debug("frame received")
return frame
if length is None:
if b":" not in frame_buffer:
continue
header, _, data = frame_buffer.partition(b":")
try:
length = int(header.decode())
frame_buffer = data
except ValueError:
logging.error("Invalid header format")
return np.array([])

if len(frame_buffer) < length:
continue

frame_data = frame_buffer[:length]

try:
frame = np.load(BytesIO(frame_data), allow_pickle=True)["frame"]
logging.debug("Frame received")
return frame
except Exception as e:
logging.error(f"Error parsing frame: {e}")
return np.array([])

def accept(self) -> tuple["NumpySocket", Union[tuple[str, int], tuple[Any, ...]]]:
fd, addr = super()._accept() # type: ignore
Expand Down
149 changes: 149 additions & 0 deletions tests/test_numpysocket_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import unittest
import os
import socket
import threading
import numpy as np

from numpysocket import NumpySocket


class TestNumpySocketIntegration(unittest.TestCase):
def setUp(self) -> None:
self.socket_path = "/tmp/numpysocket_integration_test"

try:
os.unlink(self.socket_path)
except OSError:
pass

def tearDown(self) -> None:
try:
os.unlink(self.socket_path)
except OSError:
pass

def _run_socket_test(
self, send_arrays: np.array, buffer_size: int = 1024
) -> list[np.array]:
recv_count = len(send_arrays)

received_arrays = []
server_ready = threading.Event()
server_done = threading.Event()

def server_thread() -> None:
with NumpySocket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
server.bind(self.socket_path)
server.listen()
server_ready.set()

conn, _ = server.accept()
with conn:
print("Server: Client connected")
for _ in range(recv_count):
frame = conn.recv(buffer_size)
received_arrays.append(frame)

server_done.set()

server = threading.Thread(target=server_thread)
server.daemon = True
server.start()

server_ready.wait(timeout=5)

with NumpySocket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
client.connect(self.socket_path)
print("Client: Connected to server")

for _, array in enumerate(send_arrays):
client.sendall(array)

server_done.wait(timeout=5)

return received_arrays

def test_small_array_large_buffer(self) -> None:
test_array = np.array([[1, 2, 3], [4, 5, 6]])

buffer_size = 4096

received = self._run_socket_test([test_array], buffer_size=buffer_size)
self.assertEqual(len(received), 1, "Should receive exactly one array")
np.testing.assert_array_equal(received[0], test_array, "Arrays don't match")

def test_small_array_small_buffer(self) -> None:
test_array = np.array([[1, 2, 3], [4, 5, 6]])

buffer_size = 32

received = self._run_socket_test([test_array], buffer_size=buffer_size)

self.assertEqual(len(received), 1, "Should receive exactly one array")
np.testing.assert_array_equal(received[0], test_array, "Arrays don't match")

def test_large_array_small_buffer(self) -> None:
test_array = np.arange(900).reshape(30, 30)

buffer_size = 64

received = self._run_socket_test([test_array], buffer_size=buffer_size)

self.assertEqual(len(received), 1, "Should receive exactly one array")
np.testing.assert_array_equal(received[0], test_array, "Arrays don't match")

def test_two_small_arrays(self) -> None:
test_arrays = [
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[7, 8, 9], [10, 11, 12]]),
]

received = self._run_socket_test(test_arrays, buffer_size=64)

self.assertEqual(
len(received),
len(test_arrays),
f"Should receive {len(test_arrays)} arrays, got {len(received)}",
)

for i, (sent, recv) in enumerate(zip(test_arrays, received)):
np.testing.assert_array_equal(recv, sent, f"Array {i} doesn't match")

def test_two_large_arrays(self) -> None:
test_arrays = [
np.arange(900).reshape(30, 30),
np.arange(900, 1800).reshape(30, 30),
]

received = self._run_socket_test(test_arrays, buffer_size=64)

self.assertEqual(
len(received),
len(test_arrays),
f"Should receive {len(test_arrays)} arrays, got {len(received)}",
)

for i, (sent, recv) in enumerate(zip(test_arrays, received)):
np.testing.assert_array_equal(recv, sent, f"Array {i} doesn't match")

def test_three_arrays_mixed_sizes(self) -> None:
test_arrays = [
np.array([[1, 2], [3, 4]]),
np.arange(100).reshape(10, 10),
np.arange(1000, 1900).reshape(30, 30),
]

received = self._run_socket_test(test_arrays, buffer_size=64)

self.assertEqual(
len(received),
len(test_arrays),
f"Should receive {len(test_arrays)} arrays, got {len(received)}",
)

for i, (sent, recv) in enumerate(zip(test_arrays, received)):
np.testing.assert_array_equal(recv, sent, f"Array {i} doesn't match")


if __name__ == "__main__":
unittest.main()
Loading