-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathoperations.py
275 lines (237 loc) · 10.6 KB
/
operations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import datetime
import json
import re
import uuid
from decimal import Decimal
from bson.decimal128 import Decimal128
from django.conf import settings
from django.db import DataError
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import TextField
from django.db.models.expressions import Combinable, Expression
from django.db.models.functions import Cast
from django.utils import timezone
from django.utils.regex_helper import _lazy_re_compile
class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django_mongodb_backend.compiler"
combine_operators = {
Combinable.ADD: "add",
Combinable.SUB: "subtract",
Combinable.MUL: "multiply",
Combinable.DIV: "divide",
Combinable.POW: "pow",
Combinable.MOD: "mod",
Combinable.BITAND: "bitAnd",
Combinable.BITOR: "bitOr",
Combinable.BITXOR: "bitXor",
}
explain_options = {"comment", "verbosity"}
def adapt_datefield_value(self, value):
"""Store DateField as datetime."""
if value is None:
return None
return datetime.datetime.combine(value, datetime.datetime.min.time())
def adapt_datetimefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
return value
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError(
"MongoDB backend does not support timezone-aware "
"datetimes when USE_TZ is False."
)
return value
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
"""Store DecimalField as Decimal128."""
if value is None:
return None
return Decimal128(value)
def adapt_json_value(self, value, encoder):
if encoder is None:
return value
try:
return json.loads(json.dumps(value, cls=encoder))
except json.decoder.JSONDecodeError as e:
raise DataError from e
def adapt_timefield_value(self, value):
"""Store TimeField as datetime."""
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
return value
if timezone.is_aware(value):
raise ValueError("MongoDB backend does not support timezone-aware times.")
return datetime.datetime.combine(datetime.datetime.min.date(), value)
def _get_arrayfield_converter(self, converter, *args, **kwargs):
# Return a database converter that can be applied to a list of values.
def convert_value(value, expression, connection):
return [converter(x, expression, connection) for x in value]
return convert_value
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == "ArrayField":
converters.extend(
[
self._get_arrayfield_converter(converter)
for converter in self.get_db_converters(
Expression(output_field=expression.output_field.base_field)
)
]
)
elif internal_type == "DateField":
converters.append(self.convert_datefield_value)
elif internal_type == "DateTimeField":
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == "DecimalField":
converters.append(self.convert_decimalfield_value)
elif internal_type == "JSONField":
converters.append(self.convert_jsonfield_value)
elif internal_type == "TimeField":
converters.append(self.convert_timefield_value)
elif internal_type == "UUIDField":
converters.append(self.convert_uuidfield_value)
return converters
def convert_datefield_value(self, value, expression, connection):
if value is not None:
value = value.date()
return value
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_decimalfield_value(self, value, expression, connection):
if value is not None:
# from Decimal128 to decimal.Decimal()
try:
value = value.to_decimal()
except AttributeError:
# `value` could be an integer in the case of an annotation
# like ExpressionWrapper(Value(1), output_field=DecimalField().
return Decimal(value)
return value
def convert_durationfield_value(self, value, expression, connection):
if value is not None:
try:
value = datetime.timedelta(milliseconds=value)
except TypeError:
# `value` could be Decimal128 if doing a computation with
# DurationField and Decimal128.
value = datetime.timedelta(milliseconds=int(str(value)))
return value
def convert_jsonfield_value(self, value, expression, connection):
"""
Convert dict data to a string so that JSONField.from_db_value() can
decode it using json.loads().
"""
return json.dumps(value)
def convert_timefield_value(self, value, expression, connection):
if value is not None:
value = value.time()
return value
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
def combine_expression(self, connector, sub_expressions):
lhs, rhs = sub_expressions
if connector == Combinable.BITLEFTSHIFT:
return {"$multiply": [lhs, {"$pow": [2, rhs]}]}
if connector == Combinable.BITRIGHTSHIFT:
return {"$floor": {"$divide": [lhs, {"$pow": [2, rhs]}]}}
operator = self.combine_operators[connector]
return {f"${operator}": sub_expressions}
def prep_for_like_query(self, x):
"""Escape "x" for $regexMatch queries."""
return re.escape(x)
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
return name # Quoting once is enough.
return name
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
"""
Return a list of the tables which will be passed as argument to
execute_sql_flush().
"""
return tables
def execute_sql_flush(self, tables):
for table in tables:
if table.startswith("system."):
# Do not drop system collections.
continue
collection = self.connection.database[table]
options = collection.options()
if not options.get("capped", False):
collection.delete_many({})
def explain_query_prefix(self, format=None, **options):
# Validate options.
validated_options = {}
if options:
for valid_option in self.explain_options:
value = options.pop(valid_option, None)
if value is not None:
validated_options[valid_option] = value
# super() raises an error if any options are left after the valid ones
# are popped above.
super().explain_query_prefix(format, **options)
return validated_options
def integer_field_range(self, internal_type):
# MongODB doesn't enforce any integer constraints, but it supports
# integers up to 64 bits.
if internal_type in {
"PositiveBigIntegerField",
"PositiveIntegerField",
"PositiveSmallIntegerField",
}:
return (0, 9223372036854775807)
return (-9223372036854775808, 9223372036854775807)
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_expr, rhs_expr = super().prepare_join_on_clause(
lhs_table, lhs_field, rhs_table, rhs_field
)
# If the types are different, cast both to string.
if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
if lhs_field.db_type(self.connection) != "string":
lhs_expr = Cast(lhs_expr, output_field=TextField())
if rhs_field.db_type(self.connection) != "string":
rhs_expr = Cast(rhs_expr, output_field=TextField())
return lhs_expr, rhs_expr
"""Django uses these methods to generate SQL queries before it generates MQL queries."""
# EXTRACT format cannot be passed in parameters.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
def date_extract_sql(self, lookup_type, sql, params):
if lookup_type == "week_day":
# For consistency across backends, we return Sunday=1, Saturday=7.
return f"EXTRACT(DOW FROM {sql}) + 1", params
if lookup_type == "iso_week_day":
return f"EXTRACT(ISODOW FROM {sql})", params
if lookup_type == "iso_year":
return f"EXTRACT(ISOYEAR FROM {sql})", params
lookup_type = lookup_type.upper()
if not self._extract_format_re.fullmatch(lookup_type):
raise ValueError(f"Invalid lookup type: {lookup_type!r}")
return f"EXTRACT({lookup_type} FROM {sql})", params
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
if lookup_type == "second":
# Truncate fractional seconds.
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
def datetime_cast_date_sql(self, sql, params, tzname):
return f"({sql})::date", params
def datetime_cast_time_sql(self, sql, params, tzname):
return f"({sql})::time", params
def format_for_duration_arithmetic(self, sql):
return "INTERVAL %s MILLISECOND" % sql
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)