diff --git a/CHANGES.md b/CHANGES.md index 22d0726a..3bb8aa9d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,8 @@ - Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying [KNN_MATCH] function, for HNSW matches. For SQLAlchemy column definitions, you can use it like `FloatVector(dimensions=1536)`. +- Fixed `get_table_names()` reflection method to respect the + `schema` query argument in SQLAlchemy connection URLs. [FLOAT_VECTOR]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector [KNN_MATCH]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match diff --git a/pyproject.toml b/pyproject.toml index 30c13633..48dac541 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ release = [ "twine<6", ] test = [ + "cratedb-toolkit[testing]", "dask[dataframe]", "pandas<2.3", "pueblo>=0.0.7", diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 53fae734..43af2fc4 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -229,6 +229,15 @@ def connect(self, host=None, port=None, *args, **kwargs): def _get_default_schema_name(self, connection): return 'doc' + def _get_effective_schema_name(self, connection): + schema_name_raw = connection.engine.url.query.get("schema") + schema_name = None + if isinstance(schema_name_raw, str): + schema_name = schema_name_raw + elif isinstance(schema_name_raw, tuple): + schema_name = schema_name_raw[0] + return schema_name + def _get_server_version_info(self, connection): return tuple(connection.connection.lowest_server_version.version) @@ -258,6 +267,8 @@ def get_schema_names(self, connection, **kw): @reflection.cache def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = self._get_effective_schema_name(connection) cursor = connection.exec_driver_sql( "SELECT table_name FROM information_schema.tables " "WHERE {0} = ? " diff --git a/tests/compiler_test.py b/tests/compiler_test.py index 6773b75e..a40ebb0f 100644 --- a/tests/compiler_test.py +++ b/tests/compiler_test.py @@ -18,7 +18,6 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. -import sys import warnings from textwrap import dedent from unittest import mock, skipIf, TestCase @@ -289,8 +288,7 @@ def test_for_update(self): FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) -@skipIf(SA_VERSION < SA_1_4 and (3, 9) <= sys.version_info < (3, 10), - "SQLAlchemy 1.3 has problems with these test cases on Python 3.9") +@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") class CompilerTestCase(TestCase): """ A base class for providing mocking infrastructure to validate the DDL compiler. diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..88b10d9d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,21 @@ +# Copyright (c) 2021-2023, Crate.io Inc. +# Distributed under the terms of the AGPLv3 license, see LICENSE. +import pytest +from cratedb_toolkit.testing.testcontainers.cratedb import CrateDBTestAdapter + +# Use different schemas for storing the subsystem database tables, and the +# test/example data, so that they do not accidentally touch the default `doc` +# schema. +TESTDRIVE_EXT_SCHEMA = "testdrive-ext" +TESTDRIVE_DATA_SCHEMA = "testdrive-data" + + +@pytest.fixture(scope="session") +def cratedb_service(): + """ + Provide a CrateDB service instance to the test suite. + """ + db = CrateDBTestAdapter() + db.start() + yield db + db.stop() diff --git a/tests/datetime_test.py b/tests/datetime_test.py index 07e98ede..53c30fce 100644 --- a/tests/datetime_test.py +++ b/tests/datetime_test.py @@ -20,13 +20,17 @@ # software solely pursuant to the terms of the relevant commercial agreement. from __future__ import absolute_import + from datetime import datetime, tzinfo, timedelta -from unittest import TestCase +from unittest import TestCase, skipIf from unittest.mock import patch, MagicMock import sqlalchemy as sa from sqlalchemy.exc import DBAPIError from sqlalchemy.orm import Session + +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 + try: from sqlalchemy.orm import declarative_base except ImportError: @@ -52,6 +56,7 @@ def dst(self, date_time): return timedelta(seconds=-7200) +@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") @patch('crate.client.connection.Cursor', FakeCursor) class SqlAlchemyDateAndDateTimeTest(TestCase): diff --git a/tests/dict_test.py b/tests/dict_test.py index 84b6f491..5f2692c1 100644 --- a/tests/dict_test.py +++ b/tests/dict_test.py @@ -20,7 +20,8 @@ # software solely pursuant to the terms of the relevant commercial agreement. from __future__ import absolute_import -from unittest import TestCase + +from unittest import TestCase, skipIf from unittest.mock import patch, MagicMock import sqlalchemy as sa @@ -31,7 +32,7 @@ except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_cratedb import ObjectArray, ObjectType +from sqlalchemy_cratedb import ObjectArray, ObjectType, SA_VERSION, SA_1_4 from crate.client.cursor import Cursor @@ -40,6 +41,7 @@ FakeCursor.return_value = fake_cursor +@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") class SqlAlchemyDictTypeTest(TestCase): def setUp(self): diff --git a/tests/insert_from_select_test.py b/tests/insert_from_select_test.py index 692dfa55..a4533a55 100644 --- a/tests/insert_from_select_test.py +++ b/tests/insert_from_select_test.py @@ -18,14 +18,16 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. - from datetime import datetime -from unittest import TestCase +from unittest import TestCase, skipIf from unittest.mock import patch, MagicMock import sqlalchemy as sa from sqlalchemy import select, insert from sqlalchemy.orm import Session + +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 + try: from sqlalchemy.orm import declarative_base except ImportError: @@ -40,6 +42,7 @@ FakeCursor.return_value = fake_cursor +@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") class SqlAlchemyInsertFromSelectTest(TestCase): def assertSQL(self, expected_str, actual_expr): diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 00000000..83eb3481 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,25 @@ +import sqlalchemy as sa + +from tests.conftest import TESTDRIVE_DATA_SCHEMA + + +def test_correct_schema(cratedb_service): + """ + Tests that the correct schema is being picked up. + """ + database = cratedb_service.database + + tablename = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"' + inspector: sa.Inspector = sa.inspect(database.engine) + database.run_sql(f"CREATE TABLE {tablename} AS SELECT 1") + + assert TESTDRIVE_DATA_SCHEMA in inspector.get_schema_names() + + table_names = inspector.get_table_names(schema=TESTDRIVE_DATA_SCHEMA) + assert table_names == ["foobar"] + + view_names = inspector.get_view_names(schema=TESTDRIVE_DATA_SCHEMA) + assert view_names == [] + + indexes = inspector.get_indexes(tablename) + assert indexes == [] diff --git a/tests/update_test.py b/tests/update_test.py index 5062f229..a70b56cb 100644 --- a/tests/update_test.py +++ b/tests/update_test.py @@ -18,12 +18,11 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. - from datetime import datetime -from unittest import TestCase +from unittest import TestCase, skipIf from unittest.mock import patch, MagicMock -from sqlalchemy_cratedb import ObjectType +from sqlalchemy_cratedb import ObjectType, SA_VERSION, SA_1_4 import sqlalchemy as sa from sqlalchemy.orm import Session @@ -41,6 +40,7 @@ FakeCursor.return_value = fake_cursor +@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") class SqlAlchemyUpdateTest(TestCase): def setUp(self):