diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index aeebf9d58..d4da9ac17 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -566,6 +566,28 @@ def value(self): elif self.name == ">=": return arg_vals[0] >= arg_vals[1] return None # default + def get_bounds(self): + (lb1, ub1), (lb2, ub2) = get_bounds(self.args[0]), get_bounds(self.args[1]) + if self.name == "==": + if lb1 == ub1 == lb2 == ub2: return (1,1) # equal domains, trivially true + if ub1 < lb2 or ub2 < lb1: return (0,0) # disjoint, trivially false + if self.name == "!=": + if ub1 < lb2 or ub2 < lb1: return (1,1) # disjoint, trivially true + if lb1 == ub1 == lb2 == ub2: return (0,0) # equal domains, trivially false + if self.name == "<=": + if ub1 <= lb2: return (1,1) # domain of lhs is leq domain of rhs + if lb1 > ub2: return (0,0) # domain of lhs is gt domain of rhs + if self.name == "<": + if ub1 < lb2: return (1,1) # domain of lhs is lt domain of rhs + if lb1 >= ub2: return (0,0) # domain of lhs is geq domain of rhs + if self.name == ">=": + if lb1 >= ub2: return (1,1) # domain of lhs is geq domain of rhs + if ub1 < lb2: return (0,0) # domain of lhs is lt domain of rhs + if self.name == ">": + if lb1 > ub2: return (1,1) # domain of lhs is gt domain of rhs + if ub1 <= lb2: return (0,0) # domain of lhs is leq domain of rhs + return (0,1) + class Operator(Expression): """ @@ -751,18 +773,10 @@ def get_bounds(self): lowerbound, upperbound = sum(lbs), sum(ubs) elif self.name == 'wsum': weights, vars = self.args - bounds = [] - lowerbound, upperbound = 0,0 - #this may seem like too many lines, but avoiding np.sum avoids overflowing things at int32 bounds - for w, (lb, ub) in zip(weights, [get_bounds(arg) for arg in vars]): - x,y = int(w) * lb, int(w) * ub - if x <= y: # x is the lb of this arg - lowerbound += x - upperbound += y - else: - lowerbound += y - upperbound += x - + lbs, ubs = get_bounds(vars) + lbs, ubs = [w * lb for w,lb in zip(weights,lbs)], [w * ub for w, ub in zip(weights,ubs)] + lowerbound = sum(lb if lb <= ub else ub for lb,ub in zip(lbs,ubs)) + upperbound = sum(ub if ub >= lb else lb for lb, ub in zip(lbs, ubs)) elif self.name == 'sub': lb1, ub1 = get_bounds(self.args[0]) lb2, ub2 = get_bounds(self.args[1]) diff --git a/cpmpy/expressions/globalconstraints.py b/cpmpy/expressions/globalconstraints.py index 2be5c53d4..ced328b31 100644 --- a/cpmpy/expressions/globalconstraints.py +++ b/cpmpy/expressions/globalconstraints.py @@ -658,7 +658,7 @@ def decompose(self): decomp = [sum(self.args[:2]) == 1] if len(self.args) > 2: decomp = Xor([decomp,self.args[2:]]).decompose()[0] - return decomp, [] + return cp.transformations.normalize.simplify_boolean(decomp), [] def value(self): return sum(argvals(self.args)) % 2 == 1 diff --git a/cpmpy/transformations/normalize.py b/cpmpy/transformations/normalize.py index f45e4401c..53c68922b 100644 --- a/cpmpy/transformations/normalize.py +++ b/cpmpy/transformations/normalize.py @@ -9,7 +9,7 @@ from ..expressions.core import BoolVal, Expression, Comparison, Operator from ..expressions.globalfunctions import GlobalFunction -from ..expressions.utils import eval_comparison, is_false_cst, is_true_cst, is_boolexpr, is_num, is_bool +from ..expressions.utils import eval_comparison, is_false_cst, is_true_cst, is_boolexpr, is_num, is_bool, get_bounds from ..expressions.variables import NDVarArray, _BoolVarImpl from ..exceptions import NotSupportedError from ..expressions.globalconstraints import GlobalConstraint @@ -169,6 +169,15 @@ def simplify_boolean(lst_of_expr, num_context=False): elif isinstance(expr, Comparison): lhs, rhs = simplify_boolean(expr.args, num_context=True) name = expr.name + + lb, ub = get_bounds(eval_comparison(name, lhs, rhs)) + if lb == 0 == ub: + newlist.append(0 if num_context else BoolVal(False)) + continue + if lb == 1 == ub: + newlist.append(1 if num_context else BoolVal(True)) + continue + if is_num(lhs) and is_boolexpr(rhs): # flip arguments of comparison to reduct nb of cases if name == "<": name = ">" elif name == ">": name = "<" diff --git a/tests/test_constraints.py b/tests/test_constraints.py index d2d1cada8..4276769af 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -196,6 +196,7 @@ def global_constraints(solver): if name == "Xor": yield Xor(BOOL_ARGS) yield Xor(BOOL_ARGS + [True,False]) + yield Xor([True, BOOL_ARGS[0]]) continue elif name == "Inverse": expr = cls(NUM_ARGS, [1,0,2]) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 74520dc2b..ca2524797 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -6,7 +6,8 @@ from cpmpy.expressions import * from cpmpy.expressions.variables import NDVarArray from cpmpy.expressions.core import Comparison, Operator, Expression -from cpmpy.expressions.utils import eval_comparison, get_bounds, argval +from cpmpy.expressions.utils import eval_comparison, get_bounds, argval, all_pairs + class TestComparison(unittest.TestCase): def test_comps(self): @@ -450,6 +451,29 @@ def test_bounds_unary(self): self.assertGreaterEqual(val,lb) self.assertLessEqual(val,ub) + def test_bounds_comparison(self): + + x_00 = intvar(0,0, name="x00") + x_01 = intvar(0,1, name="x01") + x_12= intvar(1,2, name="x12") + x_23 = intvar(2,3, name="x23") + + for x,y in all_pairs([0, x_00, x_01, x_12, x_23]): + for comp in ['==','!=','<=','<','>=','>']: + x_bounds = get_bounds(x) + y_bounds = get_bounds(y) + + total_vals = len(range(x_bounds[0],x_bounds[1]+1)) * len(range(y_bounds[0],y_bounds[1]+1)) + + for expr in [Comparison(comp, x,y), Comparison(comp, y,x)]: + lb, ub = expr.get_bounds() + + if lb == 0 == ub: + self.assertEqual(cp.Model(expr).solveAll(), 0) + elif lb == 1 == ub: + self.assertEqual(cp.Model(expr).solveAll(), total_vals) + else: + self.assertNotEqual(cp.Model(expr).solveAll(), total_vals) def test_incomplete_func(self): # element constraint diff --git a/tests/test_flatten.py b/tests/test_flatten.py index 5384c5260..629138745 100644 --- a/tests/test_flatten.py +++ b/tests/test_flatten.py @@ -167,18 +167,18 @@ def test_constraint(self): self.assertEqual( str(flatten_constraint( x&y&~z )), "[BV0, BV1, ~BV2]" ) self.assertEqual( str(flatten_constraint( x.implies(y) )), "[(BV0) -> (BV1)]" ) self.assertEqual( str(flatten_constraint( x|(y.implies(z)) )), "[or([BV0, ~BV1, BV2])]" ) - self.assertEqual( str(flatten_constraint( (a > 10)&x )), "[IV0 > 10, BV0]" ) + self.assertEqual( str(flatten_constraint( (a > 8)&x )), "[IV0 > 8, BV0]" ) cp.boolvar() # increase counter - self.assertEqual( str(flatten_constraint( (a > 10).implies(x) )), "[(IV0 > 10) -> (BV0)]" ) + self.assertEqual( str(flatten_constraint( (a > 8).implies(x) )), "[(IV0 > 8) -> (BV0)]" ) cp.boolvar() # increase counter - self.assertEqual( str(flatten_constraint( (a > 10) )), "[IV0 > 10]" ) - self.assertEqual( str(flatten_constraint( (a > 10) == 1 )), "[IV0 > 10]" ) - self.assertEqual( str(flatten_constraint( (a > 10) == 0 )), "[IV0 <= 10]" ) - self.assertEqual( str(flatten_constraint( (a > 10) == x )), "[(IV0 > 10) == (BV0)]" ) + self.assertEqual( str(flatten_constraint( (a > 8) )), "[IV0 > 8]" ) + self.assertEqual( str(flatten_constraint( (a > 8) == 1 )), "[IV0 > 8]" ) + self.assertEqual( str(flatten_constraint( (a > 8) == 0 )), "[IV0 <= 8]" ) + self.assertEqual( str(flatten_constraint( (a > 8) == x )), "[(IV0 > 8) == (BV0)]" ) #self.assertEqual( str(flatten_constraint( x == (a > 10) )), "[(IV0 > 10) == (BV0)]" ) # TODO, make it do the swap (again) - self.assertEqual( str(flatten_constraint( (a > 10) | (b + c > 2) )), "[(BV5) or (BV6), (IV0 > 10) == (BV5), ((IV1) + (IV2) > 2) == (BV6)]" ) - self.assertEqual( str(flatten_constraint( a > 10 )), "[IV0 > 10]" ) - self.assertEqual( str(flatten_constraint( 10 > a )), "[IV0 < 10]" ) # surprising + self.assertEqual( str(flatten_constraint( (a > 8) | (b + c > 2) )), "[(BV5) or (BV6), (IV0 > 8) == (BV5), ((IV1) + (IV2) > 2) == (BV6)]" ) + self.assertEqual( str(flatten_constraint( a > 8 )), "[IV0 > 8]" ) + self.assertEqual( str(flatten_constraint( 8 > a )), "[IV0 < 8]" ) # surprising self.assertEqual( str(flatten_constraint( a+b > c )), "[((IV0) + (IV1)) > (IV2)]" ) #self.assertEqual( str(flatten_constraint( c < a+b )), "[((IV0) + (IV1)) > (IV2)]" ) # TODO, make it do the swap (again) self.assertEqual( str(flatten_constraint( (a+b > c) == x|y )), "[(((IV0) + (IV1)) > (IV2)) == (BV7), ((BV0) or (BV1)) == (BV7)]" ) @@ -213,7 +213,7 @@ def test_constraint(self): self.assertEqual( str(a % 1 == 0), "(IV0) mod 1 == 0" ) # boolexpr as numexpr - self.assertEqual( str(flatten_constraint((a + b == 2) <= c)), "[(BV11) <= (IV2), ((IV0) + (IV1) == 2) == (BV11)]" ) + self.assertEqual( str(flatten_constraint((a + b == 2) < c)), "[(BV11) < (IV2), ((IV0) + (IV1) == 2) == (BV11)]" ) # != in boolexpr, bug #170 self.assertEqual( str(normalized_boolexpr(x != (a == 1))), "((BV12) == (~BV0), [(IV0 == 1) == (BV12)])" ) diff --git a/tests/test_globalconstraints.py b/tests/test_globalconstraints.py index 2b8c8e20e..309b6d53d 100644 --- a/tests/test_globalconstraints.py +++ b/tests/test_globalconstraints.py @@ -721,6 +721,23 @@ def test_xor_with_constants(self): self.assertFalse(cp.Model(cp.Xor([False, False])).solve()) self.assertFalse(cp.Model(cp.Xor([False, False, False])).solve()) + def test_issue_620(self): + a = cp.boolvar() + b = cp.boolvar() + c = cp.boolvar() + + model = cp.Model(cp.Xor([(cp.Xor([a, b, c])) <= True, ~((cp.Xor([a, b, c])) <= True)])) + + self.assertTrue(model.solve(solver='ortools')) + if "minizinc" in cp.SolverLookup.supported(): + self.assertTrue(model.solve(solver='minizinc')) + if "z3" in cp.SolverLookup.supported(): + self.assertTrue(model.solve(solver='z3')) + if "choco" in cp.SolverLookup.supported(): + self.assertTrue(model.solve(solver='choco')) + if "gurobi" in cp.SolverLookup.supported(): + self.assertTrue(model.solve(solver='gurobi')) + def test_ite_with_constants(self): x,y,z = cp.boolvar(shape=3) expr = cp.IfThenElse(True, y, z) diff --git a/tests/test_solvers.py b/tests/test_solvers.py index bb4add10d..7219a2d5e 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -840,7 +840,7 @@ def test_installed_solvers(self, solver): model.solve(solver=solver) assert [int(a) for a in v.value()] == [0, 1, 0] - s = cp.SolverLookup.get(solver) + s = cp.SolverLookup.get(solver, model) s.solve() assert [int(a) for a in v.value()] == [0, 1, 0] diff --git a/tests/test_trans_simplify.py b/tests/test_trans_simplify.py index d733bec15..f8a80989b 100644 --- a/tests/test_trans_simplify.py +++ b/tests/test_trans_simplify.py @@ -9,7 +9,7 @@ class TransSimplify(unittest.TestCase): def setUp(self) -> None: self.bvs = cp.boolvar(shape=3, name="bv") - self.ivs = cp.intvar(0, 5, shape=3, name="iv") + self.ivs = cp.intvar(-1, 5, shape=3, name="iv") self.transform = lambda x: simplify_boolean(toplevel_list(x)) @@ -19,10 +19,10 @@ def test_bool_ops(self): expr = Operator("or", self.bvs.tolist() + [True]) self.assertEqual(str(self.transform(expr)), "[boolval(True)]") - expr = Operator("and", self.bvs.tolist() + [False]) + self.ivs[0] >= 10 - self.assertEqual(str(self.transform(expr)), "[0 + (iv[0]) >= 10]") - expr = Operator("and", self.bvs.tolist() + [True]) + self.ivs[0] >= 10 - self.assertEqual(str(self.transform(expr)), "[(and([bv[0], bv[1], bv[2]])) + (iv[0]) >= 10]") + expr = Operator("and", self.bvs.tolist() + [False]) + self.ivs[0] >= 3 + self.assertEqual(str(self.transform(expr)), "[0 + (iv[0]) >= 3]") + expr = Operator("and", self.bvs.tolist() + [True]) + self.ivs[0] >= 3 + self.assertEqual(str(self.transform(expr)), "[(and([bv[0], bv[1], bv[2]])) + (iv[0]) >= 3]") expr = Operator("->", [self.bvs[0], True]) @@ -35,16 +35,16 @@ def test_bool_ops(self): self.assertEqual(str(self.transform(expr)), "[boolval(True)]") def test_bool_in_comp(self): - expr = self.ivs[0] >= False - self.assertEqual(str(self.transform(expr)), '[iv[0] >= 0]') + expr = self.ivs[0] > False + self.assertEqual(str(self.transform(expr)), '[iv[0] > 0]') expr = self.ivs[0] >= True self.assertEqual(str(self.transform(expr)), '[iv[0] >= 1]') expr = (cp.sum(self.ivs) + True) >= 10 self.assertEqual(str(self.transform(expr)), '[sum([iv[0], iv[1], iv[2], 1]) >= 10]') - expr = True + self.ivs[0] >= False - self.assertEqual(str(self.transform(expr)), '[1 + (iv[0]) >= 0]') + expr = True + self.ivs[0] > False + self.assertEqual(str(self.transform(expr)), '[1 + (iv[0]) > 0]') def test_boolvar_comps(self): num_args = {"<0": -1, "0": 0, "]0..1[": 0.5, "1": 1, ">0": 2} @@ -87,8 +87,8 @@ def test_simplify_expressions(self): # with constant, does not change (surprisingly? but we cannot check what the res type is...) expr = cp.max(self.ivs.tolist() + [False]) == 0 self.assertEqual(str(self.transform(expr)), '[max(iv[0],iv[1],iv[2],boolval(False)) == 0]') - expr = 0 == cp.max(self.ivs.tolist() + [True]) - self.assertEqual(str(self.transform(expr)), '[max(iv[0],iv[1],iv[2],boolval(True)) == 0]') + expr = 1 == cp.max(self.ivs.tolist() + [True]) + self.assertEqual(str(self.transform(expr)), '[max(iv[0],iv[1],iv[2],boolval(True)) == 1]') expr = (self.ivs[0] <= self.ivs[1]) == 0 self.assertEqual(str(self.transform(expr)), '[not([(iv[0]) <= (iv[1])])]') diff --git a/tests/test_transf_reif.py b/tests/test_transf_reif.py index 9dcea9540..fc085f3fa 100644 --- a/tests/test_transf_reif.py +++ b/tests/test_transf_reif.py @@ -77,7 +77,7 @@ def test_reif_element(self): def test_reif_rewrite(self): bvs = boolvar(shape=4, name="bvs") - ivs = intvar(1,9, shape=3, name="ivs") + ivs = intvar(0,9, shape=3, name="ivs") rv = boolvar(name="rv") arr = cpm_array([0,1,2])