diff --git a/.circleci/config.yml b/.circleci/config.yml index 469e6d525..2ba08e537 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -36,6 +36,7 @@ commands: command: | python3 -m venv venv . venv/bin/activate + sudo apt-get install swig pip install --progress-bar off -U pip setuptools pip install --progress-bar off --use-deprecated=legacy-resolver -e .[all] pip install --progress-bar off -U numpy>=1.20.0 diff --git a/mypy.ini b/mypy.ini index 84e1db262..3160f7638 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,6 @@ [mypy] -[mypy-scipy.*,requests,pandas,compiler_gym,compiler_gym.*,gym_anm,matplotlib.*,pytest,cma,bayes_opt.*,torchvision.models,torch.*,mpl_toolkits.*,fcmaes.*,tqdm,pillow,PIL,PIL.Image,sklearn.*,pyomo.*,pyproj,IOHexperimenter.*,tensorflow,koncept.models,cv2,imquality,imquality.brisque,lpips,mixsimulator.*,networkx.*,cdt.*,pymoo,pymoo.*,bayes_optim.*,olympus.*,pymoo,pymoo.*,pybullet,pybullet_envs,pybulletgym,pyvirtualdisplay,nlopt,aquacrop.*] +[mypy-scipy.*,requests,pandas,compiler_gym,compiler_gym.*,gym_anm,gym,gym.*,matplotlib.*,pytest,cma,bayes_opt.*,torchvision.models,torch.*,mpl_toolkits.*,fcmaes.*,tqdm,pillow,PIL,PIL.Image,sklearn.*,pyomo.*,pyproj,IOHexperimenter.*,tensorflow,koncept.models,cv2,imquality,imquality.brisque,lpips,mixsimulator.*,networkx.*,cdt.*,pymoo,pymoo.*,bayes_optim.*,olympus.*,autosklearn.*,openml,pymoo,pymoo.*,pybullet,pybullet_envs,pybulletgym,pyvirtualdisplay,nlopt,aquacrop.*] ignore_missing_imports = True [mypy-nevergrad.functions.rl.*,torchvision,torchvision.*,nevergrad.functions.games.*,nevergrad.functions.multiobjective.pyhv,nevergrad.optimization.test_doc,nevergrad.functions.gym.multigym] diff --git a/nevergrad/benchmark/experiments.py b/nevergrad/benchmark/experiments.py index c4aa34479..c93785846 100644 --- a/nevergrad/benchmark/experiments.py +++ b/nevergrad/benchmark/experiments.py @@ -146,6 +146,79 @@ def naivemltuning(seed: tp.Optional[int] = None) -> tp.Iterator[Experiment]: return 
@registry.register
def autosklearntuning(seed: tp.Optional[int] = None) -> tp.Iterator[Experiment]:
    """Auto-sklearn pipeline tuning on a small subset of the OpenML-CC18 tasks.

    For every budget in (10, 50, 100), every listed OpenML task and every
    optimizer, ten repetitions are generated. Each repetition builds its own
    AutoSKlearnBenchmark (3-fold cross-validation, balanced accuracy) with a
    fresh ``random_state`` drawn from the seed generator.

    Parameters
    ----------
    seed: optional int
        seed of the generator from which every benchmark/experiment seed is drawn

    Yields
    ------
    Experiment
        the experiments which are not flagged as incoherent
    """
    # pylint: disable=import-outside-toplevel
    from nevergrad.functions.automl import AutoSKlearnBenchmark

    seedg = create_seed_generator(seed)

    # Only a small subset of OpenML-CC18 is considered
    list_tasks = [
        3, 11, 15, 18, 23, 29, 31, 37, 45, 49, 53,
        2079, 3022, 3549, 3560, 3902, 3903, 3913, 3917, 3918,
        9946, 9957, 9964, 9971, 9978, 9981,
        10093, 10101, 14954,
        125920, 146800, 146817, 146819, 146821, 146822,
    ]
    optims = [
        "HyperOpt",
        "RandomSearch",
        "CMA",
        "DE",
        "BO",
    ]
    optims += get_optimizers("splitters", seed=next(seedg))  # type: ignore

    # Called before any benchmark is instantiated: building an AutoSKlearnBenchmark
    # fetches an OpenML task, which should not happen when the experiment is skipped.
    # NOTE(review): skip_ci is defined elsewhere; presumably it aborts the generator on CI.
    skip_ci(reason="Too slow")

    for budget in [10, 50, 100]:
        for task_id in list_tasks:
            for algo in optims:
                # 10 repetitions; the index is unused and must not shadow the `seed` argument
                for _ in range(10):
                    func = AutoSKlearnBenchmark(
                        openml_task_id=task_id,
                        cv=3,
                        overfitter=False,
                        time_budget_per_run=300,
                        memory_limit=1024 * 10,
                        scoring_func="balanced_accuracy",
                        random_state=next(seedg),
                    )
                    xp = Experiment(func, algo, budget, num_workers=1, seed=next(seedg))  # type: ignore
                    if not xp.is_incoherent:
                        yield xp
