Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow delegating the archive #1348

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions nevergrad/optimization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import pickle
import logging
import warnings
from pathlib import Path
from numbers import Real
Expand All @@ -26,7 +27,8 @@
X = tp.TypeVar("X", bound="Optimizer")
Y = tp.TypeVar("Y")
IntOrParameter = tp.Union[int, p.Parameter]
_PruningCallable = tp.Callable[[utils.Archive[utils.MultiValue]], utils.Archive[utils.MultiValue]]
_PruningCallable = tp.Callable[[utils.Archive[utils.MultiValue]], None]
logger = logging.getLogger(__name__)


def _loss(param: p.Parameter) -> float:
Expand Down Expand Up @@ -404,18 +406,23 @@ def _update_archive_and_bests(self, candidate: p.Parameter, loss: tp.FloatLoss)
)
if np.isnan(loss) or loss == np.inf:
self._warn(f"Updating fitness with {loss} value", errors.BadLossWarning)
mvalue: tp.Optional[utils.MultiValue] = None
mvalue = utils.MultiValue(candidate, loss, reference=self.parametrization)
if not self.archive.is_delegated:
# print(f"Updating archive for {candidate.uid[:8]} in {self.__class__.__name__}")
if x not in self.archive:
self.archive[x] = mvalue
else: # reevaluation: needs updating
mvalue = self.archive[x]
mvalue.add_evaluation(loss)
# both parameters should be non-None
if mvalue.parameter.loss > candidate.loss: # type: ignore
mvalue.parameter = candidate # keep best candidate
# the following should not happen since delegating archives are used afterwards
if x not in self.archive:
self.archive[x] = utils.MultiValue(candidate, loss, reference=self.parametrization)
else:
mvalue = self.archive[x]
mvalue.add_evaluation(loss)
# both parameters should be non-None
if mvalue.parameter.loss > candidate.loss: # type: ignore
mvalue.parameter = candidate # keep best candidate
logger.warning("Archive is not correctly filled, please open an issue.")
mvalue = self.archive.get(x, mvalue) # type: ignore # update with master archive
# update current best records
# this may have to be improved if we want to keep more kinds of best losses

for name in self.current_bests:
if mvalue is self.current_bests[name]: # reboot
best = min(self.archive.values(), key=lambda mv, n=name: mv.get_estimation(n)) # type: ignore
Expand All @@ -432,8 +439,9 @@ def _update_archive_and_bests(self, candidate: p.Parameter, loss: tp.FloatLoss)
# max(v.get_estimation(name) for v in self.archive.values()))
# raise RuntimeError(f"Best value should exist in the archive at num_tell={self.num_tell})\n"
# f"Best value is {bval} and archive is within range {avals} for {name}")
if self.pruning is not None:
self.archive = self.pruning(self.archive)
if self.pruning is not None and not self.archive.is_delegated:
self.pruning(self.archive)
self.archive[x] = mvalue # we must make sure that the current point is available for suboptim

def ask(self) -> p.Parameter:
"""Provides a point to explore.
Expand Down
4 changes: 4 additions & 0 deletions nevergrad/optimization/optimizerlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,7 @@ def __init__(
) -> None:
super().__init__(parametrization, budget=budget, num_workers=num_workers)
self._optimizer = base_optimizer(self.parametrization, budget=budget, num_workers=num_workers)
# cannot delegate the archive since the parametrization is different :(
self._subcandidates: tp.Dict[str, p.Parameter] = {}
if scale is None:
assert self.budget is not None, "Either scale or budget must be known in _Rescaled."
Expand Down Expand Up @@ -1374,6 +1375,8 @@ def __init__(
num_workers=sub_workers,
)
)
for optim in self.optims:
optim.archive.delegate_to(self.archive)
# current optimizer choice
self._selected_ind: tp.Optional[int] = None
self._current = -1
Expand Down Expand Up @@ -2407,6 +2410,7 @@ def optim(self) -> base.Optimizer:
if self._optim is None:
self._optim = self._select_optimizer_cls()(self.parametrization, self.budget, self.num_workers)
self._optim = self._optim if not isinstance(self._optim, NGOptBase) else self._optim.optim
self._optim.archive.delegate_to(self.archive)
logger.debug("%s selected %s optimizer.", *(x.name for x in (self, self._optim)))
return self._optim

Expand Down
6 changes: 3 additions & 3 deletions nevergrad/optimization/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ def test_pruning() -> None:
# pruning
pruning = utils.Pruning(min_len=1, max_len=3)
# 0 is best optimistic and average, and 3 is best pessimistic (variance=0)
archive = pruning(archive)
pruning(archive)
testing.assert_set_equal([x[0] for x in archive.keys_as_arrays()], [0, 3], err_msg=f"Repetition #{k+1}")
pickle.dumps(archive) # should be picklable
# should not change anything this time
archive2 = pruning(archive)
testing.assert_set_equal([x[0] for x in archive2.keys_as_arrays()], [0, 3], err_msg=f"Repetition #{k+1}")
pruning(archive)
testing.assert_set_equal([x[0] for x in archive.keys_as_arrays()], [0, 3], err_msg=f"Repetition #{k+1}")


@pytest.mark.parametrize( # type: ignore
Expand Down
35 changes: 26 additions & 9 deletions nevergrad/optimization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,27 @@ class Archive(tp.Generic[Y]):
"""

