From ce87cc3d8edd339f9466a31f6859fd0cb1271db5 Mon Sep 17 00:00:00 2001 From: Ashley Whetter Date: Wed, 30 Oct 2019 21:24:37 -0700 Subject: [PATCH 1/3] Can understand functools.total_ordering --- mypy/plugins/default.py | 4 + mypy/plugins/functools.py | 105 +++++++++++++++++++++++++++ mypy/test/testcheck.py | 1 + test-data/unit/check-functools.test | 109 ++++++++++++++++++++++++++++ 4 files changed, 219 insertions(+) create mode 100644 mypy/plugins/functools.py create mode 100644 test-data/unit/check-functools.test diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index dc17450664c8..4257eb44df2a 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -93,6 +93,7 @@ def get_class_decorator_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: from mypy.plugins import attrs from mypy.plugins import dataclasses + from mypy.plugins import functools if fullname in attrs.attr_class_makers: return attrs.attr_class_maker_callback @@ -103,6 +104,9 @@ def get_class_decorator_hook(self, fullname: str ) elif fullname in dataclasses.dataclass_makers: return dataclasses.dataclass_class_maker_callback + elif fullname in functools.functools_total_ordering_makers: + return functools.functools_total_ordering_maker_callback + return None diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py new file mode 100644 index 000000000000..7a534244daf2 --- /dev/null +++ b/mypy/plugins/functools.py @@ -0,0 +1,105 @@ +"""Plugin for supporting the functools standard library module.""" +from typing import Dict, NamedTuple, Optional + +import mypy.plugin +from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR2, Argument, FuncItem, Var +from mypy.plugins.common import add_method_to_class +from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType + + +functools_total_ordering_makers = { + 'functools.total_ordering', +} + +_ORDERING_METHODS = { + '__lt__', + '__le__', + '__gt__', + '__ge__', +} + + +_MethodInfo = NamedTuple('_MethodInfo', [('is_static', bool), ('type', CallableType)]) + + +def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext, + auto_attribs_default: bool = False) -> None: + """Add dunder methods to classes decorated with functools.total_ordering.""" + if ctx.api.options.python_version < (3, 2): + ctx.api.fail('"functools.total_ordering" is not supported in Python 2 or 3.1', ctx.reason) + return + + comparison_methods = _analyze_class(ctx) + if not comparison_methods: + ctx.api.fail( + 'No ordering operation defined when using "functools.total_ordering": < > <= >=', + ctx.reason) + return + + # prefer __lt__ to __le__ to __gt__ to __ge__ + root = max(comparison_methods, key=lambda k: (comparison_methods[k] is None, k)) + root_method = comparison_methods[root] + if not root_method: + # None of the defined comparison methods can be analysed + return + + other_type = _find_other_type(root_method) + bool_type = ctx.api.named_type('__builtins__.bool') + ret_type = bool_type # type: Type + if root_method.type.ret_type != ctx.api.named_type('__builtins__.bool'): + proper_ret_type = get_proper_type(root_method.type.ret_type) + if not (isinstance(proper_ret_type, UnboundType) + and proper_ret_type.name.endswith('bool')): + ret_type = AnyType(TypeOfAny.implementation_artifact) + for additional_op in _ORDERING_METHODS: + # Either the method is not implemented + # or has an unknown signature that we can now extrapolate. + if not comparison_methods.get(additional_op): + args = [Argument(Var('other', other_type), other_type, None, ARG_POS)] + add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type) + + +def _find_other_type(method: _MethodInfo) -> Type: + """Find the type of the ``other`` argument in a comparision method.""" + first_arg_pos = 0 if method.is_static else 1 + cur_pos_arg = 0 + other_arg = None + for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types): + if arg_kind in (ARG_POS, ARG_OPT): + if cur_pos_arg == first_arg_pos: + other_arg = arg_type + break + + cur_pos_arg += 1 + elif arg_kind != ARG_STAR2: + other_arg = arg_type + break + + if other_arg is None: + return AnyType(TypeOfAny.implementation_artifact) + + return other_arg + + +def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> Dict[str, Optional[_MethodInfo]]: + """Analyze the class body, its parents, and return the comparison methods found.""" + # Traverse the MRO and collect ordering methods. + comparison_methods = {} # type: Dict[str, Optional[_MethodInfo]] + # Skip object because total_ordering does not use methods from object + for cls in ctx.cls.info.mro[:-1]: + for name in _ORDERING_METHODS: + if name in cls.names and name not in comparison_methods: + node = cls.names[name].node + if isinstance(node, FuncItem) and isinstance(node.type, CallableType): + comparison_methods[name] = _MethodInfo(node.is_static, node.type) + continue + + if isinstance(node, Var): + proper_type = get_proper_type(node.type) + if isinstance(proper_type, CallableType): + comparison_methods[name] = _MethodInfo(node.is_staticmethod, proper_type) + continue + + comparison_methods[name] = None + + return comparison_methods diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 34d9b66da0c1..c8b957b55f4d 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -91,6 +91,7 @@ 'check-errorcodes.test', 'check-annotated.test', 'check-parameter-specification.test', + 'check-functools.test', ] # Tests that use Python 3.8-only AST features (like expression-scoped ignores): diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test new file mode 100644 index 000000000000..4b44a20a093e --- /dev/null +++ b/test-data/unit/check-functools.test @@ -0,0 +1,109 @@ +[case testTotalOrderingEqLt] +from functools import total_ordering + +@total_ordering +class Ord: + def __eq__(self, other: object) -> bool: + return False + + def __lt__(self, other: "Ord") -> bool: + return False + +reveal_type(Ord() < Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() <= Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() == Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() > Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() >= Ord()) # N: Revealed type is 'builtins.bool' + +Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int") +Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int") +Ord() == 1 +Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int") +Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int") +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] + +[case testTotalOrderingLambda] +from functools import total_ordering +from typing import Any, Callable + +@total_ordering +class Ord: + __eq__: Callable[[Any, object], bool] = lambda self, other: False + __lt__: Callable[[Any, "Ord"], bool] = lambda self, other: False + +reveal_type(Ord() < Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() <= Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() == Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() > Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() >= Ord()) # N: Revealed type is 'builtins.bool' + +Ord() < 1 # E: Argument 1 has incompatible type "int"; expected "Ord" +Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int") +Ord() == 1 +Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int") +Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int") +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] + +[case testTotalOrderingNonCallable] +from functools import total_ordering + +@total_ordering +class Ord(object): + def __eq__(self, other: object) -> bool: + return False + + __lt__ = 5 + +Ord() <= Ord() # E: Unsupported left operand type for <= ("Ord") +Ord() > Ord() # E: "int" not callable +Ord() >= Ord() # E: Unsupported left operand type for >= ("Ord") + +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] + +[case testTotalOrderingReturnNotBool] +from functools import total_ordering + +@total_ordering +class Ord: + def __eq__(self, other: object) -> bool: + return False + + def __lt__(self, other: "Ord") -> str: + return "blah" + +reveal_type(Ord() < Ord()) # N: Revealed type is 'builtins.str' +reveal_type(Ord() <= Ord()) # N: Revealed type is 'Any' +reveal_type(Ord() == Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() > Ord()) # N: Revealed type is 'Any' +reveal_type(Ord() >= Ord()) # N: Revealed type is 'Any' + +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] + +[case testTotalOrderingAllowsAny] +from functools import total_ordering + +@total_ordering +class Ord: + def __eq__(self, other): + return False + + def __gt__(self, other): + return False + +reveal_type(Ord() < Ord()) # N: Revealed type is 'Any' +Ord() <= Ord() # E: Unsupported left operand type for <= ("Ord") +reveal_type(Ord() == Ord()) # N: Revealed type is 'Any' +reveal_type(Ord() > Ord()) # N: Revealed type is 'Any' +Ord() >= Ord() # E: Unsupported left operand type for >= ("Ord") + +Ord() < 1 # E: Unsupported left operand type for < ("Ord") +Ord() <= 1 # E: Unsupported left operand type for <= ("Ord") +Ord() == 1 +Ord() > 1 +Ord() >= 1 # E: Unsupported left operand type for >= ("Ord") +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] From 15b51fcc071499b933e4cf469b701d9668cba255 Mon Sep 17 00:00:00 2001 From: Ashley Whetter Date: Wed, 30 Oct 2019 21:24:37 -0700 Subject: [PATCH 2/3] fixup! Can understand functools.total_ordering --- mypy/plugins/functools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index 7a534244daf2..d2905c06c2e8 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -25,8 +25,8 @@ def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext, auto_attribs_default: bool = False) -> None: """Add dunder methods to classes decorated with functools.total_ordering.""" - if ctx.api.options.python_version < (3, 2): - ctx.api.fail('"functools.total_ordering" is not supported in Python 2 or 3.1', ctx.reason) + if ctx.api.options.python_version < (3,): + ctx.api.fail('"functools.total_ordering" is not supported in Python 2', ctx.reason) return comparison_methods = _analyze_class(ctx) @@ -49,7 +49,7 @@ def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext, if root_method.type.ret_type != ctx.api.named_type('__builtins__.bool'): proper_ret_type = get_proper_type(root_method.type.ret_type) if not (isinstance(proper_ret_type, UnboundType) - and proper_ret_type.name.endswith('bool')): + and proper_ret_type.name.split('.')[-1] == 'bool'): ret_type = AnyType(TypeOfAny.implementation_artifact) for additional_op in _ORDERING_METHODS: # Either the method is not implemented @@ -60,7 +60,7 @@ def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext, def _find_other_type(method: _MethodInfo) -> Type: - """Find the type of the ``other`` argument in a comparision method.""" + """Find the type of the ``other`` argument in a comparison method.""" first_arg_pos = 0 if method.is_static else 1 cur_pos_arg = 0 other_arg = None From 33f80f5909f260d92309d03a6259f8fcbbbb3938 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 11 Apr 2021 14:36:27 -0700 Subject: [PATCH 3/3] Update double quotes in error messages --- test-data/unit/check-functools.test | 36 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test index 4b44a20a093e..416006591425 100644 --- a/test-data/unit/check-functools.test +++ b/test-data/unit/check-functools.test @@ -9,11 +9,11 @@ class Ord: def __lt__(self, other: "Ord") -> bool: return False -reveal_type(Ord() < Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() <= Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() == Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() > Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() >= Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() <= Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() > Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() >= Ord()) # N: Revealed type is "builtins.bool" Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int") Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int") @@ -32,11 +32,11 @@ class Ord: __eq__: Callable[[Any, object], bool] = lambda self, other: False __lt__: Callable[[Any, "Ord"], bool] = lambda self, other: False -reveal_type(Ord() < Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() <= Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() == Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() > Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() >= Ord()) # N: Revealed type is 'builtins.bool' +reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() <= Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() > Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() >= Ord()) # N: Revealed type is "builtins.bool" Ord() < 1 # E: Argument 1 has incompatible type "int"; expected "Ord" Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int") @@ -74,11 +74,11 @@ class Ord: def __lt__(self, other: "Ord") -> str: return "blah" -reveal_type(Ord() < Ord()) # N: Revealed type is 'builtins.str' -reveal_type(Ord() <= Ord()) # N: Revealed type is 'Any' -reveal_type(Ord() == Ord()) # N: Revealed type is 'builtins.bool' -reveal_type(Ord() > Ord()) # N: Revealed type is 'Any' -reveal_type(Ord() >= Ord()) # N: Revealed type is 'Any' +reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.str" +reveal_type(Ord() <= Ord()) # N: Revealed type is "Any" +reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() > Ord()) # N: Revealed type is "Any" +reveal_type(Ord() >= Ord()) # N: Revealed type is "Any" [builtins fixtures/ops.pyi] [builtins fixtures/dict.pyi] @@ -94,10 +94,10 @@ class Ord: def __gt__(self, other): return False -reveal_type(Ord() < Ord()) # N: Revealed type is 'Any' +reveal_type(Ord() < Ord()) # N: Revealed type is "Any" Ord() <= Ord() # E: Unsupported left operand type for <= ("Ord") -reveal_type(Ord() == Ord()) # N: Revealed type is 'Any' -reveal_type(Ord() > Ord()) # N: Revealed type is 'Any' +reveal_type(Ord() == Ord()) # N: Revealed type is "Any" +reveal_type(Ord() > Ord()) # N: Revealed type is "Any" Ord() >= Ord() # E: Unsupported left operand type for >= ("Ord") Ord() < 1 # E: Unsupported left operand type for < ("Ord")