Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User-Defined AQL functions #37

Merged
merged 2 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion arangoasync/aql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["AQL", "AQLQueryCache"]


from typing import Optional
from typing import Optional, cast

from arangoasync.cursor import Cursor
from arangoasync.errno import HTTP_NOT_FOUND
Expand All @@ -10,6 +10,9 @@
AQLCacheConfigureError,
AQLCacheEntriesError,
AQLCachePropertiesError,
AQLFunctionCreateError,
AQLFunctionDeleteError,
AQLFunctionListError,
AQLQueryClearError,
AQLQueryExecuteError,
AQLQueryExplainError,
Expand Down Expand Up @@ -634,3 +637,117 @@ def response_handler(resp: Response) -> Jsons:
return self.deserializer.loads_many(resp.raw_body)

return await self._executor.execute(request, response_handler)

async def functions(self, namespace: Optional[str] = None) -> Result[Jsons]:
"""List the registered used-defined AQL functions.

Args:
namespace (str | None): Returns all registered AQL user functions from
the specified namespace.

Returns:
list: List of the AQL functions defined in the database.

Raises:
AQLFunctionListError: If retrieval fails.

References:
- `list-the-registered-user-defined-aql-functions <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#list-the-registered-user-defined-aql-functions>`__
""" # noqa: E501
params: Json = dict()
if namespace is not None:
params["namespace"] = namespace
request = Request(
method=Method.GET,
endpoint="/_api/aqlfunction",
params=params,
)

def response_handler(resp: Response) -> Jsons:
if not resp.is_success:
raise AQLFunctionListError(resp, request)
result = cast(Jsons, self.deserializer.loads(resp.raw_body).get("result"))
if result is None:
raise AQLFunctionListError(resp, request)
return result

return await self._executor.execute(request, response_handler)

async def create_function(
self,
name: str,
code: str,
is_deterministic: Optional[bool] = None,
) -> Result[Json]:
"""Registers a user-defined AQL function (UDF) written in JavaScript.

Args:
name (str): Name of the function.
code (str): JavaScript code of the function.
is_deterministic (bool | None): If set to `True`, the function is
deterministic.

Returns:
dict: Information about the registered function.

Raises:
AQLFunctionCreateError: If registration fails.

References:
- `create-a-user-defined-aql-function <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#create-a-user-defined-aql-function>`__
""" # noqa: E501
request = Request(
method=Method.POST,
endpoint="/_api/aqlfunction",
data=self.serializer.dumps(
dict(name=name, code=code, isDeterministic=is_deterministic)
),
)

def response_handler(resp: Response) -> Json:
if not resp.is_success:
raise AQLFunctionCreateError(resp, request)
return self.deserializer.loads(resp.raw_body)

return await self._executor.execute(request, response_handler)

async def delete_function(
self,
name: str,
group: Optional[bool] = None,
ignore_missing: bool = False,
) -> Result[Json]:
"""Remove a user-defined AQL function.

Args:
name (str): Name of the function.
group (bool | None): If set to `True`, the function name is treated
as a namespace prefix.
ignore_missing (bool): If set to `True`, will not raise an exception
if the function is not found.

Returns:
dict: Information about the removed functions (their count).

Raises:
AQLFunctionDeleteError: If removal fails.

References:
- `remove-a-user-defined-aql-function <https://docs.arangodb.com/stable/develop/http-api/queries/user-defined-aql-functions/#remove-a-user-defined-aql-function>`__
""" # noqa: E501
params: Json = dict()
if group is not None:
params["group"] = group
request = Request(
method=Method.DELETE,
endpoint=f"/_api/aqlfunction/{name}",
params=params,
)

def response_handler(resp: Response) -> Json:
if not resp.is_success:
if not (resp.status_code == HTTP_NOT_FOUND and ignore_missing):
raise AQLFunctionDeleteError(resp, request)
return self.deserializer.loads(resp.raw_body)

return await self._executor.execute(request, response_handler)
12 changes: 12 additions & 0 deletions arangoasync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ class AQLCachePropertiesError(ArangoServerError):
"""Failed to retrieve query cache properties."""


class AQLFunctionCreateError(ArangoServerError):
"""Failed to create AQL user function."""


class AQLFunctionDeleteError(ArangoServerError):
"""Failed to delete AQL user function."""


class AQLFunctionListError(ArangoServerError):
"""Failed to retrieve AQL user functions."""


class AQLQueryClearError(ArangoServerError):
"""Failed to clear slow AQL queries."""