def __init__(self) -> None:
self.bytesdict: tp.Dict[bytes, Y] = {}
self._data: tp.Union["Archive[Y]", tp.Dict[bytes, Y]] = {}

@property
def bytesdict(self) -> tp.Dict[bytes, Y]:
    """Dictionary which actually holds the archive content.

    When this instance is not delegated, this is its own dictionary.
    When it is delegated, the delegation chain is followed down to the
    first archive which is not delegated itself, and that archive's
    dictionary is returned. The chain is shortcut on the way, so that this
    instance ends up pointing directly at that final archive.
    """
    current = self._data
    if not isinstance(current, Archive):
        return current  # plain dictionary: no delegation involved
    # walk down the chain until reaching the archive owning the dictionary
    while current.is_delegated:
        current = current._data  # type: ignore
    self._data = current  # remember the end of the chain for later accesses
    return current._data  # type: ignore

def delegate_to(self, archive: "Archive[Y]") -> None:
    """Use another archive as the storage of this instance.

    After this call, this instance is considered delegated: ``bytesdict``
    resolves to the dictionary held at the end of the delegation chain,
    and ``__setitem__`` on this instance raises a ``RuntimeError``.

    Parameters
    ----------
    archive: Archive
        the archive holding the data (it may itself be delegated)

    Raises
    ------
    ValueError
        if ``archive`` is this instance or already delegates (directly or
        not) to this instance, since resolving ``bytesdict`` on a cyclic
        delegation chain would never terminate
    """
    node: tp.Any = archive
    # follow the chain of the target and make sure it never comes back to us
    while isinstance(node, Archive):
        if node is self:
            raise ValueError("Cannot delegate an archive to itself (delegation cycle)")
        node = node._data
    self._data = archive

@property
def is_delegated(self) -> bool:
    """Whether the data is held by another archive rather than by a local dictionary."""
    storage = self._data
    return isinstance(storage, Archive)

def __setitem__(self, x: tp.ArrayLike, value: Y) -> None:
    """Store ``value`` under the bytes key computed from ``x``.

    Raises
    ------
    RuntimeError
        if this instance is delegated to another archive
    """
    if not self.is_delegated:
        self.bytesdict[_tobytes(x)] = value
        return
    # writing is refused on delegated instances
    raise RuntimeError("Cannot set from a delegated instance")

def __getitem__(self, x: tp.ArrayLike) -> Y:
Expand Down Expand Up @@ -255,13 +273,13 @@ def __init__(self, min_len: int, max_len: int):
self.max_len = max_len
self._num_prunings = 0 # for testing it is not called too often

def __call__(self, archive: Archive[MultiValue]) -> Archive[MultiValue]:
if len(archive) < self.max_len:
return archive
return self._prune(archive)
def __call__(self, archive: Archive[MultiValue]) -> None:
if len(archive) >= self.max_len:
self._prune(archive)

def _prune(self, archive: Archive[MultiValue]) -> Archive[MultiValue]:
def _prune(self, archive: Archive[MultiValue]) -> None:
self._num_prunings += 1
assert not isinstance(archive._data, Archive), "Cannot prune on delegated instance"
# separate function to ease profiling
quantiles: tp.Dict[str, float] = {}
threshold = float(self.min_len + 1) / len(archive)
Expand All @@ -270,14 +288,13 @@ def _prune(self, archive: Archive[MultiValue]) -> Archive[MultiValue]:
quantiles[name] = np.quantile(
[v.get_estimation(name) for v in archive.values()], threshold, interpolation="lower"
)
new_archive: Archive[MultiValue] = Archive()
new_archive.bytesdict = {
bytesdict = {
b: v
for b, v in archive.bytesdict.items()
if any(v.get_estimation(n) < quantiles[n] for n in names)
} # strict comparison to make sure we prune even for values repeated maaany times
# note that this may remove all points, but never mind for now
return new_archive
archive._data = bytesdict

@classmethod
def sensible_default(cls, num_workers: int, dimension: int) -> "Pruning":
Expand Down