Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions tests/unit/test_tables/test_knowledge_table_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def _make_store():
class TestGetGraphEmbeddings:

@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_row_converts_to_entity_embeddings_with_singular_vector(
self, mock_async_execute
self, mock_async_execute_paged
):
"""
Cassandra rows return entities as a list of [entity_tuple, vector]
Expand All @@ -57,7 +57,7 @@ async def test_row_converts_to_entity_embeddings_with_singular_vector(
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
mock_async_execute.return_value = [fake_row]
mock_async_execute_paged.return_value = [[fake_row]]

received = []

Expand All @@ -66,7 +66,7 @@ async def receiver(msg):

await store.get_graph_embeddings("alice", "doc-1", receiver)

mock_async_execute.assert_called_once_with(
mock_async_execute_paged.assert_called_once_with(
store.cassandra,
store.get_graph_embeddings_stmt,
("alice", "doc-1"),
Expand Down Expand Up @@ -96,16 +96,16 @@ async def receiver(msg):
assert ge.entities[2].entity.value == "a literal entity"

@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute):
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged):
"""row[3] being None / empty must produce a GraphEmbeddings with
no entities, not raise."""
fake_row = (None, None, None, None)

store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
mock_async_execute.return_value = [fake_row]
mock_async_execute_paged.return_value = [[fake_row]]

received = []

Expand All @@ -118,8 +118,8 @@ async def receiver(msg):
assert received[0].entities == []

@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute):
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged):
fake_rows = [
(None, None, None, [
(("http://example.org/a", True), [1.0]),
Expand All @@ -132,7 +132,7 @@ async def test_multiple_rows_each_emit_one_message(self, mock_async_execute):
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
mock_async_execute.return_value = fake_rows
mock_async_execute_paged.return_value = [fake_rows]

received = []

Expand All @@ -153,8 +153,8 @@ class TestGetTriples:
the same Metadata construction. Cover it for parity."""

