Skip to content

Commit

Permalink
User-Defined AQL functions (#37)
Browse files Browse the repository at this point in the history
* Adding user-defined AQL functions

* Deterministic test
  • Loading branch information
apetenchea authored Feb 13, 2025
1 parent fd840a1 commit efdad0e
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 2 deletions.
119 changes: 118 additions & 1 deletion arangoasync/aql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["AQL", "AQLQueryCache"]


from typing import Optional
from typing import Optional, cast

from arangoasync.cursor import Cursor
from arangoasync.errno import HTTP_NOT_FOUND
Expand All @@ -10,6 +10,9 @@
AQLCacheConfigureError,
AQLCacheEntriesError,
AQLCachePropertiesError,
AQLFunctionCreateError,
AQLFunctionDeleteError,
AQLFunctionListError,
AQLQueryClearError,
AQLQueryExecuteError,
AQLQueryExplainError,
Expand Down Expand Up @@ -634,3 +637,117 @@ def response_handler(resp: Response) -> Jsons:
return self.deserializer.loads_many(resp.raw_body)

return await self._executor.execute(request, response_handler)

async def functions(self, namespace: Optional[str] = None) -> Result[Jsons]:
"""List the registered used-defined AQL functions.
Args:
namespace (str | None): Returns all registered AQL user functions from
the specified namespace.
Returns:
list: List of the AQL functions defined in the database.
Raises:
AQLFunctionListError: If retrieval fails.
References:
- `list-the-registered-user-defined-aql-functions <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#list-the-registered-user-defined-aql-functions>`__
""" # noqa: E501
params: Json = dict()
if namespace is not None:
params["namespace"] = namespace
request = Request(
method=Method.GET,
endpoint="/_api/aqlfunction",
params=params,
)

def response_handler(resp: Response) -> Jsons:
if not resp.is_success:
raise AQLFunctionListError(resp, request)
result = cast(Jsons, self.deserializer.loads(resp.raw_body).get("result"))
if result is None:
raise AQLFunctionListError(resp, request)
return result

return await self._executor.execute(request, response_handler)

async def create_function(
self,
name: str,
code: str,
is_deterministic: Optional[bool] = None,
) -> Result[Json]:
"""Registers a user-defined AQL function (UDF) written in JavaScript.
Args:
name (str): Name of the function.
code (str): JavaScript code of the function.
is_deterministic (bool | None): If set to `True`, the function is
deterministic.
Returns:
dict: Information about the registered function.
Raises:
AQLFunctionCreateError: If registration fails.
References:
- `create-a-user-defined-aql-function <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#create-a-user-defined-aql-function>`__
""" # noqa: E501
request = Request(
method=Method.POST,
endpoint="/_api/aqlfunction",
data=self.serializer.dumps(
dict(name=name, code=code, isDeterministic=is_deterministic)
),
)

def response_handler(resp: Response) -> Json:
if not resp.is_success:
raise AQLFunctionCreateError(resp, request)
return self.deserializer.loads(resp.raw_body)

return await self._executor.execute(request, response_handler)

async def delete_function(
self,
name: str,
group: Optional[bool] = None,
ignore_missing: bool = False,
) -> Result[Json]:
"""Remove a user-defined AQL function.
Args:
name (str): Name of the function.
group (bool | None): If set to `True`, the function name is treated
as a namespace prefix.
ignore_missing (bool): If set to `True`, will not raise an exception
if the function is not found.
Returns:
dict: Information about the removed functions (their count).
Raises:
AQLFunctionDeleteError: If removal fails.
References:
- `remove-a-user-defined-aql-function <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#remove-a-user-defined-aql-function>`__
""" # noqa: E501
params: Json = dict()
if group is not None:
params["group"] = group
request = Request(
method=Method.DELETE,
endpoint=f"/_api/aqlfunction/{name}",
params=params,
)

def response_handler(resp: Response) -> Json:
if not resp.is_success:
if not (resp.status_code == HTTP_NOT_FOUND and ignore_missing):
raise AQLFunctionDeleteError(resp, request)
return self.deserializer.loads(resp.raw_body)

return await self._executor.execute(request, response_handler)
12 changes: 12 additions & 0 deletions arangoasync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ class AQLCachePropertiesError(ArangoServerError):
"""Failed to retrieve query cache properties."""


class AQLFunctionCreateError(ArangoServerError):
"""Failed to create AQL user function."""


class AQLFunctionDeleteError(ArangoServerError):
"""Failed to delete AQL user function."""


class AQLFunctionListError(ArangoServerError):
"""Failed to retrieve AQL user functions."""


class AQLQueryClearError(ArangoServerError):
"""Failed to clear slow AQL queries."""

Expand Down
89 changes: 88 additions & 1 deletion tests/test_aql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@
import pytest
from packaging import version

from arangoasync.errno import FORBIDDEN, QUERY_PARSE
from arangoasync.errno import (
FORBIDDEN,
QUERY_FUNCTION_INVALID_CODE,
QUERY_FUNCTION_NOT_FOUND,
QUERY_PARSE,
)
from arangoasync.exceptions import (
AQLCacheClearError,
AQLCacheConfigureError,
AQLCacheEntriesError,
AQLCachePropertiesError,
AQLFunctionCreateError,
AQLFunctionDeleteError,
AQLFunctionListError,
AQLQueryClearError,
AQLQueryExecuteError,
AQLQueryExplainError,
Expand Down Expand Up @@ -276,3 +284,82 @@ async def test_cache_plan_management(db, bad_db, doc_col, docs, db_version):
with pytest.raises(AQLCacheClearError) as err:
await bad_db.aql.cache.clear_plan()
assert err.value.error_code == FORBIDDEN


@pytest.mark.asyncio
async def test_aql_function_management(db, bad_db):
fn_group = "functions::temperature"
fn_name_1 = "functions::temperature::celsius_to_fahrenheit"
fn_body_1 = "function (celsius) { return celsius * 1.8 + 32; }"
fn_name_2 = "functions::temperature::fahrenheit_to_celsius"
fn_body_2 = "function (fahrenheit) { return (fahrenheit - 32) / 1.8; }"
bad_fn_name = "functions::temperature::should_not_exist"
bad_fn_body = "function (celsius) { invalid syntax }"

aql = db.aql
# List AQL functions
assert await aql.functions() == []

# List AQL functions with bad database
with pytest.raises(AQLFunctionListError) as err:
await bad_db.aql.functions()
assert err.value.error_code == FORBIDDEN

# Create invalid AQL function
with pytest.raises(AQLFunctionCreateError) as err:
await aql.create_function(bad_fn_name, bad_fn_body)
assert err.value.error_code == QUERY_FUNCTION_INVALID_CODE

# Create first AQL function
result = await aql.create_function(fn_name_1, fn_body_1, is_deterministic=True)
assert result["isNewlyCreated"] is True
functions = await aql.functions()
assert len(functions) == 1
assert functions[0]["name"] == fn_name_1
assert functions[0]["code"] == fn_body_1
assert functions[0]["isDeterministic"] is True

# Create same AQL function again
result = await aql.create_function(fn_name_1, fn_body_1, is_deterministic=True)
assert result["isNewlyCreated"] is False
functions = await aql.functions()
assert len(functions) == 1
assert functions[0]["name"] == fn_name_1
assert functions[0]["code"] == fn_body_1
assert functions[0]["isDeterministic"] is True

# Create second AQL function
result = await aql.create_function(fn_name_2, fn_body_2, is_deterministic=False)
assert result["isNewlyCreated"] is True
functions = await aql.functions()
assert len(functions) == 2
assert functions[0]["name"] == fn_name_1
assert functions[0]["code"] == fn_body_1
assert functions[0]["isDeterministic"] is True
assert functions[1]["name"] == fn_name_2
assert functions[1]["code"] == fn_body_2
assert functions[1]["isDeterministic"] is False

# Delete first function
result = await aql.delete_function(fn_name_1)
assert result["deletedCount"] == 1
functions = await aql.functions()
assert len(functions) == 1

# Delete missing function
with pytest.raises(AQLFunctionDeleteError) as err:
await aql.delete_function(fn_name_1)
assert err.value.error_code == QUERY_FUNCTION_NOT_FOUND
result = await aql.delete_function(fn_name_1, ignore_missing=True)
assert "deletedCount" not in result

# Delete function from bad db
with pytest.raises(AQLFunctionDeleteError) as err:
await bad_db.aql.delete_function(fn_name_2)
assert err.value.error_code == FORBIDDEN

# Delete function group
result = await aql.delete_function(fn_group, group=True)
assert result["deletedCount"] == 1
functions = await aql.functions()
assert len(functions) == 0
1 change: 1 addition & 0 deletions tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ async def test_cursor_write_query(db, doc_col, docs):
cursor = await aql.execute(
"""
FOR d IN {col} FILTER d.val == @first OR d.val == @second
SORT d.val
UPDATE {{_key: d._key, _val: @val }} IN {col}
RETURN NEW
""".format(
Expand Down

0 comments on commit efdad0e

Please sign in to comment.