Skip to content

Commit

Permalink
Add support for index analyzers
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Mar 11, 2024
1 parent e25b698 commit 24ab3c8
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,5 @@ ipython_config.py

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

.idea
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pytest==7.3.1
pytest-asyncio==0.23.3
ruff==0.0.282
twine==4.0.1
python-dotenv==1.0.1
47 changes: 46 additions & 1 deletion src/cassio/table/base_table.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
from asyncio import InvalidStateError, Task
import logging
from typing import (
Expand Down Expand Up @@ -31,6 +32,7 @@
DELETE_CQL_TEMPLATE,
SELECT_CQL_TEMPLATE,
INSERT_ROW_CQL_TEMPLATE,
CREATE_INDEX_ANALYZER_CQL_TEMPLATE,
)
from cassio.table.utils import call_wrapped_async

Expand Down Expand Up @@ -62,6 +64,7 @@ def __init__(
row_id_type: Union[str, List[str]] = ["TEXT"],
skip_provisioning: bool = False,
async_setup: bool = False,
body_index_options: List[Tuple[str, Any]] = None,
) -> None:
self.session = check_resolve_session(session)
self.keyspace = check_resolve_keyspace(keyspace)
Expand All @@ -70,6 +73,7 @@ def __init__(
self.row_id_type = normalize_type_desc(row_id_type)
self.skip_provisioning = skip_provisioning
self._prepared_statements: Dict[str, PreparedStatement] = {}
self._body_index_options = body_index_options
self.db_setup_task: Optional[Task[None]] = None
if async_setup:
self.db_setup_task = asyncio.create_task(self.adb_setup())
Expand Down Expand Up @@ -229,8 +233,16 @@ def _parse_select_core_params(
where_clause_blocks,
select_cql_vals,
) = self._extract_where_clause_blocks(n_kwargs)

if "content" in rest_kwargs:
where_clause_blocks.append(f"body_blob : '{rest_kwargs.pop('content')}'")

assert rest_kwargs == {}
where_clause = "WHERE " + " AND ".join(where_clause_blocks)

if not where_clause_blocks:
where_clause = ""
else:
where_clause = "WHERE " + " AND ".join(where_clause_blocks)
return columns_desc, where_clause, select_cql_vals

def _get_select_cql(self, **kwargs: Any) -> Tuple[str, Tuple[Any, ...]]:
Expand Down Expand Up @@ -354,13 +366,46 @@ def _get_db_setup_cql(self) -> str:
)
return create_table_cql

def _get_create_analyzer_index_cql(self) -> str:
index_name = "idx_body"
index_column = "body_blob"
body_index_options = []
for option in self._body_index_options:
key, value = option
if isinstance(value, dict):
body_index_options.append(f"'{key}': '{json.dumps(value)}'")
elif isinstance(value, str):
body_index_options.append(f"'{key}': '{value}'")
elif isinstance(value, bool):
if value:
body_index_options.append(f"'{key}': true")
else:
body_index_options.append(f"'{key}': false")
else:
raise ValueError("Unsupported body_index_option format")

create_index_cql = CREATE_INDEX_ANALYZER_CQL_TEMPLATE.format(
index_name=index_name,
index_column=index_column,
body_index_options=", ".join(body_index_options),
)
return create_index_cql

def db_setup(self) -> None:
create_table_cql = self._get_db_setup_cql()
self.execute_cql(create_table_cql, op_type=CQLOpType.SCHEMA)
if self._body_index_options:
self.execute_cql(
self._get_create_analyzer_index_cql(), op_type=CQLOpType.SCHEMA
)

async def adb_setup(self) -> None:
create_table_cql = self._get_db_setup_cql()
await self.aexecute_cql(create_table_cql, op_type=CQLOpType.SCHEMA)
if self._body_index_options:
await self.aexecute_cql(
self._get_create_analyzer_index_cql(), op_type=CQLOpType.SCHEMA
)

def _ensure_db_setup(self) -> None:
if self.db_setup_task:
Expand Down
17 changes: 14 additions & 3 deletions src/cassio/table/cql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
from typing import Any, List, Tuple, Union
from enum import Enum

