Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 148 additions & 18 deletions cpmpy/tools/tune_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
The parameter tuner iteratively finds better hyperparameters close to the current best configuration during the search.
Searching and time-out start at the default configuration for a solver (if available in the solver class)
"""
import math
import time
from random import shuffle

import numpy as np

from ..solvers.utils import SolverLookup, param_combinations
from ..solvers.solver_interface import ExitStatus
from ..solvers.solver_interface import ExitStatus, SolverInterface, SolverStatus


class ParameterTuner:
"""
Expand All @@ -42,7 +42,7 @@ def __init__(self, solvername, model, all_params=None, defaults=None):
self.best_params = SolverLookup.lookup(solvername).default_params()

self._param_order = list(self.all_params.keys())
self._best_config = self._params_to_np([self.best_params])
self._best_config = self._params_to_np([self.best_params])[0]

def tune(self, time_limit=None, max_tries=None, fix_params={}):
"""
Expand All @@ -54,12 +54,19 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}):
start_time = time.time()

# Init solver
solver = SolverLookup.get(self.solvername, self.model)
solver.solve(**self.best_params)
if not isinstance(self.model, list):
solver = SolverLookup.get(self.solvername, self.model)
else:
solver = MultiSolver(self.solvername, self.model)
solver.solve(**self.best_params, time_limit=time_limit)
if not _has_finished(solver):
raise TimeoutError("Time's up before solving init solver call")

self.base_runtime = solver.status().runtime
self.best_runtime = self.base_runtime



# Get all possible hyperparameter configurations
combos = list(param_combinations(self.all_params))
combos_np = self._params_to_np(combos)
Expand All @@ -72,7 +79,10 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}):
max_tries = len(combos_np)
while len(combos_np) and i < max_tries:
# Make new solver
solver = SolverLookup.get(self.solvername, self.model)
if not isinstance(self.model, list):
solver = SolverLookup.get(self.solvername, self.model)
else:
solver = MultiSolver(self.solvername, self.model)
# Apply scoring to all combos
scores = self._get_score(combos_np)
max_idx = np.where(scores == scores.min())[0][0]
Expand All @@ -87,22 +97,23 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}):
timeout = self.best_runtime
# set timeout depending on time budget
if time_limit is not None:
if (time.time() - start_time) >= time_limit:
break
timeout = min(timeout, time_limit - (time.time() - start_time))
# run solver
solver.solve(**params_dict, time_limit=timeout)
if solver.status().exitstatus == ExitStatus.OPTIMAL and solver.status().runtime < self.best_runtime:
if _has_finished(solver):
self.best_runtime = solver.status().runtime
# update surrogate
self._best_config = params_np

if time_limit is not None and (time.time() - start_time) >= time_limit:
break
i += 1

self.best_params = self._np_to_params(self._best_config)
self.best_params.update(fix_params)
return self.best_params


def _get_score(self, combos):
"""
Return the hamming distance for each remaining configuration to the current best config.
Expand Down Expand Up @@ -134,13 +145,17 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}):
start_time = time.time()

# Init solver
solver = SolverLookup.get(self.solvername, self.model)
solver.solve(**self.best_params)
if not isinstance(self.model, list):
solver = SolverLookup.get(self.solvername, self.model)
else:
solver = MultiSolver(self.solvername, self.model)
solver.solve(**self.best_params,time_limit=time_limit)
if not _has_finished(solver):
raise TimeoutError("Time's up before solving init solver call")


self.base_runtime = solver.status().runtime
self.best_runtime = self.base_runtime

# Get all possible hyperparameter configurations
combos = list(param_combinations(self.all_params))
shuffle(combos) # test in random order
Expand All @@ -150,24 +165,139 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}):

