Skip to content

Commit

Permalink
Improve KeyTransform initializer and types (#940)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamchainz authored Aug 28, 2022
1 parent 5019529 commit fb9322b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 39 deletions.
40 changes: 21 additions & 19 deletions src/django_mysql/models/fields/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,18 @@ def _check_spec_recursively(
subpath = f"{path}.{key}"
errors.extend(self._check_spec_recursively(value, subpath))
elif value not in KeyTransform.SPEC_MAP:
valid_names = ", ".join(
sorted(x.__name__ for x in KeyTransform.SPEC_MAP.keys())
)
errors.append(
checks.Error(
"The value for '{}' in 'spec{}' is not an allowed type".format(
key, path
),
hint="'spec' values must be one of the following "
"types: {}".format(KeyTransform.SPEC_MAP_NAMES),
hint=(
"'spec' values must be one of the following types: "
+ valid_names
),
obj=self,
id="django_mysql.E011",
)
Expand Down Expand Up @@ -306,10 +311,8 @@ class KeyTransform(Transform):
dict: "BINARY",
}

SPEC_MAP_NAMES = ", ".join(sorted(x.__name__ for x in SPEC_MAP.keys()))

TYPE_MAP: dict[str, type[Field] | Field] = {
"BINARY": DynamicField,
TYPE_MAP: dict[str, Field[Any, Any]] = {
# Excludes BINARY -> DynamicField as that’s requires spec
"CHAR": TextField(),
"DATE": DateField(),
"DATETIME": DateTimeField(),
Expand All @@ -322,23 +325,22 @@ def __init__(
self,
key_name: str,
data_type: str,
*args: Any,
*expressions: Any,
subspec: SpecDict | None = None,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.key_name = key_name
self.data_type = data_type

try:
output_field = self.TYPE_MAP[data_type]
except KeyError: # pragma: no cover
raise ValueError(f"Invalid data_type '{data_type}'")

output_field: Field[Any, Any]
if data_type == "BINARY":
self.output_field = output_field(spec=subspec)
output_field = DynamicField(spec=subspec)
else:
self.output_field = output_field
try:
output_field = self.TYPE_MAP[data_type]
except KeyError:
raise ValueError(f"Invalid data_type {data_type!r}")

super().__init__(*expressions, output_field=output_field)

self.key_name = key_name
self.data_type = data_type

def as_sql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
Expand Down
33 changes: 13 additions & 20 deletions src/django_mysql/models/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,20 +441,16 @@ class AsType(Func):
template = "%(expressions)s AS %(data_type)s"

def __init__(self, expression: ExpressionArgument, data_type: str) -> None:
from django_mysql.models.fields.dynamic import KeyTransform

if not hasattr(expression, "resolve_expression"):
expression = Value(expression)

if data_type not in self.TYPE_MAP:
if data_type not in KeyTransform.TYPE_MAP and data_type != "BINARY":
raise ValueError(f"Invalid data_type '{data_type}'")

super().__init__(expression, data_type=data_type)

@property
def TYPE_MAP(self) -> dict[str, type[DjangoField] | DjangoField]:
from django_mysql.models.fields.dynamic import KeyTransform

return KeyTransform.TYPE_MAP


class ColumnAdd(Func):
function = "COLUMN_ADD"
Expand Down Expand Up @@ -508,25 +504,22 @@ def __init__(
self,
expression: ExpressionArgument,
column_name: ExpressionArgument,
data_type: ExpressionArgument,
data_type: str,
):
from django_mysql.models.fields.dynamic import DynamicField, KeyTransform

if not hasattr(column_name, "resolve_expression"):
column_name = Value(column_name)

try:
output_field = self.TYPE_MAP[data_type]
except KeyError:
raise ValueError(f"Invalid data_type '{data_type}'")

output_field: DjangoField[Any, Any]
if data_type == "BINARY":
output_field = output_field()
output_field = DynamicField()
else:
try:
output_field = KeyTransform.TYPE_MAP[data_type]
except KeyError:
raise ValueError(f"Invalid data_type {data_type!r}")

super().__init__(
expression, column_name, output_field=output_field, data_type=data_type
)

@property
def TYPE_MAP(self) -> dict[str, DjangoField | type[DjangoField]]:
from django_mysql.models.fields.dynamic import KeyTransform

return KeyTransform.TYPE_MAP
7 changes: 7 additions & 0 deletions tests/testapp/test_dynamicfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from django.test.utils import isolate_apps

from django_mysql.models import DynamicField
from django_mysql.models.fields.dynamic import KeyTransform
from tests.testapp.models import DynamicModel, SpeclessDynamicModel


Expand Down Expand Up @@ -159,6 +160,12 @@ def test_non_existent_transform(self):
def test_has_key(self):
assert list(DynamicModel.objects.filter(attrs__has_key="c")) == self.objs[1:3]

def test_key_transform_initialize_bad_type(self):
with pytest.raises(ValueError) as excinfo:
KeyTransform("x", "unknown")

assert str(excinfo.value) == "Invalid data_type 'unknown'"

def test_key_transform_datey(self):
assert list(DynamicModel.objects.filter(attrs__datey=dt.date(2001, 1, 4))) == [
self.objs[4]
Expand Down

0 comments on commit fb9322b

Please sign in to comment.