Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/implicitdict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
from dataclasses import dataclass
from datetime import datetime as datetime_type
from types import UnionType
from typing import Literal, Optional, Union, get_args, get_origin, get_type_hints # pyright:ignore[reportDeprecated]
from typing import ( # pyright:ignore[reportDeprecated]
Literal,
Optional,
Self,
Union,
get_args,
get_origin,
get_type_hints,
)

import arrow
import pytimeparse
Expand Down Expand Up @@ -101,7 +109,7 @@ def parse(cls, source: dict, parse_type: type):
if key in hints:
# This entry has an explicit type
try:
kwargs[key] = _parse_value(value, hints[key])
kwargs[key] = _parse_value(value, hints[key], parse_type)
except _PARSING_ERRORS as e:
raise _bubble_up_parse_error(e, key)
else:
Expand Down Expand Up @@ -175,7 +183,7 @@ def has_field_with_value(self, field_name: str) -> bool:
return field_name in self and self[field_name] is not None


def _parse_value(value, value_type: type):
def _parse_value(value, value_type: type, root_type: type):
generic_type = get_origin(value_type)
if generic_type:
# Type is generic
Expand All @@ -192,7 +200,7 @@ def _parse_value(value, value_type: type):
result = []
for i, v in enumerate(value_list):
try:
result.append(_parse_value(v, arg_types[0]))
result.append(_parse_value(v, arg_types[0], root_type))
except _PARSING_ERRORS as e:
raise _bubble_up_parse_error(e, f"[{i}]")
return result
Expand All @@ -201,9 +209,9 @@ def _parse_value(value, value_type: type):
# value is a dict of some kind
result = {}
for k, v in value.items():
parsed_key = k if arg_types[0] is str else _parse_value(k, arg_types[0])
parsed_key = k if arg_types[0] is str else _parse_value(k, arg_types[0], root_type)
try:
parsed_value = _parse_value(v, arg_types[1])
parsed_value = _parse_value(v, arg_types[1], root_type)
except _PARSING_ERRORS as e:
raise _bubble_up_parse_error(e, k)
result[parsed_key] = parsed_value
Expand All @@ -220,7 +228,7 @@ def _parse_value(value, value_type: type):
# omitting the field's value
return None
else:
return _parse_value(value, arg_types[0])
return _parse_value(value, arg_types[0], root_type)

elif generic_type is Literal and len(arg_types) == 1:
# Type is a Literal (parsed value must match specified value)
Expand All @@ -231,12 +239,15 @@ def _parse_value(value, value_type: type):
else:
raise ValueError(f"Automatic parsing of {value_type} type is not yet implemented")

elif value_type == Self:
# value is outself type
return ImplicitDict.parse(value, root_type)
elif issubclass(value_type, ImplicitDict):
# value is an ImplicitDict
return ImplicitDict.parse(value, value_type)

if hasattr(value_type, "__orig_bases__") and value_type.__orig_bases__:
return value_type(_parse_value(value, value_type.__orig_bases__[0]))
return value_type(_parse_value(value, value_type.__orig_bases__[0], root_type))

else:
# value is a non-generic type that is not an ImplicitDict
Expand Down
8 changes: 7 additions & 1 deletion src/implicitdict/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from datetime import datetime
from types import UnionType
from typing import Literal, TypeAlias, Union, cast, get_args, get_origin, get_type_hints
from typing import Literal, Self, TypeAlias, Union, cast, get_args, get_origin, get_type_hints

from . import ImplicitDict, StringBasedDateTime, StringBasedTimeDelta, _fullname, _get_fields

Expand Down Expand Up @@ -170,6 +170,12 @@ def _schema_for(

schema_vars = schema_vars_resolver(value_type)

if value_type == Self:
if not schema_vars.path_to:
raise NotImplementedError(f"SchemaVarsResolver for {value_type} didn't returned a path_to function")

return {"$ref": schema_vars.path_to(context, context)}, False

if issubclass(value_type, ImplicitDict):
make_json_schema(value_type, schema_vars_resolver, schema_repository)

Expand Down
33 changes: 31 additions & 2 deletions tests/test_jsonschema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Self

import jsonschema

Expand All @@ -8,11 +9,13 @@

from .test_types import (
ContainerData,
HiddenReferencingSelf,
InheritanceData,
NestedDefinitionsData,
NormalUsageData,
OptionalData,
PropertiesData,
ReferencingSelf,
SpecialSubclassesContainer,
SpecialTypesData,
)
Expand All @@ -28,10 +31,26 @@ def path_to(t_dest: type, t_src: type) -> str:


def _verify_schema_validation(obj, obj_type: type[ImplicitDict]) -> None:
def _root_resolver(t: type) -> SchemaVars:
"""Special resolver that references '#' for Self at root of the schema"""

if t == Self:

def path_to(t_dest: type, t_src: type) -> str:
if t_src == obj_type:
return "#"
else:
return "#/definitions/" + t_dest.__module__ + t_dest.__qualname__

full_name = t.__module__ + t.__qualname__
return SchemaVars(name=full_name, path_to=path_to)

return _resolver(t)

repo = {}
implicitdict.jsonschema.make_json_schema(obj_type, _resolver, repo)
implicitdict.jsonschema.make_json_schema(obj_type, _root_resolver, repo)

name = _resolver(obj_type).name
name = _root_resolver(obj_type).name
schema = repo[name]
del repo[name]
if repo:
Expand Down Expand Up @@ -113,3 +132,13 @@ def test_special_types():
def test_nested_definitions():
data = NestedDefinitionsData.example_value()
_verify_schema_validation(data, NestedDefinitionsData)


def test_self():
data = ReferencingSelf.example_value()
_verify_schema_validation(data, ReferencingSelf)


def test_hidden_self():
data = HiddenReferencingSelf.example_value()
_verify_schema_validation(data, HiddenReferencingSelf)
46 changes: 45 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# comments)
import enum
from datetime import UTC, datetime
from typing import List, Optional # noqa UP035
from typing import List, Optional, Self # noqa UP035

from implicitdict import ImplicitDict, StringBasedDateTime, StringBasedTimeDelta

Expand Down Expand Up @@ -210,3 +210,47 @@ def example_value():
},
NestedDefinitionsData,
)


class ReferencingSelf(ImplicitDict):
foo: str
bar: Self | None

@staticmethod
def example_value():
return ImplicitDict.parse(
{
"foo": "foo",
"bar": {
"foo": "subfoo",
},
},
ReferencingSelf,
)


class HiddenReferencingSelf(ImplicitDict):
baz: ReferencingSelf
bazs: list[ReferencingSelf]

@staticmethod
def example_value():
return ImplicitDict.parse(
{
"baz": {
"foo": "foo",
"bar": {
"foo": "subfoo",
},
},
"bazs": [
{
"foo": "foo",
"bar": {
"foo": "subfoo",
},
}
],
},
HiddenReferencingSelf,
)