for params_dict in combos:
# Make new solver
solver = SolverLookup.get(self.solvername, self.model)
if not isinstance(self.model, list):
solver = SolverLookup.get(self.solvername, self.model)
else:
solver = MultiSolver(self.solvername, self.model)
# set fixed params
params_dict.update(fix_params)
timeout = self.best_runtime
# set timeout depending on time budget
if time_limit is not None:
if (time.time() - start_time) >= time_limit:
break
timeout = min(timeout, time_limit - (time.time() - start_time))
# run solver
solver.solve(**params_dict, time_limit=timeout)
if solver.status().exitstatus == ExitStatus.OPTIMAL and solver.status().runtime < self.best_runtime:
if _has_finished(solver):
self.best_runtime = solver.status().runtime
# update surrogate
self.best_params = params_dict
return self.best_params

if time_limit is not None and (time.time() - start_time) >= time_limit:
break
def _has_finished(solver):
"""
Check whether a given solver has found the target solution.
Parameters
----------
solver : SolverInterface

Returns
-------
bool
True if the solver has has found the target solution. This means:
- For a `MultiSolver`: its own `has_finished()` method determines completion.
- For a problem with an objective: status is OPTIMAL.
- For a problem without an objective: status is FEASIBLE.
- For an unsat problem: status is UNSATISFIABLE.
False otherwise.
"""
if isinstance(solver,MultiSolver):
return solver.has_finished()
elif (((solver.has_objective() and solver.status().exitstatus == ExitStatus.OPTIMAL) or
(not solver.has_objective() and solver.status().exitstatus == ExitStatus.FEASIBLE)) or
(solver.status().exitstatus == ExitStatus.UNSATISFIABLE)):
return True
return False

return self.best_params


class MultiSolver(SolverInterface):
"""
Class that manages multiple solver instances.
Attributes
----------
name : str
Name of the solver used for all instances.
solvers : list of SolverInterface
The solver instances corresponding to each model.
cpm_status : SolverStatus
Aggregated solver status. Tracks runtime and per-solver exit statuses.
"""

def __init__(self,solvername,models):
"""
Initialize a MultiSolver with the given list of solvers.
Parameters
----------
solvername : str
Name of the solver backend (e.g., "ortools", "gurobi").
models : list of Model
The models to create solver instances for.
"""

self.name = solvername
self.solvers = []
for mdl in models:
self.solvers.append(SolverLookup.get(solvername,mdl))

def solve(self, time_limit=None, **kwargs):
"""
Solve the models sequentially using the solvers.

Parameters
----------
time_limit :
Global time limit in seconds for all solvers combined.
**kwargs : dict
Additional arguments passed to each solve method.

Returns
-------
bool
True if all solvers returned a solution, False otherwise.
"""
self.cpm_status = SolverStatus(self.name)
self.cpm_status.exitstatus = [ExitStatus.NOT_RUN] * len(self.solvers)
all_has_sol = True
# initialize exitstatus list
init_start = time.time()
for i, s in enumerate(self.solvers):
# call solver
start = time.time()
has_sol = s.solve(time_limit=time_limit, **kwargs)
# update only the current solver's exitstatus
self.cpm_status.exitstatus[i] = s.status().exitstatus
if time_limit is not None:
time_limit = time_limit - (time.time() - start)
if time_limit <= 0:
break
all_has_sol = all_has_sol and has_sol
end = time.time()
# update runtime
self.cpm_status.runtime = end - init_start
return all_has_sol

def has_finished(self):
"""
Check whether all solvers in the MultiSolver have finished.

A solver is considered finished if:
- It has an objective and reached OPTIMAL, or
- It has no objective and reached FEASIBLE, or
- It reached UNSATISFIABLE.

Returns
-------
bool
True if all solvers have finished, False otherwise.
"""
all_have_finished = True
for s in self.solvers:
finished = ((s.has_objective() and s.status().exitstatus == ExitStatus.OPTIMAL) or
(not s.has_objective() and s.status().exitstatus == ExitStatus.FEASIBLE) or
(s.status().exitstatus == ExitStatus.UNSATISFIABLE))
all_have_finished = all_have_finished and finished
return all_have_finished


Loading