Expand Down
89 changes: 88 additions & 1 deletion tests/test_aql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@
import pytest
from packaging import version

from arangoasync.errno import FORBIDDEN, QUERY_PARSE
from arangoasync.errno import (
FORBIDDEN,
QUERY_FUNCTION_INVALID_CODE,
QUERY_FUNCTION_NOT_FOUND,
QUERY_PARSE,
)
from arangoasync.exceptions import (
AQLCacheClearError,
AQLCacheConfigureError,
AQLCacheEntriesError,
AQLCachePropertiesError,
AQLFunctionCreateError,
AQLFunctionDeleteError,
AQLFunctionListError,
AQLQueryClearError,
AQLQueryExecuteError,
AQLQueryExplainError,
Expand Down Expand Up @@ -276,3 +284,82 @@ async def test_cache_plan_management(db, bad_db, doc_col, docs, db_version):
with pytest.raises(AQLCacheClearError) as err:
await bad_db.aql.cache.clear_plan()
assert err.value.error_code == FORBIDDEN


@pytest.mark.asyncio
async def test_aql_function_management(db, bad_db):
fn_group = "functions::temperature"
fn_name_1 = "functions::temperature::celsius_to_fahrenheit"
fn_body_1 = "function (celsius) { return celsius * 1.8 + 32; }"
fn_name_2 = "functions::temperature::fahrenheit_to_celsius"
fn_body_2 = "function (fahrenheit) { return (fahrenheit - 32) / 1.8; }"
bad_fn_name = "functions::temperature::should_not_exist"
bad_fn_body = "function (celsius) { invalid syntax }"

aql = db.aql
# List AQL functions
assert await aql.functions() == []

# List AQL functions with bad database
with pytest.raises(AQLFunctionListError) as err:
await bad_db.aql.functions()
assert err.value.error_code == FORBIDDEN

# Create invalid AQL function
with pytest.raises(AQLFunctionCreateError) as err:
await aql.create_function(bad_fn_name, bad_fn_body)
assert err.value.error_code == QUERY_FUNCTION_INVALID_CODE

# Create first AQL function
result = await aql.create_function(fn_name_1, fn_body_1, is_deterministic=True)
assert result["isNewlyCreated"] is True
functions = await aql.functions()
assert len(functions) == 1
assert functions[0]["name"] == fn_name_1
assert functions[0]["code"] == fn_body_1
assert functions[0]["isDeterministic"] is True

# Create same AQL function again
result = await aql.create_function(fn_name_1, fn_body_1, is_deterministic=True)
assert result["isNewlyCreated"] is False
functions = await aql.functions()
assert len(functions) == 1
assert functions[0]["name"] == fn_name_1
assert functions[0]["code"] == fn_body_1
assert functions[0]["isDeterministic"] is True

# Create second AQL function
result = await aql.create_function(fn_name_2, fn_body_2, is_deterministic=False)
assert result["isNewlyCreated"] is True
functions = await aql.functions()
assert len(functions) == 2
assert functions[0]["name"] == fn_name_1
assert functions[0]["code"] == fn_body_1
assert functions[0]["isDeterministic"] is True
assert functions[1]["name"] == fn_name_2
assert functions[1]["code"] == fn_body_2
assert functions[1]["isDeterministic"] is False

# Delete first function
result = await aql.delete_function(fn_name_1)
assert result["deletedCount"] == 1
functions = await aql.functions()
assert len(functions) == 1

# Delete missing function
with pytest.raises(AQLFunctionDeleteError) as err:
await aql.delete_function(fn_name_1)
assert err.value.error_code == QUERY_FUNCTION_NOT_FOUND
result = await aql.delete_function(fn_name_1, ignore_missing=True)
assert "deletedCount" not in result

# Delete function from bad db
with pytest.raises(AQLFunctionDeleteError) as err:
await bad_db.aql.delete_function(fn_name_2)
assert err.value.error_code == FORBIDDEN

# Delete function group
result = await aql.delete_function(fn_group, group=True)
assert result["deletedCount"] == 1
functions = await aql.functions()
assert len(functions) == 0
1 change: 1 addition & 0 deletions tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ async def test_cursor_write_query(db, doc_col, docs):
cursor = await aql.execute(
"""
FOR d IN {col} FILTER d.val == @first OR d.val == @second
SORT d.val
UPDATE {{_key: d._key, _val: @val }} IN {col}
RETURN NEW
""".format(
Expand Down
Loading