class AutoSKlearnBenchmark(base.ExperimentFunction):
    """Loss of an auto-sklearn classification pipeline on one OpenML task.

    The parametrization describes a pipeline configuration; each call submits the
    evaluation of that configuration to a local submitit executor and returns the
    resulting loss (the evaluation function returns ``1 - score``).

    Parameters
    ----------
    openml_task_id: int
        identifier of the OpenML task providing the data and the train/test split
    cv: int
        number of folds used for cross-validation on the training split
    time_budget_per_run: int
        forwarded to the executor as ``timeout_min``
        NOTE(review): submitit's ``timeout_min`` is expressed in minutes while callers
        pass values such as 60 or 300 - confirm the intended unit.
    memory_limit: int
        recorded as an attribute and as a descriptor
        NOTE(review): not enforced anywhere in this class - confirm whether it should
        be forwarded to the executor.
    scoring_func: str
        name of the scorer forwarded to the evaluation function
    error_penalty: float
        loss returned when the job fails or does not return a float
    overfitter: bool
        if True, ``evaluation_function`` keeps using the cross-validation loss instead
        of the loss on the held-out test split
    random_state: optional int
        random state forwarded to the evaluation function

    Raises
    ------
    UnsupportedExperiment
        on Windows (``os.name == "nt"``)
    """

    def __init__(
        self,
        openml_task_id: int,
        cv: int,
        time_budget_per_run: int,
        memory_limit: int,
        scoring_func: str = "balanced_accuracy",
        error_penalty: float = 1.0,
        overfitter: bool = False,
        random_state: tp.Optional[int] = None,
    ) -> None:
        if os.name == "nt":
            raise UnsupportedExperiment("Auto-Sklearn is not working under Windows")

        # heavy optional dependencies are only imported when the benchmark is instantiated
        # pylint: disable=import-outside-toplevel
        from .ngautosklearn import get_parametrization, _eval_function, get_config_space
        import openml
        import submitit

        self.openml_task_id = openml_task_id
        self.random_state = random_state
        self.cv = cv
        self.scoring_func = scoring_func
        self.memory_limit = memory_limit
        self.time_budget_per_run = time_budget_per_run
        self.error_penalty = error_penalty
        self.overfitter = overfitter
        # toggled (temporarily) by evaluation_function to score on the test split
        self.evaluate_on_test = False
        self.eval_func = _eval_function

        # data: the task's own train/test split is used
        openml_task = openml.tasks.get_task(openml_task_id)
        self.dataset_name = openml_task.get_dataset().name
        X, y = openml_task.get_X_and_y()
        split = openml_task.get_train_test_split_indices()
        self.X_train, self.y_train = X[split[0]], y[split[0]]
        self.X_test, self.y_test = X[split[1]], y[split[1]]

        # search space
        self.config_space = get_config_space(X=self.X_train, y=self.y_train)
        parametrization = get_parametrization(self.config_space)
        parametrization = parametrization.set_name(f"time={time_budget_per_run}")

        # each evaluation runs as a job of a local submitit executor
        log_folder = "/tmp"
        self.executor = submitit.AutoExecutor(folder=log_folder, cluster="local")
        self.executor.update_parameters(timeout_min=time_budget_per_run)

        self.add_descriptors(
            openml_task_id=openml_task_id,
            cv=cv,
            scoring_func=scoring_func,
            memory_limit=memory_limit,
            time_budget_per_run=time_budget_per_run,
            error_penalty=error_penalty,
            overfitter=overfitter,
            dataset_name=self.dataset_name,
        )
        self._descriptors.pop("random_state", None)  # remove it from automatically added descriptors
        self.best_loss = np.inf
        self.best_config = None
        parametrization.function.proxy = not overfitter
        parametrization.function.deterministic = False
        super().__init__(self._simulate, parametrization)

    def _simulate(self, **x) -> float:
        """Evaluates one configuration in a submitit job and returns its loss.

        The loss is computed by cross-validation on the training split, or on the
        held-out test split when ``self.evaluate_on_test`` is set. ``self.error_penalty``
        is returned when the job raises or when its result is not a float.
        """
        # pylint: disable=import-outside-toplevel
        from .ngautosklearn import get_configuration

        config = get_configuration(x, self.config_space)
        # None requests cross-validation on the training data
        test_data = (self.X_test, self.y_test) if self.evaluate_on_test else None
        job = self.executor.submit(
            self.eval_func,
            config,
            self.X_train,
            self.y_train,
            self.scoring_func,
            self.cv,
            self.random_state,
            test_data,
        )
        try:
            loss = job.result()
        except Exception:  # pylint: disable=broad-except
            # failed/timed-out job: penalize, but let KeyboardInterrupt/SystemExit propagate
            return self.error_penalty
        # the evaluation function signals its own failures with a non-float value
        return loss if isinstance(loss, float) else self.error_penalty

    def evaluation_function(self, *args) -> float:
        """Final evaluation: scored on the test split unless ``overfitter`` is set.

        The previous value of ``evaluate_on_test`` is restored afterwards so that
        later optimization-time calls are not silently scored on the test split.
        NOTE(review): this relies on the parent implementation evaluating the function
        before returning - confirm against ExperimentFunction.evaluation_function.
        """
        previous = self.evaluate_on_test
        self.evaluate_on_test = not self.overfitter
        try:
            return super().evaluation_function(*args)
        finally:
            self.evaluate_on_test = previous