Expand All @@ -22,11 +23,15 @@ class CQLOpType(Enum):

INSERT_ROW_CQL_TEMPLATE = """INSERT INTO {{table_fqname}} ({columns_desc}) VALUES ({value_placeholders}) {ttl_spec} ;""" # noqa: E501

CREATE_INDEX_CQL_TEMPLATE = """CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} ({index_column}) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';""" # noqa: E501
CREATE_INDEX_CQL_PREFIX = "CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} "

CREATE_KEYS_INDEX_CQL_TEMPLATE = """CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} (KEYS({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';""" # noqa: E501
CREATE_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + "({index_column}) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';" # noqa: E501

CREATE_ENTRIES_INDEX_CQL_TEMPLATE = """CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} (ENTRIES({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';""" # noqa: E501
CREATE_INDEX_ANALYZER_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + "({index_column}) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' WITH OPTIONS = {{{{{body_index_options}}}}};" # noqa: E501

CREATE_KEYS_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + "(KEYS({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';" # noqa: E501

CREATE_ENTRIES_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + "(ENTRIES({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';""" # noqa: E501

SELECT_ANN_CQL_TEMPLATE = """SELECT {columns_desc} FROM {{table_fqname}} {where_clause} ORDER BY {vector_column} ANN OF %s {limit_clause};""" # noqa: E501

