diff --git a/cpmpy/solvers/cplex.py b/cpmpy/solvers/cplex.py index 79d256ca7..917bea778 100644 --- a/cpmpy/solvers/cplex.py +++ b/cpmpy/solvers/cplex.py @@ -341,6 +341,10 @@ def _make_numexpr(self, cpm_expr): if cpm_expr.name == "sub": a,b = self.solver_vars(cpm_expr.args) return a - b + + if cpm_expr.name == "mul": + a,b = self.solver_vars(cpm_expr.args) + return a * b raise NotImplementedError("CPLEX: Not a known supported numexpr {}".format(cpm_expr)) @@ -366,10 +370,10 @@ def transform(self, cpm_expr): cpm_cons = decompose_in_tree(cpm_cons, supported, csemap=self._csemap) cpm_cons = flatten_constraint(cpm_cons, csemap=self._csemap) # flat normal form cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum', 'sub']), csemap=self._csemap) # constraints that support reification - cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"]), csemap=self._csemap) # supports >, <, != + cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub", "mul"]), csemap=self._csemap) # supports >, <, != cpm_cons = only_bv_reifies(cpm_cons, csemap=self._csemap) cpm_cons = only_implies(cpm_cons, csemap=self._csemap) # anything that can create full reif should go above... - cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum", "sub", "min", "max", "abs", "mul"}), csemap=self._csemap) # CPLEX supports quadratic constraints and division by constants + cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum", "->", "sub", "min", "max", "abs"}), csemap=self._csemap) # CPLEX supports quadratic constraints and division by constants cpm_cons = only_positive_bv(cpm_cons, csemap=self._csemap) # after linearization, rewrite ~bv into 1-bv return cpm_cons @@ -416,10 +420,6 @@ def add(self, cpm_expr_orig): # a BoundedLinearExpression LHS, special case, like in objective cplexlhs = self._make_numexpr(lhs) self.cplex_model.add_constraint(cplexlhs == cplexrhs) - - elif lhs.name == 'mul': - raise NotImplementedError(f'CPLEX only supports quadratic constraints that define a convex region, i.e. quadratic equalities are not supported: {cpm_expr}') - else: # Global functions if lhs.name == 'min': diff --git a/cpmpy/solvers/exact.py b/cpmpy/solvers/exact.py index 37c00c1b0..553f5181f 100644 --- a/cpmpy/solvers/exact.py +++ b/cpmpy/solvers/exact.py @@ -511,7 +511,7 @@ def transform(self, cpm_expr): cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum"]), csemap=self._csemap) # supports >, <, != cpm_cons = only_bv_reifies(cpm_cons, csemap=self._csemap) cpm_cons = only_implies(cpm_cons, csemap=self._csemap) # anything that can create full reif should go above... - cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum","mul"}), csemap=self._csemap) # the core of the MIP-linearization + cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum","->","mul"}), csemap=self._csemap) # the core of the MIP-linearization cpm_cons = only_positive_bv(cpm_cons, csemap=self._csemap) # after linearisation, rewrite ~bv into 1-bv return cpm_cons diff --git a/cpmpy/solvers/gurobi.py b/cpmpy/solvers/gurobi.py index 9d1e6f012..570b4d95e 100644 --- a/cpmpy/solvers/gurobi.py +++ b/cpmpy/solvers/gurobi.py @@ -356,7 +356,7 @@ def transform(self, cpm_expr): cpm_cons = only_bv_reifies(cpm_cons, csemap=self._csemap) cpm_cons = only_implies(cpm_cons, csemap=self._csemap) # anything that can create full reif should go above... # gurobi does not round towards zero, so no 'div' in supported set: https://github.com/CPMpy/cpmpy/pull/593#issuecomment-2786707188 - cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum","sub","min","max","mul","abs","pow"}), csemap=self._csemap) # the core of the MIP-linearization + cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum","->","sub","min","max","mul","abs","pow"}), csemap=self._csemap) # the core of the MIP-linearization cpm_cons = only_positive_bv(cpm_cons, csemap=self._csemap) # after linearization, rewrite ~bv into 1-bv return cpm_cons diff --git a/cpmpy/solvers/pindakaas.py b/cpmpy/solvers/pindakaas.py index 20540903a..66cefc3d0 100755 --- a/cpmpy/solvers/pindakaas.py +++ b/cpmpy/solvers/pindakaas.py @@ -238,7 +238,7 @@ def transform(self, cpm_expr): cpm_cons = only_bv_reifies(cpm_cons, csemap=self._csemap) cpm_cons = only_implies(cpm_cons, csemap=self._csemap) cpm_cons = linearize_constraint( - cpm_cons, supported=frozenset({"sum", "wsum", "and", "or"}), csemap=self._csemap + cpm_cons, supported=frozenset({"sum", "wsum", "->", "and", "or"}), csemap=self._csemap ) cpm_cons = int2bool(cpm_cons, self.ivarmap, encoding=self.encoding) return cpm_cons diff --git a/cpmpy/solvers/pysat.py b/cpmpy/solvers/pysat.py index def9e6804..1525058d1 100644 --- a/cpmpy/solvers/pysat.py +++ b/cpmpy/solvers/pysat.py @@ -371,7 +371,7 @@ def transform(self, cpm_expr): cpm_cons = flatten_constraint(cpm_cons, csemap=self._csemap) # flat normal form cpm_cons = only_bv_reifies(cpm_cons, csemap=self._csemap) cpm_cons = only_implies(cpm_cons, csemap=self._csemap) - cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum", "and", "or"}), csemap=self._csemap) # the core of the MIP-linearization + cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum", "->", "and", "or"}), csemap=self._csemap) # the core of the MIP-linearization cpm_cons = int2bool(cpm_cons, self.ivarmap, encoding=self.encoding) cpm_cons = only_positive_coefficients(cpm_cons) return cpm_cons diff --git a/cpmpy/transformations/linearize.py b/cpmpy/transformations/linearize.py index d128ba0c7..b8ca87efe 100644 --- a/cpmpy/transformations/linearize.py +++ b/cpmpy/transformations/linearize.py @@ -70,7 +70,7 @@ from ..expressions.variables import _BoolVarImpl, boolvar, NegBoolView, _NumVarImpl, intvar -def linearize_constraint(lst_of_expr, supported={"sum","wsum"}, reified=False, csemap=None): +def linearize_constraint(lst_of_expr, supported={"sum","wsum","->"}, reified=False, csemap=None): """ Transforms all constraints to a linear form. This function assumes all constraints are in 'flat normal form' with only boolean variables on the lhs of an implication. @@ -128,14 +128,40 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum"}, reified=False, c continue elif is_false_cst(lin): indicator_constraints=[] # do not add any constraints - newlist+=linearize_constraint([~cond], supported=supported, csemap=csemap) # post linear version of unary constraint + newlist += linearize_constraint([~cond], supported=supported, csemap=csemap, reified=reified) # post linear version of unary constraint break # do not need to add other - else: + elif "->" in supported and not reified: indicator_constraints.append(cond.implies(lin)) # Add indicator constraint + else: # need to linearize the implication constraint itself + # either -> is not supported, or we are in a reified context (nested -> constraints are not linear) + assert isinstance(lin, Comparison), f"Expected a comparison as rhs of implication constraint, got {lin}" + if lin.args[0].name not in {"sum", "wsum"}: + assert lin.args[0].name in supported, f"Unexpected rhs of implication: {lin}, it is not supported ({supported})" + indicator_constraints.append(cond.implies(lin)) + continue + + # need to write as big-M + assert lin.args[0].name in frozenset({'sum', 'wsum'}), f"Expected sum or wsum as rhs of implication constraint, but got {lin}" + assert is_num(lin.args[1]) + lb, ub = get_bounds(lin.args[0]) + if lin.name == "<=": + M = lin.args[1] - ub # subtracting M from lhs will always satisfy the implied constraint + lin.args[0] += M * ~cond + indicator_constraints.append(lin) + elif lin.name == ">=": + M = lin.args[1] - lb # adding M to lhs will always satisfy the implied constraint + lin.args[0] += M * ~cond + indicator_constraints.append(lin) + elif lin.name == "==": + indicator_constraints += linearize_constraint([cond.implies(lin.args[0] <= lin.args[1]), + cond.implies(lin.args[0] >= lin.args[1])], + supported=supported, reified=reified, csemap=csemap) + else: + raise ValueError(f"Unexpected linearized rhs of implication {lin} in {cpm_expr}") newlist+=indicator_constraints # ensure no new solutions are created - new_vars = set(get_variables(lin_sub)) - set(get_variables(sub_expr)) + new_vars = set(get_variables(lin_sub)) - set(get_variables(sub_expr)) - {cond, ~cond} newlist += linearize_constraint([(~cond).implies(nv == nv.lb) for nv in new_vars], supported=supported, reified=reified, csemap=csemap) else: # supported operator @@ -158,9 +184,26 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum"}, reified=False, c # linearize unsupported operators elif isinstance(lhs, Operator) and lhs.name not in supported: - if lhs.name == "mul" and is_num(lhs.args[0]): - lhs = Operator("wsum",[[lhs.args[0]], [lhs.args[1]]]) - cpm_expr = eval_comparison(cpm_expr.name, lhs, rhs) + if lhs.name == "mul": + bv_idx = None + if is_num(lhs.args[0]): # const * iv rhs + lhs = Operator("wsum",[[lhs.args[0]], [lhs.args[1]]]) + newlist += linearize_constraint([eval_comparison(cpm_expr.name, lhs, rhs)], supported=supported, reified=reified, csemap=csemap) + continue + elif isinstance(lhs.args[0], _BoolVarImpl): + bv_idx = 0 + elif isinstance(lhs.args[1], _BoolVarImpl): + bv_idx = 1 + + if bv_idx is not None: + # bv * iv rhs, rewrite to (bv -> iv rhs) & (~bv -> 0 rhs) + bv, iv = lhs.args[bv_idx], lhs.args[1-bv_idx] + bv_true = bv.implies(eval_comparison(cpm_expr.name, iv, rhs)) + bv_false = (~bv).implies(eval_comparison(cpm_expr.name, 0, rhs)) + newlist += linearize_constraint(simplify_boolean([bv_true, bv_false]), supported=supported, reified=reified, csemap=csemap) + continue + else: + raise NotImplementedError(f"Linearization of integer multiplication {cpm_expr} is not supported") elif lhs.name == "pow" and "pow" not in supported: if "mul" not in supported: diff --git a/tests/test_constraints.py b/tests/test_constraints.py index b7eb9f836..a41ab5fde 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -46,11 +46,11 @@ # Exclude certain operators for solvers. # Not all solvers support all operators in CPMpy EXCLUDE_OPERATORS = {"gurobi": {}, - "pysat": {"mul", "div", "pow", "mod"}, # int2bool but mul, and friends, not linearized + "pysat": {"mul-int", "div", "pow", "mod"}, # int2bool but integer-multiplication, and friends, not linearized "pysdd": {"sum", "wsum", "sub", "mod", "div", "pow", "abs", "mul","-"}, - "pindakaas": {"mul", "div", "pow", "mod"}, + "pindakaas": {"mul-int", "div", "pow", "mod"}, "exact": {}, - "cplex": {"mul", "div", "mod", "pow"}, + "cplex": {"mul-int", "div", "mod", "pow"}, "pumpkin": {"pow", "mod"}, } @@ -86,13 +86,17 @@ def numexprs(solver): yield Operator("wsum", [list(range(len(NUM_ARGS))), NUM_ARGS]) yield Operator("wsum", [[True, BoolVal(False), np.True_], NUM_ARGS]) # bit of everything continue - elif name == "mul": - yield Operator(name, [3,NUM_ARGS[0]]) - yield Operator(name, NUM_ARGS[:2]) - if solver != "minizinc": # bug in minizinc, see https://github.com/MiniZinc/libminizinc/issues/962 - yield Operator(name, [3,BOOL_ARGS[0]]) elif name == "div" or name == "pow": yield Operator(name, [NN_VAR,3]) + elif name == "mul" and "mul-int" not in EXCLUDE_OPERATORS.get(solver, {}): + yield Operator(name, [3, NUM_ARGS[0]]) + yield Operator(name, NUM_ARGS[:arity]) + yield Operator(name, NUM_ARGS[:2]) + if solver != "minizinc": # bug in minizinc, see https://github.com/MiniZinc/libminizinc/issues/962 + yield Operator(name, [3, BOOL_ARGS[0]]) + + elif name == "mul" and "mul-bool" not in EXCLUDE_OPERATORS.get(solver, {}): + yield Operator(name, BOOL_ARGS[:arity]) elif arity != 0: yield Operator(name, NUM_ARGS[:arity]) else: diff --git a/tests/test_trans_linearize.py b/tests/test_trans_linearize.py index 87c0c0178..ba42cae70 100644 --- a/tests/test_trans_linearize.py +++ b/tests/test_trans_linearize.py @@ -238,6 +238,73 @@ def test_sub(self): [lin_cons] = linearize_constraint([cons]) self.assertEqual(str(lin_cons), "sum([1, -1] * [x, y]) == 3") + def test_bool_mult(self): + + x = cp.intvar(-5, 10, name="x") + y = cp.intvar(-5, 10, name="y") + a = cp.boolvar(name="a") + b = cp.boolvar(name="b") + + def assert_cons_is_true(cons): + return lambda : self.assertTrue(cons.value()) + + cons = b * x == y + bt,bf = linearize_constraint([cons]) + self.assertEqual(str(bt), "(b) -> (sum([1, -1] * [x, y]) == 0)") + self.assertEqual(str(bf), "(~b) -> (sum([y]) == 0)") + + cp.Model([bt,bf]).solveAll(display=assert_cons_is_true(cons)) + + cons = x * b == y + bt,bf = linearize_constraint([cons]) + self.assertEqual(str(bt), "(b) -> (sum([1, -1] * [x, y]) == 0)") + self.assertEqual(str(bf), "(~b) -> (sum([y]) == 0)") + + cp.Model([bt,bf]).solveAll(display=assert_cons_is_true(cons)) + + cons = a.implies(b * x <= y) + lin_cons = linearize_constraint([cons]) + self.assertEqual(str(lin_cons[0]), "(a) -> (sum([1, -1, -15] * [x, y, ~b]) <= 0)") + self.assertEqual(str(lin_cons[1]), "(a) -> (sum([1, 5] * [y, b]) >= 0)") + + lin_cnt = cp.Model(lin_cons).solveAll(display=assert_cons_is_true(cons)) + cons_cnt = cp.Model(cons).solveAll(display=assert_cons_is_true(cp.all(lin_cons))) + self.assertEqual(lin_cnt, cons_cnt) + + cons = a.implies(b * x >= y) + lin_cons = linearize_constraint([cons]) + self.assertEqual(str(lin_cons[0]), "(a) -> (sum([1, -1, 15] * [x, y, ~b]) >= 0)") + self.assertEqual(str(lin_cons[1]), "(a) -> (sum([1, -10] * [y, b]) <= 0)") + + lin_cnt = cp.Model(lin_cons).solveAll(display=assert_cons_is_true(cons)) + cons_cnt = cp.Model(cons).solveAll(display=assert_cons_is_true(cp.all(lin_cons))) + self.assertEqual(lin_cnt, cons_cnt) + + + def test_implies(self): + + x = cp.intvar(1, 10, name="x") + y = cp.intvar(1, 10, name="y") + b = cp.boolvar(name="b") + + cons = b.implies(x + y <= 5) + [lin_cons] = linearize_constraint([cons], supported={"sum", "wsum"}) # no support for "->" + self.assertEqual(str(lin_cons), "sum([1, 1, -15] * [x, y, ~b]) <= 5") + + cons = b.implies(x + y >= 5) + [lin_cons] = linearize_constraint([cons], supported={"sum", "wsum"}) # no support for "->" + self.assertEqual(str(lin_cons), "sum([1, 1, 3] * [x, y, ~b]) >= 5") + + cons = b.implies(x + y == 5) + lin_cons = linearize_constraint([cons], supported={"sum", "wsum"}) # no support for "->" + assert len(lin_cons) == 2 + self.assertEqual(str(lin_cons[0]), "sum([1, 1, -15] * [x, y, ~b]) <= 5") + self.assertEqual(str(lin_cons[1]), "sum([1, 1, 3] * [x, y, ~b]) >= 5") + + + + + class TestConstRhs(unittest.TestCase):