Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions prodict/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, TypeVar, Tuple
from typing import Any, List, TypeVar, Tuple, Union
import copy

# from typing_inspect import get_parameters
Expand Down Expand Up @@ -55,7 +55,21 @@ def get_attr_default_value(cls, attr_name: str):

@classmethod
def attr_type(cls, attr_name: str):
return cls.attr_types()[attr_name]
t = cls.attr_types()[attr_name]
if hasattr(t, "__origin__") and t.__origin__ is Union:
args = t.__args__
if len(args) != 2:
raise TypeError("Unsupported Union -- only 2 elements (i.e. Optional[]) allowed")

# Index 1 is where we'll usually find the None, so check there first
if args[1] is type(None):
return args[0]
elif args[0] is type(None):
return args[1]
else:
raise TypeError("Unsupported Union -- only Unions with None (i.e. Optional[]) allowed")
else:
return cls.attr_types()[attr_name]

@classmethod
def attr_types(cls):
Expand Down
34 changes: 33 additions & 1 deletion test_prodict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Any, Tuple
from typing import List, Optional, Any, Tuple, Union
import unittest
from datetime import datetime
from prodict import Prodict
Expand Down Expand Up @@ -480,6 +480,37 @@ class MyLinkListNode(Prodict):
assert type(root_node) is type(copied)


class SupportedOptional(Prodict):
f1: int
f2: Optional[str]

class UnsupportedUnion1(Prodict):
f1: Union[str, int] # Illegal

class UnsupportedUnion2(Prodict):
f1: Union[str, int, Ram] # Also illegal


def test_optional():
# Should work
obj = SupportedOptional.from_dict({'f1': 33, 'f2': "field2"}) # Fails with original code, works with updated

# Should fail
try:
obj = UnsupportedUnion1.from_dict({'f1': 4})
assert(False)
except TypeError:
pass

# Should fail
try:
obj = UnsupportedUnion2.from_dict({'f1': 4})
assert(False)
except TypeError:
pass



if __name__ == '__main__':
start_time = datetime.now().timestamp()

Expand Down Expand Up @@ -509,6 +540,7 @@ class MyLinkListNode(Prodict):
test_use_defaults_method()
test_deepcopy1()
test_deepcopy2()
test_optional()

end_time = datetime.now().timestamp()

Expand Down