diff --git a/cpmpy/tools/tune_solver.py b/cpmpy/tools/tune_solver.py index 0e3def99a..9a629f577 100644 --- a/cpmpy/tools/tune_solver.py +++ b/cpmpy/tools/tune_solver.py @@ -14,13 +14,13 @@ The parameter tuner iteratively finds better hyperparameters close to the current best configuration during the search. Searching and time-out start at the default configuration for a solver (if available in the solver class) """ +import math import time from random import shuffle - import numpy as np - from ..solvers.utils import SolverLookup, param_combinations -from ..solvers.solver_interface import ExitStatus +from ..solvers.solver_interface import ExitStatus, SolverInterface, SolverStatus + class ParameterTuner: """ @@ -42,7 +42,7 @@ def __init__(self, solvername, model, all_params=None, defaults=None): self.best_params = SolverLookup.lookup(solvername).default_params() self._param_order = list(self.all_params.keys()) - self._best_config = self._params_to_np([self.best_params]) + self._best_config = self._params_to_np([self.best_params])[0] def tune(self, time_limit=None, max_tries=None, fix_params={}): """ @@ -54,12 +54,19 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}): start_time = time.time() # Init solver - solver = SolverLookup.get(self.solvername, self.model) - solver.solve(**self.best_params) + if not isinstance(self.model, list): + solver = SolverLookup.get(self.solvername, self.model) + else: + solver = MultiSolver(self.solvername, self.model) + solver.solve(**self.best_params, time_limit=time_limit) + if not _has_finished(solver): + raise TimeoutError("Time's up before solving init solver call") self.base_runtime = solver.status().runtime self.best_runtime = self.base_runtime + + # Get all possible hyperparameter configurations combos = list(param_combinations(self.all_params)) combos_np = self._params_to_np(combos) @@ -72,7 +79,10 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}): max_tries = len(combos_np) while len(combos_np) and i < max_tries: # Make new solver - solver = SolverLookup.get(self.solvername, self.model) + if not isinstance(self.model, list): + solver = SolverLookup.get(self.solvername, self.model) + else: + solver = MultiSolver(self.solvername, self.model) # Apply scoring to all combos scores = self._get_score(combos_np) max_idx = np.where(scores == scores.min())[0][0] @@ -87,22 +97,23 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}): timeout = self.best_runtime # set timeout depending on time budget if time_limit is not None: + if (time.time() - start_time) >= time_limit: + break timeout = min(timeout, time_limit - (time.time() - start_time)) # run solver solver.solve(**params_dict, time_limit=timeout) - if solver.status().exitstatus == ExitStatus.OPTIMAL and solver.status().runtime < self.best_runtime: + if _has_finished(solver): self.best_runtime = solver.status().runtime # update surrogate self._best_config = params_np - if time_limit is not None and (time.time() - start_time) >= time_limit: - break i += 1 self.best_params = self._np_to_params(self._best_config) self.best_params.update(fix_params) return self.best_params + def _get_score(self, combos): """ Return the hamming distance for each remaining configuration to the current best config. @@ -134,13 +145,17 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}): start_time = time.time() # Init solver - solver = SolverLookup.get(self.solvername, self.model) - solver.solve(**self.best_params) + if not isinstance(self.model, list): + solver = SolverLookup.get(self.solvername, self.model) + else: + solver = MultiSolver(self.solvername, self.model) + solver.solve(**self.best_params,time_limit=time_limit) + if not _has_finished(solver): + raise TimeoutError("Time's up before solving init solver call") self.base_runtime = solver.status().runtime self.best_runtime = self.base_runtime - # Get all possible hyperparameter configurations combos = list(param_combinations(self.all_params)) shuffle(combos) # test in random order @@ -150,24 +165,139 @@ def tune(self, time_limit=None, max_tries=None, fix_params={}): for params_dict in combos: # Make new solver - solver = SolverLookup.get(self.solvername, self.model) + if not isinstance(self.model, list): + solver = SolverLookup.get(self.solvername, self.model) + else: + solver = MultiSolver(self.solvername, self.model) # set fixed params params_dict.update(fix_params) timeout = self.best_runtime # set timeout depending on time budget if time_limit is not None: + if (time.time() - start_time) >= time_limit: + break timeout = min(timeout, time_limit - (time.time() - start_time)) # run solver solver.solve(**params_dict, time_limit=timeout) - if solver.status().exitstatus == ExitStatus.OPTIMAL and solver.status().runtime < self.best_runtime: + if _has_finished(solver): self.best_runtime = solver.status().runtime # update surrogate self.best_params = params_dict + return self.best_params - if time_limit is not None and (time.time() - start_time) >= time_limit: - break +def _has_finished(solver): + """ + Check whether a given solver has found the target solution. + Parameters + ---------- + solver : SolverInterface + + Returns + ------- + bool + True if the solver has has found the target solution. This means: + - For a `MultiSolver`: its own `has_finished()` method determines completion. + - For a problem with an objective: status is OPTIMAL. + - For a problem without an objective: status is FEASIBLE. + - For an unsat problem: status is UNSATISFIABLE. + False otherwise. + """ + if isinstance(solver,MultiSolver): + return solver.has_finished() + elif (((solver.has_objective() and solver.status().exitstatus == ExitStatus.OPTIMAL) or + (not solver.has_objective() and solver.status().exitstatus == ExitStatus.FEASIBLE)) or + (solver.status().exitstatus == ExitStatus.UNSATISFIABLE)): + return True + return False - return self.best_params +class MultiSolver(SolverInterface): + """ + Class that manages multiple solver instances. + Attributes + ---------- + name : str + Name of the solver used for all instances. + solvers : list of SolverInterface + The solver instances corresponding to each model. + cpm_status : SolverStatus + Aggregated solver status. Tracks runtime and per-solver exit statuses. + """ + + def __init__(self,solvername,models): + """ + Initialize a MultiSolver with the given list of solvers. + Parameters + ---------- + solvername : str + Name of the solver backend (e.g., "ortools", "gurobi"). + models : list of Model + The models to create solver instances for. + """ + + self.name = solvername + self.solvers = [] + for mdl in models: + self.solvers.append(SolverLookup.get(solvername,mdl)) + + def solve(self, time_limit=None, **kwargs): + """ + Solve the models sequentially using the solvers. + + Parameters + ---------- + time_limit : + Global time limit in seconds for all solvers combined. + **kwargs : dict + Additional arguments passed to each solve method. + + Returns + ------- + bool + True if all solvers returned a solution, False otherwise. + """ + self.cpm_status = SolverStatus(self.name) + self.cpm_status.exitstatus = [ExitStatus.NOT_RUN] * len(self.solvers) + all_has_sol = True + # initialize exitstatus list + init_start = time.time() + for i, s in enumerate(self.solvers): + # call solver + start = time.time() + has_sol = s.solve(time_limit=time_limit, **kwargs) + # update only the current solver's exitstatus + self.cpm_status.exitstatus[i] = s.status().exitstatus + if time_limit is not None: + time_limit = time_limit - (time.time() - start) + if time_limit <= 0: + break + all_has_sol = all_has_sol and has_sol + end = time.time() + # update runtime + self.cpm_status.runtime = end - init_start + return all_has_sol + + def has_finished(self): + """ + Check whether all solvers in the MultiSolver have finished. + + A solver is considered finished if: + - It has an objective and reached OPTIMAL, or + - It has no objective and reached FEASIBLE, or + - It reached UNSATISFIABLE. + + Returns + ------- + bool + True if all solvers have finished, False otherwise. + """ + all_have_finished = True + for s in self.solvers: + finished = ((s.has_objective() and s.status().exitstatus == ExitStatus.OPTIMAL) or + (not s.has_objective() and s.status().exitstatus == ExitStatus.FEASIBLE) or + (s.status().exitstatus == ExitStatus.UNSATISFIABLE)) + all_have_finished = all_have_finished and finished + return all_have_finished +