Skip to content

Commit 0c981b3

Browse files
committed
Improve regular logistic regression
1 parent 935481a commit 0c981b3

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

docs/examples/example2_student_admission.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@
123123

124124
# Run our regression
125125
scaler = StandardScaler()
126-
regression = LogisticRegression(random_state=1)
126+
regression = LogisticRegression(random_state=10)
127127
pipe = make_pipeline(scaler, regression)
128128
pipe.fit(X=historical_data.loc[:, features], y=historical_data.loc[:, target])
129129

@@ -144,7 +144,7 @@
144144
nstudents = 20
145145

146146
# Select randomly nstudents in the data
147-
studentsdata = studentsdata.sample(nstudents, random_state=1)
147+
studentsdata = studentsdata.sample(nstudents, random_state=10)
148148

149149

150150
######################################################################
@@ -229,7 +229,8 @@
229229
# documentation <https://gurobi-machinelearning.readthedocs.io/en/v1.3.0/mlm-examples/student_admission.html>`__
230230
# for dealing with those approximations.
231231
#
232-
232+
m.Params.NodeLimit = 10000
233+
m.write("students.lp")
233234
m.optimize()
234235

235236

@@ -251,7 +252,7 @@
251252
# regression in a solution as a pandas dataframe using input_values.
252253
#
253254

254-
pred_constr.input_values
255+
pred_constr.input_values.round(3)
255256

256257

257258
######################################################################

src/gurobi_ml/modeling/softmax.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,22 @@ def _addGenConstrIndicatorMvarV10(binvar, binval, lhs, sense, rhs, name):
100100
def logistic(predictor_model: AbstractPredictorConstr, linear_predictor: gp.MVar):
101101
log_result = predictor_model.output[:, 1]
102102

103-
for index in np.ndindex(log_result.shape):
104-
predictor_model.gp_model.addGenConstrLogistic(
105-
linear_predictor[index],
106-
log_result[index],
107-
name=predictor_model._indexed_name(index, "logistic"),
103+
if _HAS_NL_EXPR:
104+
predictor_model.gp_model.addConstr(
105+
log_result == nlfunc.logistic(linear_predictor[:, 0])
108106
)
109-
num_gc = predictor_model.gp_model.NumGenConstrs
110-
predictor_model.gp_model.update()
111-
for gen_constr in predictor_model.gp_model.getGenConstrs()[num_gc:]:
112-
for attr, val in predictor_model.attributes.items():
113-
gen_constr.setAttr(attr, val)
107+
else:
108+
for index in np.ndindex(log_result.shape):
109+
predictor_model.gp_model.addGenConstrLogistic(
110+
linear_predictor[index],
111+
log_result[index],
112+
name=predictor_model._indexed_name(index, "logistic"),
113+
)
114+
num_gc = predictor_model.gp_model.NumGenConstrs
115+
predictor_model.gp_model.update()
116+
for gen_constr in predictor_model.gp_model.getGenConstrs()[num_gc:]:
117+
for attr, val in predictor_model.attributes.items():
118+
gen_constr.setAttr(attr, val)
114119

115120

116121
def hardmax(

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ python =
1414
3.9: py39-all_deps-gurobi11
1515
3.10: py310-all_deps-gurobi11
1616
3.12: py312-all_deps-gurobi11
17-
3.11: pre-commit,docs,py311-{lightgbm,keras,pytorch,sklearn,xgboost,no_deps,all_deps}-{gurobi11-gurobi12}
17+
3.11: pre-commit,docs,py311-{lightgbm,keras,pytorch,sklearn,xgboost,no_deps,all_deps}-{gurobi11,gurobi12}
1818

1919
[testenv:docs]
2020
deps=

0 commit comments

Comments
 (0)