# Hyperparameters encoded as a Choice over (choice_name, {sub-hyperparameter: value})
# tuples; the order is the one in which to_dict flattens them.
_NESTED_CHOICE_KEYS = (
    "classifier:__choice__",
    "feature_preprocessor:__choice__",
    "data_preprocessor:feature_type:categorical_transformer:category_coalescence:__choice__",
    "data_preprocessor:feature_type:numerical_transformer:rescaling:__choice__",
)


def _eval_function(
    config: "cs.Configuration", X, y, scoring_func: str, cv: int, random_state: int, test_data: tuple
):
    """Returns the loss (1 - score) of the pipeline described by ``config``.

    Parameters
    ----------
    config: cs.Configuration
        configuration of the SimpleClassificationPipeline
    X, y
        training data
    scoring_func: str
        name passed to sklearn's ``get_scorer``
    cv: int
        number of stratified (shuffled) folds, only used when ``test_data`` is None
    random_state: int
        random state of both the pipeline and the fold splitter
    test_data: tuple or None
        if None, the score is the mean cross-validation score on (X, y); otherwise
        the pipeline is fitted on (X, y) and scored on ``(test_data[0], test_data[1])``

    Returns
    -------
    ``1 - score``, or the integer ``1`` if anything raised an Exception.
    """
    try:
        classifier = SimpleClassificationPipeline(config=config, random_state=random_state)
        scorer = get_scorer(scoring_func)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if test_data is None:
                scores = cross_val_score(
                    estimator=classifier,
                    X=X,
                    y=y,
                    cv=StratifiedKFold(n_splits=cv, random_state=random_state, shuffle=True),
                    scoring=scorer,
                    n_jobs=1,
                )
                return 1 - np.mean(scores)
            classifier.fit(X, y)
            return 1 - scorer(classifier, test_data[0], test_data[1])
    except Exception:  # pylint: disable=broad-except
        # kept as an int on purpose: AutoSKlearnBenchmark._simulate replaces any
        # non-float result by its error penalty
        return 1


