From 73c83e9728187dc7d1c851e91a386cb1c9a372e3 Mon Sep 17 00:00:00 2001 From: Dmitry Maslennikov Date: Fri, 28 Oct 2022 22:45:40 +0400 Subject: [PATCH] various updates --- setup.cfg | 5 ++- setup.py | 2 +- sqlalchemy_iris/__init__.py | 8 ++-- sqlalchemy_iris/base.py | 55 +++++++++++++++++++++++--- sqlalchemy_iris/iris.py | 4 +- sqlalchemy_iris/requirements.py | 9 ++--- test/conftest.py | 6 +-- test/test_suite.py | 68 ++++++++++++++++++++++++++++++++- 8 files changed, 131 insertions(+), 26 deletions(-) diff --git a/setup.cfg b/setup.cfg index 943b097..9f8976c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,9 +36,12 @@ addopts= --tb native -v -r fxX --maxfail=25 -p no:warnings python_files=test/*test_*.py [db] -default=iris+iris://_SYSTEM:SYS@localhost:1972/USER +default=iris://_SYSTEM:SYS@localhost:1972/USER sqlite=sqlite:///:memory: [sqla_testing] requirement_cls=sqlalchemy_iris.requirements:Requirements profile_file=test/profiles.txt + +[flake8] +max-line-length=120 diff --git a/setup.py b/setup.py index 02bf717..254aa1f 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ dependency_links=dependency_links, entry_points={ "sqlalchemy.dialects": [ - "iris = sqlalchemy_iris:IRISDialect", + "iris = sqlalchemy_iris.iris:IRISDialect_iris", ] }, ) diff --git a/sqlalchemy_iris/__init__.py b/sqlalchemy_iris/__init__.py index 75bae7b..70e9fde 100644 --- a/sqlalchemy_iris/__init__.py +++ b/sqlalchemy_iris/__init__.py @@ -1,12 +1,12 @@ +from sqlalchemy.dialects import registry as _registry + from . import base from . import iris -from .base import IRISDialect -from .iris import IRISDialect_iris base.dialect = dialect = iris.dialect +_registry.register("iris.iris", "sqlalchemy_iris.iris", "IRISDialect_iris") + __all__ = [ - IRISDialect, - IRISDialect_iris, dialect, ] diff --git a/sqlalchemy_iris/base.py b/sqlalchemy_iris/base.py index cb7aa08..4c49c45 100644 --- a/sqlalchemy_iris/base.py +++ b/sqlalchemy_iris/base.py @@ -9,7 +9,7 @@ from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import between from sqlalchemy.sql import func -from sqlalchemy import sql +from sqlalchemy import sql, text from sqlalchemy import util from sqlalchemy import types as sqltypes @@ -393,6 +393,9 @@ def translate_select_structure(self, select_stmt, **kwargs): for elem in select._order_by_clause.clauses ] + if not _order_by_clauses: + _order_by_clauses = [text('%id')] + limit_clause = self._get_limit_or_fetch(select) offset_clause = select._offset_clause @@ -440,6 +443,44 @@ def visit_drop_schema(self, drop, **kw): def visit_check_constraint(self, constraint, **kw): raise exc.CompileError("Check CONSTRAINT not supported") + def get_column_specification(self, column, **kwargs): + + colspec = [ + self.preparer.format_column(column), + ] + + if ( + column.primary_key + and column is column.table._autoincrement_column + ): + colspec.append("SERIAL") + else: + colspec.append( + self.dialect.type_compiler.process( + column.type, + type_expression=column, + identifier_preparer=self.preparer, + ) + ) + + if column.computed is not None: + colspec.append(self.process(column.computed)) + default = self.get_column_default_string(column) + if default is not None: + colspec.append("DEFAULT " + default) + + if not column.nullable: + colspec.append("NOT NULL") + + comment = column.comment + if comment is not None: + literal = self.sql_compiler.render_literal_value( + comment, sqltypes.String() + ) + colspec.append("%%DESCRIPTION " + literal) + + return " ".join(colspec) + class IRISTypeCompiler(compiler.GenericTypeCompiler): def visit_boolean(self, type_, **kw): @@ -536,11 +577,13 @@ def process(value): class IRISDialect(default.DefaultDialect): - driver = 'iris' + + name = 'iris' default_schema_name = "SQLUser" default_paramstyle = "format" + supports_statement_cache = True supports_native_decimal = True supports_sane_rowcount = True @@ -551,7 +594,6 @@ class IRISDialect(default.DefaultDialect): supports_sequences = False - supports_statement_cache = False postfetch_lastrowid = False non_native_boolean_check_constraint = False supports_simple_order_by_label = False @@ -610,6 +652,7 @@ def _fix_for_params(self, query, params, many=False): def do_execute(self, cursor, query, params, context=None): query, params = self._fix_for_params(query, params) + # print('do_execute', query, params) cursor.execute(query, params) def do_executemany(self, cursor, query, params, context=None): @@ -903,7 +946,7 @@ def get_columns(self, connection, table_name, schema=None, **kw): ): if charlen == -1: charlen = None - kwargs["length"] = charlen + kwargs["length"] = int(charlen) if collation: kwargs["collation"] = collation if coltype is None: @@ -914,10 +957,10 @@ def get_columns(self, connection, table_name, schema=None, **kw): coltype = sqltypes.NULLTYPE else: if issubclass(coltype, sqltypes.Numeric): - kwargs["precision"] = numericprec + kwargs["precision"] = int(numericprec) if not issubclass(coltype, sqltypes.Float): - kwargs["scale"] = numericscale + kwargs["scale"] = int(numericscale) coltype = coltype(**kwargs) diff --git a/sqlalchemy_iris/iris.py b/sqlalchemy_iris/iris.py index a473c15..7bc563a 100644 --- a/sqlalchemy_iris/iris.py +++ b/sqlalchemy_iris/iris.py @@ -4,9 +4,7 @@ class IRISDialect_iris(IRISDialect): driver = "iris" - def create_connect_args(self, url): - opts = dict(url.query) - return ([], opts) + supports_statement_cache = True dialect = IRISDialect_iris diff --git a/sqlalchemy_iris/requirements.py b/sqlalchemy_iris/requirements.py index 3bee82f..9150a5e 100644 --- a/sqlalchemy_iris/requirements.py +++ b/sqlalchemy_iris/requirements.py @@ -211,7 +211,7 @@ def autoincrement_insert(self): """target platform generates new surrogate integer primary key values when insert() is executed, excluding the pk column.""" - return exclusions.closed() + return exclusions.open() @property def fetch_rows_post_commit(self): @@ -367,8 +367,7 @@ def reflects_pk_names(self): @property def table_reflection(self): """target database has general support for table reflection""" - # return exclusions.open() - return exclusions.closed() + return exclusions.open() @property def reflect_tables_no_columns(self): @@ -1124,5 +1123,5 @@ def autoincrement_without_sequence(self): """If autoincrement=True on a column does not require an explicit sequence. This should be false only for oracle. """ - # return exclusions.open() - return exclusions.closed() + return exclusions.open() + # return exclusions.closed() diff --git a/test/conftest.py b/test/conftest.py index 968009b..9f8b22e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,8 @@ from sqlalchemy.dialects import registry import pytest -registry.register( - "iris.iris", "sqlalchemy_iris", "IRISDialect" -) +registry.register("iris.iris", "sqlalchemy_iris.iris", "IRISDialect_iris") pytest.register_assert_rewrite("sqlalchemy.testing.assertions") -from sqlalchemy.testing.plugin.pytestplugin import * \ No newline at end of file +from sqlalchemy.testing.plugin.pytestplugin import * # noqa diff --git a/test/test_suite.py b/test/test_suite.py index 8a7adec..1bb4ece 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -1,8 +1,12 @@ -from sqlalchemy.testing.suite.test_reflection import QuotedNameArgumentTest as _QuotedNameArgumentTest +from sqlalchemy.testing.suite import QuotedNameArgumentTest as _QuotedNameArgumentTest +from sqlalchemy.testing.suite import FetchLimitOffsetTest as _FetchLimitOffsetTest from sqlalchemy.testing.suite import CompoundSelectTest as _CompoundSelectTest +from sqlalchemy.testing import fixtures, AssertsExecutionResults, AssertsCompiledSQL +from sqlalchemy import testing +from sqlalchemy import Table, Column, Integer, String, select import pytest -from sqlalchemy.testing.suite import * # noqa +from sqlalchemy.testing.suite import * # noqa class CompoundSelectTest(_CompoundSelectTest): @@ -14,3 +18,63 @@ def test_limit_offset_aliased_selectable_in_unions(self): @pytest.mark.skip() class QuotedNameArgumentTest(_QuotedNameArgumentTest): pass + + +class FetchLimitOffsetTest(_FetchLimitOffsetTest): + + def test_simple_offset_no_order(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + self._assert_result( + connection, + select(table).offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.combinations( + ([(2, 0), (2, 1), (3, 2)]), + ([(2, 1), (2, 0), (3, 2)]), + ([(3, 1), (2, 1), (3, 1)]), + argnames="cases", + ) + def test_simple_limit_offset_no_order(self, connection, cases): + table = self.tables.some_table + connection = connection.execution_options(compiled_cache={}) + + assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)] + + for limit, offset in cases: + expected = assert_data[offset: offset + limit] + self._assert_result( + connection, + select(table).limit(limit).offset(offset), + expected, + ) + + +class MiscTest(AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest): + + __backend__ = True + + __only_on__ = "iris" + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + # def test_compile(self): + # table = self.tables.some_table + + # stmt = select(table.c.id, table.c.x).offset(20).limit(10) +