@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
async def test_row_converts_to_triples(self, mock_async_execute):
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_row_converts_to_triples(self, mock_async_execute_paged):
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
fake_row = (
None, None, None,
Expand All @@ -170,7 +170,7 @@ async def test_row_converts_to_triples(self, mock_async_execute):
store = _make_store()
store.cassandra = Mock()
store.get_triples_stmt = Mock()
mock_async_execute.return_value = [fake_row]
mock_async_execute_paged.return_value = [[fake_row]]

received = []

Expand Down
4 changes: 2 additions & 2 deletions trustgraph-base/trustgraph/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def triple_generator():
from . bulk_client import BulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
self._bulk_client = BulkClient(base_url, self.timeout, self.token)
self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._bulk_client

def metrics(self):
Expand Down Expand Up @@ -462,7 +462,7 @@ async def triple_gen():
from . async_bulk_client import AsyncBulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._async_bulk_client

def async_metrics(self):
Expand Down
51 changes: 23 additions & 28 deletions trustgraph-base/trustgraph/api/async_bulk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
class AsyncBulkClient:
"""Asynchronous bulk operations client"""

def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
self.workspace: str = workspace

def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
Expand All @@ -25,11 +26,21 @@ def _convert_to_ws_url(self, url: str) -> str:
else:
return f"ws://{url}"

def _build_ws_url(self, path: str) -> str:
    """Build a WebSocket URL with token and workspace query params.

    Args:
        path: Endpoint path appended verbatim to ``self.url``.

    Returns:
        ``self.url`` followed by ``path``. When ``self.token`` and/or
        ``self.workspace`` are truthy they are appended as ``token`` and
        ``workspace`` query parameters, in that order.
    """
    # Function-scope import so the method does not rely on module imports.
    from urllib.parse import quote, urlencode

    params = {}
    if self.token:
        params["token"] = self.token
    if self.workspace:
        params["workspace"] = self.workspace

    ws_url = f"{self.url}{path}"
    if params:
        # Percent-encode values: a raw '&', '=', '+', '/' or space in a
        # token or workspace name would otherwise split or alter the query
        # string. Values made only of unreserved characters are unchanged.
        ws_url = f"{ws_url}?{urlencode(params, quote_via=quote)}"
    return ws_url

async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
"""Bulk import triples via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for triple in triples:
Expand All @@ -42,9 +53,7 @@ async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwar

async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
"""Bulk export triples via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
Expand All @@ -57,69 +66,55 @@ async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple

async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import graph embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
await websocket.send(json.dumps(embedding))

async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export graph embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
yield json.loads(raw_message)

async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import document embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
await websocket.send(json.dumps(embedding))

async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export document embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
yield json.loads(raw_message)

async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import entity contexts via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for context in contexts:
await websocket.send(json.dumps(context))

async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export entity contexts via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
yield json.loads(raw_message)

async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import rows via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for row in rows:
Expand Down
52 changes: 24 additions & 28 deletions trustgraph-base/trustgraph/api/bulk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,20 @@ class BulkClient:
Note: For true async support, use AsyncBulkClient instead.
"""

def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
"""
Initialize synchronous bulk client.

Args:
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
timeout: WebSocket timeout in seconds
token: Optional bearer token for authentication
workspace: Workspace for data isolation
"""
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
self.workspace: str = workspace

def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
Expand All @@ -58,6 +60,18 @@ def _convert_to_ws_url(self, url: str) -> str:
else:
return f"ws://{url}"

def _build_ws_url(self, path: str) -> str:
    """Build a WebSocket URL with token and workspace query params.

    Args:
        path: Endpoint path appended verbatim to ``self.url``.

    Returns:
        ``self.url`` followed by ``path``. When ``self.token`` and/or
        ``self.workspace`` are truthy they are appended as ``token`` and
        ``workspace`` query parameters, in that order.
    """
    # Function-scope import so the method does not rely on module imports.
    from urllib.parse import quote, urlencode

    query = {}
    if self.token:
        query["token"] = self.token
    if self.workspace:
        query["workspace"] = self.workspace

    ws_url = f"{self.url}{path}"
    if query:
        # Percent-encode values: a raw '&', '=', '+', '/' or space in a
        # token or workspace name would otherwise split or alter the query
        # string. Values made only of unreserved characters are unchanged.
        ws_url = f"{ws_url}?{urlencode(query, quote_via=quote)}"
    return ws_url

def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
"""Run async coroutine synchronously"""
try:
Expand Down Expand Up @@ -116,9 +130,7 @@ async def _import_triples_async(
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of triple import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")

if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
Expand Down Expand Up @@ -194,9 +206,7 @@ def export_triples(self, flow: str, **kwargs: Any) -> Iterator[Triple]:

async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
"""Async implementation of triple export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
Expand Down Expand Up @@ -238,9 +248,7 @@ def embedding_generator():

async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of graph embeddings import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
Expand Down Expand Up @@ -296,9 +304,7 @@ def export_graph_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str

async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of graph embeddings export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
Expand Down Expand Up @@ -336,9 +342,7 @@ def doc_embedding_generator():

async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of document embeddings import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
Expand Down Expand Up @@ -394,9 +398,7 @@ def export_document_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[

async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of document embeddings export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
Expand Down Expand Up @@ -446,9 +448,7 @@ async def _import_entity_contexts_async(
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of entity contexts import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")

if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
Expand Down Expand Up @@ -522,9 +522,7 @@ def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str,

async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of entity contexts export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
Expand Down Expand Up @@ -562,9 +560,7 @@ def row_generator():

async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of rows import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")

async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for row in rows:
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-base/trustgraph/api/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ async def _ensure_connected(self):
)

if resp.get("type") == "auth-ok":
self.workspace = resp.get("workspace", self.workspace)
if self.workspace == "default":
self.workspace = resp.get("workspace", self.workspace)
elif resp.get("type") == "auth-failed":
await self._socket.close()
raise ProtocolException(
Expand Down
4 changes: 3 additions & 1 deletion trustgraph-flow/trustgraph/gateway/endpoint/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,10 @@ async def handle(self, request):

running = Running()

params = dict(request.query)
params.update(request.match_info)
dispatcher = await self.dispatcher(
ws, running, request.match_info
ws, running, params
)

worker_task = tg.create_task(
Expand Down
Loading