Skip to content

Commit

Permalink
Fixed BestSoFar optimisers being broken.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 5, 2023
1 parent 21dab0d commit bdde300
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 0 deletions.
28 changes: 28 additions & 0 deletions optimistix/_root_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,38 @@ def postprocess(self, fn, y, aux, args, options, state, tags, result):
class _MinimToRoot(AbstractMinimiser, _ToRoot):
solver: AbstractMinimiser

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property # pyright: ignore
def rtol(self):
return self.solver.rtol

@property # pyright: ignore
def atol(self):
return self.solver.atol

@property # pyright: ignore
def norm(self):
return self.solver.norm


class _LstsqToRoot(AbstractLeastSquaresSolver, _ToRoot):
solver: AbstractLeastSquaresSolver

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property # pyright: ignore
def rtol(self):
return self.solver.rtol

@property # pyright: ignore
def atol(self):
return self.solver.atol

@property # pyright: ignore
def norm(self):
return self.solver.norm


@eqx.filter_jit
def root_find(
Expand Down
56 changes: 56 additions & 0 deletions optimistix/_solver/best_so_far.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@ def __init__(self, solver: AbstractMinimiser[Y, tuple[Scalar, Aux], Any]):
def _to_loss(self, y: Y, f: Scalar) -> Scalar:
return f

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property # pyright: ignore
def rtol(self):
return self.solver.rtol

@property # pyright: ignore
def atol(self):
return self.solver.atol

@property # pyright: ignore
def norm(self):
return self.solver.norm


BestSoFarMinimiser.__init__.__doc__ = """**Arguments:**
Expand All @@ -158,6 +172,20 @@ def __init__(
def _to_loss(self, y: Y, f: Out) -> Scalar:
return sum_squares(f)

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property # pyright: ignore
def rtol(self):
return self.solver.rtol

@property # pyright: ignore
def atol(self):
return self.solver.atol

@property # pyright: ignore
def norm(self):
return self.solver.norm


BestSoFarLeastSquares.__init__.__doc__ = """**Arguments:**
Expand All @@ -181,6 +209,20 @@ def __init__(self, solver: AbstractRootFinder[Y, Out, tuple[Out, Aux], Any]):
def _to_loss(self, y: Y, f: Out) -> Scalar:
return sum_squares(f)

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property # pyright: ignore
def rtol(self):
return self.solver.rtol

@property # pyright: ignore
def atol(self):
return self.solver.atol

@property # pyright: ignore
def norm(self):
return self.solver.norm


BestSoFarRootFinder.__init__.__doc__ = """**Arguments:**
Expand All @@ -204,6 +246,20 @@ def __init__(self, solver: AbstractFixedPointSolver[Y, tuple[Y, Aux], Any]):
def _to_loss(self, y: Y, f: Y) -> Scalar:
return sum_squares((y**ω - f**ω).ω)

# Redeclare these three to work around the Equinox bug fixed here:
# https://github.com/patrick-kidger/equinox/pull/544
@property # pyright: ignore
def rtol(self):
return self.solver.rtol

@property # pyright: ignore
def atol(self):
return self.solver.atol

@property # pyright: ignore
def norm(self):
return self.solver.norm


BestSoFarFixedPoint.__init__.__doc__ = """**Arguments:**
Expand Down
1 change: 1 addition & 0 deletions tests/test_best_so_far.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ def fn(y, _):
return 0.5 * (y - jnp.tanh(y + 1)) ** 2

solver = optx.BFGS(rtol=1e-6, atol=1e-6)
solver = optx.BestSoFarMinimiser(solver)
sol = optx.minimise(fn, solver, jnp.array(0.0))
assert jnp.allclose(sol.value, 0.96118069, rtol=1e-5, atol=1e-5)

0 comments on commit bdde300

Please sign in to comment.