diff --git a/arangoasync/aql.py b/arangoasync/aql.py index 1fdcc6e..072fdbe 100644 --- a/arangoasync/aql.py +++ b/arangoasync/aql.py @@ -53,7 +53,7 @@ async def execute( memory_limit: Optional[int] = None, ttl: Optional[int] = None, allow_dirty_read: Optional[bool] = None, - options: Optional[QueryProperties] = None, + options: Optional[QueryProperties | Json] = None, ) -> Result[Cursor]: """Execute the query and return the result cursor. @@ -73,7 +73,7 @@ async def execute( will be removed on the server automatically after the specified amount of time. allow_dirty_read (bool | None): Allow reads from followers in a cluster. - options (QueryProperties | None): Extra options for the query. + options (QueryProperties | dict | None): Extra options for the query. References: - `create-a-cursor `__ @@ -92,7 +92,9 @@ async def execute( if ttl is not None: data["ttl"] = ttl if options is not None: - data["options"] = options.to_dict() + if isinstance(options, QueryProperties): + options = options.to_dict() + data["options"] = options headers = dict() if allow_dirty_read is not None: diff --git a/arangoasync/cursor.py b/arangoasync/cursor.py index 64375cb..55ba40a 100644 --- a/arangoasync/cursor.py +++ b/arangoasync/cursor.py @@ -50,7 +50,7 @@ def __init__(self, executor: ApiExecutor, data: Json) -> None: self._batch: Deque[Any] = deque() self._update(data) - async def __aiter__(self) -> "Cursor": + def __aiter__(self) -> "Cursor": return self async def __anext__(self) -> Any: diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 64ffac4..8995206 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -3,8 +3,14 @@ import pytest from arangoasync.aql import AQL -from arangoasync.errno import CURSOR_NOT_FOUND -from arangoasync.exceptions import CursorCloseError +from arangoasync.errno import CURSOR_NOT_FOUND, HTTP_BAD_PARAMETER +from arangoasync.exceptions import ( + CursorCloseError, + CursorCountError, + CursorEmptyError, + CursorNextError, + CursorStateError, +) from arangoasync.typings import QueryExecutionStats, QueryProperties @@ -180,3 +186,208 @@ async def test_cursor_write_query(db, doc_col, docs): await cursor.close(ignore_missing=False) assert err.value.error_code == CURSOR_NOT_FOUND assert await cursor.close(ignore_missing=True) is False + + +@pytest.mark.asyncio +async def test_cursor_invalid_id(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=2, + ttl=1000, + options={"optimizer": {"rules": ["+all"]}, "profile": 1}, + ) + + # Set the cursor ID to "invalid" and assert errors + setattr(cursor, "_id", "invalid") + + # Cursor should not be found + with pytest.raises(CursorNextError) as err: + async for _ in cursor: + pass + assert err.value.error_code == CURSOR_NOT_FOUND + with pytest.raises(CursorCloseError) as err: + await cursor.close(ignore_missing=False) + assert err.value.error_code == CURSOR_NOT_FOUND + assert await cursor.close(ignore_missing=True) is False + + # Set the cursor ID to None and assert errors + setattr(cursor, "_id", None) + with pytest.raises(CursorStateError): + print(await cursor.next()) + with pytest.raises(CursorStateError): + await cursor.fetch() + assert await cursor.close() is False + + +@pytest.mark.asyncio +async def test_cursor_premature_close(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=2, + ttl=1000, + ) + assert len(cursor.batch) == 2 + assert await cursor.close() is True + + # Cursor should be already closed + with pytest.raises(CursorCloseError) as err: + await cursor.close(ignore_missing=False) + assert err.value.error_code == CURSOR_NOT_FOUND + assert await cursor.close(ignore_missing=True) is False + + +@pytest.mark.asyncio +async def test_cursor_context_manager(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=2, + ttl=1000, + ) + async with cursor as ctx: + assert (await ctx.next())["val"] == docs[0]["val"] + + # Cursor should be already closed + with pytest.raises(CursorCloseError) as err: + await cursor.close(ignore_missing=False) + assert err.value.error_code == CURSOR_NOT_FOUND + assert await cursor.close(ignore_missing=True) is False + + +@pytest.mark.asyncio +async def test_cursor_manual_fetch_and_pop(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=1, + ttl=1000, + options={"allowRetry": True}, + ) + + # Fetch documents manually + for idx in range(2, len(docs)): + result = await cursor.fetch() + assert len(result) == 1 + assert cursor.count == len(docs) + assert cursor.has_more + assert len(cursor.batch) == idx + assert result[0]["val"] == docs[idx - 1]["val"] + result = await cursor.fetch() + assert result[0]["val"] == docs[len(docs) - 1]["val"] + assert len(cursor.batch) == len(docs) + assert not cursor.has_more + + # Pop documents manually + idx = 0 + while not cursor.empty(): + doc = cursor.pop() + assert doc["val"] == docs[idx]["val"] + idx += 1 + assert len(cursor.batch) == 0 + + # Cursor should be empty + with pytest.raises(CursorEmptyError): + await cursor.pop() + + +@pytest.mark.asyncio +async def test_cursor_retry(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + # Do not allow retries + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=1, + ttl=1000, + options={"allowRetry": False}, + ) + + # Increase the batch id by doing a fetch + await cursor.fetch() + while not cursor.empty(): + cursor.pop() + next_batch_id = cursor.next_batch_id + + # Fetch the next batch + await cursor.fetch() + # Retry is not allowed + with pytest.raises(CursorNextError) as err: + await cursor.fetch(batch_id=next_batch_id) + assert err.value.error_code == HTTP_BAD_PARAMETER + + await cursor.close() + + # Now let's allow retries + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=1, + ttl=1000, + options={"allowRetry": True}, + ) + + # Increase the batch id by doing a fetch + await cursor.fetch() + while not cursor.empty(): + cursor.pop() + next_batch_id = cursor.next_batch_id + + # Fetch the next batch + prev_batch = await cursor.fetch() + next_next_batch_id = cursor.next_batch_id + # Should fetch the same batch again + next_batch = await cursor.fetch(batch_id=next_batch_id) + assert next_batch == prev_batch + # Next batch id should be the same + assert cursor.next_batch_id == next_next_batch_id + + # Fetch the next batch + next_next_batch = await cursor.fetch() + assert next_next_batch != next_batch + + assert await cursor.close() + + +@pytest.mark.asyncio +async def test_cursor_no_count(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=False, + batch_size=2, + ttl=1000, + ) + + # Cursor count is not enabled + with pytest.raises(CursorCountError): + _ = len(cursor) + with pytest.raises(CursorCountError): + _ = bool(cursor) + + while cursor.has_more: + assert cursor.count is None + assert await cursor.fetch()