diff --git a/runtype/dispatch.py b/runtype/dispatch.py index ef6c6e5..8849e40 100644 --- a/runtype/dispatch.py +++ b/runtype/dispatch.py @@ -1,6 +1,6 @@ from collections import defaultdict from functools import wraps -from typing import Any, Dict, Callable, Sequence +from typing import Any, Dict, Callable, Sequence, List from operator import itemgetter import warnings @@ -14,6 +14,22 @@ class DispatchError(Exception): "Thrown whenever a dispatch fails. Contains text describing the conflict." +@dataclass +class DuplicateSignatureError(Exception): + signature: List[type] + first: Callable + second: Callable + + def __str__(self): + code1 = self.first.__code__ + code2 = self.second.__code__ + return (f"Duplicate signature defined for '{self.first.__name__}'\n" + f" - Signature: {self.signature}.\n" + f" - Definition 1: {code1.co_filename}:{code1.co_firstlineno}\n" + f" - Definition 2: {code2.co_filename}:{code2.co_firstlineno}\n" + ) + + # TODO: Remove test_subtypes, replace with support for Type[], like isa(t, Type[t]) class MultiDispatch: """Creates a dispatch group for multiple dispatch @@ -184,10 +200,8 @@ def define_function(self, f): node = node.follow_type[t] if node.func is not None: - code_obj = node.func[0].__code__ - raise ValueError( - f"Function {f.__name__} at {code_obj.co_filename}:{code_obj.co_firstlineno} matches existing signature: {signature}!" - ) + raise DuplicateSignatureError(signature, node.func[0], f) + node.func = f, signature def choose_most_specific_function(self, args, *funcs): diff --git a/tests/test_basic.py b/tests/test_basic.py index 7c869be..bc4ed3e 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -12,7 +12,7 @@ logging.basicConfig(level=logging.INFO) from runtype import Dispatch, DispatchError, dataclass, isa, is_subtype, issubclass, assert_isa, String, Int, validate_func, cv_type_checking, multidispatch -from runtype.dispatch import MultiDispatch +from runtype.dispatch import MultiDispatch, DuplicateSignatureError from runtype.dataclass import Configuration try: @@ -291,7 +291,7 @@ def f(s:str): @dy def f(x: int): return NotImplemented - except ValueError: + except DuplicateSignatureError: pass else: assert False, f @@ -512,7 +512,7 @@ def f(i:int, j:object, k:object): @dy def f(i:int, j:object, k:object): return "Oops" - except ValueError: + except DuplicateSignatureError: pass else: assert False @@ -536,7 +536,7 @@ def f(x: types[0]): def f(x): pass assert False - except ValueError: + except DuplicateSignatureError: pass for t in types[1:]: @@ -545,7 +545,7 @@ def f(x): def f(x: t): pass assert False, t - except ValueError: + except DuplicateSignatureError: pass _test_canon(object, Union[object], include_none=True) @@ -688,7 +688,7 @@ def f(a: int): def f(a: int): return 'a' f.__module__ = 'a' - self.assertRaises(ValueError, multidispatch, f) + self.assertRaises(DuplicateSignatureError, multidispatch, f) def test_none(self): dp = Dispatch()