Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions tests/unit/test_python_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unittest.mock import Mock, patch, MagicMock, call
import json

from trustgraph.api.socket_client import SocketClient
from trustgraph.api import (
Api,
Triple,
Expand Down Expand Up @@ -222,6 +223,82 @@ def test_socket_flow_has_methods(self):
for method in expected_methods:
assert hasattr(flow_instance, method), f"Missing method: {method}"

def test_socket_client_close_does_not_swallow_base_exceptions(self):
"""Test close cleanup does not suppress process-level interrupts."""

class InterruptingLoop:
def is_closed(self):
return False

def run_until_complete(self, awaitable):
if hasattr(awaitable, "close"):
awaitable.close()
raise SystemExit("stop")

socket = SocketClient(url="http://test/", timeout=60, token=None)
socket._loop = InterruptingLoop()

with pytest.raises(SystemExit):
socket.close()

@pytest.mark.parametrize(
("generator_method", "async_method"),
[
("_streaming_generator", "_send_request_async_streaming"),
("_streaming_generator_raw", "_send_request_async_streaming_raw"),
],
)
def test_socket_client_streaming_cleanup_does_not_swallow_base_exceptions(
self, generator_method, async_method
):
"""Test streaming cleanup does not suppress process-level interrupts."""

class FakeAsyncGenerator:
def __anext__(self):
return "next"

def aclose(self):
return "close"

class InterruptingLoop:
def run_until_complete(self, awaitable):
if awaitable == "next":
raise StopAsyncIteration
if awaitable == "close":
raise SystemExit("stop")
raise AssertionError(f"unexpected awaitable: {awaitable!r}")

socket = SocketClient(url="http://test/", timeout=60, token=None)
setattr(socket, async_method, lambda *args, **kwargs: FakeAsyncGenerator())
generator = getattr(socket, generator_method)(
"agent", "default", {}, InterruptingLoop()
)

with pytest.raises(SystemExit):
next(generator)

@pytest.mark.asyncio
async def test_socket_client_reader_does_not_swallow_base_exceptions(self):
"""Test reader error fanout does not suppress process-level interrupts."""

class FailingSocket:
def __aiter__(self):
return self

async def __anext__(self):
raise ValueError("reader failed")

class InterruptingQueue:
async def put(self, message):
raise SystemExit("stop")

socket = SocketClient(url="http://test/", timeout=60, token=None)
socket._socket = FailingSocket()
socket._pending = {"req-1": InterruptingQueue()}

with pytest.raises(SystemExit):
await socket._reader()


class TestBulkClient:
"""Test bulk operations client"""
Expand Down
11 changes: 6 additions & 5 deletions trustgraph-base/trustgraph/api/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import json
import asyncio
import websockets
from websockets.exceptions import ConnectionClosed
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock

Expand Down Expand Up @@ -191,13 +192,13 @@ async def _reader(self):
if request_id and request_id in self._pending:
await self._pending[request_id].put(response)

except websockets.exceptions.ConnectionClosed:
except ConnectionClosed:
pass
except Exception as e:
for queue in self._pending.values():
try:
await queue.put({"error": str(e)})
except:
except Exception:
pass
finally:
self._connected = False
Expand Down Expand Up @@ -250,7 +251,7 @@ def _streaming_generator(
finally:
try:
loop.run_until_complete(async_gen.aclose())
except:
except Exception:
pass

def _streaming_generator_raw(
Expand All @@ -273,7 +274,7 @@ def _streaming_generator_raw(
finally:
try:
loop.run_until_complete(async_gen.aclose())
except:
except Exception:
pass

async def _send_request_async_streaming_raw(
Expand Down Expand Up @@ -542,7 +543,7 @@ def close(self) -> None:
if self._loop and not self._loop.is_closed():
try:
self._loop.run_until_complete(self._close_async())
except:
except Exception:
pass

async def _close_async(self):
Expand Down
Loading