From 24116a9b8201996724ee1f350c0874773e231407 Mon Sep 17 00:00:00 2001 From: "victor.lee" Date: Fri, 27 Jul 2018 19:47:19 -0700 Subject: [PATCH] table alias bugfix from https://github.com/adamchainz/django-mysql/pull/286 --- django_mysql/models/query.py | 106 ++++++++++++++++++++++------------ django_mysql/rewrite_query.py | 18 ++++-- 2 files changed, 82 insertions(+), 42 deletions(-) diff --git a/django_mysql/models/query.py b/django_mysql/models/query.py index 6f802b05..146bdf91 100644 --- a/django_mysql/models/query.py +++ b/django_mysql/models/query.py @@ -8,7 +8,7 @@ from subprocess import PIPE, Popen from django.db import connections, models -from django.db.models.sql.where import ExtraWhere +from django.db.models.sql.where import AND, ExtraWhere from django.db.transaction import atomic from django.test.utils import CaptureQueriesContext from django.utils import six @@ -155,44 +155,15 @@ def ignore_index(self, *index_names, **kwargs): return self._index_hint(*index_names, **kwargs) def _index_hint(self, *index_names, **kwargs): - hint = kwargs.pop('hint') - table_name = kwargs.pop('table_name', None) - for_ = kwargs.pop('for_', None) - if kwargs: + kwargs.setdefault('table_name', self.model._meta.db_table) + if set(kwargs.keys()) - {'table_name', 'for_', 'hint'}: raise ValueError( - "{}_index accepts only 'for_' and 'table_name' as keyword " - "arguments" - .format(hint.lower()) + "{}_index accepts only 'for_' and 'table_name' as " + "keyword arguments" + .format(kwargs['hint'].lower()) ) - - if hint != 'USE' and not len(index_names): - raise ValueError( - "{}_index requires at least one index name" - .format(hint.lower()) - ) - - if table_name is None: - table_name = self.model._meta.db_table - - if for_ in ('JOIN', 'ORDER BY', 'GROUP BY'): - for_bit = 'FOR {} '.format(for_) - elif for_ is None: - for_bit = '' - else: - raise ValueError("for_ must be one of: None, 'JOIN', 'ORDER BY', " - "'GROUP BY'") - - if len(index_names) == 0: - indexes = "NONE" - else: - indexes = "`" + "`,`".join(index_names) + "`" - - hint = ( - "/*QueryRewrite':index=`{table_name}` {hint} {for_bit}{indexes}*/1" - .format(table_name=table_name, hint=hint, for_bit=for_bit, - indexes=indexes) - ) - return self.extra(where=[hint]) + hint = IndexHint(index_names, **kwargs) + return self.extra(hints=(hint, )) # Features handled by extra classes/functions @@ -217,6 +188,17 @@ def pt_visual_explain(self, display=True): def handler(self): return Handler(self) + def extra(self, select=None, where=None, params=None, tables=None, + order_by=None, select_params=None, hints=None): + clone = super(QuerySetMixin, self).extra( + select, where, params, tables, order_by, select_params) + if hints: + for hint in hints: + if not isinstance(hint, IndexHint): + raise ValueError("hint should be instance of IndexHint") + clone.query.where.add(hint, AND) + return clone + class QuerySet(QuerySetMixin, models.QuerySet): pass @@ -608,3 +590,53 @@ def pt_visual_explain(queryset, display=True): print(explanation) else: return explanation + + +class IndexHint(object): + contains_aggregate = False + + def __init__(self, index_names, table_name, hint, for_=None, alias=None): + if hint != 'USE' and not len(index_names): + raise ValueError( + "{}_index requires at least one index name" + .format(hint.lower())) + self.hint = hint + + self.for_ = for_ + if self.for_ not in ('JOIN', 'ORDER BY', 'GROUP BY') \ + and self.for_ is not None: + raise ValueError("for_ must be one of: None, 'JOIN', 'ORDER BY', " + "'GROUP BY'") + + self.table_name = table_name + self.indexes = index_names + self.alias = alias or self.table_name + + def clone(self): + return type(self)( + self.indexes, self.table_name, self.hint, self.for_, self.alias) + + def __str__(self): + alias_str = "" + if self.alias and self.alias != self.table_name: + alias_str = " AS `{}`".format(self.alias) + indexes = "NONE" + if len(self.indexes) > 0: + indexes = "`{}`".format("`,`".join(self.indexes)) + for_bit = "" + if self.for_ is not None: + for_bit = "FOR {} " .format(self.for_) + return ( + "/*QueryRewrite':index=`{table_name}`{alias_str} {hint} " + "{for_bit}{indexes}*/1" + .format(table_name=self.table_name, hint=self.hint, + for_bit=for_bit, indexes=indexes, alias_str=alias_str) + ) + + def as_sql(self, compiler, connection): + return str(self), [] + + def relabeled_clone(self, change_map): + clone = self.clone() + clone.alias = change_map.get(self.alias, self.alias) + return clone diff --git a/django_mysql/rewrite_query.py b/django_mysql/rewrite_query.py index 0a379c76..4c9978a9 100644 --- a/django_mysql/rewrite_query.py +++ b/django_mysql/rewrite_query.py @@ -25,6 +25,7 @@ r""" index= (?P`[^`]+`) + (?:\ AS\ `(?P[^`]+)`)? \ # space (?PUSE|IGNORE|FORCE) \ # space @@ -58,6 +59,7 @@ def rewrite_query(sql): index_match.group('rule'), index_match.group('index_names'), index_match.group('for_what'), + index_match.group('alias'), )) # Silently fail on unrecognized rewrite requests @@ -167,18 +169,23 @@ def modify_sql(sql, add_comments, add_hints, add_index_hints): table_spec_re_template = r''' \b(?PFROM|JOIN) \s+ - {table_name} + {table_name}{optional_alias} \s+ ''' replacement_template = ( - r'\g {table_name} ' + r'\g {table_name}{optional_alias} ' r'{rule} INDEX {for_section}({index_names}) ' ) -def modify_sql_index_hints(sql, table_name, rule, index_names, for_what): - table_spec_re = table_spec_re_template.format(table_name=table_name) +def modify_sql_index_hints( + sql, table_name, rule, index_names, for_what, alias): + alias_re = '' + if alias: + alias_re = '\s+{}'.format(re.escape(alias)) + table_spec_re = table_spec_re_template.format( + table_name=re.escape(table_name), optional_alias=alias_re) if for_what: for_section = 'FOR {} '.format(for_what) else: @@ -187,6 +194,7 @@ def modify_sql_index_hints(sql, table_name, rule, index_names, for_what): table_name=table_name, rule=rule, for_section=for_section, - index_names=('' if index_names == 'NONE' else index_names) + index_names=('' if index_names == 'NONE' else index_names), + optional_alias=' {}'.format(alias) if alias else '' ) return re.sub(table_spec_re, replacement, sql, count=1, flags=re.VERBOSE)