From b69e366e285ebfbf53f38d87a05ce912749cac4d Mon Sep 17 00:00:00 2001 From: Dmitry Maslennikov Date: Sat, 22 Jun 2024 09:36:36 +1000 Subject: [PATCH] fix support older engine --- .github/workflows/python-publish.yml | 15 +- sqlalchemy_iris/base.py | 5 +- sqlalchemy_iris/types.py | 117 ++++++----- tests/test_suite.py | 302 ++++++++++++++------------- tox.ini | 21 ++ 5 files changed, 250 insertions(+), 210 deletions(-) create mode 100644 tox.ini diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 4b82acf..f453270 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -24,19 +24,22 @@ jobs: image: - intersystemsdc/iris-community:latest - intersystemsdc/iris-community:preview - - intersystemsdc/iris-community:2024.1-preview + engine: + - old + - new runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' - name: Install requirements run: | - pip install -r requirements-dev.txt \ - -r requirements-iris.txt \ - -e . + pip install tox - name: Run Tests run: | - pytest --container ${{ matrix.image }} - + tox -e py311${{ matrix.engine }} -- --container ${{ matrix.image }} deploy: needs: test if: github.event_name != 'pull_request' diff --git a/sqlalchemy_iris/base.py b/sqlalchemy_iris/base.py index 60e02d3..60ddfe8 100644 --- a/sqlalchemy_iris/base.py +++ b/sqlalchemy_iris/base.py @@ -91,7 +91,6 @@ def check_constraints(cls): from .types import IRISTimeStamp from .types import IRISDate from .types import IRISDateTime -from .types import IRISUniqueIdentifier from .types import IRISListBuild # noqa from .types import IRISVector # noqa @@ -819,8 +818,10 @@ def create_cursor(self): sqltypes.DateTime: IRISDateTime, sqltypes.TIMESTAMP: IRISTimeStamp, sqltypes.Time: IRISTime, - sqltypes.UUID: IRISUniqueIdentifier, } +if sqlalchemy_version.startswith("2."): + from .types import IRISUniqueIdentifier + colspecs[sqltypes.UUID] = IRISUniqueIdentifier class IRISExact(ReturnTypeFromArgs): diff --git a/sqlalchemy_iris/types.py b/sqlalchemy_iris/types.py index 2f4cd26..d77a373 100644 --- a/sqlalchemy_iris/types.py +++ b/sqlalchemy_iris/types.py @@ -2,9 +2,10 @@ from decimal import Decimal from sqlalchemy import func, text from sqlalchemy.sql import sqltypes -from sqlalchemy.types import UserDefinedType, Float +from sqlalchemy.types import UserDefinedType from uuid import UUID as _python_UUID from intersystems_iris import IRISList +from sqlalchemy import __version__ as sqlalchemy_version HOROLOG_ORDINAL = datetime.date(1840, 12, 31).toordinal() @@ -134,73 +135,79 @@ def process(value): return process -class IRISUniqueIdentifier(sqltypes.Uuid): - def literal_processor(self, dialect): - if not self.as_uuid: +if sqlalchemy_version.startswith("2."): - def process(value): - return f"""'{value.replace("'", "''")}'""" - - return process - else: - - def process(value): - return f"""'{str(value).replace("'", "''")}'""" - - return process - - def bind_processor(self, dialect): - character_based_uuid = not dialect.supports_native_uuid or not self.native_uuid - - if character_based_uuid: - if self.as_uuid: + class IRISUniqueIdentifier(sqltypes.Uuid): + def literal_processor(self, dialect): + if not self.as_uuid: def process(value): - if value is not None: - value = str(value) - return value + return f"""'{value.replace("'", "''")}'""" return process else: def process(value): - return value + return f"""'{str(value).replace("'", "''")}'""" return process - else: - return None - def result_processor(self, dialect, coltype): - character_based_uuid = not dialect.supports_native_uuid or not self.native_uuid + def bind_processor(self, dialect): + character_based_uuid = ( + not dialect.supports_native_uuid or not self.native_uuid + ) - if character_based_uuid: - if self.as_uuid: + if character_based_uuid: + if self.as_uuid: - def process(value): - if value and not isinstance(value, _python_UUID): - value = _python_UUID(value) - return value + def process(value): + if value is not None: + value = str(value) + return value - return process + return process + else: + + def process(value): + return value + + return process else: + return None - def process(value): - if value and isinstance(value, _python_UUID): - value = str(value) - return value + def result_processor(self, dialect, coltype): + character_based_uuid = ( + not dialect.supports_native_uuid or not self.native_uuid + ) - return process - else: - if not self.as_uuid: + if character_based_uuid: + if self.as_uuid: - def process(value): - if value and isinstance(value, _python_UUID): - value = str(value) - return value + def process(value): + if value and not isinstance(value, _python_UUID): + value = _python_UUID(value) + return value - return process + return process + else: + + def process(value): + if value and isinstance(value, _python_UUID): + value = str(value) + return value + + return process else: - return None + if not self.as_uuid: + + def process(value): + if value and isinstance(value, _python_UUID): + value = str(value) + return value + + return process + else: + return None class IRISListBuild(UserDefinedType): @@ -267,9 +274,7 @@ def __init__(self, max_items: int = None, item_type: type = float): item_type_server = ( "decimal" if self.item_type is float - else "float" - if self.item_type is Decimal - else "int" + else "float" if self.item_type is Decimal else "int" ) self.item_type_server = item_type_server @@ -304,19 +309,21 @@ class comparator_factory(UserDefinedType.Comparator): # return self.func('vector_l2', other) def max_inner_product(self, other): - return self.func('vector_dot_product', other) + return self.func("vector_dot_product", other) def cosine_distance(self, other): - return self.func('vector_cosine', other) + return self.func("vector_cosine", other) def cosine(self, other): - return (1 - self.func('vector_cosine', other)) + return 1 - self.func("vector_cosine", other) def func(self, funcname: str, other): if not isinstance(other, list) and not isinstance(other, tuple): raise ValueError("expected list or tuple, got '%s'" % type(other)) othervalue = f"[{','.join([str(v) for v in other])}]" - return getattr(func, funcname)(self, func.to_vector(othervalue, text(self.type.item_type_server))) + return getattr(func, funcname)( + self, func.to_vector(othervalue, text(self.type.item_type_server)) + ) class BIT(sqltypes.TypeEngine): diff --git a/tests/test_suite.py b/tests/test_suite.py index 2f948af..afaa961 100644 --- a/tests/test_suite.py +++ b/tests/test_suite.py @@ -4,9 +4,6 @@ from sqlalchemy.testing.suite import CompoundSelectTest as _CompoundSelectTest from sqlalchemy.testing.suite import CTETest as _CTETest from sqlalchemy.testing.suite import DifficultParametersTest as _DifficultParametersTest -from sqlalchemy.testing.suite import ( - BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest, -) from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import eq_ from sqlalchemy.testing import config @@ -26,6 +23,8 @@ from sqlalchemy.testing.suite import * # noqa +from sqlalchemy import __version__ as sqlalchemy_version + class CompoundSelectTest(_CompoundSelectTest): @pytest.mark.skip() @@ -270,166 +269,175 @@ def test_expect_bytes(self): ) -class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): - @testing.combinations( - ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" - ) - @testing.variation("use_composite", [True, False]) - @testing.combinations( - ("plain",), - # ("(2)",), not in IRIS - ("per % cent",), - ("[brackets]",), - argnames="tablename", +if sqlalchemy_version.startswith("2."): + from sqlalchemy.testing.suite import ( + BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest, ) - def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname): - super().test_fk_ref(connection, metadata, use_composite, tablename, columnname) - -class IRISListBuildTest(fixtures.TablesTest): - __backend__ = True - - @classmethod - def define_tables(cls, metadata): - Table( - "data", - metadata, - Column("val", IRISListBuild(10, float)), + class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + # ("(2)",), not in IRIS + ("per % cent",), + ("[brackets]",), + argnames="tablename", ) + def test_fk_ref( + self, connection, metadata, use_composite, tablename, columnname + ): + super().test_fk_ref( + connection, metadata, use_composite, tablename, columnname + ) - @classmethod - def fixtures(cls): - return dict( - data=( - ("val",), - ([1.0] * 50,), - ([1.23] * 50,), - ([i for i in range(0, 50)],), - (None,), + class IRISListBuildTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "data", + metadata, + Column("val", IRISListBuild(10, float)), ) - ) - def _assert_result(self, select, result): - with config.db.connect() as conn: - eq_(conn.execute(select).fetchall(), result) + @classmethod + def fixtures(cls): + return dict( + data=( + ("val",), + ([1.0] * 50,), + ([1.23] * 50,), + ([i for i in range(0, 50)],), + (None,), + ) + ) - def test_listbuild(self): - self._assert_result( - select(self.tables.data), - [ - ([1.0] * 50,), - ([1.23] * 50,), - ([i for i in range(0, 50)],), - (None,), - ], - ) - self._assert_result( - select(self.tables.data).where(self.tables.data.c.val == [1.0] * 50), - [ - ([1.0] * 50,), - ], - ) + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) - self._assert_result( - select( - self.tables.data, - self.tables.data.c.val.func("$listsame", [1.0] * 50).label("same"), - ).limit(1), - [ - ([1.0] * 50, 1), - ], - ) + def test_listbuild(self): + self._assert_result( + select(self.tables.data), + [ + ([1.0] * 50,), + ([1.23] * 50,), + ([i for i in range(0, 50)],), + (None,), + ], + ) + self._assert_result( + select(self.tables.data).where(self.tables.data.c.val == [1.0] * 50), + [ + ([1.0] * 50,), + ], + ) + self._assert_result( + select( + self.tables.data, + self.tables.data.c.val.func("$listsame", [1.0] * 50).label("same"), + ).limit(1), + [ + ([1.0] * 50, 1), + ], + ) -class IRISVectorTest(fixtures.TablesTest): - __backend__ = True + class IRISVectorTest(fixtures.TablesTest): + __backend__ = True - __requires__ = ("iris_vector",) + __requires__ = ("iris_vector",) - @classmethod - def define_tables(cls, metadata): - Table( - "data", - metadata, - Column("id", INTEGER), - Column("emb", IRISVector(3, float)), - ) + @classmethod + def define_tables(cls, metadata): + Table( + "data", + metadata, + Column("id", INTEGER), + Column("emb", IRISVector(3, float)), + ) - @classmethod - def fixtures(cls): - return dict( - data=( - ( - "id", - "emb", - ), - ( - 1, - [1, 1, 1], - ), - ( - 2, - [2, 2, 2], - ), - ( - 3, - [1, 1, 2], - ), + @classmethod + def fixtures(cls): + return dict( + data=( + ( + "id", + "emb", + ), + ( + 1, + [1, 1, 1], + ), + ( + 2, + [2, 2, 2], + ), + ( + 3, + [1, 1, 2], + ), + ) ) - ) - def _assert_result(self, select, result): - with config.db.connect() as conn: - eq_(conn.execute(select).fetchall(), result) + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) - def test_vector(self): - self._assert_result( - select(self.tables.data.c.emb), - [ - ([1, 1, 1],), - ([2, 2, 2],), - ([1, 1, 2],), - ], - ) - self._assert_result( - select(self.tables.data.c.id).where(self.tables.data.c.emb == [2, 2, 2]), - [ - (2,), - ], - ) + def test_vector(self): + self._assert_result( + select(self.tables.data.c.emb), + [ + ([1, 1, 1],), + ([2, 2, 2],), + ([1, 1, 2],), + ], + ) + self._assert_result( + select(self.tables.data.c.id).where( + self.tables.data.c.emb == [2, 2, 2] + ), + [ + (2,), + ], + ) - def test_cosine(self): - self._assert_result( - select( - self.tables.data.c.id, - ).order_by(self.tables.data.c.emb.cosine([1, 1, 1])), - [ - (1,), - (2,), - (3,), - ], - ) + def test_cosine(self): + self._assert_result( + select( + self.tables.data.c.id, + ).order_by(self.tables.data.c.emb.cosine([1, 1, 1])), + [ + (1,), + (2,), + (3,), + ], + ) - def test_cosine_distance(self): - self._assert_result( - select( - self.tables.data.c.id, - ).order_by(1 - self.tables.data.c.emb.cosine_distance([1, 1, 1])), - [ - (1,), - (2,), - (3,), - ], - ) + def test_cosine_distance(self): + self._assert_result( + select( + self.tables.data.c.id, + ).order_by(1 - self.tables.data.c.emb.cosine_distance([1, 1, 1])), + [ + (1,), + (2,), + (3,), + ], + ) - def test_max_inner_product(self): - self._assert_result( - select( - self.tables.data.c.id, - ).order_by(self.tables.data.c.emb.max_inner_product([1, 1, 1])), - [ - (1,), - (3,), - (2,), - ], - ) + def test_max_inner_product(self): + self._assert_result( + select( + self.tables.data.c.id, + ).order_by(self.tables.data.c.emb.max_inner_product([1, 1, 1])), + [ + (1,), + (3,), + (2,), + ], + ) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..fbbc048 --- /dev/null +++ b/tox.ini @@ -0,0 +1,21 @@ +[tox] +requires = + tox>=4 +env_list = py{310,311,312}{old,new} + + +[testenv:py{38,39,310,311,312}old] +deps = + sqlalchemy<2 + -r requirements-dev.txt + -r requirements-iris.txt + -e. +commands = {envpython} -m pytest {posargs} + +[testenv:py{38,39,310,311,312}new] +deps = + sqlalchemy>=2 + -r requirements-dev.txt + -r requirements-iris.txt + -e. +commands = {envpython} -m pytest {posargs}