diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index aeebf9d58..9ee450360 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -87,12 +87,22 @@ import copy import warnings from types import GeneratorType +from typing import Union, TypeAlias, TypeVar, Collection, Optional, Sequence, cast, List + import numpy as np import cpmpy as cp from .utils import is_int, is_num, is_any_list, flatlist, get_bounds, is_boolexpr, is_true_cst, is_false_cst, argvals, is_bool from ..exceptions import IncompleteFunctionError, TypeError +# Define types +BoolConst : TypeAlias = Union[bool, np.bool_, "BoolVal"] +NumConst : TypeAlias = Union[BoolConst, int, float, np.integer, np.floating] +ExprOrConst : TypeAlias = Union["Expression", NumConst] + +T = TypeVar('T') +FlatList = Sequence[T] | "NDVarArray" +NestedList = Sequence[Union[T, Collection[T]]] class Expression(object): """ @@ -110,7 +120,7 @@ class Expression(object): - any ``__op__`` python operator overloading """ - def __init__(self, name, arg_list): + def __init__(self, name:str, arg_list): self.name = name if isinstance(arg_list, (tuple, GeneratorType)): @@ -149,7 +159,7 @@ def set_description(self, txt, override_print=True, full_print=False): self._override_print = override_print self._full_print = full_print - def __str__(self): + def __str__(self) -> str: if not hasattr(self, "desc") or self._override_print is False: return self.__repr__() out = self.desc @@ -158,7 +168,7 @@ def __str__(self): return out - def __repr__(self): + def __repr__(self) -> str: strargs = [] for arg in self.args: if isinstance(arg, np.ndarray): @@ -226,7 +236,7 @@ def deepcopy(self, memodict={}): # implication constraint: self -> other # Python does not offer relevant syntax... # for double implication, use equivalence self == other - def implies(self, other): + def implies(self, other:ExprOrConst) -> "Expression": # other constant if is_true_cst(other): return BoolVal(True) @@ -235,7 +245,7 @@ def implies(self, other): return Operator('->', [self, other]) # Comparisons - def __eq__(self, other): + def __eq__(self, other: ExprOrConst) -> "Expression": # type: ignore # BoolExpr == 1|true|0|false, common case, simply BoolExpr if self.is_bool() and is_num(other): if other is True or other == 1: @@ -244,24 +254,24 @@ def __eq__(self, other): return ~self return Comparison("==", self, other) - def __ne__(self, other): + def __ne__(self, other: ExprOrConst) -> "Comparison": # type: ignore return Comparison("!=", self, other) - def __lt__(self, other): + def __lt__(self, other: ExprOrConst) -> "Comparison": # type: ignore return Comparison("<", self, other) - def __le__(self, other): + def __le__(self, other: ExprOrConst) -> "Comparison": # type: ignore return Comparison("<=", self, other) - def __gt__(self, other): + def __gt__(self, other: ExprOrConst) -> "Comparison": # type: ignore return Comparison(">", self, other) - def __ge__(self, other): + def __ge__(self, other: ExprOrConst) -> "Comparison": # type: ignore return Comparison(">=", self, other) # Boolean Operators # Implements bitwise operations & | ^ and ~ (and, or, xor, not) - def __and__(self, other): + def __and__(self, other: ExprOrConst) -> "Expression": # some simple constant removal if is_true_cst(other): return self @@ -271,7 +281,7 @@ def __and__(self, other): f"E.g. always write (x==2)&(y<5).") return Operator("and", [self, other]) - def __rand__(self, other): + def __rand__(self, other: ExprOrConst) -> "Expression": # some simple constant removal if is_true_cst(other): return self @@ -281,7 +291,7 @@ def __rand__(self, other): f"did you forget to put brackets? E.g. always write (x==2)&(y<5).") return Operator("and", [other, self]) - def __or__(self, other): + def __or__(self, other: ExprOrConst) -> "Expression": # some simple constant removal if is_false_cst(other): return self @@ -291,7 +301,7 @@ def __or__(self, other): f"did you forget to put brackets? E.g. always write (x==2)|(y<5).") return Operator("or", [self, other]) - def __ror__(self, other): + def __ror__(self, other: ExprOrConst) -> "Expression": # some simple constant removal if is_false_cst(other): return self @@ -301,7 +311,7 @@ def __ror__(self, other): f"did you forget to put brackets? E.g. always write (x==2)|(y<5).") return Operator("or", [other, self]) - def __xor__(self, other): + def __xor__(self, other: ExprOrConst) -> "Expression": # some simple constant removal if is_true_cst(other): return ~self @@ -309,7 +319,7 @@ def __xor__(self, other): return self return cp.Xor([self, other]) - def __rxor__(self, other): + def __rxor__(self, other: ExprOrConst) -> "Expression": # some simple constant removal if is_true_cst(other): return ~self @@ -319,31 +329,31 @@ def __rxor__(self, other): # Mathematical Operators, including 'r'everse if it exists # Addition - def __add__(self, other): + def __add__(self, other: ExprOrConst) -> "Expression": if is_num(other) and other == 0: return self return Operator("sum", [self, other]) - def __radd__(self, other): + def __radd__(self, other: ExprOrConst) -> "Expression": # type: ignore if is_num(other) and other == 0: return self return Operator("sum", [other, self]) # substraction - def __sub__(self, other): + def __sub__(self, other: ExprOrConst) -> "Expression": # if is_num(other) and other == 0: # return self # return Operator("sub", [self, other]) return self.__add__(-other) - def __rsub__(self, other): + def __rsub__(self, other: ExprOrConst) -> "Expression": # type: ignore # if is_num(other) and other == 0: # return -self # return Operator("sub", [other, self]) return (-self).__radd__(other) # multiplication, puts the 'constant' (other) first - def __mul__(self, other): + def __mul__(self, other: ExprOrConst) -> "Expression": if is_num(other) and other == 1: return self # this unnecessarily complicates wsum creation @@ -351,7 +361,7 @@ def __mul__(self, other): # return other return Operator("mul", [self, other]) - def __rmul__(self, other): + def __rmul__(self, other: ExprOrConst) -> "Expression": # type: ignore if is_num(other) and other == 1: return self # this unnecessarily complicates wsum creation @@ -363,36 +373,36 @@ def __rmul__(self, other): #object.__matmul__(self, other) # other mathematical ones - def __truediv__(self, other): + def __truediv__(self, other: ExprOrConst) -> "Expression": warnings.warn("We only support floordivision, use // in stead of /", SyntaxWarning) return self.__floordiv__(other) - def __rtruediv__(self, other): + def __rtruediv__(self, other: ExprOrConst) -> "Expression": # type: ignore warnings.warn("We only support floordivision, use // in stead of /", SyntaxWarning) return self.__rfloordiv__(other) - def __floordiv__(self, other): + def __floordiv__(self, other: ExprOrConst) -> "Expression": if is_num(other) and other == 1: return self return Operator("div", [self, other]) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: ExprOrConst) -> "Expression": # type: ignore return Operator("div", [other, self]) - def __mod__(self, other): + def __mod__(self, other: ExprOrConst) -> "Expression": return Operator("mod", [self, other]) - def __rmod__(self, other): + def __rmod__(self, other: ExprOrConst) -> "Expression": # type: ignore return Operator("mod", [other, self]) - def __pow__(self, other, modulo=None): + def __pow__(self, other: ExprOrConst, modulo=None) -> "Expression": assert (modulo is None), "Power operator: modulo not supported" if is_num(other): if other == 1: return self return Operator("pow", [self, other]) - def __rpow__(self, other, modulo=None): + def __rpow__(self, other: ExprOrConst, modulo=None) -> "Expression": # type: ignore assert (modulo is None), "Power operator: modulo not supported" return Operator("pow", [other, self]) @@ -400,7 +410,7 @@ def __rpow__(self, other, modulo=None): #object.__divmod__(self, other) # unary mathematical operators - def __neg__(self): + def __neg__(self) -> "Expression": # special case, -(w*x) -> -w*x if self.name == 'mul' and is_num(self.args[0]): return Operator(self.name, [-self.args[0], self.args[1]]) @@ -409,13 +419,13 @@ def __neg__(self): return Operator(self.name, [[-a for a in self.args[0]], self.args[1]]) return Operator("-", [self]) - def __pos__(self): + def __pos__(self) -> "Expression": return self - def __abs__(self): + def __abs__(self) -> "Expression": return cp.Abs(self) - def __invert__(self): + def __invert__(self) -> "Expression": if not (is_boolexpr(self)): raise TypeError("Not operator is only allowed on boolean expressions: {0}".format(self)) return Operator("not", [self]) @@ -430,21 +440,21 @@ class BoolVal(Expression): Wrapper for python or numpy BoolVals """ - def __init__(self, arg): + def __init__(self, arg: BoolConst) -> None: assert is_true_cst(arg) or is_false_cst(arg), f"BoolVal must be initialized with a boolean constant, got {arg} of type {type(arg)}" super(BoolVal, self).__init__("boolval", [bool(arg)]) def value(self): return self.args[0] - def __invert__(self): + def __invert__(self) -> "BoolVal": return BoolVal(not self.args[0]) - def __bool__(self): + def __bool__(self) -> bool: """Called to implement truth value testing and the built-in operation bool(), return stored value""" return self.args[0] - def __int__(self): + def __int__(self) -> int: """Called to implement conversion to numerical""" return int(self.args[0]) @@ -452,9 +462,9 @@ def get_bounds(self): v = int(self.args[0]) return (v,v) - def __and__(self, other): + def __and__(self, other: ExprOrConst) -> Expression: if is_bool(other): # Boolean constant - return BoolVal(self.args[0] and other) + return BoolVal(cast("bool", self.args[0] and other)) elif isinstance(other, Expression) and other.is_bool(): if self.args[0]: return other @@ -463,9 +473,9 @@ def __and__(self, other): raise ValueError(f"{self}&{other} is not valid. Expected Boolean constant or Boolean Expression, but got {other} of type {type(other)}.") - def __rand__(self, other): + def __rand__(self, other: ExprOrConst) -> Expression: if is_bool(other): # Boolean constant - return BoolVal(self.args[0] and other) + return BoolVal(cast("bool", self.args[0] and other)) elif isinstance(other, Expression) and other.is_bool(): if self.args[0]: return other @@ -474,9 +484,9 @@ def __rand__(self, other): raise ValueError(f"{self}&{other} is not valid. Expected Boolean constant or Boolean Expression, but got {other} of type {type(other)}.") - def __or__(self, other): + def __or__(self, other : ExprOrConst) -> Expression: if is_bool(other): # Boolean constant - return BoolVal(self.args[0] or other) + return BoolVal(cast("bool", self.args[0] or other)) elif isinstance(other, Expression) and other.is_bool(): if not self.args[0]: return other @@ -485,9 +495,9 @@ def __or__(self, other): raise ValueError(f"{self}|{other} is not valid. Expected Boolean constant or Boolean Expression, but got {other} of type {type(other)}.") - def __ror__(self, other): + def __ror__(self, other: ExprOrConst) -> Expression: if is_bool(other): # Boolean constant - return BoolVal(self.args[0] or other) + return BoolVal(cast("bool", self.args[0] or other)) elif isinstance(other, Expression) and other.is_bool(): if not self.args[0]: return other @@ -495,7 +505,7 @@ def __ror__(self, other): return BoolVal(True) raise ValueError(f"{self}|{other} is not valid. Expected Boolean constant or Boolean Expression, but got {other} of type {type(other)}.") - def __xor__(self, other): + def __xor__(self, other : ExprOrConst) -> Expression: if is_bool(other): # Boolean constant return BoolVal(self.args[0] ^ other) elif isinstance(other, Expression) and other.is_bool(): @@ -506,7 +516,7 @@ def __xor__(self, other): raise ValueError(f"{self}^^{other} is not valid. Expected Boolean constant or Boolean Expression, but got {other} of type {type(other)}.") - def __rxor__(self, other): + def __rxor__(self, other: ExprOrConst) -> Expression: if is_bool(other): # Boolean constant return BoolVal(self.args[0] ^ other) elif isinstance(other, Expression) and other.is_bool(): @@ -524,11 +534,17 @@ def has_subexpr(self) -> bool: """ return False # BoolVal is a wrapper for a python or numpy constant boolean. - def implies(self, other): - if self.args[0]: + def implies(self, other: ExprOrConst) -> "Expression": + # other constant + if is_true_cst(other): + return BoolVal(True) + elif is_false_cst(other): + return ~self + elif self.args[0] and isinstance(other, Expression): return other - else: - return other == other # Always true, but keep variables in the model + elif not self.args[0]: + return Operator('->', [self, other]) + raise ValueError(f"Expected Boolean constant or Expression but got {other} of type {type(other)}.") class Comparison(Expression): @@ -536,7 +552,7 @@ class Comparison(Expression): """ allowed = {'==', '!=', '<=', '<', '>=', '>'} - def __init__(self, name, left, right): + def __init__(self, name:str, left:ExprOrConst, right:ExprOrConst): assert (name in Comparison.allowed), f"Symbol {name} not allowed" super().__init__(name, [left, right]) @@ -554,7 +570,7 @@ def __bool__(self): # return the value of the expression # optional, default: None - def value(self): + def value(self) -> Optional[bool]: arg_vals = argvals(self.args) if any(a is None for a in arg_vals): return None @@ -591,7 +607,7 @@ class Operator(Expression): } printmap = {'sum': '+', 'sub': '-', 'mul': '*', 'div': '//'} - def __init__(self, name, arg_list): + def __init__(self, name: str, arg_list: NestedList[ExprOrConst]): # sanity checks assert (name in Operator.allowed), "Operator {} not allowed".format(name) arity, is_bool_op = Operator.allowed[name] diff --git a/cpmpy/transformations/normalize.py b/cpmpy/transformations/normalize.py index f45e4401c..587002ed4 100644 --- a/cpmpy/transformations/normalize.py +++ b/cpmpy/transformations/normalize.py @@ -125,10 +125,10 @@ def simplify_boolean(lst_of_expr, num_context=False): cond, bool_expr = args if is_false_cst(cond) or is_true_cst(bool_expr): newlist.append(1 if num_context else BoolVal(True)) - elif is_true_cst(cond): - newlist.append(bool_expr) elif is_false_cst(bool_expr): newlist += simplify_boolean([cp.transformations.negation.recurse_negation(cond)]) + elif is_true_cst(cond): + newlist.append(bool_expr) else: newlist.append(cond.implies(bool_expr)) diff --git a/tests/test_trans_simplify.py b/tests/test_trans_simplify.py index d733bec15..6961e5557 100644 --- a/tests/test_trans_simplify.py +++ b/tests/test_trans_simplify.py @@ -133,3 +133,42 @@ def test_nested_boolval(self): self.assertEqual(str(self.transform(cons)), "[sum([1, 2] * [bv[0], ~bv[1]]) == 1]") self.assertTrue(cp.Model(cons).solve()) + def test_implies(self): + + true = cp.BoolVal(True) + false = cp.BoolVal(False) + expr = cp.intvar(0,10,name="x") >= 3 + + e = true.implies(expr) + self.assertEqual(str(e), "x >= 3") + e = Operator("->", [true, expr]) + self.assertEqual(str(self.transform(e)), "[x >= 3]") + + e = false.implies(expr) + self.assertEqual(str(e), "(boolval(False)) -> (x >= 3)") + e = Operator("->", [false, expr]) + self.assertEqual(str(self.transform(e)), "[boolval(True)]") + + e = false.implies(true) + self.assertEqual(str(e), "boolval(True)") + e = true.implies(false) + self.assertEqual(str(e), "boolval(False)") + + e = Operator("->", [false, true]) + self.assertEqual(str(self.transform(e)), "[boolval(True)]") + e = Operator("->", [true, false]) + self.assertEqual(str(self.transform(e)), "[boolval(False)]") + + # with non-CPMpy constants + e = true.implies(False) + self.assertEqual(str(e), "boolval(False)") + e = Operator("->", [true, False]) + self.assertEqual(str(self.transform(e)), "[boolval(False)]") + e = false.implies(True) + self.assertEqual(str(e), "boolval(True)") + e = Operator("->", [false, True]) + self.assertEqual(str(self.transform(e)), "[boolval(True)]") + + + +