-
Notifications
You must be signed in to change notification settings - Fork 527
/
Copy path_dataclass.py
145 lines (118 loc) · 4.98 KB
/
_dataclass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import enum
import json
from dataclasses import fields, is_dataclass
from typing import Any, Dict, get_args, get_origin, get_type_hints, Union
class _DataclassEncoder(json.JSONEncoder):
# pyre-ignore
def default(self, o: Any) -> Any:
if is_dataclass(o):
props = {}
for field in fields(o):
props[field.name] = getattr(o, field.name)
origin = get_origin(get_type_hints(type(o))[field.name])
if isinstance(field.type, str) and origin == Union:
props[f"{field.name}_type"] = type(getattr(o, field.name)).__name__
return props
if isinstance(o, bytes):
return list(o)
return super().default(o)
# Dataclass Decoder
# pyre-ignore
def _is_optional(T: Any) -> bool:
return (
get_origin(T) is Union
and len(get_args(T)) > 0
and isinstance(None, get_args(T)[-1])
)
# pyre-ignore
def _is_strict_union(T: Any, cls: Any, key: str) -> bool:
return isinstance(T, str) and get_origin(get_type_hints(cls)[key]) is Union
# pyre-ignore
def _get_class_from_union(json_dict: Dict[str, Any], key: str, cls: Any) -> Any:
"""Search through all possible types in the Union and select the type we
want to unpack (note in the serialization of a PyObject to JSON,
the type we want to unpack is keyed by f"{field.name}_type").
"""
_type = json_dict[key + "_type"]
res = [x for x in get_args(get_type_hints(cls)[key]) if x.__name__ == _type]
return res[0]
# pyre-ignore
def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any:
"""Initializes a dataclass given a dictionary loaded from a json,
`json_dict`, and the expected class, `cls`, by iterating through the fields
of the class and retrieving the data for each. If there is a field that is
missing in the data, and that field is not the Optional type,
`_json_to_dataclass` raises a TypeError.
Args:
`json_dict` : Dictionary formatted to represent a class, where fields are keys in the dictionary,
and values are values with the required type (as outlined in the dataclass definition). If a field is
specified to be another dataclass, the value will be another dictionary. See example below:
SAMPLE JSON:
{field1 : v1, inner_class : {field2_1: v2_1, field2_2: v2_1}}.
`cls` : The class that we should be unpacking from the given dictionary
(gives us an idea of what fields and values will be present in `json_dict`)
SAMPLE CLASSES for Above JSON:
@dataclass
class AnotherDataClass
field2_1: int
field2_2: str
@dataclass
class Example
field1 : str
inner_class: AnotherDataClass
Returns: An initialized PyObject of class: `cls`, given the data from `json_dict`.
"""
if not is_dataclass(cls) or is_dataclass(json_dict):
return json_dict
# initialize dataclass by iterating through all required fields
cls_flds = fields(cls)
data = {}
for field in cls_flds:
key = field.name
T = field.type
if _is_optional(T):
T = get_args(T)[0]
value = json_dict.get(key, None)
elif _is_strict_union(T, cls, key):
# If the field is a Union type, we determine exactly what type we
# are trying to initialize by calling `_get_class_from_union`, and
# then make a recursive call construct this new class
_cls = _get_class_from_union(json_dict, key, cls)
data[key] = _json_to_dataclass(json_dict[key], _cls)
continue
else:
try:
value = json_dict[key]
except KeyError:
raise TypeError(
f"Invalid Buffer. Received no value for field: {key}, but {key} : {T} is not an Optional type."
)
if value is None:
data[key] = None
continue
if is_dataclass(T):
data[key] = _json_to_dataclass(value, T)
continue
if get_origin(T) is list:
T = get_args(T)[0]
data[key] = [_json_to_dataclass(e, T) for e in value]
continue
# If T is a Union, then check which type in the Union it is and initialize.
# eg. Double type in schema.py
if get_origin(T) is Union:
res = [x for x in get_args(get_type_hints(cls)[key]) if x == type(value)]
data[key] = res[0](value)
continue
# If T is an enum then lookup the value in the enum otherwise try to
# cast value to whatever type is required
if isinstance(T, enum.EnumMeta):
data[key] = T[value]
else:
data[key] = T(value)
return cls(**data)