Skip to content

Commit 6437e0c

Browse files
committed
switch to ternary search for zeta
switch on n format missed a space
1 parent 352ddaf commit 6437e0c

File tree

2 files changed

+148
-6
lines changed

2 files changed

+148
-6
lines changed

estimator/lwe_primal.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sage.all import oo, ceil, sqrt, log, RR, ZZ, binomial, cached_function
1111
from .reduction import delta as deltaf
1212
from .reduction import cost as costf
13-
from .util import local_minimum
13+
from .util import local_minimum, ternary_search
1414
from .cost import Cost
1515
from .lwe_parameters import LWEParameters
1616
from .simulator import normalize as simulator_normalize
@@ -723,11 +723,19 @@ def __call__(
723723
):
724724
zeta_max += 1
725725

726-
with local_minimum(0, min(zeta_max, params.n), log_level=log_level) as it:
727-
for zeta in it:
728-
it.update(f(zeta=zeta, optimize_d=False, **kwds))
729-
# TODO: this should not be required
730-
cost = min(it.y, f(0, optimize_d=False, **kwds))
726+
if params.n >= 2 ** 15:
727+
with ternary_search(0, min(zeta_max, params.n), log_level=log_level) as it:
728+
for zeta in it:
729+
it.update(f(zeta=zeta, optimize_d=False, **kwds))
730+
# TODO: this should not be required
731+
cost = min(it.y, f(0, optimize_d=False, **kwds))
732+
else:
733+
with local_minimum(0, min(zeta_max, params.n), log_level=log_level) as it:
734+
for zeta in it:
735+
it.update(f(zeta=zeta, optimize_d=False, **kwds))
736+
# TODO: this should not be required
737+
cost = min(it.y, f(0, optimize_d=False, **kwds))
738+
731739
else:
732740
cost = f(zeta=zeta)
733741

estimator/util.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,140 @@ def neighborhood(self):
275275
return range(start, stop)
276276

277277

278+
class ternary_search:
279+
"""
280+
an iterator context for finding a local minimum using ternary search.
281+
282+
For an interval [a, b] we evaluate f(x) and the points
283+
x1 = a + (b - a) / 3
284+
x2 = a + 2 * (b - a) / 3
285+
if f(x1) < f(x2), we keep [a, x2]
286+
if f(x1) > f(x2), we keep [x1, b]
287+
"""
288+
289+
def __init__(
290+
self,
291+
start,
292+
stop,
293+
smallerf=lambda x, best: x <= best,
294+
suppress_bounds_warning=False,
295+
log_level=5,
296+
):
297+
"""
298+
Create a fresh local minimum ternary search context.
299+
300+
:param start: starting point
301+
:param stop: end point (exclusive)
302+
:param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``
303+
:param suppress_bounds_warning: do not warn if a boundary is picked as optimal
304+
305+
"""
306+
307+
if stop < start:
308+
raise ValueError(f"Incorrect bounds {start} > {stop}.")
309+
310+
self._suppress_bounds_warning = suppress_bounds_warning
311+
self._log_level = log_level
312+
self._start = start
313+
self._stop = stop - 1
314+
self._x1 = start + (stop - start) // 3
315+
self._x2 = start + (2 * (stop - start)) // 3
316+
self._fx1 = None
317+
self._fx2 = None
318+
self._initial_bounds = Bounds(start, stop - 1)
319+
self._smallerf = smallerf
320+
self._last_x = None
321+
self._best = Bounds(None, None)
322+
self._vals = {}
323+
324+
def __enter__(self):
325+
""" """
326+
return self
327+
328+
def __exit__(self, type, value, traceback):
329+
""" """
330+
pass
331+
332+
def __iter__(self):
333+
""" """
334+
return self
335+
336+
def __next__(self):
337+
if self._x1 is not None and self._fx1 is None:
338+
self._last_x = self._x1
339+
return self._last_x
340+
if self._x2 is not None and self._fx2 is None:
341+
self._last_x = self._x2
342+
return self._last_x
343+
if self._best.low in self._initial_bounds and not self._suppress_bounds_warning:
344+
# We warn the user if the optimal solution is at the edge and thus possibly not optimal.
345+
msg = (
346+
f'warning: "optimal" solution {self._best.low} matches a bound ∈ {self._initial_bounds}.',
347+
)
348+
Logging.log("bins", self._log_level, msg)
349+
raise StopIteration
350+
351+
@property
352+
def x(self):
353+
return self._best.low
354+
355+
@property
356+
def y(self):
357+
return self._best.high
358+
359+
def update(self, res):
360+
Logging.log("bins", self._log_level, f"({self._last_x}, {repr(res)})")
361+
362+
self._vals[self._last_x] = res
363+
364+
# We got nothing yet
365+
if self._best.low is None:
366+
self._best = Bounds(self._last_x, res)
367+
368+
# We found something better
369+
if res is not False and self._smallerf(res, self._best.high):
370+
# store it
371+
self._best = Bounds(self._last_x, res)
372+
373+
if self._last_x == self._x1:
374+
self._fx1 = res
375+
376+
if self._last_x == self._x2:
377+
self._fx2 = res
378+
379+
# we need to exit this loop either with something to do, or having calculated f for every point in [start, stop]
380+
# if stop - start > 2, we are guaranteed to shrink
381+
# to avoid getting stuck, we handle the cases stop - start <= 2 separately.
382+
383+
while self._fx1 is not None and self._fx2 is not None and (self._stop - self._start) > 2:
384+
# drop the right third
385+
if self._smallerf(self._fx1, self._fx2):
386+
self._start = self._start
387+
self._stop = self._x2
388+
# drop the left third
389+
else:
390+
self._start = self._x1
391+
self._stop = self._stop
392+
self._x1 = self._start + (self._stop - self._start) // 3
393+
self._x2 = self._start + (2 * (self._stop - self._start)) // 3
394+
395+
# if already seen, load the value: otherwise, mark None
396+
self._fx1 = self._vals.get(self._x1, None)
397+
self._fx2 = self._vals.get(self._x2, None)
398+
399+
# at most three integers remain: exhaustively search over them
400+
if self._stop - self._start <= 2:
401+
# print(self._start, self._stop)
402+
next = [x for x in range(self._start, self._stop + 1) if x not in self._vals]
403+
if next:
404+
# we assign remaining points arbitrarily to x1 and x2
405+
self._x1 = next[0]
406+
self._fx1 = None
407+
if len(next) > 1:
408+
self._x2 = next[1]
409+
self._fx2 = None
410+
411+
278412
class early_abort_range:
279413
"""
280414
An iterator context for finding a local minimum using linear search.

0 commit comments

Comments
 (0)