|
| 1 | +from datetime import datetime |
| 2 | + |
| 3 | +from django.conf import settings |
1 | 4 | from django.db import NotSupportedError
|
| 5 | +from django.db.models import DateField, DateTimeField, TimeField |
2 | 6 | from django.db.models.expressions import Func
|
3 | 7 | from django.db.models.functions import JSONArray
|
4 | 8 | from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
|
@@ -196,6 +200,33 @@ def trunc(self, compiler, connection):
|
196 | 200 | return {"$dateTrunc": lhs_mql}
|
197 | 201 |
|
198 | 202 |
|
| 203 | +def trunc_convert_value(self, value, expression, connection): |
| 204 | + if connection.vendor == "mongodb": |
| 205 | + # A custom TruncBase.convert_value() for MongoDB. |
| 206 | + if value is None: |
| 207 | + return None |
| 208 | + convert_to_tz = settings.USE_TZ and self.get_tzname() != "UTC" |
| 209 | + if isinstance(self.output_field, DateTimeField): |
| 210 | + if convert_to_tz: |
| 211 | + # Unlike other databases, MongoDB returns the value in UTC, |
| 212 | + # so rather than setting the time zone equal to self.tzinfo, |
| 213 | + # the value must be converted to tzinfo. |
| 214 | + value = value.astimezone(self.tzinfo) |
| 215 | + elif isinstance(value, datetime): |
| 216 | + if isinstance(self.output_field, DateField): |
| 217 | + if convert_to_tz: |
| 218 | + value = value.astimezone(self.tzinfo) |
| 219 | + # Truncate for Trunc(..., output_field=DateField) |
| 220 | + value = value.date() |
| 221 | + elif isinstance(self.output_field, TimeField): |
| 222 | + if convert_to_tz: |
| 223 | + value = value.astimezone(self.tzinfo) |
| 224 | + # Truncate for Trunc(..., output_field=TimeField) |
| 225 | + value = value.time() |
| 226 | + return value |
| 227 | + return self.convert_value(value, expression, connection) |
| 228 | + |
| 229 | + |
199 | 230 | def trunc_date(self, compiler, connection):
|
200 | 231 | # Cast to date rather than truncate to date.
|
201 | 232 | lhs_mql = process_lhs(self, compiler, connection)
|
@@ -256,6 +287,7 @@ def register_functions():
|
256 | 287 | Substr.as_mql = substr
|
257 | 288 | Trim.as_mql = trim("trim")
|
258 | 289 | TruncBase.as_mql = trunc
|
| 290 | + TruncBase.convert_value = trunc_convert_value |
259 | 291 | TruncDate.as_mql = trunc_date
|
260 | 292 | TruncTime.as_mql = trunc_time
|
261 | 293 | Upper.as_mql = preserve_null("toUpper")
|
0 commit comments