diff --git a/.gitignore b/.gitignore index b1f2e49..d4c70ae 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,5 @@ ipython_config.py # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ + +.idea diff --git a/requirements-dev.txt b/requirements-dev.txt index df2e196..fae2f2e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,3 +8,4 @@ pytest==7.3.1 pytest-asyncio==0.23.3 ruff==0.0.282 twine==4.0.1 +python-dotenv==1.0.1 \ No newline at end of file diff --git a/src/cassio/table/base_table.py b/src/cassio/table/base_table.py index e663bf1..45cc225 100644 --- a/src/cassio/table/base_table.py +++ b/src/cassio/table/base_table.py @@ -1,4 +1,5 @@ import asyncio +import json from asyncio import InvalidStateError, Task import logging from typing import ( @@ -31,6 +32,7 @@ DELETE_CQL_TEMPLATE, SELECT_CQL_TEMPLATE, INSERT_ROW_CQL_TEMPLATE, + CREATE_INDEX_ANALYZER_CQL_TEMPLATE, ) from cassio.table.utils import call_wrapped_async @@ -62,6 +64,7 @@ def __init__( row_id_type: Union[str, List[str]] = ["TEXT"], skip_provisioning: bool = False, async_setup: bool = False, + body_index_options: List[Tuple[str, Any]] = None, ) -> None: self.session = check_resolve_session(session) self.keyspace = check_resolve_keyspace(keyspace) @@ -70,6 +73,7 @@ def __init__( self.row_id_type = normalize_type_desc(row_id_type) self.skip_provisioning = skip_provisioning self._prepared_statements: Dict[str, PreparedStatement] = {} + self._body_index_options = body_index_options self.db_setup_task: Optional[Task[None]] = None if async_setup: self.db_setup_task = asyncio.create_task(self.adb_setup()) @@ -229,8 +233,16 @@ def _parse_select_core_params( where_clause_blocks, select_cql_vals, ) = self._extract_where_clause_blocks(n_kwargs) + + if "content" in rest_kwargs: + where_clause_blocks.append(f"body_blob : '{rest_kwargs.pop('content')}'") + assert rest_kwargs == {} - where_clause = "WHERE " + " AND ".join(where_clause_blocks) + + if not where_clause_blocks: + where_clause = "" + else: + where_clause = "WHERE " + " AND ".join(where_clause_blocks) return columns_desc, where_clause, select_cql_vals def _get_select_cql(self, **kwargs: Any) -> Tuple[str, Tuple[Any, ...]]: @@ -354,13 +366,46 @@ def _get_db_setup_cql(self) -> str: ) return create_table_cql + def _get_create_analyzer_index_cql(self) -> str: + index_name = "idx_body" + index_column = "body_blob" + body_index_options = [] + for option in self._body_index_options: + key, value = option + if isinstance(value, dict): + body_index_options.append(f"'{key}': '{json.dumps(value)}'") + elif isinstance(value, str): + body_index_options.append(f"'{key}': '{value}'") + elif isinstance(value, bool): + if value: + body_index_options.append(f"'{key}': true") + else: + body_index_options.append(f"'{key}': false") + else: + raise ValueError("Unsupported body_index_option format") + + create_index_cql = CREATE_INDEX_ANALYZER_CQL_TEMPLATE.format( + index_name=index_name, + index_column=index_column, + body_index_options=", ".join(body_index_options), + ) + return create_index_cql + def db_setup(self) -> None: create_table_cql = self._get_db_setup_cql() self.execute_cql(create_table_cql, op_type=CQLOpType.SCHEMA) + if self._body_index_options: + self.execute_cql( + self._get_create_analyzer_index_cql(), op_type=CQLOpType.SCHEMA + ) async def adb_setup(self) -> None: create_table_cql = self._get_db_setup_cql() await self.aexecute_cql(create_table_cql, op_type=CQLOpType.SCHEMA) + if self._body_index_options: + await self.aexecute_cql( + self._get_create_analyzer_index_cql(), op_type=CQLOpType.SCHEMA + ) def _ensure_db_setup(self) -> None: if self.db_setup_task: diff --git a/src/cassio/table/cql.py b/src/cassio/table/cql.py index c0ae52b..d7abf59 100644 --- a/src/cassio/table/cql.py +++ b/src/cassio/table/cql.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Any, List, Tuple, Union from enum import Enum @@ -22,11 +23,15 @@ class CQLOpType(Enum): INSERT_ROW_CQL_TEMPLATE = """INSERT INTO {{table_fqname}} ({columns_desc}) VALUES ({value_placeholders}) {ttl_spec} ;""" # noqa: E501 -CREATE_INDEX_CQL_TEMPLATE = """CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} ({index_column}) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';""" # noqa: E501 +CREATE_INDEX_CQL_PREFIX = "CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} " -CREATE_KEYS_INDEX_CQL_TEMPLATE = """CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} (KEYS({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';""" # noqa: E501 +CREATE_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + "({index_column}) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';" # noqa: E501 -CREATE_ENTRIES_INDEX_CQL_TEMPLATE = """CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} (ENTRIES({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';""" # noqa: E501 +CREATE_INDEX_ANALYZER_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + "({index_column}) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' WITH OPTIONS = {{{{{body_index_options}}}}};" # noqa: E501 + +CREATE_KEYS_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + "(KEYS({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';" # noqa: E501 + +CREATE_ENTRIES_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + "(ENTRIES({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';""" # noqa: E501 SELECT_ANN_CQL_TEMPLATE = """SELECT {columns_desc} FROM {{table_fqname}} {where_clause} ORDER BY {vector_column} ANN OF %s {limit_clause};""" # noqa: E501 @@ -128,3 +133,9 @@ def assert_last_equal( expe_cql = self.normalizeCQLStatement(s_expe[0]) assert exe_cql == expe_cql, f"EXE#{exe_cql}# != EXPE#{expe_cql}#" return None + + +STANDARD_ANALYZER = ("index_analyzer", "STANDARD") +LOWER_CASE_ANALYZER = ("case_sensitive", False) +NORMALIZE_ANALYZER = ("normalize", True) +ASCII_ANALYZER = ("ascii", True) diff --git a/src/cassio/table/mixins.py b/src/cassio/table/mixins.py index ba4448a..0e1683b 100644 --- a/src/cassio/table/mixins.py +++ b/src/cassio/table/mixins.py @@ -632,8 +632,12 @@ def _get_ann_search_cql( where_clause_blocks, where_cql_vals, ) = self._extract_where_clause_blocks(n_kwargs) + + if "content" in rest_kwargs: + where_clause_blocks.append(f"body_blob : '{rest_kwargs.pop('content')}'") + assert rest_kwargs == {} - if where_clause_blocks == []: + if not where_clause_blocks: where_clause = "" else: where_clause = "WHERE " + " AND ".join(where_clause_blocks) diff --git a/tests/conftest.py b/tests/conftest.py index eea3ba6..cca9b7a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ """ import os +import tempfile from typing import Dict, List import pytest @@ -10,24 +11,42 @@ from cassandra.cluster import Cluster # type: ignore from cassandra.auth import PlainTextAuthProvider # type: ignore +from cassio.config import download_astra_bundle_url from cassio.table.cql import MockDBSession import cassio +from dotenv import load_dotenv + +load_dotenv() + # DB session (as per settings detected in env vars) dbSession = None -def createDBSessionSingleton(): +def createDBSessionSingleton(secure_bundle_dir: str = None): global dbSession if dbSession is None: mode = os.getenv("TEST_DB_MODE", "LOCAL_CASSANDRA") # the proper DB session is created as required if mode == "ASTRA_DB": - ASTRA_DB_SECURE_BUNDLE_PATH = os.environ["ASTRA_DB_SECURE_BUNDLE_PATH"] - ASTRA_DB_CLIENT_ID = "token" ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"] + if "ASTRA_DB_SECURE_BUNDLE_PATH" in os.environ: + ASTRA_DB_SECURE_BUNDLE_PATH = os.environ["ASTRA_DB_SECURE_BUNDLE_PATH"] + elif "ASTRA_DB_DATABASE_ID" in os.environ and secure_bundle_dir: + ASTRA_DB_DATABASE_ID = os.environ["ASTRA_DB_DATABASE_ID"] + ASTRA_DB_SECURE_BUNDLE_PATH = os.path.join( + secure_bundle_dir, "secure-connect-bundle_devopsapi.zip" + ) + download_astra_bundle_url( + ASTRA_DB_DATABASE_ID, + ASTRA_DB_APPLICATION_TOKEN, + ASTRA_DB_SECURE_BUNDLE_PATH, + ) + else: + raise ValueError("Missing secure bundle path") + ASTRA_DB_CLIENT_ID = "token" cluster = Cluster( cloud={ "secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH, @@ -70,7 +89,7 @@ def createDBSessionSingleton(): def getDBKeyspace(): mode = os.getenv("TEST_DB_MODE", "LOCAL_CASSANDRA") if mode == "ASTRA_DB": - ASTRA_DB_KEYSPACE = os.environ["ASTRA_DB_KEYSPACE"] + ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE", "default_keyspace") return ASTRA_DB_KEYSPACE elif mode == "LOCAL_CASSANDRA": CASSANDRA_KEYSPACE = os.getenv("CASSANDRA_KEYSPACE", "default_keyspace") @@ -82,7 +101,9 @@ def getDBKeyspace(): @pytest.fixture(scope="session") def db_session(): - return createDBSessionSingleton() + secure_bundle_dir = tempfile.TemporaryDirectory() + yield createDBSessionSingleton(secure_bundle_dir.name) + secure_bundle_dir.cleanup() @pytest.fixture(scope="session") diff --git a/tests/integration/test_tableclasses_plaincassandratable.py b/tests/integration/test_tableclasses_plaincassandratable.py index 7674199..dece37d 100644 --- a/tests/integration/test_tableclasses_plaincassandratable.py +++ b/tests/integration/test_tableclasses_plaincassandratable.py @@ -4,6 +4,7 @@ import pytest +from cassio.table.cql import STANDARD_ANALYZER from cassio.table.tables import ( PlainCassandraTable, ) @@ -20,13 +21,18 @@ def test_crud(self, db_session, db_keyspace): keyspace=db_keyspace, table=table_name, primary_key_type="TEXT", + body_index_options=[STANDARD_ANALYZER], ) t.put(row_id="empty_row") gotten1 = t.get(row_id="empty_row") assert gotten1 == {"row_id": "empty_row", "body_blob": None} - t.put(row_id="full_row", body_blob="body_blob") + t.put(row_id="full_row", body_blob="body blob") gotten2 = t.get(row_id="full_row") - assert gotten2 == {"row_id": "full_row", "body_blob": "body_blob"} + assert gotten2 == {"row_id": "full_row", "body_blob": "body blob"} + gotten2b = t.get(content="blob") + assert gotten2b == {"row_id": "full_row", "body_blob": "body blob"} + gotten2c = t.get(content="foo") + assert gotten2c is None t.delete(row_id="full_row") gotten2n = t.get(row_id="full_row") assert gotten2n is None diff --git a/tests/integration/test_tableclasses_vectorcassandratable.py b/tests/integration/test_tableclasses_vectorcassandratable.py index d954117..e88730f 100644 --- a/tests/integration/test_tableclasses_vectorcassandratable.py +++ b/tests/integration/test_tableclasses_vectorcassandratable.py @@ -4,6 +4,7 @@ import math import pytest +from cassio.table.cql import STANDARD_ANALYZER from cassio.table.tables import ( VectorCassandraTable, ) @@ -24,13 +25,14 @@ def test_crud(self, db_session, db_keyspace): table=table_name, vector_dimension=2, primary_key_type="TEXT", + body_index_options=[STANDARD_ANALYZER], ) for n_theta in range(N): theta = n_theta * math.pi * 2 / N t.put( row_id=f"theta_{n_theta}", - body_blob=f"theta = {theta:.4f}", + body_blob=f"theta_{n_theta} = {theta:.4f}", vector=[math.cos(theta), math.sin(theta)], ) @@ -44,8 +46,15 @@ def test_crud(self, db_session, db_keyspace): query_theta = 1 * math.pi * 2 / (2 * N) ref_vector = [math.cos(query_theta), math.sin(query_theta)] ann_results = list(t.ann_search(ref_vector, n=4)) - assert {r["row_id"] for r in ann_results[:2]} == {"theta_1", "theta_0"} - assert {r["row_id"] for r in ann_results[2:4]} == {"theta_2", "theta_7"} + assert [r["row_id"] for r in ann_results] == [ + "theta_1", + "theta_0", + "theta_2", + "theta_7", + ] + + ann_results = list(t.ann_search(ref_vector, n=4, content="theta_2")) + assert [r["row_id"] for r in ann_results] == ["theta_2"] t.clear()