Skip to content

Commit df973ca

Browse files
committed
More fixes
1 parent 8042b51 commit df973ca

File tree

2 files changed

+32
-32
lines changed

2 files changed

+32
-32
lines changed

src/gurobi_ml/modeling/softmax.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,36 +28,6 @@
2828
_HAS_NL_EXPR = False
2929

3030

31-
def _addGenConstrIndicatorMvarV10(self, binvar, binval, lhs, sense, rhs, name):
32-
"""This function is to work around the lack of MVar compatibility in
33-
Gurobi v10 indicator constraints. Note, it is not as flexible as Model.addGenConstrIndicator
34-
in V11+. If support for v10 is dropped this function can be removed.
35-
36-
Parameters
37-
----------
38-
binvar : MVar
39-
binval : {0,1}
40-
lhs : MVar or MLinExpr
41-
sense : (char)
42-
Options are gp.GRB.LESS_EQUAL, gp.GRB.EQUAL, or gp.GRB.GREATER_EQUAL
43-
rhs : scalar
44-
name : string
45-
"""
46-
assert binvar.shape == lhs.shape
47-
total_constraints = np.prod(binvar.shape)
48-
binvar = binvar.reshape(total_constraints).tolist()
49-
lhs = lhs.reshape(total_constraints).tolist()
50-
for index in range(total_constraints):
51-
self.gp_model.addGenConstrIndicator(
52-
binvar[index],
53-
binval,
54-
lhs[index],
55-
sense,
56-
rhs,
57-
name=self._indexed_name(index, name),
58-
)
59-
60-
6131
def max2(
6232
predictor_model: AbstractPredictorConstr, linear_predictor: gp.MVar, epsilon: float
6333
):
@@ -66,11 +36,40 @@ def max2(
6636
(predictor_model.output.shape[0], 1), vtype=gp.GRB.BINARY, name="bin_output"
6737
)
6838

39+
def _addGenConstrIndicatorMvarV10(binvar, binval, lhs, sense, rhs, name):
40+
"""This function is to work around the lack of MVar compatibility in
41+
Gurobi v10 indicator constraints. Note, it is not as flexible as Model.addGenConstrIndicator
42+
in V11+. If support for v10 is dropped this function can be removed.
43+
44+
Parameters
45+
----------
46+
binvar : MVar
47+
binval : {0,1}
48+
lhs : MVar or MLinExpr
49+
sense : (char)
50+
Options are gp.GRB.LESS_EQUAL, gp.GRB.EQUAL, or gp.GRB.GREATER_EQUAL
51+
rhs : scalar
52+
name : string
53+
"""
54+
assert binvar.shape == lhs.shape
55+
total_constraints = np.prod(binvar.shape)
56+
binvar = binvar.reshape(total_constraints).tolist()
57+
lhs = lhs.reshape(total_constraints).tolist()
58+
for index in range(total_constraints):
59+
predictor_model.gp_model.addGenConstrIndicator(
60+
binvar[index],
61+
binval,
62+
lhs[index],
63+
sense,
64+
rhs,
65+
name=predictor_model._indexed_name(index, name),
66+
)
67+
6968
# Workaround for MVars in indicator constraints for v10.
7069
addGenConstrIndicator = (
7170
predictor_model.gp_model.addGenConstrIndicator
7271
if gp.gurobi.version()[0] >= 11
73-
else predictor_model._addGenConstrIndicatorMvarV10
72+
else _addGenConstrIndicatorMvarV10
7473
)
7574

7675
# The original epsilon is with respect to the range of the logistic function.

tests/test_sklearn/sklearn_cases.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,10 @@ def __init__(self):
334334
remainder="drop",
335335
),
336336
]
337+
excluded = ["LogisticRegression", "MLPClassifier"]
337338
super().__init__(
338339
"wages",
339-
excluded=["LogisticRegression"],
340+
excluded=excluded,
340341
transformers=preprocessors,
341342
need_pipeline=True,
342343
saved_training=100,

0 commit comments

Comments
 (0)