From 4247c122b20e78c79b6c89042bf80f46b254aee3 Mon Sep 17 00:00:00 2001 From: Chris Eykamp Date: Sun, 12 Jul 2020 02:11:17 -0700 Subject: [PATCH] Add support for Optional types --- prodict/__init__.py | 18 ++++++++++++++++-- test_prodict.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/prodict/__init__.py b/prodict/__init__.py index 84cc0eb..e91b9f7 100644 --- a/prodict/__init__.py +++ b/prodict/__init__.py @@ -1,4 +1,4 @@ -from typing import Any, List, TypeVar, Tuple +from typing import Any, List, TypeVar, Tuple, Union import copy # from typing_inspect import get_parameters @@ -55,7 +55,21 @@ def get_attr_default_value(cls, attr_name: str): @classmethod def attr_type(cls, attr_name: str): - return cls.attr_types()[attr_name] + t = cls.attr_types()[attr_name] + if hasattr(t, "__origin__") and t.__origin__ is Union: + args = t.__args__ + if len(args) != 2: + raise TypeError("Unsupported Union -- only 2 elements (i.e. Optional[]) allowed") + + # Index 1 is where we'll usually find the None, so check there first + if args[1] is type(None): + return args[0] + elif args[0] is type(None): + return args[1] + else: + raise TypeError("Unsupported Union -- only Unions with None (i.e. Optional[]) allowed") + else: + return cls.attr_types()[attr_name] @classmethod def attr_types(cls): diff --git a/test_prodict.py b/test_prodict.py index 9bc1a14..1cc7387 100644 --- a/test_prodict.py +++ b/test_prodict.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Any, Tuple +from typing import List, Optional, Any, Tuple, Union import unittest from datetime import datetime from prodict import Prodict @@ -480,6 +480,37 @@ class MyLinkListNode(Prodict): assert type(root_node) is type(copied) +class SupportedOptional(Prodict): + f1: int + f2: Optional[str] + +class UnsupportedUnion1(Prodict): + f1: Union[str, int] # Illegal + +class UnsupportedUnion2(Prodict): + f1: Union[str, int, Ram] # Also illegal + + +def test_optional(): + # Should work + obj = SupportedOptional.from_dict({'f1': 33, 'f2': "field2"}) # Fails with original code, works with updated + + # Should fail + try: + obj = UnsupportedUnion1.from_dict({'f1': 4}) + assert(False) + except TypeError: + pass + + # Should fail + try: + obj = UnsupportedUnion2.from_dict({'f1': 4}) + assert(False) + except TypeError: + pass + + + if __name__ == '__main__': start_time = datetime.now().timestamp() @@ -509,6 +540,7 @@ class MyLinkListNode(Prodict): test_use_defaults_method() test_deepcopy1() test_deepcopy2() + test_optional() end_time = datetime.now().timestamp()