Expand Down Expand Up @@ -128,3 +133,9 @@ def assert_last_equal(
expe_cql = self.normalizeCQLStatement(s_expe[0])
assert exe_cql == expe_cql, f"EXE#{exe_cql}# != EXPE#{expe_cql}#"
return None


STANDARD_ANALYZER = ("index_analyzer", "STANDARD")
LOWER_CASE_ANALYZER = ("case_sensitive", False)
NORMALIZE_ANALYZER = ("normalize", True)
ASCII_ANALYZER = ("ascii", True)
6 changes: 5 additions & 1 deletion src/cassio/table/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,12 @@ def _get_ann_search_cql(
where_clause_blocks,
where_cql_vals,
) = self._extract_where_clause_blocks(n_kwargs)

if "content" in rest_kwargs:
where_clause_blocks.append(f"body_blob : '{rest_kwargs.pop('content')}'")

assert rest_kwargs == {}
if where_clause_blocks == []:
if not where_clause_blocks:
where_clause = ""
else:
where_clause = "WHERE " + " AND ".join(where_clause_blocks)
Expand Down
31 changes: 26 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,50 @@
"""

import os
import tempfile
from typing import Dict, List

import pytest

from cassandra.cluster import Cluster # type: ignore
from cassandra.auth import PlainTextAuthProvider # type: ignore

from cassio.config import download_astra_bundle_url
from cassio.table.cql import MockDBSession

import cassio

from dotenv import load_dotenv

load_dotenv()


# DB session (as per settings detected in env vars)
dbSession = None


def createDBSessionSingleton():
def createDBSessionSingleton(secure_bundle_dir: str = None):
global dbSession
if dbSession is None:
mode = os.getenv("TEST_DB_MODE", "LOCAL_CASSANDRA")
# the proper DB session is created as required
if mode == "ASTRA_DB":
ASTRA_DB_SECURE_BUNDLE_PATH = os.environ["ASTRA_DB_SECURE_BUNDLE_PATH"]
ASTRA_DB_CLIENT_ID = "token"
ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
if "ASTRA_DB_SECURE_BUNDLE_PATH" in os.environ:
ASTRA_DB_SECURE_BUNDLE_PATH = os.environ["ASTRA_DB_SECURE_BUNDLE_PATH"]
elif "ASTRA_DB_DATABASE_ID" in os.environ and secure_bundle_dir:
ASTRA_DB_DATABASE_ID = os.environ["ASTRA_DB_DATABASE_ID"]
ASTRA_DB_SECURE_BUNDLE_PATH = os.path.join(
secure_bundle_dir, "secure-connect-bundle_devopsapi.zip"
)
download_astra_bundle_url(
ASTRA_DB_DATABASE_ID,
ASTRA_DB_APPLICATION_TOKEN,
ASTRA_DB_SECURE_BUNDLE_PATH,
)
else:
raise ValueError("Missing secure bundle path")
ASTRA_DB_CLIENT_ID = "token"
cluster = Cluster(
cloud={
"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH,
Expand Down Expand Up @@ -70,7 +89,7 @@ def createDBSessionSingleton():
def getDBKeyspace():
mode = os.getenv("TEST_DB_MODE", "LOCAL_CASSANDRA")
if mode == "ASTRA_DB":
ASTRA_DB_KEYSPACE = os.environ["ASTRA_DB_KEYSPACE"]
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE", "default_keyspace")
return ASTRA_DB_KEYSPACE
elif mode == "LOCAL_CASSANDRA":
CASSANDRA_KEYSPACE = os.getenv("CASSANDRA_KEYSPACE", "default_keyspace")
Expand All @@ -82,7 +101,9 @@ def getDBKeyspace():

@pytest.fixture(scope="session")
def db_session():
return createDBSessionSingleton()
secure_bundle_dir = tempfile.TemporaryDirectory()
yield createDBSessionSingleton(secure_bundle_dir.name)
secure_bundle_dir.cleanup()


@pytest.fixture(scope="session")
Expand Down
10 changes: 8 additions & 2 deletions tests/integration/test_tableclasses_plaincassandratable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

from cassio.table.cql import STANDARD_ANALYZER
from cassio.table.tables import (
PlainCassandraTable,
)
Expand All @@ -20,13 +21,18 @@ def test_crud(self, db_session, db_keyspace):
keyspace=db_keyspace,
table=table_name,
primary_key_type="TEXT",
body_index_options=[STANDARD_ANALYZER],
)
t.put(row_id="empty_row")
gotten1 = t.get(row_id="empty_row")
assert gotten1 == {"row_id": "empty_row", "body_blob": None}
t.put(row_id="full_row", body_blob="body_blob")
t.put(row_id="full_row", body_blob="body blob")
gotten2 = t.get(row_id="full_row")
assert gotten2 == {"row_id": "full_row", "body_blob": "body_blob"}
assert gotten2 == {"row_id": "full_row", "body_blob": "body blob"}
gotten2b = t.get(content="blob")
assert gotten2b == {"row_id": "full_row", "body_blob": "body blob"}
gotten2c = t.get(content="foo")
assert gotten2c is None
t.delete(row_id="full_row")
gotten2n = t.get(row_id="full_row")
assert gotten2n is None
Expand Down
15 changes: 12 additions & 3 deletions tests/integration/test_tableclasses_vectorcassandratable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import pytest

from cassio.table.cql import STANDARD_ANALYZER
from cassio.table.tables import (
VectorCassandraTable,
)
Expand All @@ -24,13 +25,14 @@ def test_crud(self, db_session, db_keyspace):
table=table_name,
vector_dimension=2,
primary_key_type="TEXT",
body_index_options=[STANDARD_ANALYZER],
)

for n_theta in range(N):
theta = n_theta * math.pi * 2 / N
t.put(
row_id=f"theta_{n_theta}",
body_blob=f"theta = {theta:.4f}",
body_blob=f"theta_{n_theta} = {theta:.4f}",
vector=[math.cos(theta), math.sin(theta)],
)

Expand All @@ -44,8 +46,15 @@ def test_crud(self, db_session, db_keyspace):
query_theta = 1 * math.pi * 2 / (2 * N)
ref_vector = [math.cos(query_theta), math.sin(query_theta)]
ann_results = list(t.ann_search(ref_vector, n=4))
assert {r["row_id"] for r in ann_results[:2]} == {"theta_1", "theta_0"}
assert {r["row_id"] for r in ann_results[2:4]} == {"theta_2", "theta_7"}
assert [r["row_id"] for r in ann_results] == [
"theta_1",
"theta_0",
"theta_2",
"theta_7",
]

ann_results = list(t.ann_search(ref_vector, n=4, content="theta_2"))
assert [r["row_id"] for r in ann_results] == ["theta_2"]

t.clear()

Expand Down

0 comments on commit 24ab3c8

Please sign in to comment.