diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 98bdc51..3bb1c77 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -73,7 +73,7 @@ jobs: run: | uv sync --locked --no-install-package=django uv pip install "${{ matrix.django }}" - - name: Run tests + - name: Run tests on PostgreSQL env: DB_SETTINGS: >- { @@ -86,5 +86,8 @@ jobs: } run: .venv/bin/pytest -v continue-on-error: ${{ matrix.python == env.allowed_python_failure }} + - name: Run tests on SQLite + run: .venv/bin/pytest -v + continue-on-error: ${{ matrix.python == env.allowed_python_failure }} - name: Check style run: .venv/bin/ruff check diff --git a/django_cte/__init__.py b/django_cte/__init__.py index aa34e97..da5a7be 100644 --- a/django_cte/__init__.py +++ b/django_cte/__init__.py @@ -1,3 +1,4 @@ -from .cte import CTEManager, CTEQuerySet, With # noqa +from .cte import CTE, with_cte, CTEManager, CTEQuerySet, With # noqa -__version__ = "1.3.3" +__version__ = "2.0.0" +__all__ = ["CTE", "with_cte"] diff --git a/django_cte/_deprecated.py b/django_cte/_deprecated.py new file mode 100644 index 0000000..7f5ca33 --- /dev/null +++ b/django_cte/_deprecated.py @@ -0,0 +1,138 @@ +try: + from warnings import deprecated +except ImportError: + from warnings import warn + + # Copied from Python 3.13, lightly modified for Python 3.9 compatibility. + # Can be removed when the oldest supported Python version is 3.13. + class deprecated: + """Indicate that a class, function or overload is deprecated. + + When this decorator is applied to an object, the type checker + will generate a diagnostic on usage of the deprecated object. + + Usage: + + @deprecated("Use B instead") + class A: + pass + + @deprecated("Use g instead") + def f(): + pass + + @overload + @deprecated("int support is deprecated") + def g(x: int) -> int: ... + @overload + def g(x: str) -> int: ... + + The warning specified by *category* will be emitted at runtime + on use of deprecated objects. For functions, that happens on calls; + for classes, on instantiation and on creation of subclasses. + If the *category* is ``None``, no warning is emitted at runtime. + The *stacklevel* determines where the + warning is emitted. If it is ``1`` (the default), the warning + is emitted at the direct caller of the deprecated object; if it + is higher, it is emitted further up the stack. + Static type checker behavior is not affected by the *category* + and *stacklevel* arguments. + + The deprecation message passed to the decorator is saved in the + ``__deprecated__`` attribute on the decorated object. + If applied to an overload, the decorator + must be after the ``@overload`` decorator for the attribute to + exist on the overload as returned by ``get_overloads()``. + + See PEP 702 for details. + + """ + def __init__( + self, + message: str, + /, + *, + category=DeprecationWarning, + stacklevel=1, + ): + if not isinstance(message, str): + raise TypeError( + f"Expected an object of type str for 'message', not {type(message).__name__!r}" + ) + self.message = message + self.category = category + self.stacklevel = stacklevel + + def __call__(self, arg, /): + # Make sure the inner functions created below don't + # retain a reference to self. + msg = self.message + category = self.category + stacklevel = self.stacklevel + if category is None: + arg.__deprecated__ = msg + return arg + elif isinstance(arg, type): + import functools + from types import MethodType + + original_new = arg.__new__ + + @functools.wraps(original_new) + def __new__(cls, /, *args, **kwargs): + if cls is arg: + warn(msg, category=category, stacklevel=stacklevel + 1) + if original_new is not object.__new__: + return original_new(cls, *args, **kwargs) + # Mirrors a similar check in object.__new__. + elif cls.__init__ is object.__init__ and (args or kwargs): + raise TypeError(f"{cls.__name__}() takes no arguments") + else: + return original_new(cls) + + arg.__new__ = staticmethod(__new__) + + original_init_subclass = arg.__init_subclass__ + # We need slightly different behavior if __init_subclass__ + # is a bound method (likely if it was implemented in Python) + if isinstance(original_init_subclass, MethodType): + original_init_subclass = original_init_subclass.__func__ + + @functools.wraps(original_init_subclass) + def __init_subclass__(*args, **kwargs): + warn(msg, category=category, stacklevel=stacklevel + 1) + return original_init_subclass(*args, **kwargs) + + arg.__init_subclass__ = classmethod(__init_subclass__) + # Or otherwise, which likely means it's a builtin such as + # object's implementation of __init_subclass__. + else: + @functools.wraps(original_init_subclass) + def __init_subclass__(*args, **kwargs): + warn(msg, category=category, stacklevel=stacklevel + 1) + return original_init_subclass(*args, **kwargs) + + arg.__init_subclass__ = __init_subclass__ + + arg.__deprecated__ = __new__.__deprecated__ = msg + __init_subclass__.__deprecated__ = msg + return arg + elif callable(arg): + import functools + import inspect + + @functools.wraps(arg) + def wrapper(*args, **kwargs): + warn(msg, category=category, stacklevel=stacklevel + 1) + return arg(*args, **kwargs) + + if inspect.iscoroutinefunction(arg): + wrapper = inspect.markcoroutinefunction(wrapper) + + arg.__deprecated__ = wrapper.__deprecated__ = msg + return wrapper + else: + raise TypeError( + "@deprecated decorator with non-None category must be applied to " + f"a class or callable, not {arg!r}" + ) diff --git a/django_cte/cte.py b/django_cte/cte.py index 5e57fb9..2b0b1db 100644 --- a/django_cte/cte.py +++ b/django_cte/cte.py @@ -1,17 +1,38 @@ -from django.db.models import Manager +from copy import copy + +from django.db.models import Manager, sql from django.db.models.expressions import Ref from django.db.models.query import Q, QuerySet, ValuesIterable from django.db.models.sql.datastructures import BaseTable +from .jitmixin import jit_mixin from .join import QJoin, INNER from .meta import CTEColumnRef, CTEColumns from .query import CTEQuery +from ._deprecated import deprecated + +__all__ = ["CTE", "with_cte"] + + +def with_cte(*ctes, select): + """Add Common Table Expression(s) (CTEs) to a model or queryset -__all__ = ["With", "CTEManager", "CTEQuerySet"] + :param *ctes: One or more CTE objects. + :param select: A model class, queryset, or CTE to use as the base + query to which CTEs are attached. + :returns: A queryset with the given CTE added to it. + """ + if isinstance(select, CTE): + select = select.queryset() + elif not isinstance(select, QuerySet): + select = select._default_manager.all() + jit_mixin(select.query, CTEQuery) + select.query._with_ctes += ctes + return select -class With(object): - """Common Table Expression query object: `WITH ...` +class CTE: + """Common Table Expression :param queryset: A queryset to use as the body of the CTE. :param name: Optional name parameter for the CTE (default: "cte"). @@ -41,7 +62,7 @@ def __repr__(self): @classmethod def recursive(cls, make_cte_queryset, name="cte", materialized=False): - """Recursive Common Table Expression: `WITH RECURSIVE ...` + """Recursive Common Table Expression :param make_cte_queryset: Function taking a single argument (a not-yet-fully-constructed cte object) and returning a `QuerySet` @@ -58,10 +79,11 @@ def recursive(cls, make_cte_queryset, name="cte", materialized=False): def join(self, model_or_queryset, *filter_q, **filter_kw): """Join this CTE to the given model or queryset - This CTE will be refernced by the returned queryset, but the + This CTE will be referenced by the returned queryset, but the + corresponding `WITH ...` statement will not be prepended to the - queryset's SQL output; use `.with_cte(cte)` to - achieve that outcome. + queryset's SQL output; use `with_cte(cte, select=cte.join(...))` + to achieve that outcome. :param model_or_queryset: Model class or queryset to which the CTE should be joined. @@ -96,15 +118,15 @@ def queryset(self): This CTE will be referenced by the returned queryset, but the corresponding `WITH ...` statement will not be prepended to the - queryset's SQL output; use `.with_cte(cte)` to - achieve that outcome. + queryset's SQL output; use `with_cte(cte, select=cte)` to do + that. :returns: A queryset. """ cte_query = self.query qs = cte_query.model._default_manager.get_queryset() - query = CTEQuery(cte_query.model) + query = jit_mixin(sql.Query(cte_query.model), CTEQuery) query.join(BaseTable(self.name, None)) query.default_cols = cte_query.default_cols query.deferred_loading = cte_query.deferred_loading @@ -130,26 +152,38 @@ def _resolve_ref(self, name): return Ref(name, self.query.resolve_ref(name)) return self.query.resolve_ref(name) + def resolve_expression(self, *args, **kw): + if self.query is None: + raise ValueError("Cannot resolve recursive CTE without a query.") + clone = copy(self) + clone.query = clone.query.resolve_expression(*args, **kw) + return clone + + +@deprecated("Use `django_cte.CTE` instead.") +class With(CTE): + + @staticmethod + @deprecated("Use `django_cte.CTE.recursive` instead.") + def recursive(*args, **kw): + return CTE.recursive(*args, **kw) + +@deprecated("CTEQuerySet is deprecated. " + "CTEs can now be applied to any queryset using `with_cte()`") class CTEQuerySet(QuerySet): """QuerySet with support for Common Table Expressions""" def __init__(self, model=None, query=None, using=None, hints=None): # Only create an instance of a Query if this is the first invocation in # a query chain. - if query is None: - query = CTEQuery(model) super(CTEQuerySet, self).__init__(model, query, using, hints) + jit_mixin(self.query, CTEQuery) + @deprecated("Use `django_cte.with_cte(cte, select=...)` instead.") def with_cte(self, cte): - """Add a Common Table Expression to this queryset - - The CTE `WITH ...` clause will be added to the queryset's SQL - output (after other CTEs that have already been added) so it - can be referenced in annotations, filters, etc. - """ qs = self._clone() - qs.query._with_ctes.append(cte) + qs.query._with_ctes += cte, return qs def as_manager(cls): @@ -161,36 +195,9 @@ def as_manager(cls): as_manager.queryset_only = True as_manager = classmethod(as_manager) - def _combinator_query(self, *args, **kw): - clone = super()._combinator_query(*args, **kw) - if clone.query.combinator: - ctes = clone.query._with_ctes = [] - seen = {} - for query in clone.query.combined_queries: - for cte in getattr(query, "_with_ctes", []): - if seen.get(cte.name) is cte: - continue - if cte.name in seen: - raise ValueError( - f"Found two or more CTEs named '{cte.name}'. " - "Hint: assign a unique name to each CTE." - ) - ctes.append(cte) - seen[cte.name] = cte - if ctes: - def without_ctes(query): - if getattr(query, "_with_ctes", None): - query = query.clone() - query._with_ctes = [] - return query - - clone.query.combined_queries = [ - without_ctes(query) - for query in clone.query.combined_queries - ] - return clone - +@deprecated("CTEMAnager is deprecated. " + "CTEs can now be applied to any queryset using `with_cte()`") class CTEManager(Manager.from_queryset(CTEQuerySet)): """Manager for models that perform CTE queries""" diff --git a/django_cte/expressions.py b/django_cte/expressions.py deleted file mode 100644 index 850ccd7..0000000 --- a/django_cte/expressions.py +++ /dev/null @@ -1,45 +0,0 @@ -from django.db.models import Subquery - - -class CTESubqueryResolver(object): - - def __init__(self, annotation): - self.annotation = annotation - - def resolve_expression(self, *args, **kw): - # source: django.db.models.expressions.Subquery.resolve_expression - # --- begin copied code (lightly adapted) --- # - - # Need to recursively resolve these. - def resolve_all(child): - if hasattr(child, 'children'): - [resolve_all(_child) for _child in child.children] - if hasattr(child, 'rhs'): - child.rhs = resolve(child.rhs) - - def resolve(child): - if hasattr(child, 'resolve_expression'): - resolved = child.resolve_expression(*args, **kw) - # Add table alias to the parent query's aliases to prevent - # quoting. - if hasattr(resolved, 'alias') and \ - resolved.alias != resolved.target.model._meta.db_table: - get_query(clone).external_aliases.add(resolved.alias) - return resolved - return child - - # --- end copied code --- # - - def get_query(clone): - return clone.query - - # NOTE this uses the old (pre-Django 3) way of resolving. - # Should a different technique should be used on Django 3+? - clone = self.annotation.resolve_expression(*args, **kw) - if isinstance(self.annotation, Subquery): - for cte in getattr(get_query(clone), '_with_ctes', []): - resolve_all(cte.query.where) - for key, value in cte.query.annotations.items(): - if isinstance(value, Subquery): - cte.query.annotations[key] = resolve(value) - return clone diff --git a/django_cte/jitmixin.py b/django_cte/jitmixin.py new file mode 100644 index 0000000..99f7353 --- /dev/null +++ b/django_cte/jitmixin.py @@ -0,0 +1,28 @@ +def jit_mixin(obj, mixin): + """Apply mixin to object and return the object""" + if not isinstance(obj, mixin): + obj.__class__ = jit_mixin_type(obj.__class__, mixin) + return obj + + +def jit_mixin_type(base, *mixins): + assert not issubclass(base, mixins), (base, mixins) + mixed = _mixin_cache.get((base, mixins)) + if mixed is None: + prefix = "".join(m._jit_mixin_prefix for m in mixins) + name = f"{prefix}{base.__name__}" + mixed = _mixin_cache[(base, mixins)] = type(name, (*mixins, base), { + "_jit_mixin_base": getattr(base, "_jit_mixin_base", base), + "_jit_mixins": mixins + getattr(base, "_jit_mixins", ()), + }) + return mixed + + +_mixin_cache = {} + + +class JITMixin: + + def __reduce__(self): + # make JITMixin classes pickleable + return (jit_mixin_type, (self._jit_mixin_base, *self._jit_mixins)) diff --git a/django_cte/join.py b/django_cte/join.py index bc18d26..00664cb 100644 --- a/django_cte/join.py +++ b/django_cte/join.py @@ -1,7 +1,7 @@ from django.db.models.sql.constants import INNER -class QJoin(object): +class QJoin: """Join clause with join condition from Q object clause :param parent_alias: Alias of parent table. diff --git a/django_cte/meta.py b/django_cte/meta.py index 0cb6ed8..b1865be 100644 --- a/django_cte/meta.py +++ b/django_cte/meta.py @@ -3,7 +3,7 @@ from django.db.models.expressions import Col, Expression -class CTEColumns(object): +class CTEColumns: def __init__(self, cte): self._cte = weakref.ref(cte) @@ -87,7 +87,7 @@ def resolve_expression(self, query=None, allow_joins=True, reuse=None, clone._alias = self._alias or query.table_map.get( self.cte_name, [self.cte_name])[0] return clone - return super(CTEColumnRef, self).resolve_expression( + return super().resolve_expression( query, allow_joins, reuse, summarize, for_save) def relabeled_clone(self, change_map): @@ -95,7 +95,7 @@ def relabeled_clone(self, change_map): self.cte_name not in change_map and self._alias not in change_map ): - return super(CTEColumnRef, self).relabeled_clone(change_map) + return super().relabeled_clone(change_map) clone = self.copy() if self.cte_name in change_map: diff --git a/django_cte/query.py b/django_cte/query.py index cbb1801..cbba107 100644 --- a/django_cte/query.py +++ b/django_cte/query.py @@ -1,194 +1,168 @@ import django from django.core.exceptions import EmptyResultSet -from django.db import connections -from django.db.models.sql import DeleteQuery, Query, UpdateQuery -from django.db.models.sql.compiler import ( - SQLCompiler, - SQLDeleteCompiler, - SQLUpdateCompiler, -) from django.db.models.sql.constants import LOUTER -from .expressions import CTESubqueryResolver +from .jitmixin import JITMixin, jit_mixin from .join import QJoin +# NOTE: it is currently not possible to execute delete queries that +# reference CTEs without patching `QuerySet.delete` (Django method) +# to call `self.query.chain(sql.DeleteQuery)` instead of +# `sql.DeleteQuery(self.model)` -class CTEQuery(Query): - """A Query which processes SQL compilation through the CTE compiler""" - - def __init__(self, *args, **kwargs): - super(CTEQuery, self).__init__(*args, **kwargs) - self._with_ctes = [] - - def combine(self, other, connector): - if other._with_ctes: - if self._with_ctes: - raise TypeError("cannot merge queries with CTEs on both sides") - self._with_ctes = other._with_ctes[:] - return super(CTEQuery, self).combine(other, connector) - - def get_compiler(self, using=None, connection=None, *args, **kwargs): - """ Overrides the Query method get_compiler in order to return - a CTECompiler. - """ - # Copy the body of this method from Django except the final - # return statement. We will ignore code coverage for this. - if using is None and connection is None: # pragma: no cover - raise ValueError("Need either using or connection") - if using: - connection = connections[using] - # Check that the compiler will be able to execute the query - for alias, aggregate in self.annotation_select.items(): - connection.ops.check_expression_support(aggregate) - # Instantiate the custom compiler. - klass = COMPILER_TYPES.get(self.__class__, CTEQueryCompiler) - return klass(self, connection, using, *args, **kwargs) - - def add_annotation(self, annotation, *args, **kw): - annotation = CTESubqueryResolver(annotation) - super(CTEQuery, self).add_annotation(annotation, *args, **kw) - - def __chain(self, _name, klass=None, *args, **kwargs): - klass = QUERY_TYPES.get(klass, self.__class__) - clone = getattr(super(CTEQuery, self), _name)(klass, *args, **kwargs) - clone._with_ctes = self._with_ctes[:] - return clone - - def chain(self, klass=None): - return self.__chain("chain", klass) +class CTEQuery(JITMixin): + """A Query mixin that processes SQL compilation through a CTE compiler""" + _jit_mixin_prefix = "CTE" + _with_ctes = () -class CTECompiler(object): - - @classmethod - def generate_sql(cls, connection, query, as_sql): - if not query._with_ctes: - return as_sql() + @property + def combined_queries(self): + return self.__dict__.get("combined_queries", ()) + @combined_queries.setter + def combined_queries(self, queries): ctes = [] - params = [] - for cte in query._with_ctes: - if django.VERSION > (4, 2): - _ignore_with_col_aliases(cte.query) - - alias = query.alias_map.get(cte.name) - should_elide_empty = ( - not isinstance(alias, QJoin) or alias.join_type != LOUTER - ) - - compiler = cte.query.get_compiler( - connection=connection, elide_empty=should_elide_empty - ) - - qn = compiler.quote_name_unless_alias - try: - cte_sql, cte_params = compiler.as_sql() - except EmptyResultSet: - # If the CTE raises an EmptyResultSet the SqlCompiler still - # needs to know the information about this base compiler - # like, col_count and klass_info. - as_sql() - raise - template = cls.get_cte_query_template(cte) - ctes.append(template.format(name=qn(cte.name), query=cte_sql)) - params.extend(cte_params) - - explain_attribute = "explain_info" - explain_info = getattr(query, explain_attribute, None) - explain_format = getattr(explain_info, "format", None) - explain_options = getattr(explain_info, "options", {}) - - explain_query_or_info = getattr(query, explain_attribute, None) - sql = [] - if explain_query_or_info: - sql.append( - connection.ops.explain_query_prefix( - explain_format, - **explain_options - ) - ) - # this needs to get set to None so that the base as_sql() doesn't - # insert the EXPLAIN statement where it would end up between the - # WITH ... clause and the final SELECT - setattr(query, explain_attribute, None) - - if ctes: - # Always use WITH RECURSIVE - # https://www.postgresql.org/message-id/13122.1339829536%40sss.pgh.pa.us - sql.extend(["WITH RECURSIVE", ", ".join(ctes)]) - base_sql, base_params = as_sql() + seen = {cte.name: cte for cte in self._with_ctes} + for query in queries: + for cte in getattr(query, "_with_ctes", ()): + if seen.get(cte.name) is cte: + continue + if cte.name in seen: + raise ValueError( + f"Found two or more CTEs named '{cte.name}'. " + "Hint: assign a unique name to each CTE." + ) + ctes.append(cte) + seen[cte.name] = cte + + if seen: + def without_ctes(query): + if getattr(query, "_with_ctes", None): + query = query.clone() + del query._with_ctes + return query + + self._with_ctes += tuple(ctes) + queries = tuple(without_ctes(q) for q in queries) + self.__dict__["combined_queries"] = queries + + def resolve_expression(self, *args, **kwargs): + clone = super().resolve_expression(*args, **kwargs) + clone._with_ctes = tuple( + cte.resolve_expression(*args, **kwargs) + for cte in clone._with_ctes + ) + return clone - if explain_query_or_info: - setattr(query, explain_attribute, explain_query_or_info) + def get_compiler(self, *args, **kwargs): + return jit_mixin(super().get_compiler(*args, **kwargs), CTECompiler) - sql.append(base_sql) - params.extend(base_params) - return " ".join(sql), tuple(params) + def chain(self, klass=None): + clone = jit_mixin(super().chain(klass), CTEQuery) + clone._with_ctes = self._with_ctes + return clone - @classmethod - def get_cte_query_template(cls, cte): - if cte.materialized: - return "{name} AS MATERIALIZED ({query})" - return "{name} AS ({query})" +def generate_cte_sql(connection, query, as_sql): + if not query._with_ctes: + return as_sql() + + ctes = [] + params = [] + for cte in query._with_ctes: + if django.VERSION > (4, 2): + _ignore_with_col_aliases(cte.query) + + alias = query.alias_map.get(cte.name) + should_elide_empty = ( + not isinstance(alias, QJoin) or alias.join_type != LOUTER + ) + + compiler = cte.query.get_compiler( + connection=connection, elide_empty=should_elide_empty + ) + + qn = compiler.quote_name_unless_alias + try: + cte_sql, cte_params = compiler.as_sql() + except EmptyResultSet: + # If the CTE raises an EmptyResultSet the SqlCompiler still + # needs to know the information about this base compiler + # like, col_count and klass_info. + as_sql() + raise + template = get_cte_query_template(cte) + ctes.append(template.format(name=qn(cte.name), query=cte_sql)) + params.extend(cte_params) + + explain_attribute = "explain_info" + explain_info = getattr(query, explain_attribute, None) + explain_format = getattr(explain_info, "format", None) + explain_options = getattr(explain_info, "options", {}) + + explain_query_or_info = getattr(query, explain_attribute, None) + sql = [] + if explain_query_or_info: + sql.append( + connection.ops.explain_query_prefix( + explain_format, + **explain_options + ) + ) + # this needs to get set to None so that the base as_sql() doesn't + # insert the EXPLAIN statement where it would end up between the + # WITH ... clause and the final SELECT + setattr(query, explain_attribute, None) -class CTEUpdateQuery(UpdateQuery, CTEQuery): - pass + if ctes: + # Always use WITH RECURSIVE + # https://www.postgresql.org/message-id/13122.1339829536%40sss.pgh.pa.us + sql.extend(["WITH RECURSIVE", ", ".join(ctes)]) + base_sql, base_params = as_sql() + if explain_query_or_info: + setattr(query, explain_attribute, explain_query_or_info) -class CTEDeleteQuery(DeleteQuery, CTEQuery): - pass + sql.append(base_sql) + params.extend(base_params) + return " ".join(sql), tuple(params) -QUERY_TYPES = { - Query: CTEQuery, - UpdateQuery: CTEUpdateQuery, - DeleteQuery: CTEDeleteQuery, -} +def get_cte_query_template(cte): + if cte.materialized: + return "{name} AS MATERIALIZED ({query})" + return "{name} AS ({query})" def _ignore_with_col_aliases(cte_query): if getattr(cte_query, "combined_queries", None): - for query in cte_query.combined_queries: - query.ignore_with_col_aliases = True - - -class CTEQueryCompiler(SQLCompiler): - - def as_sql(self, *args, **kwargs): - def _as_sql(): - return super(CTEQueryCompiler, self).as_sql(*args, **kwargs) - return CTECompiler.generate_sql(self.connection, self.query, _as_sql) - - def get_select(self, **kw): - if kw.get("with_col_aliases") \ - and getattr(self.query, "ignore_with_col_aliases", False): - kw.pop("with_col_aliases") - return super().get_select(**kw) + cte_query.combined_queries = tuple( + jit_mixin(q, NoAliasQuery) for q in cte_query.combined_queries + ) -class CTEUpdateQueryCompiler(SQLUpdateCompiler): +class CTECompiler(JITMixin): + """Mixin for django.db.models.sql.compiler.SQLCompiler""" + _jit_mixin_prefix = "CTE" def as_sql(self, *args, **kwargs): def _as_sql(): - return super(CTEUpdateQueryCompiler, self).as_sql(*args, **kwargs) - return CTECompiler.generate_sql(self.connection, self.query, _as_sql) + return super(CTECompiler, self).as_sql(*args, **kwargs) + return generate_cte_sql(self.connection, self.query, _as_sql) -class CTEDeleteQueryCompiler(SQLDeleteCompiler): +class NoAliasQuery(JITMixin): + """Mixin for django.db.models.sql.compiler.Query""" + _jit_mixin_prefix = "NoAlias" - # NOTE: it is currently not possible to execute delete queries that - # reference CTEs without patching `QuerySet.delete` (Django method) - # to call `self.query.chain(sql.DeleteQuery)` instead of - # `sql.DeleteQuery(self.model)` + def get_compiler(self, *args, **kwargs): + return jit_mixin(super().get_compiler(*args, **kwargs), NoAliasCompiler) - def as_sql(self, *args, **kwargs): - def _as_sql(): - return super(CTEDeleteQueryCompiler, self).as_sql(*args, **kwargs) - return CTECompiler.generate_sql(self.connection, self.query, _as_sql) +class NoAliasCompiler(JITMixin): + """Mixin for django.db.models.sql.compiler.SQLCompiler""" + _jit_mixin_prefix = "NoAlias" -COMPILER_TYPES = { - CTEUpdateQuery: CTEUpdateQueryCompiler, - CTEDeleteQuery: CTEDeleteQueryCompiler, -} + def get_select(self, *, with_col_aliases=False, **kw): + return super().get_select(**kw) diff --git a/django_cte/raw.py b/django_cte/raw.py index dcd1919..c1bbe2a 100644 --- a/django_cte/raw.py +++ b/django_cte/raw.py @@ -7,14 +7,14 @@ def raw_cte_sql(sql, params, refs): :returns: Object that can be passed to `With`. """ - class raw_cte_ref(object): + class raw_cte_ref: def __init__(self, output_field): self.output_field = output_field def get_source_expressions(self): return [] - class raw_cte_compiler(object): + class raw_cte_compiler: def __init__(self, connection): self.connection = connection @@ -25,8 +25,8 @@ def as_sql(self): def quote_name_unless_alias(self, name): return self.connection.ops.quote_name(name) - class raw_cte_queryset(object): - class query(object): + class raw_cte_queryset: + class query: @staticmethod def get_compiler(connection, *, elide_empty=None): return raw_cte_compiler(connection) diff --git a/docs/index.md b/docs/index.md index 7986ef9..cc56e2b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,48 +8,32 @@ for the duration of the query it is attached to. django-cte allows common table expressions to be attached to normal Django ORM queries. -## Prerequisite: A Model with a "CTEManager" - -The custom manager class, `CTEManager`, constructs `CTEQuerySet`s, which have -all of the same features as normal `QuerySet`s and also support CTE queries. - -```py -from django_cte import CTEManager - -class Order(Model): - objects = CTEManager() - id = AutoField(primary_key=True) - region = ForeignKey("Region", on_delete=CASCADE) - amount = IntegerField(default=0) - - class Meta: - db_table = "orders" -``` - - ## Simple Common Table Expressions -Simple CTEs are constructed using `With(...)`. A CTE can be joined to a model or -other `CTEQuerySet` using its `join(...)` method, which creates a new queryset -with a `JOIN` and `ON` condition. Finally, the CTE is added to the resulting -queryset using `with_cte(cte)`, which adds the `WITH` expression before the -main `SELECT` query. +See [Appendix A](#appendix-a-model-definitions-used-in-sample-code) for model +definitions used in sample code. + +Simple CTEs are constructed using `CTE(...)`. A CTE is added to a queryset using +`with_cte(cte, select=queryset)`, which adds the `WITH` expression before the +main `SELECT` query. A CTE can be joined to a model or other `QuerySet` using +its `.join(...)` method, which creates a new queryset with a `JOIN` and +`ON` condition. ```py -from django_cte import With +from django_cte import CTE, with_cte -cte = With( +cte = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")) ) -orders = ( - # FROM orders INNER JOIN cte ON orders.region_id = cte.region_id - cte.join(Order, region=cte.col.region_id) +orders = with_cte( + # WITH cte ... + cte, - # Add `WITH ...` before `SELECT ... FROM orders ...` - .with_cte(cte) + # SELECT ... FROM orders INNER JOIN cte ON orders.region_id = cte.region_id + select=cte.join(Order, region=cte.col.region_id) # Annotate each Order with a "region_total" .annotate(region_total=cte.col.total) @@ -78,7 +62,7 @@ FROM "orders" INNER JOIN "cte" ON "orders"."region_id" = "cte"."region_id" ``` -The `orders` query is a query set containing annotated `Order` objects, just as +The `orders` query is a queryset containing annotated `Order` objects, just as you would get from a query like `Order.objects.annotate(region_total=...)`. Each `Order` object will be annotated with a `region_total` attribute, which is populated with the value of the corresponding total from the joined CTE query. @@ -93,19 +77,9 @@ recursive CTEs to be included in the WITH block. ## Recursive Common Table Expressions Recursive CTE queries allow fundamentally new types of queries that are -not otherwise possible. First, a model for the example. - -```py -class Region(Model): - objects = CTEManager() - name = TextField(primary_key=True) - parent = ForeignKey("self", null=True, on_delete=CASCADE) - - class Meta: - db_table = "region" -``` +not otherwise possible. -Recursive CTEs are constructed using `With.recursive()`, which takes as its +Recursive CTEs are constructed using `CTE.recursive()`, which takes as its first argument a function that constructs and returns a recursive query. Recursive queries have two elements: first a non-recursive query element, and second a recursive query element. The second is typically attached to the first @@ -133,11 +107,11 @@ def make_regions_cte(cte): all=True, ) -cte = With.recursive(make_regions_cte) +cte = CTE.recursive(make_regions_cte) -regions = ( - cte.join(Region, name=cte.col.name) - .with_cte(cte) +regions = with_cte( + cte, + select=cte.join(Region, name=cte.col.name) .annotate( path=cte.col.path, depth=cte.col.depth, @@ -184,9 +158,9 @@ ORDER BY "path" ASC ## Named Common Table Expressions It is possible to add more than one CTE to a query. To do this, each CTE must -have a unique name. `With(queryset)` returns a CTE with the name `'cte'` by -default, but that can be overridden: `With(queryset, name='custom')` or -`With.recursive(make_queryset, name='custom')`. This allows each CTE to be +have a unique name. `CTE(queryset)` returns a CTE with the name `'cte'` by +default, but that can be overridden: `CTE(queryset, name='custom')` or +`CTE.recursive(make_queryset, name='custom')`. This allows each CTE to be referenced uniquely within a single query. Also note that a CTE may reference other CTEs in the same query. @@ -208,9 +182,9 @@ def make_root_mapping(rootmap): ), all=True, ) -rootmap = With.recursive(make_root_mapping, name="rootmap") +rootmap = CTE.recursive(make_root_mapping, name="rootmap") -totals = With( +totals = CTE( rootmap.join(Order, region_id=rootmap.col.name) .values( root=rootmap.col.root, @@ -221,11 +195,12 @@ totals = With( name="totals", ) -root_regions = ( - totals.join(Region, name=totals.col.root) - # Important: add both CTEs to the final query - .with_cte(rootmap) - .with_cte(totals) +root_regions = with_cte( + # Important: add both CTEs to the query + rootmap, + totals, + + select=totals.join(Region, name=totals.col.root) .annotate( # count of orders in this region and all subregions orders_count=totals.col.orders_count, @@ -276,16 +251,16 @@ INNER JOIN "totals" ON "region"."name" = "totals"."root" Sometimes it is useful to construct queries where the final `FROM` clause contains only common table expression(s). This is possible with -`With(...).queryset()`. +`CTE(...).queryset()`. Each returned row may be a model object: ```py -cte = With( +cte = CTE( Order.objects .annotate(region_parent=F("region__parent_id")), ) -orders = cte.queryset().with_cte(cte) +orders = with_cte(cte, select=cte.queryset()) ``` And the resulting SQL: @@ -311,7 +286,7 @@ FROM "cte" It is also possible to do the same with `values(...)` queries: ```py -cte = With( +cte = CTE( Order.objects .values( "region_id", @@ -319,7 +294,7 @@ cte = With( ) .distinct() ) -values = cte.queryset().with_cte(cte).filter(region_parent__isnull=False) +values = with_cte(cte, select=cte).filter(region_parent__isnull=False) ``` Which produces this SQL: @@ -339,55 +314,30 @@ FROM "cte" WHERE "cte"."region_parent" IS NOT NULL ``` - -## Custom QuerySets and Managers - -Custom `QuerySet`s that will be used in CTE queries should be derived from -`CTEQuerySet`. - -```py -class LargeOrdersQuerySet(CTEQuerySet): - def big_amounts(self): - return self.filter(amount__gt=100) - - -class Order(Model): - amount = models.IntegerField() - large = LargeOrdersQuerySet.as_manager() -``` - -Custom `CTEQuerySet`s can also be used with custom `CTEManager`s. - -```py -class CustomManager(CTEManager): - ... - - -class Order(Model): - large = CustomManager.from_queryset(LargeOrdersQuerySet)() - objects = CustomManager() -``` +You may have noticed that when a CTE is passed to the `select=...` argument as +in `with_cte(cte, select=cte)`, the `.queryset()` call is optional and may be +omitted. ## Experimental: Left Outer Join Django does not provide precise control over joins, but there is an experimental way to perform a `LEFT OUTER JOIN` with a CTE query using the `_join_type` -keyword argument of `With.join(...)`. +keyword argument of `CTE.join(...)`. ```py from django.db.models.sql.constants import LOUTER -totals = With( +totals = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")) .filter(total__gt=100) ) -orders = ( - totals +orders = with_cte( + totals, + select=totals .join(Order, region=totals.col.region_id, _join_type=LOUTER) - .with_cte(totals) .annotate(region_total=totals.col.total) ) ``` @@ -420,12 +370,13 @@ produce the desired SQL. ## Materialized CTE -Both PostgreSQL 12+ and sqlite 3.35+ supports `MATERIALIZED` keyword for CTE queries. -To enforce using of this keyword add `materialized` as a parameter of `With(..., materialized=True)`. +Both PostgreSQL 12+ and sqlite 3.35+ supports `MATERIALIZED` keyword for CTE +queries. To enforce usage of this keyword add `materialized` as a parameter of +`CTE(..., materialized=True)`. ```py -cte = With( +cte = CTE( Order.objects.values('id'), materialized=True ) @@ -457,7 +408,7 @@ A short example: from django.db.models import IntegerField, TextField from django_cte.raw import raw_cte_sql -cte = With(raw_cte_sql( +cte = CTE(raw_cte_sql( """ SELECT region_id, AVG(amount) AS avg_order FROM orders @@ -470,11 +421,11 @@ cte = With(raw_cte_sql( "avg_order": IntegerField(), }, )) -moon_avg = ( - cte +moon_avg = with_cte( + cte, + select=cte .join(Region, name=cte.col.region_id) .annotate(avg_order=cte.col.avg_order) - .with_cte(cte) ) ``` @@ -507,3 +458,33 @@ in the tests: - [`test_cte.py`](https://github.com/dimagi/django-cte/blob/main/tests/test_cte.py) - [`test_recursive.py`](https://github.com/dimagi/django-cte/blob/main/tests/test_recursive.py) - [`test_raw.py`](https://github.com/dimagi/django-cte/blob/main/tests/test_raw.py) + + +## Appendix A: Model definitions used in sample code + +```py +class Order(Model): + id = AutoField(primary_key=True) + region = ForeignKey("Region", on_delete=CASCADE) + amount = IntegerField(default=0) + + class Meta: + db_table = "orders" + + +class Region(Model): + name = TextField(primary_key=True) + parent = ForeignKey("self", null=True, on_delete=CASCADE) + + class Meta: + db_table = "region" +``` + + +## Appendix B: django-cte v1 documentation (DEPRECATED) + +The syntax for constructing CTE queries changed slightly in django-cte 2.0. The +most important change is that a custom model manager is no longer required on +models used to construct CTE queries. The documentation has been updated to use +v2 syntax, but the [documentation for v1](https://github.com/dimagi/django-cte/blob/v1.3.3/docs/index.md) +can be found on Github if needed. diff --git a/tests/__init__.py b/tests/__init__.py index b541028..ce38c96 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,6 @@ import os +import warnings +from contextlib import contextmanager import django from unmagic import fixture @@ -10,8 +12,23 @@ from .django_setup import init_db, destroy_db # noqa -@fixture(autouse=__name__, scope="package") +@fixture(autouse=__file__, scope="package") def test_db(): - init_db() + with ignore_v1_warnings(): + init_db() yield destroy_db() + + +@contextmanager +def ignore_v1_warnings(): + msg = ( + r"CTE(Manager|QuerySet) is deprecated.*" + r"|" + r"Use `django_cte\.with_cte\(.*\)` instead\." + r"|" + r"Use `django_cte\.CTE(\.recursive)?` instead\." + ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=msg, category=DeprecationWarning) + yield diff --git a/tests/models.py b/tests/models.py index ff2bd0d..f3c9e0c 100644 --- a/tests/models.py +++ b/tests/models.py @@ -2,6 +2,7 @@ CASCADE, Manager, Model, + QuerySet, AutoField, CharField, ForeignKey, @@ -9,33 +10,20 @@ TextField, ) -from django_cte import CTEManager, CTEQuerySet - -class LT40QuerySet(CTEQuerySet): +class LT40QuerySet(QuerySet): def lt40(self): return self.filter(amount__lt=40) -class LT30QuerySet(CTEQuerySet): - - def lt30(self): - return self.filter(amount__lt=30) - - -class LT25QuerySet(CTEQuerySet): +class LT25QuerySet(QuerySet): def lt25(self): return self.filter(amount__lt=25) -class LTManager(CTEManager): - pass - - class Region(Model): - objects = CTEManager() name = TextField(primary_key=True) parent = ForeignKey("self", null=True, on_delete=CASCADE) @@ -52,7 +40,6 @@ class Meta: class Order(Model): - objects = CTEManager() id = AutoField(primary_key=True) region = ForeignKey(Region, on_delete=CASCADE) amount = IntegerField(default=0) @@ -65,35 +52,16 @@ class Meta: class OrderFromLT40(Order): class Meta: proxy = True - objects = CTEManager.from_queryset(LT40QuerySet)() - - -class OrderLT40AsManager(Order): - class Meta: - proxy = True - objects = LT40QuerySet.as_manager() + objects = Manager.from_queryset(LT40QuerySet)() class OrderCustomManagerNQuery(Order): class Meta: proxy = True - objects = LTManager.from_queryset(LT25QuerySet)() - - -class OrderCustomManager(Order): - class Meta: - proxy = True - objects = LTManager() - - -class OrderPlainManager(Order): - class Meta: - proxy = True - objects = Manager() + objects = Manager.from_queryset(LT25QuerySet)() class KeyPair(Model): - objects = CTEManager() key = CharField(max_length=32) value = IntegerField(default=0) parent = ForeignKey("self", null=True, on_delete=CASCADE) diff --git a/tests/test_combinators.py b/tests/test_combinators.py index c493d06..78b0e60 100644 --- a/tests/test_combinators.py +++ b/tests/test_combinators.py @@ -3,42 +3,42 @@ from django.db.models.aggregates import Sum from django.test import TestCase -from django_cte import With +from django_cte import CTE, with_cte -from .models import Order, OrderPlainManager +from .models import Order class TestCTECombinators(TestCase): def test_cte_union_query(self): - one = With( + one = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")), name="one" ) - two = With( + two = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount") * 2), name="two" ) - earths = ( - one.join( + earths = with_cte( + one, + select=one.join( Order.objects.filter(region_id="earth"), region=one.col.region_id ) - .with_cte(one) .annotate(region_total=one.col.total) .values_list("amount", "region_id", "region_total") ) - mars = ( - two.join( + mars = with_cte( + two, + select=two.join( Order.objects.filter(region_id="mars"), region=two.col.region_id ) - .with_cte(two) .annotate(region_total=two.col.total) .values_list("amount", "region_id", "region_total") ) @@ -71,22 +71,21 @@ def test_cte_union_query(self): ]) def test_cte_union_with_non_cte_query(self): - one = With( + one = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")), ) - earths = ( - one.join( + earths = with_cte( + one, + select=one.join( Order.objects.filter(region_id="earth"), region=one.col.region_id - ) - .with_cte(one) - .annotate(region_total=one.col.total) + ).annotate(region_total=one.col.total) ) plain_mars = ( - OrderPlainManager.objects.filter(region_id="mars") + Order.objects.filter(region_id="mars") .annotate(region_total=Value(0)) ) # Note: this does not work in the opposite order. A CTE query @@ -106,27 +105,27 @@ def test_cte_union_with_non_cte_query(self): ]) def test_cte_union_with_duplicate_names(self): - cte_sun = With( + cte_sun = CTE( Order.objects .filter(region__parent="sun") .values("region_id") .annotate(total=Sum("amount")), ) - cte_proxima = With( + cte_proxima = CTE( Order.objects .filter(region__parent="proxima centauri") .values("region_id") .annotate(total=2 * Sum("amount")), ) - orders_sun = ( - cte_sun.join(Order, region=cte_sun.col.region_id) - .with_cte(cte_sun) + orders_sun = with_cte( + cte_sun, + select=cte_sun.join(Order, region=cte_sun.col.region_id) .annotate(region_total=cte_sun.col.total) ) - orders_proxima = ( - cte_proxima.join(Order, region=cte_proxima.col.region_id) - .with_cte(cte_proxima) + orders_proxima = with_cte( + cte_proxima, + select=cte_proxima.join(Order, region=cte_proxima.col.region_id) .annotate(region_total=cte_proxima.col.total) ) @@ -135,21 +134,21 @@ def test_cte_union_with_duplicate_names(self): orders_sun.union(orders_proxima) def test_cte_union_of_same_cte(self): - cte = With( + cte = CTE( Order.objects .filter(region__parent="sun") .values("region_id") .annotate(total=Sum("amount")), ) - orders_big = ( - cte.join(Order, region=cte.col.region_id) - .with_cte(cte) + orders_big = with_cte( + cte, + select=cte.join(Order, region=cte.col.region_id) .annotate(region_total=3 * cte.col.total) ) - orders_small = ( - cte.join(Order, region=cte.col.region_id) - .with_cte(cte) + orders_small = with_cte( + cte, + select=cte.join(Order, region=cte.col.region_id) .annotate(region_total=cte.col.total) ) @@ -189,27 +188,27 @@ def test_cte_union_of_same_cte(self): ]) def test_cte_intersection(self): - cte_big = With( + cte_big = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")), name='big' ) - cte_small = With( + cte_small = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")), name='small' ) - orders_big = ( - cte_big.join(Order, region=cte_big.col.region_id) - .with_cte(cte_big) + orders_big = with_cte( + cte_big, + select=cte_big.join(Order, region=cte_big.col.region_id) .annotate(region_total=cte_big.col.total) .filter(region_total__gte=86) ) - orders_small = ( - cte_small.join(Order, region=cte_small.col.region_id) - .with_cte(cte_small) + orders_small = with_cte( + cte_small, + select=cte_small.join(Order, region=cte_small.col.region_id) .annotate(region_total=cte_small.col.total) .filter(region_total__lte=123) ) @@ -229,27 +228,27 @@ def test_cte_intersection(self): ]) def test_cte_difference(self): - cte_big = With( + cte_big = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")), name='big' ) - cte_small = With( + cte_small = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")), name='small' ) - orders_big = ( - cte_big.join(Order, region=cte_big.col.region_id) - .with_cte(cte_big) + orders_big = with_cte( + cte_big, + select=cte_big.join(Order, region=cte_big.col.region_id) .annotate(region_total=cte_big.col.total) .filter(region_total__gte=86) ) - orders_small = ( - cte_small.join(Order, region=cte_small.col.region_id) - .with_cte(cte_small) + orders_small = with_cte( + cte_small, + select=cte_small.join(Order, region=cte_small.col.region_id) .annotate(region_total=cte_small.col.total) .filter(region_total__lte=123) ) diff --git a/tests/test_cte.py b/tests/test_cte.py index 920b93e..5f51959 100644 --- a/tests/test_cte.py +++ b/tests/test_cte.py @@ -1,15 +1,14 @@ -from unittest import SkipTest - +import pytest from django.db.models import IntegerField, TextField from django.db.models.aggregates import Count, Max, Min, Sum from django.db.models.expressions import ( Exists, ExpressionWrapper, F, OuterRef, Subquery, ) from django.db.models.sql.constants import LOUTER +from django.db.utils import OperationalError, ProgrammingError from django.test import TestCase -from django_cte import With -from django_cte import CTEManager +from django_cte import CTE, with_cte from .models import Order, Region, User @@ -20,22 +19,20 @@ class TestCTE(TestCase): def test_simple_cte_query(self): - cte = With( + cte = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")) ) - orders = ( - # FROM orders INNER JOIN cte ON orders.region_id = cte.region_id - cte.join(Order, region=cte.col.region_id) - - # Add `WITH ...` before `SELECT ... FROM orders ...` - .with_cte(cte) + orders = with_cte( + # WITH cte ... + cte, - # Annotate each Order with a "region_total" - .annotate(region_total=cte.col.total) - ) + # SELECT ... FROM orders + # INNER JOIN cte ON orders.region_id = cte.region_id + select=cte.join(Order, region=cte.col.region_id), + ).annotate(region_total=cte.col.total) print(orders.query) data = sorted((o.amount, o.region_id, o.region_total) for o in orders) @@ -65,17 +62,16 @@ def test_simple_cte_query(self): ]) def test_cte_name_escape(self): - totals = With( + totals = CTE( Order.objects .filter(region__parent="sun") .values("region_id") .annotate(total=Sum("amount")), name="mixedCaseCTEName" ) - orders = ( - totals - .join(Order, region=totals.col.region_id) - .with_cte(totals) + orders = with_cte( + totals, + select=totals.join(Order, region=totals.col.region_id) .annotate(region_total=totals.col.total) .order_by("amount") ) @@ -83,15 +79,14 @@ def test_cte_name_escape(self): str(orders.query).startswith('WITH RECURSIVE "mixedCaseCTEName"')) def test_cte_queryset(self): - sub_totals = With( + sub_totals = CTE( Order.objects .values(region_parent=F("region__parent_id")) .annotate(total=Sum("amount")), ) - regions = ( - Region.objects.all() - .with_cte(sub_totals) - .annotate( + regions = with_cte( + sub_totals, + select=Region.objects.annotate( child_regions_total=Subquery( sub_totals.queryset() .filter(region_parent=OuterRef("name")) @@ -118,11 +113,14 @@ def test_cte_queryset(self): ]) def test_cte_queryset_with_model_result(self): - cte = With( + cte = CTE( Order.objects .annotate(region_parent=F("region__parent_id")), ) - orders = cte.queryset().with_cte(cte) + orders = with_cte( + cte, # WITH cte AS (...) + select=cte, # SELECT ... FROM cte + ) print(orders.query) data = sorted( @@ -140,13 +138,13 @@ def test_cte_queryset_with_model_result(self): ) def test_cte_queryset_with_join(self): - cte = With( + cte = CTE( Order.objects .annotate(region_parent=F("region__parent_id")), ) - orders = ( - cte.queryset() - .with_cte(cte) + orders = with_cte( + cte, + select=cte.queryset() .annotate(parent=F("region__parent_id")) .order_by("region_id", "amount") ) @@ -162,7 +160,7 @@ def test_cte_queryset_with_join(self): ]) def test_cte_queryset_with_values_result(self): - cte = With( + cte = CTE( Order.objects .values( "region_id", @@ -170,11 +168,7 @@ def test_cte_queryset_with_values_result(self): ) .distinct() ) - values = ( - cte.queryset() - .with_cte(cte) - .filter(region_parent__isnull=False) - ) + values = with_cte(cte, select=cte).filter(region_parent__isnull=False) print(values.query) def key(item): @@ -193,27 +187,27 @@ def key(item): ]) def test_named_simple_ctes(self): - totals = With( + totals = CTE( Order.objects .filter(region__parent="sun") .values("region_id") .annotate(total=Sum("amount")), name="totals", ) - region_count = With( + region_count = CTE( Region.objects .filter(parent="sun") .values("parent_id") .annotate(num=Count("name")), name="region_count", ) - orders = ( - region_count.join( + orders = with_cte( + totals, + region_count, + select=region_count.join( totals.join(Order, region=totals.col.region_id), region__parent=region_count.col.parent_id ) - .with_cte(totals) - .with_cte(region_count) .annotate(region_total=totals.col.total) .annotate(region_count=region_count.col.num) .order_by("amount") @@ -257,9 +251,9 @@ def make_root_mapping(rootmap): ), all=True, ) - rootmap = With.recursive(make_root_mapping, name="rootmap") + rootmap = CTE.recursive(make_root_mapping, name="rootmap") - totals = With( + totals = CTE( rootmap.join(Order, region_id=rootmap.col.name) .values( root=rootmap.col.root, @@ -270,11 +264,10 @@ def make_root_mapping(rootmap): name="totals", ) - root_regions = ( - totals.join(Region, name=totals.col.root) - .with_cte(rootmap) - .with_cte(totals) - .annotate( + root_regions = with_cte( + rootmap, + totals, + select=totals.join(Region, name=totals.col.root).annotate( # count of orders in this region and all subregions orders_count=totals.col.orders_count, # sum of order amounts in this region and all subregions @@ -292,17 +285,16 @@ def make_root_mapping(rootmap): ]) def test_materialized_option(self): - totals = With( + totals = CTE( Order.objects .filter(region__parent="sun") .values("region_id") .annotate(total=Sum("amount")), materialized=True ) - orders = ( - totals - .join(Order, region=totals.col.region_id) - .with_cte(totals) + orders = with_cte( + totals, + select=totals.join(Order, region=totals.col.region_id) .annotate(region_total=totals.col.total) .order_by("amount") ) @@ -313,14 +305,14 @@ def test_materialized_option(self): ) def test_update_cte_query(self): - cte = With( + cte = CTE( Order.objects .values(region_parent=F("region__parent_id")) .annotate(total=Sum("amount")) .filter(total__isnull=False) ) # not the most efficient query, but it exercises CTEUpdateQuery - Order.objects.all().with_cte(cte).filter(region_id__in=Subquery( + with_cte(cte, select=Order).filter(region_id__in=Subquery( cte.queryset() .filter(region_parent=OuterRef("region_id")) .values("region_parent") @@ -344,7 +336,7 @@ def test_update_cte_query(self): def test_update_with_subquery(self): # Test for issue: https://github.com/dimagi/django-cte/issues/9 - # Issue is not reproduces on sqlite3 use postgres to run. + # Issue is not reproduced on sqlite3, use postgres to run. # To reproduce the problem it's required to have some join # in the select-query so the compiler will turn it into a subquery. # To add a join use a filter over field of related model @@ -358,19 +350,21 @@ def test_update_with_subquery(self): ('mars', 0), }) + @pytest.mark.xfail( + reason="this test will not work until `QuerySet.delete` " + "(Django method) calls `self.query.chain(sql.DeleteQuery)` " + "instead of `sql.DeleteQuery(self.model)`", + raises=(OperationalError, ProgrammingError), + strict=True, + ) def test_delete_cte_query(self): - raise SkipTest( - "this test will not work until `QuerySet.delete` (Django method) " - "calls `self.query.chain(sql.DeleteQuery)` instead of " - "`sql.DeleteQuery(self.model)`" - ) - cte = With( + cte = CTE( Order.objects .values(region_parent=F("region__parent_id")) .annotate(total=Sum("amount")) .filter(total__isnull=False) ) - Order.objects.all().with_cte(cte).annotate( + with_cte(cte, select=Order).annotate( cte_has_order=Exists( cte.queryset() .values("total") @@ -391,7 +385,7 @@ def test_delete_cte_query(self): def test_outerref_in_cte_query(self): # This query is meant to return the difference between min and max # order of each region, through a subquery - min_and_max = With( + min_and_max = CTE( Order.objects .filter(region=OuterRef("pk")) .values('region') # This is to force group by region_id @@ -405,7 +399,8 @@ def test_outerref_in_cte_query(self): Region.objects .annotate( difference=Subquery( - min_and_max.queryset().with_cte(min_and_max).annotate( + with_cte(min_and_max, select=min_and_max) + .annotate( difference=ExpressionWrapper( F('amount_max') - F('amount_min'), output_field=int_field, @@ -434,16 +429,16 @@ def test_outerref_in_cte_query(self): ]) def test_experimental_left_outer_join(self): - totals = With( + totals = CTE( Order.objects .values("region_id") .annotate(total=Sum("amount")) .filter(total__gt=100) ) - orders = ( - totals + orders = with_cte( + totals, + select=totals .join(Order, region=totals.col.region_id, _join_type=LOUTER) - .with_cte(totals) .annotate(region_total=totals.col.total) ) print(orders.query) @@ -482,9 +477,7 @@ def test_non_cte_subquery(self): subquery model doesn't use the CTE manager, and the query results match expected behavior """ - self.assertNotIsInstance(User.objects, CTEManager) - - sub_totals = With( + sub_totals = CTE( Order.objects .values(region_parent=F("region__parent_id")) .annotate( @@ -496,10 +489,9 @@ def test_non_cte_subquery(self): ), ), ) - regions = ( - Region.objects.all() - .with_cte(sub_totals) - .annotate( + regions = with_cte( + sub_totals, + select=Region.objects.annotate( child_regions_total=Subquery( sub_totals.queryset() .filter(region_parent=OuterRef("name")) @@ -531,27 +523,27 @@ def test_explain(self): correct position """ - totals = With( + totals = CTE( Order.objects .filter(region__parent="sun") .values("region_id") .annotate(total=Sum("amount")), name="totals", ) - region_count = With( + region_count = CTE( Region.objects .filter(parent="sun") .values("parent_id") .annotate(num=Count("name")), name="region_count", ) - orders = ( - region_count.join( + orders = with_cte( + totals, + region_count, + select=region_count.join( totals.join(Order, region=totals.col.region_id), region__parent=region_count.col.parent_id ) - .with_cte(totals) - .with_cte(region_count) .annotate(region_total=totals.col.total) .annotate(region_count=region_count.col.num) .order_by("amount") @@ -565,16 +557,16 @@ def test_empty_result_set_cte(self): Verifies that the CTEQueryCompiler can handle empty result sets in the related CTEs """ - totals = With( + totals = CTE( Order.objects .filter(id__in=[]) .values("region_id") .annotate(total=Sum("amount")), name="totals", ) - orders = ( - totals.join(Order, region=totals.col.region_id) - .with_cte(totals) + orders = with_cte( + totals, + select=totals.join(Order, region=totals.col.region_id) .annotate(region_total=totals.col.total) .order_by("amount") ) @@ -582,16 +574,17 @@ def test_empty_result_set_cte(self): self.assertEqual(len(orders), 0) def test_left_outer_join_on_empty_result_set_cte(self): - totals = With( + totals = CTE( Order.objects .filter(id__in=[]) .values("region_id") .annotate(total=Sum("amount")), name="totals", ) - orders = ( - totals.join(Order, region=totals.col.region_id, _join_type=LOUTER) - .with_cte(totals) + orders = with_cte( + totals, + select=totals + .join(Order, region=totals.col.region_id, _join_type=LOUTER) .annotate(region_total=totals.col.total) .order_by("amount") ) @@ -604,16 +597,16 @@ def test_union_query_with_cte(self): .filter(region__parent="sun") .only("region", "amount") ) - orders_cte = With(orders, name="orders_cte") + orders_cte = CTE(orders, name="orders_cte") orders_cte_queryset = orders_cte.queryset() earth_orders = orders_cte_queryset.filter(region="earth") mars_orders = orders_cte_queryset.filter(region="mars") earth_mars = earth_orders.union(mars_orders, all=True) - earth_mars_cte = ( - earth_mars - .with_cte(orders_cte) + earth_mars_cte = with_cte( + orders_cte, + select=earth_mars .order_by("region", "amount") .values_list("region", "amount") ) @@ -631,8 +624,10 @@ def test_union_query_with_cte(self): def test_cte_select_pk(self): orders = Order.objects.filter(region="earth").values("pk") - cte = With(orders) - queryset = cte.join(orders, pk=cte.col.pk).with_cte(cte).order_by("pk") + cte = CTE(orders) + queryset = with_cte( + cte, select=cte.join(orders, pk=cte.col.pk) + ).order_by("pk") print(queryset.query) self.assertEqual(list(queryset), [ {'pk': 9}, @@ -642,7 +637,7 @@ def test_cte_select_pk(self): ]) def test_django52_resolve_ref_regression(self): - cte = With( + cte = CTE( Order.objects.annotate( pnt_id=F("region__parent_id"), region_name=F("region__name"), @@ -655,9 +650,9 @@ def test_django52_resolve_ref_regression(self): "region_name", ) ) - qs = ( - cte.queryset() - .with_cte(cte) + qs = with_cte( + cte, + select=cte.queryset() .values( amt=cte.col.amount, pnt_id=cte.col.pnt_id, diff --git a/tests/test_django.py b/tests/test_django.py index 83c92a2..52d9988 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -4,52 +4,11 @@ from django.db import OperationalError, ProgrammingError from django.db.models import Window from django.db.models.functions import Rank -from django.test import TestCase, skipUnlessDBFeature +from django.test import TestCase -from .models import Order, Region, User +from django_cte import CTE, with_cte - -@skipUnlessDBFeature("supports_select_union") -class NonCteQueries(TestCase): - """Test non-CTE queries - - These tests were adapted from the Django test suite. The models used - here use CTEManager and CTEQuerySet to verify feature parity with - their base classes Manager and QuerySet. - """ - - @classmethod - def setUpTestData(cls): - Order.objects.all().delete() - - def test_union_with_select_related_and_order(self): - e1 = User.objects.create(name="e1") - a1 = Order.objects.create(region_id="earth", user=e1) - a2 = Order.objects.create(region_id="moon", user=e1) - Order.objects.create(region_id="sun", user=e1) - base_qs = Order.objects.select_related("user").order_by() - qs1 = base_qs.filter(region_id="earth") - qs2 = base_qs.filter(region_id="moon") - print(qs1.union(qs2).order_by("pk").query) - self.assertSequenceEqual(qs1.union(qs2).order_by("pk"), [a1, a2]) - - @skipUnlessDBFeature("supports_slicing_ordering_in_compound") - def test_union_with_select_related_and_first(self): - e1 = User.objects.create(name="e1") - a1 = Order.objects.create(region_id="earth", user=e1) - Order.objects.create(region_id="moon", user=e1) - base_qs = Order.objects.select_related("user") - qs1 = base_qs.filter(region_id="earth") - qs2 = base_qs.filter(region_id="moon") - self.assertEqual(qs1.union(qs2).first(), a1) - - def test_union_with_first(self): - e1 = User.objects.create(name="e1") - a1 = Order.objects.create(region_id="earth", user=e1) - base_qs = Order.objects.order_by() - qs1 = base_qs.filter(region_id="earth") - qs2 = base_qs.filter(region_id="moon") - self.assertEqual(qs1.union(qs2).first(), a1) +from .models import Order, Region class WindowFunctions(TestCase): @@ -57,8 +16,7 @@ class WindowFunctions(TestCase): def test_heterogeneous_filter_in_cte(self): if django.VERSION < (4, 2): raise SkipTest("feature added in Django 4.2") - from django_cte import With - cte = With( + cte = CTE( Order.objects.annotate( region_amount_rank=Window( Rank(), partition_by="region_id", order_by="-amount" @@ -68,7 +26,7 @@ def test_heterogeneous_filter_in_cte(self): .values("region_id", "region_amount_rank") .filter(region_amount_rank=1, region_id__in=["sun", "moon"]) ) - qs = cte.join(Region, name=cte.col.region_id).with_cte(cte) + qs = with_cte(cte, select=cte.join(Region, name=cte.col.region_id)) print(qs.query) # ProgrammingError: column cte.region_id does not exist # WITH RECURSIVE "cte" AS (SELECT * FROM ( diff --git a/tests/test_manager.py b/tests/test_manager.py index be085f9..e2ad6de 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,61 +1,28 @@ from django.db.models.expressions import F -from django.db.models.query import QuerySet from django.test import TestCase -from django_cte import With, CTEQuerySet, CTEManager +from django_cte import CTE, with_cte from .models import ( - Order, OrderFromLT40, - OrderLT40AsManager, OrderCustomManagerNQuery, - OrderCustomManager, LT40QuerySet, - LTManager, - LT25QuerySet, ) class TestCTE(TestCase): - def test_cte_queryset_correct_defaultmanager(self): - self.assertEqual(type(Order._default_manager), CTEManager) - self.assertEqual(type(Order.objects.all()), CTEQuerySet) - - def test_cte_queryset_correct_from_queryset(self): - self.assertEqual(type(OrderFromLT40.objects.all()), LT40QuerySet) - - def test_cte_queryset_correct_queryset_as_manager(self): - self.assertEqual(type(OrderLT40AsManager.objects.all()), LT40QuerySet) - - def test_cte_queryset_correct_manager_n_from_queryset(self): - self.assertIsInstance( - OrderCustomManagerNQuery._default_manager, LTManager) - self.assertEqual(type( - OrderCustomManagerNQuery.objects.all()), LT25QuerySet) - - def test_cte_create_manager_from_non_cteQuery(self): - class BrokenQuerySet(QuerySet): - "This should be a CTEQuerySet if we want this to work" - - with self.assertRaises(TypeError): - CTEManager.from_queryset(BrokenQuerySet)() - - def test_cte_queryset_correct_limitedmanager(self): - self.assertEqual(type(OrderCustomManager._default_manager), LTManager) - # Check the expected even if not ideal behavior occurs - self.assertIsInstance(OrderCustomManager.objects.all(), CTEQuerySet) def test_cte_queryset_with_from_queryset(self): self.assertEqual(type(OrderFromLT40.objects.all()), LT40QuerySet) - cte = With( + cte = CTE( OrderFromLT40.objects .annotate(region_parent=F("region__parent_id")) .filter(region__parent_id="sun") ) - orders = ( - cte.queryset() - .with_cte(cte) + orders = with_cte( + cte, + select=cte.queryset() .lt40() # custom queryset method .order_by("region_id", "amount") ) @@ -77,14 +44,14 @@ def test_cte_queryset_with_from_queryset(self): ]) def test_cte_queryset_with_custom_queryset(self): - cte = With( + cte = CTE( OrderCustomManagerNQuery.objects .annotate(region_parent=F("region__parent_id")) .filter(region__parent_id="sun") ) - orders = ( - cte.queryset() - .with_cte(cte) + orders = with_cte( + cte, + select=cte.queryset() .lt25() # custom queryset method .order_by("region_id", "amount") ) @@ -102,10 +69,10 @@ def test_cte_queryset_with_custom_queryset(self): ]) def test_cte_queryset_with_deferred_loading(self): - cte = With( + cte = CTE( OrderCustomManagerNQuery.objects.order_by("id").only("id")[:1] ) - orders = cte.queryset().with_cte(cte) + orders = with_cte(cte, select=cte) print(orders.query) self.assertEqual([x.id for x in orders], [1]) diff --git a/tests/test_raw.py b/tests/test_raw.py index ade9774..3590803 100644 --- a/tests/test_raw.py +++ b/tests/test_raw.py @@ -1,7 +1,7 @@ from django.db.models import IntegerField, TextField from django.test import TestCase -from django_cte import With +from django_cte import CTE, with_cte from django_cte.raw import raw_cte_sql from .models import Region @@ -13,7 +13,7 @@ class TestRawCTE(TestCase): def test_raw_cte_sql(self): - cte = With(raw_cte_sql( + cte = CTE(raw_cte_sql( """ SELECT region_id, AVG(amount) AS avg_order FROM orders @@ -23,19 +23,16 @@ def test_raw_cte_sql(self): ["moon"], {"region_id": text_field, "avg_order": int_field}, )) - moon_avg = ( - cte - .join(Region, name=cte.col.region_id) - .annotate(avg_order=cte.col.avg_order) - .with_cte(cte) - ) + moon_avg = with_cte( + cte, select=cte.join(Region, name=cte.col.region_id) + ).annotate(avg_order=cte.col.avg_order) print(moon_avg.query) data = [(r.name, r.parent.name, r.avg_order) for r in moon_avg] self.assertEqual(data, [('moon', 'earth', 2)]) def test_raw_cte_sql_name_escape(self): - cte = With( + cte = CTE( raw_cte_sql( """ SELECT region_id, AVG(amount) AS avg_order @@ -48,12 +45,9 @@ def test_raw_cte_sql_name_escape(self): ), name="mixedCaseCTEName" ) - moon_avg = ( - cte - .join(Region, name=cte.col.region_id) - .annotate(avg_order=cte.col.avg_order) - .with_cte(cte) - ) + moon_avg = with_cte( + cte, select=cte.join(Region, name=cte.col.region_id) + ).annotate(avg_order=cte.col.avg_order) self.assertTrue( str(moon_avg.query).startswith( 'WITH RECURSIVE "mixedCaseCTEName"') diff --git a/tests/test_recursive.py b/tests/test_recursive.py index 6506d8f..b28dbab 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -16,7 +16,7 @@ from django.db.utils import DatabaseError from django.test import TestCase -from django_cte import With +from django_cte import CTE, with_cte from .models import KeyPair, Region @@ -48,11 +48,11 @@ def make_regions_cte(cte): all=True, ) - cte = With.recursive(make_regions_cte) + cte = CTE.recursive(make_regions_cte) - regions = ( - cte.join(Region, name=cte.col.name) - .with_cte(cte) + regions = with_cte( + cte, + select=cte.join(Region, name=cte.col.name) .annotate( path=cte.col.path, depth=cte.col.depth, @@ -108,8 +108,10 @@ def make_regions_cte(cte): ), all=True, ) - cte = With.recursive(make_regions_cte) - regions = cte.join(Region, name=cte.col.name).with_cte(cte).annotate( + cte = CTE.recursive(make_regions_cte) + regions = with_cte( + cte, select=cte.join(Region, name=cte.col.name) + ).annotate( path=cte.col.path, depth=cte.col.depth, is_planet=cte.col.is_planet, @@ -135,8 +137,8 @@ def make_regions_cte(cte): cte.join(Region, parent=cte.col.name), all=True, ) - cte = With.recursive(make_regions_cte) - regions = cte.join(Region, name=cte.col.name).with_cte(cte) + cte = CTE.recursive(make_regions_cte) + regions = with_cte(cte, select=cte.join(Region, name=cte.col.name)) print(regions.query) try: @@ -164,8 +166,8 @@ def make_bad_cte(cte): return cte.join(Region, parent=cte.col.name).values( depth=cte.col.depth + 1, ) - cte = With.recursive(make_bad_cte) - regions = cte.join(Region, name=cte.col.name).with_cte(cte) + cte = CTE.recursive(make_bad_cte) + regions = with_cte(cte, select=cte.join(Region, name=cte.col.name)) with self.assertRaises(ValueError) as context: print(regions.query) self.assertIn("Circular reference:", str(context.exception)) @@ -184,11 +186,10 @@ def make_regions_cte(cte): ), all=True, ) - cte = With.recursive(make_regions_cte) - regions = ( - Region.objects.all() - .with_cte(cte) - .annotate(_ex=Exists( + cte = CTE.recursive(make_regions_cte) + regions = with_cte( + cte, + select=Region.objects.annotate(_ex=Exists( cte.queryset() .values(value=Value("1", output_field=int_field)) .filter(name=OuterRef("name")) @@ -213,8 +214,8 @@ def make_regions_cte(cte): ), all=True, ) - cte = With.recursive(make_regions_cte) - regions = cte.queryset().with_cte(cte).filter(depth=2).order_by("name") + cte = CTE.recursive(make_regions_cte) + regions = with_cte(cte, select=cte).filter(depth=2).order_by("name") pickled_qs = pickle.loads(pickle.dumps(regions)) @@ -230,30 +231,28 @@ def make_regions_cte(cte): value=F('name'), ).union( cte.join( - Region.objects.all().annotate( - value=F('name'), - ), + Region.objects.annotate(value=F('name')), parent_id=cte.col.name, ), all=True, ) - cte = With.recursive(make_regions_cte) - query = cte.queryset().with_cte(cte) + cte = CTE.recursive(make_regions_cte) + query = with_cte(cte, select=cte) - exclude_leaves = With(cte.queryset().filter( + exclude_leaves = CTE(cte.queryset().filter( parent__name='sun', ).annotate( value=Concat(F('name'), F('name')) ), name='value_cte') - query = query.annotate( + query = with_cte(exclude_leaves, select=query.annotate( _exclude_leaves=Exists( exclude_leaves.queryset().filter( name=OuterRef("name"), value=OuterRef("value"), ) ) - ).filter(_exclude_leaves=True).with_cte(exclude_leaves) + ).filter(_exclude_leaves=True)) print(query.query) # Nothing should be returned. @@ -268,23 +267,23 @@ def make_regions_cte(cte): rank=F('value'), ).union( cte.join( - KeyPair.objects.all().order_by(), + KeyPair.objects.order_by(), parent_id=cte.col.id, ).annotate( rank=F('value'), ), all=True, ) - cte = With.recursive(make_regions_cte) - children = cte.queryset().with_cte(cte) + cte = CTE.recursive(make_regions_cte) + children = with_cte(cte, select=cte) - xdups = With(cte.queryset().filter( + xdups = CTE(cte.queryset().filter( parent__key="level 1", ).annotate( rank=F('value') ).values('id', 'rank'), name='xdups') - children = children.annotate( + children = with_cte(xdups, select=children.annotate( _exclude=Exists( ( xdups.queryset().filter( @@ -293,7 +292,7 @@ def make_regions_cte(cte): ) ) ) - ).filter(_exclude=True).with_cte(xdups) + ).filter(_exclude=True)) print(children.query) query = KeyPair.objects.filter(parent__in=children) @@ -311,9 +310,9 @@ def test_materialized(self): # This test covers MATERIALIZED option in SQL query def make_regions_cte(cte): return KeyPair.objects.all() - cte = With.recursive(make_regions_cte, materialized=True) + cte = CTE.recursive(make_regions_cte, materialized=True) - query = KeyPair.objects.with_cte(cte) + query = with_cte(cte, select=KeyPair) print(query.query) self.assertTrue( str(query.query).startswith('WITH RECURSIVE "cte" AS MATERIALIZED') @@ -326,8 +325,8 @@ def make_regions_cte(cte): ).values("pk").union( cte.join(Region, parent=cte.col.pk).values("pk") ) - cte = With.recursive(make_regions_cte) - queryset = cte.queryset().with_cte(cte).order_by("pk") + cte = CTE.recursive(make_regions_cte) + queryset = with_cte(cte, select=cte).order_by("pk") print(queryset.query) self.assertEqual(list(queryset), [ {'pk': 'earth'}, diff --git a/tests/test_v1/__init__.py b/tests/test_v1/__init__.py new file mode 100644 index 0000000..0e93b30 --- /dev/null +++ b/tests/test_v1/__init__.py @@ -0,0 +1,19 @@ +from unmagic import fixture + +from .. import ignore_v1_warnings + + +@fixture(autouse=__file__) +def ignore_v1_deprecations(): + with ignore_v1_warnings(): + yield + + +@fixture(autouse=__file__, scope="class") +def ignore_v1_deprecations_in_class_setup(): + with ignore_v1_warnings(): + yield + + +with ignore_v1_warnings(): + from . import models # noqa: F401 diff --git a/tests/test_v1/models.py b/tests/test_v1/models.py new file mode 100644 index 0000000..63fb183 --- /dev/null +++ b/tests/test_v1/models.py @@ -0,0 +1,97 @@ +from django.db.models import Manager + +from django_cte import CTEManager, CTEQuerySet + +from ..models import ( + KeyPair as V2KeyPair, + Order as V2Order, + Region as V2Region, + User, # noqa: F401 +) + + +class LT40QuerySet(CTEQuerySet): + + def lt40(self): + return self.filter(amount__lt=40) + + +class LT30QuerySet(CTEQuerySet): + + def lt30(self): + return self.filter(amount__lt=30) + + +class LT25QuerySet(CTEQuerySet): + + def lt25(self): + return self.filter(amount__lt=25) + + +class LTManager(CTEManager): + pass + + +class V1Region(V2Region): + objects = CTEManager() + + class Meta: + proxy = True + + +Region = V1Region + + +class V1Order(V2Order): + objects = CTEManager() + + class Meta: + proxy = True + + +Order = V1Order + + +class V1OrderFromLT40(Order): + class Meta: + proxy = True + objects = CTEManager.from_queryset(LT40QuerySet)() + + +class V1OrderLT40AsManager(Order): + class Meta: + proxy = True + objects = LT40QuerySet.as_manager() + + +class V1OrderCustomManagerNQuery(Order): + class Meta: + proxy = True + objects = LTManager.from_queryset(LT25QuerySet)() + + +class V1OrderCustomManager(Order): + class Meta: + proxy = True + objects = LTManager() + + +class V1OrderPlainManager(Order): + class Meta: + proxy = True + objects = Manager() + + +class V1KeyPair(V2KeyPair): + objects = CTEManager() + + class Meta: + proxy = True + + +KeyPair = V1KeyPair +OrderCustomManager = V1OrderCustomManager +OrderCustomManagerNQuery = V1OrderCustomManagerNQuery +OrderFromLT40 = V1OrderFromLT40 +OrderLT40AsManager = V1OrderLT40AsManager +OrderPlainManager = V1OrderPlainManager diff --git a/tests/test_v1/test_combinators.py b/tests/test_v1/test_combinators.py new file mode 100644 index 0000000..c493d06 --- /dev/null +++ b/tests/test_v1/test_combinators.py @@ -0,0 +1,271 @@ +import pytest +from django.db.models import Value +from django.db.models.aggregates import Sum +from django.test import TestCase + +from django_cte import With + +from .models import Order, OrderPlainManager + + +class TestCTECombinators(TestCase): + + def test_cte_union_query(self): + one = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount")), + name="one" + ) + two = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount") * 2), + name="two" + ) + + earths = ( + one.join( + Order.objects.filter(region_id="earth"), + region=one.col.region_id + ) + .with_cte(one) + .annotate(region_total=one.col.total) + .values_list("amount", "region_id", "region_total") + ) + mars = ( + two.join( + Order.objects.filter(region_id="mars"), + region=two.col.region_id + ) + .with_cte(two) + .annotate(region_total=two.col.total) + .values_list("amount", "region_id", "region_total") + ) + combined = earths.union(mars, all=True) + print(combined.query) + + self.assertEqual(sorted(combined), [ + (30, 'earth', 126), + (31, 'earth', 126), + (32, 'earth', 126), + (33, 'earth', 126), + (40, 'mars', 246), + (41, 'mars', 246), + (42, 'mars', 246), + ]) + + # queries used in union should still work on their own + print(earths.query) + self.assertEqual(sorted(earths),[ + (30, 'earth', 126), + (31, 'earth', 126), + (32, 'earth', 126), + (33, 'earth', 126), + ]) + print(mars.query) + self.assertEqual(sorted(mars),[ + (40, 'mars', 246), + (41, 'mars', 246), + (42, 'mars', 246), + ]) + + def test_cte_union_with_non_cte_query(self): + one = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount")), + ) + + earths = ( + one.join( + Order.objects.filter(region_id="earth"), + region=one.col.region_id + ) + .with_cte(one) + .annotate(region_total=one.col.total) + ) + plain_mars = ( + OrderPlainManager.objects.filter(region_id="mars") + .annotate(region_total=Value(0)) + ) + # Note: this does not work in the opposite order. A CTE query + # must come first to invoke custom CTE combinator logic. + combined = earths.union(plain_mars, all=True) \ + .values_list("amount", "region_id", "region_total") + print(combined.query) + + self.assertEqual(sorted(combined), [ + (30, 'earth', 126), + (31, 'earth', 126), + (32, 'earth', 126), + (33, 'earth', 126), + (40, 'mars', 0), + (41, 'mars', 0), + (42, 'mars', 0), + ]) + + def test_cte_union_with_duplicate_names(self): + cte_sun = With( + Order.objects + .filter(region__parent="sun") + .values("region_id") + .annotate(total=Sum("amount")), + ) + cte_proxima = With( + Order.objects + .filter(region__parent="proxima centauri") + .values("region_id") + .annotate(total=2 * Sum("amount")), + ) + + orders_sun = ( + cte_sun.join(Order, region=cte_sun.col.region_id) + .with_cte(cte_sun) + .annotate(region_total=cte_sun.col.total) + ) + orders_proxima = ( + cte_proxima.join(Order, region=cte_proxima.col.region_id) + .with_cte(cte_proxima) + .annotate(region_total=cte_proxima.col.total) + ) + + msg = "Found two or more CTEs named 'cte'" + with pytest.raises(ValueError, match=msg): + orders_sun.union(orders_proxima) + + def test_cte_union_of_same_cte(self): + cte = With( + Order.objects + .filter(region__parent="sun") + .values("region_id") + .annotate(total=Sum("amount")), + ) + + orders_big = ( + cte.join(Order, region=cte.col.region_id) + .with_cte(cte) + .annotate(region_total=3 * cte.col.total) + ) + orders_small = ( + cte.join(Order, region=cte.col.region_id) + .with_cte(cte) + .annotate(region_total=cte.col.total) + ) + + orders = orders_big.union(orders_small) \ + .values_list("amount", "region_id", "region_total") + print(orders.query) + + self.assertEqual(sorted(orders), [ + (10, 'mercury', 33), + (10, 'mercury', 99), + (11, 'mercury', 33), + (11, 'mercury', 99), + (12, 'mercury', 33), + (12, 'mercury', 99), + (20, 'venus', 86), + (20, 'venus', 258), + (21, 'venus', 86), + (21, 'venus', 258), + (22, 'venus', 86), + (22, 'venus', 258), + (23, 'venus', 86), + (23, 'venus', 258), + (30, 'earth', 126), + (30, 'earth', 378), + (31, 'earth', 126), + (31, 'earth', 378), + (32, 'earth', 126), + (32, 'earth', 378), + (33, 'earth', 126), + (33, 'earth', 378), + (40, 'mars', 123), + (40, 'mars', 369), + (41, 'mars', 123), + (41, 'mars', 369), + (42, 'mars', 123), + (42, 'mars', 369) + ]) + + def test_cte_intersection(self): + cte_big = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount")), + name='big' + ) + cte_small = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount")), + name='small' + ) + orders_big = ( + cte_big.join(Order, region=cte_big.col.region_id) + .with_cte(cte_big) + .annotate(region_total=cte_big.col.total) + .filter(region_total__gte=86) + ) + orders_small = ( + cte_small.join(Order, region=cte_small.col.region_id) + .with_cte(cte_small) + .annotate(region_total=cte_small.col.total) + .filter(region_total__lte=123) + ) + + orders = orders_small.intersection(orders_big) \ + .values_list("amount", "region_id", "region_total") + print(orders.query) + + self.assertEqual(sorted(orders), [ + (20, 'venus', 86), + (21, 'venus', 86), + (22, 'venus', 86), + (23, 'venus', 86), + (40, 'mars', 123), + (41, 'mars', 123), + (42, 'mars', 123), + ]) + + def test_cte_difference(self): + cte_big = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount")), + name='big' + ) + cte_small = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount")), + name='small' + ) + orders_big = ( + cte_big.join(Order, region=cte_big.col.region_id) + .with_cte(cte_big) + .annotate(region_total=cte_big.col.total) + .filter(region_total__gte=86) + ) + orders_small = ( + cte_small.join(Order, region=cte_small.col.region_id) + .with_cte(cte_small) + .annotate(region_total=cte_small.col.total) + .filter(region_total__lte=123) + ) + + orders = orders_small.difference(orders_big) \ + .values_list("amount", "region_id", "region_total") + print(orders.query) + + self.assertEqual(sorted(orders), [ + (1, 'moon', 6), + (2, 'moon', 6), + (3, 'moon', 6), + (10, 'mercury', 33), + (10, 'proxima centauri b', 33), + (11, 'mercury', 33), + (11, 'proxima centauri b', 33), + (12, 'mercury', 33), + (12, 'proxima centauri b', 33), + ]) diff --git a/tests/test_v1/test_cte.py b/tests/test_v1/test_cte.py new file mode 100644 index 0000000..6193106 --- /dev/null +++ b/tests/test_v1/test_cte.py @@ -0,0 +1,642 @@ +from unittest import SkipTest + +from django.db.models import IntegerField, TextField +from django.db.models.aggregates import Count, Max, Min, Sum +from django.db.models.expressions import ( + Exists, ExpressionWrapper, F, OuterRef, Subquery, +) +from django.db.models.sql.constants import LOUTER +from django.test import TestCase + +from django_cte import With +from django_cte import CTEManager + +from .models import Order, Region, User + +int_field = IntegerField() +text_field = TextField() + + +class TestCTE(TestCase): + + def test_simple_cte_query(self): + cte = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount")) + ) + + orders = ( + # FROM orders INNER JOIN cte ON orders.region_id = cte.region_id + cte.join(Order, region=cte.col.region_id) + + # Add `WITH ...` before `SELECT ... FROM orders ...` + .with_cte(cte) + + # Annotate each Order with a "region_total" + .annotate(region_total=cte.col.total) + ) + print(orders.query) + + data = sorted((o.amount, o.region_id, o.region_total) for o in orders) + self.assertEqual(data, [ + (1, 'moon', 6), + (2, 'moon', 6), + (3, 'moon', 6), + (10, 'mercury', 33), + (10, 'proxima centauri b', 33), + (11, 'mercury', 33), + (11, 'proxima centauri b', 33), + (12, 'mercury', 33), + (12, 'proxima centauri b', 33), + (20, 'venus', 86), + (21, 'venus', 86), + (22, 'venus', 86), + (23, 'venus', 86), + (30, 'earth', 126), + (31, 'earth', 126), + (32, 'earth', 126), + (33, 'earth', 126), + (40, 'mars', 123), + (41, 'mars', 123), + (42, 'mars', 123), + (1000, 'sun', 1000), + (2000, 'proxima centauri', 2000), + ]) + + def test_cte_name_escape(self): + totals = With( + Order.objects + .filter(region__parent="sun") + .values("region_id") + .annotate(total=Sum("amount")), + name="mixedCaseCTEName" + ) + orders = ( + totals + .join(Order, region=totals.col.region_id) + .with_cte(totals) + .annotate(region_total=totals.col.total) + .order_by("amount") + ) + self.assertTrue( + str(orders.query).startswith('WITH RECURSIVE "mixedCaseCTEName"')) + + def test_cte_queryset(self): + sub_totals = With( + Order.objects + .values(region_parent=F("region__parent_id")) + .annotate(total=Sum("amount")), + ) + regions = ( + Region.objects.all() + .with_cte(sub_totals) + .annotate( + child_regions_total=Subquery( + sub_totals.queryset() + .filter(region_parent=OuterRef("name")) + .values("total"), + ), + ) + .order_by("name") + ) + print(regions.query) + + data = [(r.name, r.child_regions_total) for r in regions] + self.assertEqual(data, [ + ("bernard's star", None), + ('deimos', None), + ('earth', 6), + ('mars', None), + ('mercury', None), + ('moon', None), + ('phobos', None), + ('proxima centauri', 33), + ('proxima centauri b', None), + ('sun', 368), + ('venus', None) + ]) + + def test_cte_queryset_with_model_result(self): + cte = With( + Order.objects + .annotate(region_parent=F("region__parent_id")), + ) + orders = cte.queryset().with_cte(cte) + print(orders.query) + + data = sorted( + (x.region_id, x.amount, x.region_parent) for x in orders)[:5] + self.assertEqual(data, [ + ("earth", 30, "sun"), + ("earth", 31, "sun"), + ("earth", 32, "sun"), + ("earth", 33, "sun"), + ("mars", 40, "sun"), + ]) + self.assertTrue( + all(isinstance(x, Order) for x in orders), + repr([x for x in orders]), + ) + + def test_cte_queryset_with_join(self): + cte = With( + Order.objects + .annotate(region_parent=F("region__parent_id")), + ) + orders = ( + cte.queryset() + .with_cte(cte) + .annotate(parent=F("region__parent_id")) + .order_by("region_id", "amount") + ) + print(orders.query) + + data = [(x.region_id, x.region_parent, x.parent) for x in orders][:5] + self.assertEqual(data, [ + ("earth", "sun", "sun"), + ("earth", "sun", "sun"), + ("earth", "sun", "sun"), + ("earth", "sun", "sun"), + ("mars", "sun", "sun"), + ]) + + def test_cte_queryset_with_values_result(self): + cte = With( + Order.objects + .values( + "region_id", + region_parent=F("region__parent_id"), + ) + .distinct() + ) + values = ( + cte.queryset() + .with_cte(cte) + .filter(region_parent__isnull=False) + ) + print(values.query) + + def key(item): + return item["region_parent"], item["region_id"] + + data = sorted(values, key=key)[:5] + self.assertEqual(data, [ + {'region_id': 'moon', 'region_parent': 'earth'}, + { + 'region_id': 'proxima centauri b', + 'region_parent': 'proxima centauri', + }, + {'region_id': 'earth', 'region_parent': 'sun'}, + {'region_id': 'mars', 'region_parent': 'sun'}, + {'region_id': 'mercury', 'region_parent': 'sun'}, + ]) + + def test_named_simple_ctes(self): + totals = With( + Order.objects + .filter(region__parent="sun") + .values("region_id") + .annotate(total=Sum("amount")), + name="totals", + ) + region_count = With( + Region.objects + .filter(parent="sun") + .values("parent_id") + .annotate(num=Count("name")), + name="region_count", + ) + orders = ( + region_count.join( + totals.join(Order, region=totals.col.region_id), + region__parent=region_count.col.parent_id + ) + .with_cte(totals) + .with_cte(region_count) + .annotate(region_total=totals.col.total) + .annotate(region_count=region_count.col.num) + .order_by("amount") + ) + print(orders.query) + + data = [( + o.amount, + o.region_id, + o.region_count, + o.region_total, + ) for o in orders] + self.assertEqual(data, [ + (10, 'mercury', 4, 33), + (11, 'mercury', 4, 33), + (12, 'mercury', 4, 33), + (20, 'venus', 4, 86), + (21, 'venus', 4, 86), + (22, 'venus', 4, 86), + (23, 'venus', 4, 86), + (30, 'earth', 4, 126), + (31, 'earth', 4, 126), + (32, 'earth', 4, 126), + (33, 'earth', 4, 126), + (40, 'mars', 4, 123), + (41, 'mars', 4, 123), + (42, 'mars', 4, 123), + ]) + + def test_named_ctes(self): + def make_root_mapping(rootmap): + return Region.objects.filter( + parent__isnull=True + ).values( + "name", + root=F("name"), + ).union( + rootmap.join(Region, parent=rootmap.col.name).values( + "name", + root=rootmap.col.root, + ), + all=True, + ) + rootmap = With.recursive(make_root_mapping, name="rootmap") + + totals = With( + rootmap.join(Order, region_id=rootmap.col.name) + .values( + root=rootmap.col.root, + ).annotate( + orders_count=Count("id"), + region_total=Sum("amount"), + ), + name="totals", + ) + + root_regions = ( + totals.join(Region, name=totals.col.root) + .with_cte(rootmap) + .with_cte(totals) + .annotate( + # count of orders in this region and all subregions + orders_count=totals.col.orders_count, + # sum of order amounts in this region and all subregions + region_total=totals.col.region_total, + ) + ) + print(root_regions.query) + + data = sorted( + (r.name, r.orders_count, r.region_total) for r in root_regions + ) + self.assertEqual(data, [ + ('proxima centauri', 4, 2033), + ('sun', 18, 1374), + ]) + + def test_materialized_option(self): + totals = With( + Order.objects + .filter(region__parent="sun") + .values("region_id") + .annotate(total=Sum("amount")), + materialized=True + ) + orders = ( + totals + .join(Order, region=totals.col.region_id) + .with_cte(totals) + .annotate(region_total=totals.col.total) + .order_by("amount") + ) + self.assertTrue( + str(orders.query).startswith( + 'WITH RECURSIVE "cte" AS MATERIALIZED' + ) + ) + + def test_update_cte_query(self): + cte = With( + Order.objects + .values(region_parent=F("region__parent_id")) + .annotate(total=Sum("amount")) + .filter(total__isnull=False) + ) + # not the most efficient query, but it exercises CTEUpdateQuery + Order.objects.all().with_cte(cte).filter(region_id__in=Subquery( + cte.queryset() + .filter(region_parent=OuterRef("region_id")) + .values("region_parent") + )).update(amount=Subquery( + cte.queryset() + .filter(region_parent=OuterRef("region_id")) + .values("total") + )) + + data = set((o.region_id, o.amount) for o in Order.objects.filter( + region_id__in=["earth", "sun", "proxima centauri", "mars"] + )) + self.assertEqual(data, { + ('earth', 6), + ('mars', 40), + ('mars', 41), + ('mars', 42), + ('proxima centauri', 33), + ('sun', 368), + }) + + def test_update_with_subquery(self): + # Test for issue: https://github.com/dimagi/django-cte/issues/9 + # Issue is not reproduces on sqlite3 use postgres to run. + # To reproduce the problem it's required to have some join + # in the select-query so the compiler will turn it into a subquery. + # To add a join use a filter over field of related model + orders = Order.objects.filter(region__parent_id='sun') + orders.update(amount=0) + data = {(order.region_id, order.amount) for order in orders} + self.assertEqual(data, { + ('mercury', 0), + ('venus', 0), + ('earth', 0), + ('mars', 0), + }) + + def test_delete_cte_query(self): + raise SkipTest( + "this test will not work until `QuerySet.delete` (Django method) " + "calls `self.query.chain(sql.DeleteQuery)` instead of " + "`sql.DeleteQuery(self.model)`" + ) + cte = With( + Order.objects + .values(region_parent=F("region__parent_id")) + .annotate(total=Sum("amount")) + .filter(total__isnull=False) + ) + Order.objects.all().with_cte(cte).annotate( + cte_has_order=Exists( + cte.queryset() + .values("total") + .filter(region_parent=OuterRef("region_id")) + ) + ).filter(cte_has_order=False).delete() + + data = [(o.region_id, o.amount) for o in Order.objects.all()] + self.assertEqual(data, [ + ('sun', 1000), + ('earth', 30), + ('earth', 31), + ('earth', 32), + ('earth', 33), + ('proxima centauri', 2000), + ]) + + def test_outerref_in_cte_query(self): + # This query is meant to return the difference between min and max + # order of each region, through a subquery + min_and_max = With( + Order.objects + .filter(region=OuterRef("pk")) + .values('region') # This is to force group by region_id + .annotate( + amount_min=Min("amount"), + amount_max=Max("amount"), + ) + .values('amount_min', 'amount_max') + ) + regions = ( + Region.objects + .annotate( + difference=Subquery( + min_and_max.queryset().with_cte(min_and_max).annotate( + difference=ExpressionWrapper( + F('amount_max') - F('amount_min'), + output_field=int_field, + ), + ).values('difference')[:1], + output_field=IntegerField() + ) + ) + .order_by("name") + ) + print(regions.query) + + data = [(r.name, r.difference) for r in regions] + self.assertEqual(data, [ + ("bernard's star", None), + ('deimos', None), + ('earth', 3), + ('mars', 2), + ('mercury', 2), + ('moon', 2), + ('phobos', None), + ('proxima centauri', 0), + ('proxima centauri b', 2), + ('sun', 0), + ('venus', 3) + ]) + + def test_experimental_left_outer_join(self): + totals = With( + Order.objects + .values("region_id") + .annotate(total=Sum("amount")) + .filter(total__gt=100) + ) + orders = ( + totals + .join(Order, region=totals.col.region_id, _join_type=LOUTER) + .with_cte(totals) + .annotate(region_total=totals.col.total) + ) + print(orders.query) + self.assertIn("LEFT OUTER JOIN", str(orders.query)) + self.assertNotIn("INNER JOIN", str(orders.query)) + + data = sorted((o.region_id, o.amount, o.region_total) for o in orders) + self.assertEqual(data, [ + ('earth', 30, 126), + ('earth', 31, 126), + ('earth', 32, 126), + ('earth', 33, 126), + ('mars', 40, 123), + ('mars', 41, 123), + ('mars', 42, 123), + ('mercury', 10, None), + ('mercury', 11, None), + ('mercury', 12, None), + ('moon', 1, None), + ('moon', 2, None), + ('moon', 3, None), + ('proxima centauri', 2000, 2000), + ('proxima centauri b', 10, None), + ('proxima centauri b', 11, None), + ('proxima centauri b', 12, None), + ('sun', 1000, 1000), + ('venus', 20, None), + ('venus', 21, None), + ('venus', 22, None), + ('venus', 23, None), + ]) + + def test_non_cte_subquery(self): + """ + Verifies that subquery annotations are handled correctly when the + subquery model doesn't use the CTE manager, and the query results + match expected behavior + """ + self.assertNotIsInstance(User.objects, CTEManager) + + sub_totals = With( + Order.objects + .values(region_parent=F("region__parent_id")) + .annotate( + total=Sum("amount"), + # trivial subquery example testing existence of + # a user for the order + non_cte_subquery=Exists( + User.objects.filter(pk=OuterRef("user_id")) + ), + ), + ) + regions = ( + Region.objects.all() + .with_cte(sub_totals) + .annotate( + child_regions_total=Subquery( + sub_totals.queryset() + .filter(region_parent=OuterRef("name")) + .values("total"), + ), + ) + .order_by("name") + ) + print(regions.query) + + data = [(r.name, r.child_regions_total) for r in regions] + self.assertEqual(data, [ + ("bernard's star", None), + ('deimos', None), + ('earth', 6), + ('mars', None), + ('mercury', None), + ('moon', None), + ('phobos', None), + ('proxima centauri', 33), + ('proxima centauri b', None), + ('sun', 368), + ('venus', None) + ]) + + def test_explain(self): + """ + Verifies that using .explain() prepends the EXPLAIN clause in the + correct position + """ + + totals = With( + Order.objects + .filter(region__parent="sun") + .values("region_id") + .annotate(total=Sum("amount")), + name="totals", + ) + region_count = With( + Region.objects + .filter(parent="sun") + .values("parent_id") + .annotate(num=Count("name")), + name="region_count", + ) + orders = ( + region_count.join( + totals.join(Order, region=totals.col.region_id), + region__parent=region_count.col.parent_id + ) + .with_cte(totals) + .with_cte(region_count) + .annotate(region_total=totals.col.total) + .annotate(region_count=region_count.col.num) + .order_by("amount") + ) + print(orders.query) + + self.assertIsInstance(orders.explain(), str) + + def test_empty_result_set_cte(self): + """ + Verifies that the CTEQueryCompiler can handle empty result sets in the + related CTEs + """ + totals = With( + Order.objects + .filter(id__in=[]) + .values("region_id") + .annotate(total=Sum("amount")), + name="totals", + ) + orders = ( + totals.join(Order, region=totals.col.region_id) + .with_cte(totals) + .annotate(region_total=totals.col.total) + .order_by("amount") + ) + + self.assertEqual(len(orders), 0) + + def test_left_outer_join_on_empty_result_set_cte(self): + totals = With( + Order.objects + .filter(id__in=[]) + .values("region_id") + .annotate(total=Sum("amount")), + name="totals", + ) + orders = ( + totals.join(Order, region=totals.col.region_id, _join_type=LOUTER) + .with_cte(totals) + .annotate(region_total=totals.col.total) + .order_by("amount") + ) + + self.assertEqual(len(orders), 22) + + def test_union_query_with_cte(self): + orders = ( + Order.objects + .filter(region__parent="sun") + .only("region", "amount") + ) + orders_cte = With(orders, name="orders_cte") + orders_cte_queryset = orders_cte.queryset() + + earth_orders = orders_cte_queryset.filter(region="earth") + mars_orders = orders_cte_queryset.filter(region="mars") + + earth_mars = earth_orders.union(mars_orders, all=True) + earth_mars_cte = ( + earth_mars + .with_cte(orders_cte) + .order_by("region", "amount") + .values_list("region", "amount") + ) + print(earth_mars_cte.query) + + self.assertEqual(list(earth_mars_cte), [ + ('earth', 30), + ('earth', 31), + ('earth', 32), + ('earth', 33), + ('mars', 40), + ('mars', 41), + ('mars', 42), + ]) + + def test_cte_select_pk(self): + orders = Order.objects.filter(region="earth").values("pk") + cte = With(orders) + queryset = cte.join(orders, pk=cte.col.pk).with_cte(cte).order_by("pk") + print(queryset.query) + self.assertEqual(list(queryset), [ + {'pk': 9}, + {'pk': 10}, + {'pk': 11}, + {'pk': 12}, + ]) diff --git a/tests/test_v1/test_django.py b/tests/test_v1/test_django.py new file mode 100644 index 0000000..83c92a2 --- /dev/null +++ b/tests/test_v1/test_django.py @@ -0,0 +1,87 @@ +from unittest import SkipTest + +import django +from django.db import OperationalError, ProgrammingError +from django.db.models import Window +from django.db.models.functions import Rank +from django.test import TestCase, skipUnlessDBFeature + +from .models import Order, Region, User + + +@skipUnlessDBFeature("supports_select_union") +class NonCteQueries(TestCase): + """Test non-CTE queries + + These tests were adapted from the Django test suite. The models used + here use CTEManager and CTEQuerySet to verify feature parity with + their base classes Manager and QuerySet. + """ + + @classmethod + def setUpTestData(cls): + Order.objects.all().delete() + + def test_union_with_select_related_and_order(self): + e1 = User.objects.create(name="e1") + a1 = Order.objects.create(region_id="earth", user=e1) + a2 = Order.objects.create(region_id="moon", user=e1) + Order.objects.create(region_id="sun", user=e1) + base_qs = Order.objects.select_related("user").order_by() + qs1 = base_qs.filter(region_id="earth") + qs2 = base_qs.filter(region_id="moon") + print(qs1.union(qs2).order_by("pk").query) + self.assertSequenceEqual(qs1.union(qs2).order_by("pk"), [a1, a2]) + + @skipUnlessDBFeature("supports_slicing_ordering_in_compound") + def test_union_with_select_related_and_first(self): + e1 = User.objects.create(name="e1") + a1 = Order.objects.create(region_id="earth", user=e1) + Order.objects.create(region_id="moon", user=e1) + base_qs = Order.objects.select_related("user") + qs1 = base_qs.filter(region_id="earth") + qs2 = base_qs.filter(region_id="moon") + self.assertEqual(qs1.union(qs2).first(), a1) + + def test_union_with_first(self): + e1 = User.objects.create(name="e1") + a1 = Order.objects.create(region_id="earth", user=e1) + base_qs = Order.objects.order_by() + qs1 = base_qs.filter(region_id="earth") + qs2 = base_qs.filter(region_id="moon") + self.assertEqual(qs1.union(qs2).first(), a1) + + +class WindowFunctions(TestCase): + + def test_heterogeneous_filter_in_cte(self): + if django.VERSION < (4, 2): + raise SkipTest("feature added in Django 4.2") + from django_cte import With + cte = With( + Order.objects.annotate( + region_amount_rank=Window( + Rank(), partition_by="region_id", order_by="-amount" + ), + ) + .order_by("region_id") + .values("region_id", "region_amount_rank") + .filter(region_amount_rank=1, region_id__in=["sun", "moon"]) + ) + qs = cte.join(Region, name=cte.col.region_id).with_cte(cte) + print(qs.query) + # ProgrammingError: column cte.region_id does not exist + # WITH RECURSIVE "cte" AS (SELECT * FROM ( + # SELECT "orders"."region_id" AS "col1", ... + # "region" INNER JOIN "cte" ON "region"."name" = ("cte"."region_id") + try: + self.assertEqual({r.name for r in qs}, {"moon", "sun"}) + except (OperationalError, ProgrammingError) as err: + if "cte.region_id" in str(err): + raise SkipTest( + "window function auto-aliasing breaks CTE " + "column references" + ) + raise + if django.VERSION < (5, 2): + assert 0, "unexpected pass" diff --git a/tests/test_v1/test_manager.py b/tests/test_v1/test_manager.py new file mode 100644 index 0000000..be085f9 --- /dev/null +++ b/tests/test_v1/test_manager.py @@ -0,0 +1,111 @@ +from django.db.models.expressions import F +from django.db.models.query import QuerySet +from django.test import TestCase + +from django_cte import With, CTEQuerySet, CTEManager + +from .models import ( + Order, + OrderFromLT40, + OrderLT40AsManager, + OrderCustomManagerNQuery, + OrderCustomManager, + LT40QuerySet, + LTManager, + LT25QuerySet, +) + + +class TestCTE(TestCase): + def test_cte_queryset_correct_defaultmanager(self): + self.assertEqual(type(Order._default_manager), CTEManager) + self.assertEqual(type(Order.objects.all()), CTEQuerySet) + + def test_cte_queryset_correct_from_queryset(self): + self.assertEqual(type(OrderFromLT40.objects.all()), LT40QuerySet) + + def test_cte_queryset_correct_queryset_as_manager(self): + self.assertEqual(type(OrderLT40AsManager.objects.all()), LT40QuerySet) + + def test_cte_queryset_correct_manager_n_from_queryset(self): + self.assertIsInstance( + OrderCustomManagerNQuery._default_manager, LTManager) + self.assertEqual(type( + OrderCustomManagerNQuery.objects.all()), LT25QuerySet) + + def test_cte_create_manager_from_non_cteQuery(self): + class BrokenQuerySet(QuerySet): + "This should be a CTEQuerySet if we want this to work" + + with self.assertRaises(TypeError): + CTEManager.from_queryset(BrokenQuerySet)() + + def test_cte_queryset_correct_limitedmanager(self): + self.assertEqual(type(OrderCustomManager._default_manager), LTManager) + # Check the expected even if not ideal behavior occurs + self.assertIsInstance(OrderCustomManager.objects.all(), CTEQuerySet) + + def test_cte_queryset_with_from_queryset(self): + self.assertEqual(type(OrderFromLT40.objects.all()), LT40QuerySet) + + cte = With( + OrderFromLT40.objects + .annotate(region_parent=F("region__parent_id")) + .filter(region__parent_id="sun") + ) + orders = ( + cte.queryset() + .with_cte(cte) + .lt40() # custom queryset method + .order_by("region_id", "amount") + ) + print(orders.query) + + data = [(x.region_id, x.amount, x.region_parent) for x in orders] + self.assertEqual(data, [ + ("earth", 30, "sun"), + ("earth", 31, "sun"), + ("earth", 32, "sun"), + ("earth", 33, "sun"), + ('mercury', 10, 'sun'), + ('mercury', 11, 'sun'), + ('mercury', 12, 'sun'), + ('venus', 20, 'sun'), + ('venus', 21, 'sun'), + ('venus', 22, 'sun'), + ('venus', 23, 'sun'), + ]) + + def test_cte_queryset_with_custom_queryset(self): + cte = With( + OrderCustomManagerNQuery.objects + .annotate(region_parent=F("region__parent_id")) + .filter(region__parent_id="sun") + ) + orders = ( + cte.queryset() + .with_cte(cte) + .lt25() # custom queryset method + .order_by("region_id", "amount") + ) + print(orders.query) + + data = [(x.region_id, x.amount, x.region_parent) for x in orders] + self.assertEqual(data, [ + ('mercury', 10, 'sun'), + ('mercury', 11, 'sun'), + ('mercury', 12, 'sun'), + ('venus', 20, 'sun'), + ('venus', 21, 'sun'), + ('venus', 22, 'sun'), + ('venus', 23, 'sun'), + ]) + + def test_cte_queryset_with_deferred_loading(self): + cte = With( + OrderCustomManagerNQuery.objects.order_by("id").only("id")[:1] + ) + orders = cte.queryset().with_cte(cte) + print(orders.query) + + self.assertEqual([x.id for x in orders], [1]) diff --git a/tests/test_v1/test_raw.py b/tests/test_v1/test_raw.py new file mode 100644 index 0000000..ade9774 --- /dev/null +++ b/tests/test_v1/test_raw.py @@ -0,0 +1,60 @@ +from django.db.models import IntegerField, TextField +from django.test import TestCase + +from django_cte import With +from django_cte.raw import raw_cte_sql + +from .models import Region + +int_field = IntegerField() +text_field = TextField() + + +class TestRawCTE(TestCase): + + def test_raw_cte_sql(self): + cte = With(raw_cte_sql( + """ + SELECT region_id, AVG(amount) AS avg_order + FROM orders + WHERE region_id = %s + GROUP BY region_id + """, + ["moon"], + {"region_id": text_field, "avg_order": int_field}, + )) + moon_avg = ( + cte + .join(Region, name=cte.col.region_id) + .annotate(avg_order=cte.col.avg_order) + .with_cte(cte) + ) + print(moon_avg.query) + + data = [(r.name, r.parent.name, r.avg_order) for r in moon_avg] + self.assertEqual(data, [('moon', 'earth', 2)]) + + def test_raw_cte_sql_name_escape(self): + cte = With( + raw_cte_sql( + """ + SELECT region_id, AVG(amount) AS avg_order + FROM orders + WHERE region_id = %s + GROUP BY region_id + """, + ["moon"], + {"region_id": text_field, "avg_order": int_field}, + ), + name="mixedCaseCTEName" + ) + moon_avg = ( + cte + .join(Region, name=cte.col.region_id) + .annotate(avg_order=cte.col.avg_order) + .with_cte(cte) + ) + self.assertTrue( + str(moon_avg.query).startswith( + 'WITH RECURSIVE "mixedCaseCTEName"') + ) diff --git a/tests/test_v1/test_recursive.py b/tests/test_v1/test_recursive.py new file mode 100644 index 0000000..6506d8f --- /dev/null +++ b/tests/test_v1/test_recursive.py @@ -0,0 +1,335 @@ +import pickle +from unittest import SkipTest + +from django.db.models import IntegerField, TextField +from django.db.models.expressions import ( + Case, + Exists, + ExpressionWrapper, + F, + OuterRef, + Q, + Value, + When, +) +from django.db.models.functions import Concat +from django.db.utils import DatabaseError +from django.test import TestCase + +from django_cte import With + +from .models import KeyPair, Region + +int_field = IntegerField() +text_field = TextField() + + +class TestRecursiveCTE(TestCase): + + def test_recursive_cte_query(self): + def make_regions_cte(cte): + return Region.objects.filter( + # non-recursive: get root nodes + parent__isnull=True + ).values( + "name", + path=F("name"), + depth=Value(0, output_field=int_field), + ).union( + # recursive union: get descendants + cte.join(Region, parent=cte.col.name).values( + "name", + path=Concat( + cte.col.path, Value(" / "), F("name"), + output_field=text_field, + ), + depth=cte.col.depth + Value(1, output_field=int_field), + ), + all=True, + ) + + cte = With.recursive(make_regions_cte) + + regions = ( + cte.join(Region, name=cte.col.name) + .with_cte(cte) + .annotate( + path=cte.col.path, + depth=cte.col.depth, + ) + .filter(depth=2) + .order_by("path") + ) + print(regions.query) + + data = [(r.name, r.path, r.depth) for r in regions] + self.assertEqual(data, [ + ('moon', 'sun / earth / moon', 2), + ('deimos', 'sun / mars / deimos', 2), + ('phobos', 'sun / mars / phobos', 2), + ]) + + def test_recursive_cte_reference_in_condition(self): + def make_regions_cte(cte): + return Region.objects.filter( + parent__isnull=True + ).values( + "name", + path=F("name"), + depth=Value(0, output_field=int_field), + is_planet=Value(0, output_field=int_field), + ).union( + cte.join( + Region, parent=cte.col.name + ).annotate( + # annotations for filter and CASE/WHEN conditions + parent_name=ExpressionWrapper( + cte.col.name, + output_field=text_field, + ), + parent_depth=ExpressionWrapper( + cte.col.depth, + output_field=int_field, + ), + ).filter( + ~Q(parent_name="mars"), + ).values( + "name", + path=Concat( + cte.col.path, Value("\x01"), F("name"), + output_field=text_field, + ), + depth=cte.col.depth + Value(1, output_field=int_field), + is_planet=Case( + When(parent_depth=0, then=Value(1)), + default=Value(0), + output_field=int_field, + ), + ), + all=True, + ) + cte = With.recursive(make_regions_cte) + regions = cte.join(Region, name=cte.col.name).with_cte(cte).annotate( + path=cte.col.path, + depth=cte.col.depth, + is_planet=cte.col.is_planet, + ).order_by("path") + + data = [(r.path.split("\x01"), r.is_planet) for r in regions] + print(data) + self.assertEqual(data, [ + (["bernard's star"], 0), + (['proxima centauri'], 0), + (['proxima centauri', 'proxima centauri b'], 1), + (['sun'], 0), + (['sun', 'earth'], 1), + (['sun', 'earth', 'moon'], 0), + (['sun', 'mars'], 1), # mars moons excluded: parent_name != 'mars' + (['sun', 'mercury'], 1), + (['sun', 'venus'], 1), + ]) + + def test_recursive_cte_with_empty_union_part(self): + def make_regions_cte(cte): + return Region.objects.none().union( + cte.join(Region, parent=cte.col.name), + all=True, + ) + cte = With.recursive(make_regions_cte) + regions = cte.join(Region, name=cte.col.name).with_cte(cte) + + print(regions.query) + try: + self.assertEqual(regions.count(), 0) + except DatabaseError: + raise SkipTest( + "Expected failure: QuerySet omits `EmptyQuerySet` from " + "UNION queries resulting in invalid CTE SQL" + ) + + # -- recursive query "cte" does not have the form + # -- non-recursive-term UNION [ALL] recursive-term + # WITH RECURSIVE cte AS ( + # SELECT "tests_region"."name", "tests_region"."parent_id" + # FROM "tests_region", "cte" + # WHERE "tests_region"."parent_id" = ("cte"."name") + # ) + # SELECT COUNT(*) + # FROM "tests_region", "cte" + # WHERE "tests_region"."name" = ("cte"."name") + + def test_circular_ref_error(self): + def make_bad_cte(cte): + # NOTE: not a valid recursive CTE query + return cte.join(Region, parent=cte.col.name).values( + depth=cte.col.depth + 1, + ) + cte = With.recursive(make_bad_cte) + regions = cte.join(Region, name=cte.col.name).with_cte(cte) + with self.assertRaises(ValueError) as context: + print(regions.query) + self.assertIn("Circular reference:", str(context.exception)) + + def test_attname_should_not_mask_col_name(self): + def make_regions_cte(cte): + return Region.objects.filter( + name="moon" + ).values( + "name", + "parent_id", + ).union( + cte.join(Region, name=cte.col.parent_id).values( + "name", + "parent_id", + ), + all=True, + ) + cte = With.recursive(make_regions_cte) + regions = ( + Region.objects.all() + .with_cte(cte) + .annotate(_ex=Exists( + cte.queryset() + .values(value=Value("1", output_field=int_field)) + .filter(name=OuterRef("name")) + )) + .filter(_ex=True) + .order_by("name") + ) + print(regions.query) + + data = [r.name for r in regions] + self.assertEqual(data, ['earth', 'moon', 'sun']) + + def test_pickle_recursive_cte_queryset(self): + def make_regions_cte(cte): + return Region.objects.filter( + parent__isnull=True + ).annotate( + depth=Value(0, output_field=int_field), + ).union( + cte.join(Region, parent=cte.col.name).annotate( + depth=cte.col.depth + Value(1, output_field=int_field), + ), + all=True, + ) + cte = With.recursive(make_regions_cte) + regions = cte.queryset().with_cte(cte).filter(depth=2).order_by("name") + + pickled_qs = pickle.loads(pickle.dumps(regions)) + + data = [(r.name, r.depth) for r in pickled_qs] + self.assertEqual(data, [(r.name, r.depth) for r in regions]) + self.assertEqual(data, [('deimos', 2), ('moon', 2), ('phobos', 2)]) + + def test_alias_change_in_annotation(self): + def make_regions_cte(cte): + return Region.objects.filter( + parent__name="sun", + ).annotate( + value=F('name'), + ).union( + cte.join( + Region.objects.all().annotate( + value=F('name'), + ), + parent_id=cte.col.name, + ), + all=True, + ) + cte = With.recursive(make_regions_cte) + query = cte.queryset().with_cte(cte) + + exclude_leaves = With(cte.queryset().filter( + parent__name='sun', + ).annotate( + value=Concat(F('name'), F('name')) + ), name='value_cte') + + query = query.annotate( + _exclude_leaves=Exists( + exclude_leaves.queryset().filter( + name=OuterRef("name"), + value=OuterRef("value"), + ) + ) + ).filter(_exclude_leaves=True).with_cte(exclude_leaves) + print(query.query) + + # Nothing should be returned. + self.assertFalse(query) + + def test_alias_as_subquery(self): + # This test covers CTEColumnRef.relabeled_clone + def make_regions_cte(cte): + return KeyPair.objects.filter( + parent__key="level 1", + ).annotate( + rank=F('value'), + ).union( + cte.join( + KeyPair.objects.all().order_by(), + parent_id=cte.col.id, + ).annotate( + rank=F('value'), + ), + all=True, + ) + cte = With.recursive(make_regions_cte) + children = cte.queryset().with_cte(cte) + + xdups = With(cte.queryset().filter( + parent__key="level 1", + ).annotate( + rank=F('value') + ).values('id', 'rank'), name='xdups') + + children = children.annotate( + _exclude=Exists( + ( + xdups.queryset().filter( + id=OuterRef("id"), + rank=OuterRef("rank"), + ) + ) + ) + ).filter(_exclude=True).with_cte(xdups) + + print(children.query) + query = KeyPair.objects.filter(parent__in=children) + print(query.query) + print(children.query) + self.assertEqual(query.get().key, 'level 3') + # Tests the case in which children's query was modified since it was + # used in a subquery to define `query` above. + self.assertEqual( + list(c.key for c in children), + ['level 2', 'level 2'] + ) + + def test_materialized(self): + # This test covers MATERIALIZED option in SQL query + def make_regions_cte(cte): + return KeyPair.objects.all() + cte = With.recursive(make_regions_cte, materialized=True) + + query = KeyPair.objects.with_cte(cte) + print(query.query) + self.assertTrue( + str(query.query).startswith('WITH RECURSIVE "cte" AS MATERIALIZED') + ) + + def test_recursive_self_queryset(self): + def make_regions_cte(cte): + return Region.objects.filter( + pk="earth" + ).values("pk").union( + cte.join(Region, parent=cte.col.pk).values("pk") + ) + cte = With.recursive(make_regions_cte) + queryset = cte.queryset().with_cte(cte).order_by("pk") + print(queryset.query) + self.assertEqual(list(queryset), [ + {'pk': 'earth'}, + {'pk': 'moon'}, + ])