def check_configuration(config_space, values):
    """Returns True if ``values`` describes a valid configuration of ``config_space``.

    ``values[1]`` is flattened with ``to_dict`` and validated by building a
    ``cs.Configuration`` which rejects inactive hyperparameters holding a value.
    NOTE(review): ``values`` is presumably the (args, kwargs) value of the
    instrumentation this function is registered on as a constraint - confirm.
    """
    val_dict = to_dict(values[1])
    try:
        cs.Configuration(configuration_space=config_space, values=val_dict, allow_inactive_with_values=False)
    except Exception:  # pylint: disable=broad-except
        return False
    return True


def get_config_space(X, y):
    """Returns the auto-sklearn configuration space suited to the dataset.

    The task is binary classification if ``y`` holds exactly two distinct values and
    multiclass classification otherwise; sparsity is detected from ``X``.
    """
    dataset_properties = {
        "task": BINARY_CLASSIFICATION if len(np.unique(y)) == 2 else MULTICLASS_CLASSIFICATION,
        "is_sparse": scipy.sparse.issparse(X),
    }
    return get_configuration_space(dataset_properties)


def get_instrumention(param):
    """Converts a ConfigSpace hyperparameter into its nevergrad counterpart.

    - categorical: ``ng.p.Choice`` over the choices
    - uniform integer: bounded ``ng.p.Scalar`` (``ng.p.Log`` if ``param.log``) with integer casting
    - uniform float: bounded ``ng.p.Scalar`` (``ng.p.Log`` if ``param.log``)
    - constant: the raw constant value (not a parameter)

    Bounded parameters are initialized at the hyperparameter's default value.

    Raises
    ------
    TypeError
        if the hyperparameter type is not one of the above
    """
    if isinstance(param, cs.hyperparameters.CategoricalHyperparameter):
        return ng.p.Choice(param.choices)
    if isinstance(param, cs.hyperparameters.UniformIntegerHyperparameter):
        scalar_cls = ng.p.Log if param.log else ng.p.Scalar
        return scalar_cls(
            lower=param.lower, upper=param.upper, init=param.default_value
        ).set_integer_casting()
    if isinstance(param, cs.hyperparameters.UniformFloatHyperparameter):
        scalar_cls = ng.p.Log if param.log else ng.p.Scalar
        return scalar_cls(lower=param.lower, upper=param.upper, init=param.default_value)
    if isinstance(param, cs.hyperparameters.Constant):
        return param.value
    # TypeError is an Exception subclass, so callers catching Exception still work
    raise TypeError(f"{param} type not known")


