|
7 | 7 | """
|
8 | 8 | import numpy as np
|
9 | 9 |
|
| 10 | +import operator |
| 11 | + |
10 | 12 | from pandas.errors import AbstractMethodError
|
11 | 13 | from pandas.compat.numpy import function as nv
|
| 14 | +from pandas.compat import set_function_name, PY3 |
| 15 | +from pandas.core.dtypes.common import is_list_like |
| 16 | +from pandas.core import ops |
12 | 17 |
|
13 | 18 | _not_implemented_message = "{} does not implement {}."
|
14 | 19 |
|
@@ -623,3 +628,125 @@ def _ndarray_values(self):
|
623 | 628 | used for interacting with our indexers.
|
624 | 629 | """
|
625 | 630 | return np.array(self)
|
| 631 | + |
| 632 | + |
| 633 | +class ExtensionOpsMixin(object): |
| 634 | + """ |
| 635 | + A base class for linking the operators to their dunder names |
| 636 | + """ |
| 637 | + @classmethod |
| 638 | + def _add_arithmetic_ops(cls): |
| 639 | + cls.__add__ = cls._create_arithmetic_method(operator.add) |
| 640 | + cls.__radd__ = cls._create_arithmetic_method(ops.radd) |
| 641 | + cls.__sub__ = cls._create_arithmetic_method(operator.sub) |
| 642 | + cls.__rsub__ = cls._create_arithmetic_method(ops.rsub) |
| 643 | + cls.__mul__ = cls._create_arithmetic_method(operator.mul) |
| 644 | + cls.__rmul__ = cls._create_arithmetic_method(ops.rmul) |
| 645 | + cls.__pow__ = cls._create_arithmetic_method(operator.pow) |
| 646 | + cls.__rpow__ = cls._create_arithmetic_method(ops.rpow) |
| 647 | + cls.__mod__ = cls._create_arithmetic_method(operator.mod) |
| 648 | + cls.__rmod__ = cls._create_arithmetic_method(ops.rmod) |
| 649 | + cls.__floordiv__ = cls._create_arithmetic_method(operator.floordiv) |
| 650 | + cls.__rfloordiv__ = cls._create_arithmetic_method(ops.rfloordiv) |
| 651 | + cls.__truediv__ = cls._create_arithmetic_method(operator.truediv) |
| 652 | + cls.__rtruediv__ = cls._create_arithmetic_method(ops.rtruediv) |
| 653 | + if not PY3: |
| 654 | + cls.__div__ = cls._create_arithmetic_method(operator.div) |
| 655 | + cls.__rdiv__ = cls._create_arithmetic_method(ops.rdiv) |
| 656 | + |
| 657 | + cls.__divmod__ = cls._create_arithmetic_method(divmod) |
| 658 | + cls.__rdivmod__ = cls._create_arithmetic_method(ops.rdivmod) |
| 659 | + |
| 660 | + @classmethod |
| 661 | + def _add_comparison_ops(cls): |
| 662 | + cls.__eq__ = cls._create_comparison_method(operator.eq) |
| 663 | + cls.__ne__ = cls._create_comparison_method(operator.ne) |
| 664 | + cls.__lt__ = cls._create_comparison_method(operator.lt) |
| 665 | + cls.__gt__ = cls._create_comparison_method(operator.gt) |
| 666 | + cls.__le__ = cls._create_comparison_method(operator.le) |
| 667 | + cls.__ge__ = cls._create_comparison_method(operator.ge) |
| 668 | + |
| 669 | + |
| 670 | +class ExtensionScalarOpsMixin(ExtensionOpsMixin): |
| 671 | + """A mixin for defining the arithmetic and logical operations on |
| 672 | + an ExtensionArray class, where it is assumed that the underlying objects |
| 673 | + have the operators already defined. |
| 674 | +
|
| 675 | + Usage |
| 676 | + ------ |
| 677 | + If you have defined a subclass MyExtensionArray(ExtensionArray), then |
| 678 | + use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to |
| 679 | + get the arithmetic operators. After the definition of MyExtensionArray, |
| 680 | + insert the lines |
| 681 | +
|
| 682 | + MyExtensionArray._add_arithmetic_ops() |
| 683 | + MyExtensionArray._add_comparison_ops() |
| 684 | +
|
| 685 | + to link the operators to your class. |
| 686 | + """ |
| 687 | + |
| 688 | + @classmethod |
| 689 | + def _create_method(cls, op, coerce_to_dtype=True): |
| 690 | + """ |
| 691 | + A class method that returns a method that will correspond to an |
| 692 | + operator for an ExtensionArray subclass, by dispatching to the |
| 693 | + relevant operator defined on the individual elements of the |
| 694 | + ExtensionArray. |
| 695 | +
|
| 696 | + Parameters |
| 697 | + ---------- |
| 698 | + op : function |
| 699 | + An operator that takes arguments op(a, b) |
| 700 | + coerce_to_dtype : bool |
| 701 | + boolean indicating whether to attempt to convert |
| 702 | + the result to the underlying ExtensionArray dtype |
| 703 | + (default True) |
| 704 | +
|
| 705 | + Returns |
| 706 | + ------- |
| 707 | + A method that can be bound to a method of a class |
| 708 | +
|
| 709 | + Example |
| 710 | + ------- |
| 711 | + Given an ExtensionArray subclass called MyExtensionArray, use |
| 712 | +
|
| 713 | + >>> __add__ = cls._create_method(operator.add) |
| 714 | +
|
| 715 | + in the class definition of MyExtensionArray to create the operator |
| 716 | + for addition, that will be based on the operator implementation |
| 717 | + of the underlying elements of the ExtensionArray |
| 718 | +
|
| 719 | + """ |
| 720 | + |
| 721 | + def _binop(self, other): |
| 722 | + def convert_values(param): |
| 723 | + if isinstance(param, ExtensionArray) or is_list_like(param): |
| 724 | + ovalues = param |
| 725 | + else: # Assume its an object |
| 726 | + ovalues = [param] * len(self) |
| 727 | + return ovalues |
| 728 | + lvalues = self |
| 729 | + rvalues = convert_values(other) |
| 730 | + |
| 731 | + # If the operator is not defined for the underlying objects, |
| 732 | + # a TypeError should be raised |
| 733 | + res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] |
| 734 | + |
| 735 | + if coerce_to_dtype: |
| 736 | + try: |
| 737 | + res = self._from_sequence(res) |
| 738 | + except TypeError: |
| 739 | + pass |
| 740 | + |
| 741 | + return res |
| 742 | + |
| 743 | + op_name = ops._get_op_name(op, True) |
| 744 | + return set_function_name(_binop, op_name, cls) |
| 745 | + |
| 746 | + @classmethod |
| 747 | + def _create_arithmetic_method(cls, op): |
| 748 | + return cls._create_method(op) |
| 749 | + |
| 750 | + @classmethod |
| 751 | + def _create_comparison_method(cls, op): |
| 752 | + return cls._create_method(op, coerce_to_dtype=False) |
0 commit comments