diff --git a/src/django_mysql/models/fields/dynamic.py b/src/django_mysql/models/fields/dynamic.py index 287b7ae9..2bd7f099 100644 --- a/src/django_mysql/models/fields/dynamic.py +++ b/src/django_mysql/models/fields/dynamic.py @@ -179,13 +179,18 @@ def _check_spec_recursively( subpath = f"{path}.{key}" errors.extend(self._check_spec_recursively(value, subpath)) elif value not in KeyTransform.SPEC_MAP: + valid_names = ", ".join( + sorted(x.__name__ for x in KeyTransform.SPEC_MAP.keys()) + ) errors.append( checks.Error( "The value for '{}' in 'spec{}' is not an allowed type".format( key, path ), - hint="'spec' values must be one of the following " - "types: {}".format(KeyTransform.SPEC_MAP_NAMES), + hint=( + "'spec' values must be one of the following types: " + + valid_names + ), obj=self, id="django_mysql.E011", ) @@ -306,10 +311,8 @@ class KeyTransform(Transform): dict: "BINARY", } - SPEC_MAP_NAMES = ", ".join(sorted(x.__name__ for x in SPEC_MAP.keys())) - - TYPE_MAP: dict[str, type[Field] | Field] = { - "BINARY": DynamicField, + TYPE_MAP: dict[str, Field[Any, Any]] = { + # Excludes BINARY -> DynamicField as that’s requires spec "CHAR": TextField(), "DATE": DateField(), "DATETIME": DateTimeField(), @@ -322,23 +325,22 @@ def __init__( self, key_name: str, data_type: str, - *args: Any, + *expressions: Any, subspec: SpecDict | None = None, - **kwargs: Any, ) -> None: - super().__init__(*args, **kwargs) - self.key_name = key_name - self.data_type = data_type - - try: - output_field = self.TYPE_MAP[data_type] - except KeyError: # pragma: no cover - raise ValueError(f"Invalid data_type '{data_type}'") - + output_field: Field[Any, Any] if data_type == "BINARY": - self.output_field = output_field(spec=subspec) + output_field = DynamicField(spec=subspec) else: - self.output_field = output_field + try: + output_field = self.TYPE_MAP[data_type] + except KeyError: + raise ValueError(f"Invalid data_type {data_type!r}") + + super().__init__(*expressions, output_field=output_field) + + self.key_name = key_name + self.data_type = data_type def as_sql( self, compiler: SQLCompiler, connection: BaseDatabaseWrapper diff --git a/src/django_mysql/models/functions.py b/src/django_mysql/models/functions.py index e5173e21..23db10c8 100644 --- a/src/django_mysql/models/functions.py +++ b/src/django_mysql/models/functions.py @@ -441,20 +441,16 @@ class AsType(Func): template = "%(expressions)s AS %(data_type)s" def __init__(self, expression: ExpressionArgument, data_type: str) -> None: + from django_mysql.models.fields.dynamic import KeyTransform + if not hasattr(expression, "resolve_expression"): expression = Value(expression) - if data_type not in self.TYPE_MAP: + if data_type not in KeyTransform.TYPE_MAP and data_type != "BINARY": raise ValueError(f"Invalid data_type '{data_type}'") super().__init__(expression, data_type=data_type) - @property - def TYPE_MAP(self) -> dict[str, type[DjangoField] | DjangoField]: - from django_mysql.models.fields.dynamic import KeyTransform - - return KeyTransform.TYPE_MAP - class ColumnAdd(Func): function = "COLUMN_ADD" @@ -508,25 +504,22 @@ def __init__( self, expression: ExpressionArgument, column_name: ExpressionArgument, - data_type: ExpressionArgument, + data_type: str, ): + from django_mysql.models.fields.dynamic import DynamicField, KeyTransform + if not hasattr(column_name, "resolve_expression"): column_name = Value(column_name) - try: - output_field = self.TYPE_MAP[data_type] - except KeyError: - raise ValueError(f"Invalid data_type '{data_type}'") - + output_field: DjangoField[Any, Any] if data_type == "BINARY": - output_field = output_field() + output_field = DynamicField() + else: + try: + output_field = KeyTransform.TYPE_MAP[data_type] + except KeyError: + raise ValueError(f"Invalid data_type {data_type!r}") super().__init__( expression, column_name, output_field=output_field, data_type=data_type ) - - @property - def TYPE_MAP(self) -> dict[str, DjangoField | type[DjangoField]]: - from django_mysql.models.fields.dynamic import KeyTransform - - return KeyTransform.TYPE_MAP diff --git a/tests/testapp/test_dynamicfield.py b/tests/testapp/test_dynamicfield.py index 95c885b9..d324de6f 100644 --- a/tests/testapp/test_dynamicfield.py +++ b/tests/testapp/test_dynamicfield.py @@ -15,6 +15,7 @@ from django.test.utils import isolate_apps from django_mysql.models import DynamicField +from django_mysql.models.fields.dynamic import KeyTransform from tests.testapp.models import DynamicModel, SpeclessDynamicModel @@ -159,6 +160,12 @@ def test_non_existent_transform(self): def test_has_key(self): assert list(DynamicModel.objects.filter(attrs__has_key="c")) == self.objs[1:3] + def test_key_transform_initialize_bad_type(self): + with pytest.raises(ValueError) as excinfo: + KeyTransform("x", "unknown") + + assert str(excinfo.value) == "Invalid data_type 'unknown'" + def test_key_transform_datey(self): assert list(DynamicModel.objects.filter(attrs__datey=dt.date(2001, 1, 4))) == [ self.objs[4]