def get_parametrization(config_space: "cs.ConfigurationSpace"):
    """Builds the nevergrad instrumentation matching ``config_space``.

    Only the top-level pipeline hyperparameters become keyword arguments. Those listed
    in ``_NESTED_CHOICE_KEYS`` are a Choice over ``(choice, {sub-hyperparameters})``
    tuples, the others are converted directly with ``get_instrumention``.
    ``check_configuration`` is registered as a cheap constraint.
    """
    from functools import partial  # pylint: disable=import-outside-toplevel

    base_pipeline = [
        "balancing:strategy",
        "classifier:__choice__",
        "data_preprocessor:feature_type:categorical_transformer:categorical_encoding:__choice__",
        "data_preprocessor:feature_type:categorical_transformer:category_coalescence:__choice__",
        "data_preprocessor:feature_type:numerical_transformer:imputation:strategy",
        "data_preprocessor:feature_type:numerical_transformer:rescaling:__choice__",
        "feature_preprocessor:__choice__",
        "data_preprocessor:__choice__",
    ]

    hyperparameters = list(config_space.get_hyperparameters())
    params = {}
    for param in hyperparameters:
        if param.name not in base_pipeline:
            continue
        if param.name in _NESTED_CHOICE_KEYS:
            # NOTE(review): sub-hyperparameters are selected by substring match of the
            # choice name within the hyperparameter name - confirm no choice name is a
            # substring of an unrelated hyperparameter.
            params[param.name] = ng.p.Choice(
                [
                    ng.p.Tuple(
                        param_choice,
                        ng.p.Dict(
                            **{
                                hp.name: get_instrumention(hp)
                                for hp in hyperparameters
                                if param_choice in hp.name
                            }
                        ),
                    )
                    for param_choice in param.choices
                ]
            )
        else:
            params[param.name] = get_instrumention(param)

    inst = ng.p.Instrumentation(**params)
    inst.register_cheap_constraint(partial(check_configuration, config_space))
    return inst


def get_configuration(values, config_space):
    """Builds the ``cs.Configuration`` described by the instrumentation kwargs ``values``.

    Values of inactive hyperparameters are tolerated (``allow_inactive_with_values=True``).
    """
    val_dict = to_dict(values)
    return cs.Configuration(
        configuration_space=config_space, values=val_dict, allow_inactive_with_values=True
    )


def to_dict(values):
    """Flattens instrumentation kwargs into a ``{hyperparameter name: value}`` dict.

    Each entry listed in ``_NESTED_CHOICE_KEYS`` holds a ``(choice, sub_params)`` pair:
    the key is mapped to ``choice`` and ``sub_params`` is merged at the top level.
    The input dict is left untouched; a new dict is returned.

    Raises
    ------
    KeyError
        if ``"data_preprocessor:__choice__"`` or one of the nested keys is missing
    """
    flat = dict(values)  # shallow copy: the caller's dict must not be modified
    data_preprocessor = flat["data_preprocessor:__choice__"]
    clf_key, features_key, trans_cat_key, trans_num_key = _NESTED_CHOICE_KEYS
    clf, features, trans_cat, trans_num = (flat.pop(key) for key in _NESTED_CHOICE_KEYS)

    flat[clf_key] = clf[0]
    flat.update(clf[1])
    flat[features_key] = features[0]
    flat.update(features[1])
    flat["data_preprocessor:__choice__"] = data_preprocessor
    flat[trans_cat_key] = trans_cat[0]
    flat.update(trans_cat[1])
    flat[trans_num_key] = trans_num[0]
    flat.update(trans_num[1])
    return flat
def _make_benchmark(time_budget_per_run: int, memory_limit: int) -> AutoSKlearnBenchmark:
    """Benchmark on OpenML task 3 with 3-fold CV, balanced accuracy and a fixed random state."""
    return AutoSKlearnBenchmark(
        openml_task_id=3,
        cv=3,
        time_budget_per_run=time_budget_per_run,
        memory_limit=memory_limit,
        scoring_func="balanced_accuracy",
        random_state=42,
    )


def test_parametrization():
    """A 3-evaluation random search must run on the benchmark's parametrization."""
    func = _make_benchmark(time_budget_per_run=60, memory_limit=2000)
    optimizer = ng.optimizers.RandomSearch(parametrization=func.parametrization, budget=3)
    optimizer.minimize(func, verbosity=2)


def test_function():
    """The loss of two valid sampled configurations must lie in [0, 1]."""
    func = _make_benchmark(time_budget_per_run=360, memory_limit=7000)
    for _ in range(2):
        # rejection sampling until the candidate satisfies the registered constraints
        cand = func.parametrization.sample()
        while not cand.satisfies_constraints():
            cand = func.parametrization.sample()
        val = func(**cand.kwargs)
        assert 0 <= val <= 1