-
Notifications
You must be signed in to change notification settings - Fork 62
automatic generation of type checks for overload #911
base: numba_typing
Are you sure you want to change the base?
Changes from 5 commits
8bd7bc0
d3f4a5d
b54da17
05c745b
b7446ca
09289bf
1cb60da
967ae29
716402c
2096e94
5ec33ac
7da564b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import numba | ||
from numba import types | ||
from numba.extending import overload | ||
from type_annotations import product_annotations, get_func_annotations | ||
import typing | ||
from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning | ||
import warnings | ||
from numba.typed import List, Dict | ||
from inspect import getfullargspec | ||
|
||
|
||
def overload_list(orig_func): | ||
def overload_inner(ovld_list): | ||
def wrapper(*args): | ||
func_list = ovld_list() | ||
sig_list = [] | ||
for func in func_list: | ||
sig_list.append((product_annotations( | ||
get_func_annotations(func)), func)) | ||
args_orig_func = getfullargspec(orig_func) | ||
values_dict = {name: typ for name, typ in zip(args_orig_func.args, args)} | ||
defaults_dict = {} | ||
if args_orig_func.defaults: | ||
defaults_dict = {name: value for name, value in zip( | ||
args_orig_func.args[::-1], args_orig_func.defaults[::-1])} | ||
result = choose_func_by_sig(sig_list, values_dict, defaults_dict) | ||
|
||
if result is None: | ||
raise numba.TypingError(f'Unsupported types a={a}, b={b}') | ||
|
||
return result | ||
|
||
return overload(orig_func, strict=False)(wrapper) | ||
|
||
return overload_inner | ||
|
||
|
||
def check_int_type(n_type): | ||
return isinstance(n_type, types.Integer) | ||
|
||
|
||
def check_float_type(n_type): | ||
return isinstance(n_type, types.Float) | ||
|
||
|
||
def check_bool_type(n_type): | ||
return isinstance(n_type, types.Boolean) | ||
|
||
|
||
def check_str_type(n_type): | ||
return isinstance(n_type, types.UnicodeType) | ||
|
||
|
||
def check_list_type(self, p_type, n_type): | ||
res = isinstance(n_type, types.List) or isinstance(n_type, types.ListType) | ||
if isinstance(p_type, type): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isinstance(p_type, (list, typing.List)) ? |
||
return res | ||
else: | ||
return res and self.match(p_type.__args__[0], n_type.dtype) | ||
|
||
|
||
def check_tuple_type(self, p_type, n_type): | ||
res = False | ||
if isinstance(n_type, types.Tuple): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of using Something like: if not isinstance(n_type, types.Tuple, types.UniTuple):
return False
for p_val, n_val in zip(p_type.__args__, n_type.types):
if not self.match(p_val, n_val):
return False
return True And btw you need to check that size of |
||
res = True | ||
if isinstance(p_type, type): | ||
return res | ||
for p_val, n_val in zip(p_type.__args__, n_type.key): | ||
res = res and self.match(p_val, n_val) | ||
if isinstance(n_type, types.UniTuple): | ||
res = True | ||
if isinstance(p_type, type): | ||
return res | ||
for p_val in p_type.__args__: | ||
res = res and self.match(p_val, n_type.key[0]) | ||
return res | ||
|
||
|
||
def check_dict_type(self, p_type, n_type): | ||
res = False | ||
if isinstance(n_type, types.DictType): | ||
res = True | ||
if isinstance(p_type, type): | ||
return res | ||
for p_val, n_val in zip(p_type.__args__, n_type.keyvalue_type): | ||
res = res and self.match(p_val, n_val) | ||
return res | ||
|
||
|
||
class TypeChecker: | ||
|
||
_types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer this checks to be added using |
||
str: check_str_type, list: check_list_type, | ||
tuple: check_tuple_type, dict: check_dict_type} | ||
|
||
def __init__(self): | ||
self._typevars_dict = {} | ||
|
||
def clear_typevars_dict(self): | ||
self._typevars_dict.clear() | ||
|
||
@classmethod | ||
def add_type_check(cls, type_check, func): | ||
cls._types_dict[type_check] = func | ||
|
||
@staticmethod | ||
def _is_generic(p_obj): | ||
if isinstance(p_obj, typing._GenericAlias): | ||
return True | ||
|
||
if isinstance(p_obj, typing._SpecialForm): | ||
return p_obj not in {typing.Any} | ||
|
||
return False | ||
|
||
@staticmethod | ||
def _get_origin(p_obj): | ||
return p_obj.__origin__ | ||
|
||
def match(self, p_type, n_type): | ||
try: | ||
if p_type == typing.Any: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's do it like this: if p_type == typing.Any:
return True
if self._is_generic(p_type):
origin_type = self._get_origin(p_type)
if origin_type == typing.Generic:
return self.match_generic(p_type, n_type)
return self._types_dict[origin_type](self, p_type, n_type)
if isinstance(p_type, typing.TypeVar):
return self.match_typevar(p_type, n_type)
if p_type in (list, tuple):
return self._types_dict[p_type](self, p_type, n_type)
return self._types_dict[p_type](n_type) |
||
return True | ||
elif self._is_generic(p_type): | ||
origin_type = self._get_origin(p_type) | ||
if origin_type == typing.Generic: | ||
return self.match_generic(p_type, n_type) | ||
else: | ||
return self._types_dict[origin_type](self, p_type, n_type) | ||
elif isinstance(p_type, typing.TypeVar): | ||
return self.match_typevar(p_type, n_type) | ||
else: | ||
if p_type in (list, tuple): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't dict be here too? |
||
return self._types_dict[p_type](self, p_type, n_type) | ||
return self._types_dict[p_type](n_type) | ||
except KeyError: | ||
print((f'A check for the {p_type} was not found.')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should rise an exception |
||
return None | ||
|
||
def match_typevar(self, p_type, n_type): | ||
if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you need condition |
||
self._typevars_dict[p_type] = n_type | ||
return True | ||
return self._typevars_dict.get(p_type) == n_type | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe it should be |
||
|
||
def match_generic(self, p_type, n_type): | ||
res = True | ||
for arg in p_type.__args__: | ||
res = res and self.match(arg, n_type) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's doesn't feel right. Do we have any test for this case? |
||
return res | ||
|
||
|
||
def choose_func_by_sig(sig_list, values_dict, defaults_dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's do it this way: def choose_func_by_sig(sig_list, values_dict, defaults_dict):
def check_signature(sig_params, types_dict):
checker = TypeChecker()
for name, typ in types_dict.items(): # name,type = 'a',int64
if isinstance(typ, types.Literal):
typ = typ.literal_type
if not checker.match(sig_params[name], typ):
return False
return True
for sig, func in sig_list: # sig = (Signature,func)
for param in sig.parameters: # param = {'a':int,'b':int}
if check_signature(param, values_dict):
return func
return None |
||
checker = TypeChecker() | ||
for sig, func in sig_list: # sig = (Signature,func) | ||
for param in sig.parameters: # param = {'a':int,'b':int} | ||
full_match = True | ||
for name, typ in values_dict.items(): # name,type = 'a',int64 | ||
if isinstance(typ, types.Literal): | ||
|
||
full_match = full_match and checker.match( | ||
param[name], typ.literal_type) | ||
|
||
if sig.defaults.get(name, False): | ||
full_match = full_match and sig.defaults[name] == typ.literal_value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is that valid def foo(a, b=0):
...
foo(a, 1) ? |
||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this |
||
full_match = full_match and checker.match(param[name], typ) | ||
|
||
if not full_match: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we really need this |
||
break | ||
|
||
for name, val in defaults_dict.items(): | ||
if not sig.defaults.get(name) is None: | ||
full_match = full_match and sig.defaults[name] == val | ||
|
||
checker.clear_typevars_dict() | ||
if full_match: | ||
return func | ||
|
||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.