From 72dcf5b831ed6fa1962cb4cae84a0cf0993d9200 Mon Sep 17 00:00:00 2001 From: Dmitry Maslennikov Date: Thu, 16 Jan 2025 23:34:06 +1100 Subject: [PATCH] some fixes for different tests --- django_iris/__init__.py | 44 +++++++++++++++- django_iris/compiler.py | 104 ++++++++++++++++++++++++++++++++++---- django_iris/features.py | 32 ++++++++++-- django_iris/operations.py | 90 +++++++++++++++++++-------------- requirements.txt | 2 +- setup.cfg | 6 ++- 6 files changed, 221 insertions(+), 57 deletions(-) diff --git a/django_iris/__init__.py b/django_iris/__init__.py index e5bfab7..7cd5153 100644 --- a/django_iris/__init__.py +++ b/django_iris/__init__.py @@ -2,7 +2,14 @@ from django.db.models.functions.datetime import Now from django.db.models.expressions import Exists, Func, Value, Col, OrderBy from django.db.models.functions.text import Chr, ConcatPair, StrIndex +from django.db.models.functions import Cast from django.db.models.fields import TextField, CharField +from django.db.models.lookups import BuiltinLookup +from django.db.models.fields.json import ( + KeyTransform, + KeyTransformExact, + compile_json_path, +) from django_iris.compiler import SQLCompiler @@ -65,7 +72,8 @@ def convert_streams(expressions): @as_intersystems(Exists) def exists_as_intersystems(self, compiler, connection, template=None, **extra_context): - template = "(SELECT COUNT(*) FROM (%(subquery)s))" + # template = "(SELECT COUNT(*) FROM (%(subquery)s))" + template = "EXISTS %(subquery)s" return self.as_sql(compiler, connection, template, **extra_context) @@ -151,3 +159,37 @@ def orderby_as_intersystems(self, compiler, connection, **extra_context): # IRIS does not support order NULL copy.nulls_first = copy.nulls_last = False return copy.as_sql(compiler, connection, **extra_context) + + +@as_intersystems(KeyTransformExact) +def json_KeyTransformExact_as_intersystems(self, compiler, connection): + return self.as_sql(compiler, connection) + + +@as_intersystems(KeyTransform) +def json_KeyTransform_as_intersystems(self, compiler, connection): + # breakpoint() + # json_path = compile_json_path(key_transforms) + return f"{self.field.name}__{self.key_name}", [] + + +@as_intersystems(BuiltinLookup) +def BuiltinLookup_as_intersystems(self, compiler, connection): + sql, params = self.as_sql(compiler, connection) + if compiler.in_get_order_by: + return "CASE WHEN %s THEN 1 ELSE 0 END" % (sql,), params + # if not compiler.in_get_select and not compiler.in_get_order_by: + return sql, params + + +@as_intersystems(Cast) +def cast_as_intersystems(self, compiler, connection, **extra_context): + if hasattr(self.source_expressions[0], "lookup_name"): + if self.source_expressions[0].lookup_name in ["gt", "gte", "lt", "lte"]: + return self.as_sql( + compiler, + connection, + template="CASE WHEN %(expressions)s THEN 1 ELSE 0 END", + **extra_context, + ) + return self.as_sql(compiler, connection, **extra_context) diff --git a/django_iris/compiler.py b/django_iris/compiler.py index 9016661..78666da 100644 --- a/django_iris/compiler.py +++ b/django_iris/compiler.py @@ -1,15 +1,81 @@ +import itertools + from django.core.exceptions import EmptyResultSet, FullResultSet -from django.db.models.expressions import Col, Value +from django.db.models.expressions import RawSQL +from django.db.models.sql.where import AND from django.db.models.sql import compiler +from django.db.models.fields.json import KeyTransform +from django.db.models.expressions import DatabaseDefault + + +class Flag: + value = False + + def __init__(self, value): + self.value = value + + def __enter__(self): + self.value = True + return True + + def __exit__(self, *args): + self.value = False + + def __bool__(self): + return self.value + class SQLCompiler(compiler.SQLCompiler): + in_get_select = Flag(False) + in_get_order_by = Flag(False) +# def get_from_clause(self): +# result, params = super().get_from_clause() +# jsoncolumns = {} +# for column in self.query.select + tuple( +# [column for column in (col.lhs for col in self.query.where.children)] +# ): +# if isinstance(column, KeyTransform): +# if column.field.name not in jsoncolumns: +# jsoncolumns[column.field.name] = {} +# jsoncolumns[column.field.name][ +# column.key_name +# ] = f"{column.field.name}__{column.key_name}" +# [] +# for field in jsoncolumns: +# self.query.where.add(RawSQL("%s is not null" % (field,), []), AND) +# cols = ", ".join( +# [ +# f"{jsoncolumns[field][col_name]} VARCHAR PATH '$.{col_name}'" +# for col_name in jsoncolumns[field] +# ] +# ) +# result.append( +# f""" +# , JSON_TABLE("{column.field.name}", '$' COLUMNS( +# {cols} +# )) +# """ +# ) + +# if "model_fields_nullablejsonmodel" in self.query.alias_map: +# breakpoint() +# return result, params + + def get_select(self, with_col_aliases=False): + with self.in_get_select: + return super().get_select(with_col_aliases) + + def get_order_by(self): + with self.in_get_order_by: + return super().get_order_by() def as_sql(self, with_limits=True, with_col_aliases=False): with_limit_offset = (with_limits or self.query.is_sliced) and ( self.query.high_mark is not None or self.query.low_mark > 0 ) if self.query.select_for_update or not with_limit_offset: - return super().as_sql(with_limits, with_col_aliases) + query, params = super().as_sql(with_limits, with_col_aliases) + return query, params try: extra_select, order_by, group_by = self.pre_sql_setup() @@ -79,7 +145,10 @@ def as_sql(self, with_limits=True, with_col_aliases=False): order_by_result = "ORDER BY %s" % first_col if offset: - out_cols.append("ROW_NUMBER() %s AS row_number" % ("OVER (%s)" % order_by_result if order_by_result else "")) + out_cols.append( + "ROW_NUMBER() %s AS row_number" + % ("OVER (%s)" % order_by_result if order_by_result else "") + ) result += [", ".join(out_cols), "FROM", *from_] params.extend(f_params) @@ -154,19 +223,32 @@ def as_sql(self, with_limits=True, with_col_aliases=False): ), tuple(sub_params + params) if offset: - query = "SELECT * FROM (%s) WHERE row_number between %d AND %d ORDER BY row_number" % ( - query, - offset, - limit, + query = ( + "SELECT * FROM (%s) WHERE row_number between %d AND %d ORDER BY row_number" + % ( + query, + offset, + limit, + ) ) return query, tuple(params) - except: - return super().as_sql(with_limits, with_col_aliases) + except Exception: + query, params = super().as_sql(with_limits, with_col_aliases) + return query, params class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): - + def as_sql(self): + + if self.query.fields: + fields = self.query.fields + self.query.fields = [ + field + for field in fields + if not isinstance(self.pre_save_val(field, self.query.objs[0]), DatabaseDefault) + ] + if self.query.fields: return super().as_sql() @@ -195,7 +277,7 @@ def as_sql(self): if self.connection._disable_constraint_checking: sql = "UPDATE %%NOCHECK" + sql[6:] return sql, params - + class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): pass diff --git a/django_iris/features.py b/django_iris/features.py index 22d3a2b..aca93e4 100644 --- a/django_iris/features.py +++ b/django_iris/features.py @@ -84,7 +84,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): only_supports_unbounded_with_preceding_and_following = False # Does the backend support JSONField? - supports_json_field = False + supports_json_field = True # Can the backend introspect a JSONField? can_introspect_json_field = False # Does the backend support primitives in JSONField? @@ -106,13 +106,15 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_collation_on_charfield = True supports_collation_on_textfield = False + supports_boolean_expr_in_select_clause = False + # Collation names for use by the Django test suite. test_collations = { # "ci": None, # Case-insensitive. "cs": "EXACT", # Case-sensitive. # "non_default": None, # Non-default. # "swedish_ci": None, # Swedish case-insensitive. - "virtual": None + "virtual": None, } @cached_property @@ -162,12 +164,23 @@ def django_test_skips(self): # "datetimes.tests.DateTimesTests.test_datetimes_ambiguous_and_invalid_times", # }, "IRIS does not have check contsraints": { + "model_fields.test_jsonfield.JSONFieldTests.test_db_check_constraints", "constraints.tests.CheckConstraintTests.test_validate", "constraints.tests.CheckConstraintTests.test_validate_boolean_expressions", "constraints.tests.UniqueConstraintTests.test_validate_expression_condition", }, + "Does not support expressions in default": { + "field_defaults.tests.DefaultTests.test_full_clean", + "basic.tests.ModelInstanceCreationTests.test_save_primary_with_db_default", + }, + "Regex not supported": { + "lookup.tests.LookupTests.test_regex", + "lookup.tests.LookupTests.test_regex_backreferencing", + "lookup.tests.LookupTests.test_regex_non_ascii", + "lookup.tests.LookupTests.test_regex_non_string", + "lookup.tests.LookupTests.test_regex_null", + }, } - ) return skips @@ -186,6 +199,19 @@ def django_test_skips(self): "schema.tests.SchemaTests.test_alter_text_field_to_time_field", "schema.tests.SchemaTests.test_alter_text_field_to_datetime_field", "schema.tests.SchemaTests.test_alter_text_field_to_date_field", + # + "lookup.tests.LookupQueryingTests.test_filter_exists_lhs", + "lookup.tests.LookupQueryingTests.test_filter_lookup_lhs", + "lookup.tests.LookupQueryingTests.test_filter_subquery_lhs", + "lookup.tests.LookupQueryingTests.test_filter_wrapped_lookup_lhs", + "lookup.tests.LookupTests.test_lookup_rhs", + + "lookup.tests.LookupTests.test_regex", + "lookup.tests.LookupTests.test_regex_backreferencing", + "lookup.tests.LookupTests.test_regex_non_ascii", + "lookup.tests.LookupTests.test_regex_non_string", + "lookup.tests.LookupTests.test_regex_null", + } # django_test_skips["IRIS Bugs"] = django_test_expected_failures diff --git a/django_iris/operations.py b/django_iris/operations.py index 5a06f5c..7e00211 100644 --- a/django_iris/operations.py +++ b/django_iris/operations.py @@ -3,30 +3,31 @@ from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import split_tzname_delta from django.utils import timezone -from itertools import chain -from datetime import date, datetime,timedelta -from django.utils.encoding import force_str +from datetime import date, datetime, timedelta from django.utils.dateparse import parse_date, parse_datetime, parse_time +from django.db.models.expressions import RawSQL, ExpressionWrapper, Exists +from django.db.models.lookups import Lookup +from django.db.models.sql.where import WhereNode + try: - from django.db.backends.base.base import timezone_constructor # Django 4.2 + from django.db.backends.base.base import timezone_constructor # Django 4.2 except ImportError: - from django.utils.timezone import timezone as timezone_constructor # Django 5+ + from django.utils.timezone import timezone as timezone_constructor # Django 5+ from .utils import BulkInsertMapper + class DatabaseOperations(BaseDatabaseOperations): compiler_module = "django_iris.compiler" - cast_data_types = { - "TextField": "VARCHAR" - } + cast_data_types = {"TextField": "VARCHAR"} def quote_name(self, name): if name.startswith('"') and name.endswith('"'): return name # Quoting once is enough. - if '.' in name: - return ".".join(['"%s"' % n for n in name.split('.')]) + if "." in name: + return ".".join(['"%s"' % n for n in name.split(".")]) return '"%s"' % name # def last_insert_id(self, cursor, table_name, pk_name): @@ -65,20 +66,20 @@ def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False ) for table_name in tables ) - + return sql def no_limit_value(self): return None def limit_offset_sql(self, low_mark, high_mark): - return '' + return "" def adapt_datetimefield_value(self, value): if value is None: - return None # Expression values are adapted by the database. + return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value if timezone.is_aware(value): @@ -93,9 +94,9 @@ def adapt_datetimefield_value(self, value): return str(value) value = int(value.timestamp() * 1000000) if value >= 0: - value += 2 ** 60 + value += 2**60 else: - value += -(2 ** 61 * 3) + value += -(2**61 * 3) return str(value) # return str(value).split("+")[0] @@ -112,25 +113,23 @@ def get_db_converters(self, expression): elif internal_type == "BooleanField": converters.append(self.convert_booleanfield_value) return converters - + def convert_booleanfield_value(self, value, expression, connection): if value in (0, 1): value = bool(value) return value - def convert_datetimefield_value(self, value, expression, connection): - original = value if value is not None: if isinstance(value, int): if value > 0: - value -= 2 ** 60 + value -= 2**60 else: - value -= -(2 ** 61 * 3) + value -= -(2**61 * 3) value = value / 1000000 value = datetime.fromtimestamp(value) elif isinstance(value, str): - if value != '': + if value != "": value = parse_datetime(value) if value is not None: if settings.USE_TZ and not timezone.is_aware(value): @@ -142,8 +141,8 @@ def convert_datefield_value(self, value, expression, connection): if not isinstance(value, type(datetime.date)): try: value = int(value) - value=date.fromordinal(672046+value) - except: + value = date.fromordinal(672046 + value) + except Exception: value = parse_date(value) return value @@ -154,15 +153,21 @@ def convert_timefield_value(self, value, expression, connection): value = int(value) value = timedelta(seconds=value) value = (datetime.min + value).time() - except: + except Exception: value = parse_time(value) return value def conditional_expression_supported_in_where_clause(self, expression): + if isinstance(expression, (Exists, Lookup, WhereNode)): + return True + if isinstance(expression, ExpressionWrapper) and expression.conditional: + return self.conditional_expression_supported_in_where_clause(expression.expression) + if isinstance(expression, RawSQL) and expression.conditional: + return True return False def adapt_datefield_value(self, value): - if value == None: + if value is None: return None return str(value) @@ -189,7 +194,7 @@ def adapt_timefield_value(self, value): # return "%s" def lookup_cast(self, lookup_type, internal_type=None): - if lookup_type in ('TEXT', 'LONG BINARY'): + if lookup_type in ("TEXT", "LONG BINARY"): return "CONVERT(VARCHAR, %s)" return "%s" @@ -199,15 +204,15 @@ def max_name_length(self): is no limit. """ return 60 - + def date_extract_sql(self, lookup_type, sql, params): if lookup_type == "week_day": - param = 'dw' + param = "dw" elif lookup_type == "iso_week_day": - param = 'dw' + param = "dw" return f"DATEPART({param}, {sql}) - 1", params elif lookup_type == "iso_year": - param = 'yyyy' + param = "yyyy" else: param = lookup_type return f"DATEPART({param}, CAST({sql} AS TIMESTAMP))", params @@ -239,7 +244,7 @@ def _convert_sql_to_tz(self, sql, params, tzname): diff, *params, ) - except: + except Exception: pass return sql, params @@ -254,16 +259,15 @@ def datetime_cast_time_sql(self, sql, params, tzname): def date_trunc_sql(self, lookup_type, sql, params, tzname=None): sql, params = self._convert_sql_to_tz(sql, params, tzname) sql = f"CAST({sql} as DATE)" - - if lookup_type == 'year': + + if lookup_type == "year": return f"CAST(TO_CHAR(DATE({sql}), 'YYYY-01-01') AS DATE)", params - if lookup_type == 'month': + if lookup_type == "month": return f"CAST(TO_CHAR(DATE({sql}), 'YYYY-MM-01') AS DATE)", params if lookup_type == "week": return ( f"CAST(DATEADD(DAY, - ((DATEPART(WEEKDAY, {sql}) + 5) # 7 ), {sql}) AS DATE)" - ), (*params, ) - + ), (*params,) return f"DATE({sql})", params @@ -288,7 +292,7 @@ def datetime_trunc_sql(self, lookup_type, sql, params, tzname=None): else: format_str = "".join(format[:i] + format_def[i:]) return f"CAST(TO_CHAR({sql}, %s) AS TIMESTAMP)", (*params, format_str) - + return sql, params def time_trunc_sql(self, lookup_type, sql, params, tzname=None): @@ -311,7 +315,10 @@ def last_executed_query(self, cursor, sql, params): last_params = cursor._params.collect() if last_params and len(last_params) > 0: for i in range(len(last_params)): - statement = statement.replace(f":%qpar({i+1})", 'NULL' if last_params[i] is None else repr(last_params[i])) + statement = statement.replace( + f":%qpar({i+1})", + "NULL" if last_params[i] is None else repr(last_params[i]), + ) statement = statement.replace(" . ", ".") statement = statement.replace(" ,", ",") statement = statement.replace(" )", ")") @@ -354,3 +361,8 @@ def combine_duration_expression(self, connector, sub_expressions): if len(sub_expressions) > 3: raise ValueError("Too many params for timedelta operations.") return "DATEADD(%s)" % ", ".join(fn_params) + + def regex_lookup(self, lookup_type): + raise NotImplementedError( + "IRIS does not support regex" + ) diff --git a/requirements.txt b/requirements.txt index 8ed432e..8eff7d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -https://github.com/intersystems-community/intersystems-irispython/releases/download/3.7.3/intersystems_iris-3.7.3-py3-none-any.whl \ No newline at end of file +https://github.com/intersystems-community/intersystems-irispython/releases/download/3.8.0/intersystems_iris-3.8.0-py3-none-any.whl \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index fe6cc5e..f8597bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = django-iris -version = 0.2.5 +version = 0.3.0 url = https://github.com/caretdev/django-iris maintainer = CaretDev maintainer_email = dmitry@caretdev.com @@ -11,7 +11,7 @@ long_description_content_type = text/markdown classifiers = Development Status :: 5 - Production/Stable Framework :: Django - Framework :: Django :: 4.0 + Framework :: Django :: 5.0 License :: OSI Approved :: MIT License Operating System :: OS Independent Programming Language :: Python @@ -19,6 +19,8 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 project_urls = Source = https://github.com/caretdev/django-iris Tracker = https://github.com/caretdev/django-iris/issues