Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.
Open
180 changes: 180 additions & 0 deletions numba_typing/overload_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import numba
from numba import types
from numba.extending import overload
from type_annotations import product_annotations, get_func_annotations
import typing
from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings
from numba.typed import List, Dict
from inspect import getfullargspec


def overload_list(orig_func):
def overload_inner(ovld_list):
def wrapper(*args):
func_list = ovld_list()
sig_list = []
for func in func_list:
sig_list.append((product_annotations(
get_func_annotations(func)), func))
args_orig_func = getfullargspec(orig_func)
values_dict = {name: typ for name, typ in zip(args_orig_func.args, args)}
defaults_dict = {}
if args_orig_func.defaults:
defaults_dict = {name: value for name, value in zip(
args_orig_func.args[::-1], args_orig_func.defaults[::-1])}
result = choose_func_by_sig(sig_list, values_dict, defaults_dict)

if result is None:
raise numba.TypingError(f'Unsupported types a={a}, b={b}')

return result

return overload(orig_func, strict=False)(wrapper)

return overload_inner


def check_int_type(n_type):
return isinstance(n_type, types.Integer)


def check_float_type(n_type):
return isinstance(n_type, types.Float)


def check_bool_type(n_type):
return isinstance(n_type, types.Boolean)


def check_str_type(n_type):
return isinstance(n_type, types.UnicodeType)


def check_list_type(self, p_type, n_type):
res = isinstance(n_type, types.List) or isinstance(n_type, types.ListType)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

res = isinstance(n_type, (types.List, types.ListType))

if isinstance(p_type, type):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isinstance(p_type, (list, typing.List))

?

return res
else:
return res and self.match(p_type.__args__[0], n_type.dtype)


def check_tuple_type(self, p_type, n_type):
res = False
if isinstance(n_type, types.Tuple):
Copy link
Collaborator

@AlexanderKalistratov AlexanderKalistratov Sep 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using n_type.key you could use n_type.types which is defined for both Tuple and UniTuple, so you can unite both branches:

Something like:

if not isinstance(n_type, types.Tuple, types.UniTuple):
    return False

for p_val, n_val in zip(p_type.__args__, n_type.types):
    if not self.match(p_val, n_val):
        return False

return True

And btw you need to check that size of p_type.__args__ and n_type.types are the same.
And I believe you don't have tests for the case when they are different.

res = True
if isinstance(p_type, type):
return res
for p_val, n_val in zip(p_type.__args__, n_type.key):
res = res and self.match(p_val, n_val)
if isinstance(n_type, types.UniTuple):
res = True
if isinstance(p_type, type):
return res
for p_val in p_type.__args__:
res = res and self.match(p_val, n_type.key[0])
return res


def check_dict_type(self, p_type, n_type):
res = False
if isinstance(n_type, types.DictType):
res = True
if isinstance(p_type, type):
return res
for p_val, n_val in zip(p_type.__args__, n_type.keyvalue_type):
res = res and self.match(p_val, n_val)
return res


class TypeChecker:

_types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer this checks to be added using add_type_check function

str: check_str_type, list: check_list_type,
tuple: check_tuple_type, dict: check_dict_type}

def __init__(self):
self._typevars_dict = {}

def clear_typevars_dict(self):
self._typevars_dict.clear()

@classmethod
def add_type_check(cls, type_check, func):
cls._types_dict[type_check] = func

@staticmethod
def _is_generic(p_obj):
if isinstance(p_obj, typing._GenericAlias):
return True

if isinstance(p_obj, typing._SpecialForm):
return p_obj not in {typing.Any}

return False

@staticmethod
def _get_origin(p_obj):
return p_obj.__origin__

def match(self, p_type, n_type):
try:
if p_type == typing.Any:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do it like this:

            if p_type == typing.Any:
                return True

            if self._is_generic(p_type):
                origin_type = self._get_origin(p_type)
                if origin_type == typing.Generic:
                    return self.match_generic(p_type, n_type)

                return self._types_dict[origin_type](self, p_type, n_type)

            if isinstance(p_type, typing.TypeVar):
                return self.match_typevar(p_type, n_type)

            if p_type in (list, tuple):
                return self._types_dict[p_type](self, p_type, n_type)

            return self._types_dict[p_type](n_type)

return True
elif self._is_generic(p_type):
origin_type = self._get_origin(p_type)
if origin_type == typing.Generic:
return self.match_generic(p_type, n_type)
else:
return self._types_dict[origin_type](self, p_type, n_type)
elif isinstance(p_type, typing.TypeVar):
return self.match_typevar(p_type, n_type)
else:
if p_type in (list, tuple):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't dict be here too?

return self._types_dict[p_type](self, p_type, n_type)
return self._types_dict[p_type](n_type)
except KeyError:
print((f'A check for the {p_type} was not found.'))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should rise an exception

return None

def match_typevar(self, p_type, n_type):
if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need condition n_type not in self._typevars_dict.values() ?

self._typevars_dict[p_type] = n_type
return True
return self._typevars_dict.get(p_type) == n_type
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it should be self.match. E.g. list and 'types.List' are synonyms but will fail equality check ('list != types.List').
And I'm assuming you don't have such tests?


def match_generic(self, p_type, n_type):
res = True
for arg in p_type.__args__:
res = res and self.match(arg, n_type)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's doesn't feel right. Do we have any test for this case?

return res


def choose_func_by_sig(sig_list, values_dict, defaults_dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do it this way:

def choose_func_by_sig(sig_list, values_dict, defaults_dict):
    def check_signature(sig_params, types_dict):
        checker = TypeChecker()
        for name, typ in types_dict.items():  # name,type = 'a',int64
            if isinstance(typ, types.Literal):
                typ = typ.literal_type

            if not checker.match(sig_params[name], typ):
                return False
            
            return True

    for sig, func in sig_list:  # sig = (Signature,func)
        for param in sig.parameters:  # param = {'a':int,'b':int}
            if check_signature(param, values_dict):
                return func

    return None

checker = TypeChecker()
for sig, func in sig_list: # sig = (Signature,func)
for param in sig.parameters: # param = {'a':int,'b':int}
full_match = True
for name, typ in values_dict.items(): # name,type = 'a',int64
if isinstance(typ, types.Literal):

full_match = full_match and checker.match(
param[name], typ.literal_type)

if sig.defaults.get(name, False):
full_match = full_match and sig.defaults[name] == typ.literal_value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that valid sig.defaults[name] == typ.literal_value?
What would happen if we simply pass another literal value? E.g.:

def foo(a, b=0):
    ...

foo(a, 1)

?

else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this else?

full_match = full_match and checker.match(param[name], typ)

if not full_match:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need this full_match variable?
Can't we simply break as soon as checker.match returns false?

break

for name, val in defaults_dict.items():
if not sig.defaults.get(name) is None:
full_match = full_match and sig.defaults[name] == val

checker.clear_typevars_dict()
if full_match:
return func

return None
Loading