Skip to content

Commit

Permalink
Merge pull request #1175 from dedupeio/iterative_bb
Browse files Browse the repository at this point in the history
iterative branch and bound
  • Loading branch information
fgregg authored Dec 19, 2023
2 parents d986738 + 3ae13e2 commit a315886
Showing 1 changed file with 18 additions and 41 deletions.
59 changes: 18 additions & 41 deletions dedupe/branch_and_bound.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import functools
import warnings
from typing import Any, Iterable, Mapping, Sequence, Tuple

from ._typing import Cover
Expand All @@ -27,13 +26,11 @@ def _remove_dominated(coverage: Cover, dominator: Predicate) -> Cover:
def _uncovered_by(
    coverage: Mapping[Any, frozenset[int]], covered: frozenset[int]
) -> dict[Any, frozenset[int]]:
    """Return what each predicate still covers once ``covered`` is removed.

    Args:
        coverage: Mapping from a predicate to the set of ids it covers.
        covered: Ids that are already accounted for.

    Returns:
        A new dict holding, for every predicate in ``coverage``, the ids it
        covers that are not in ``covered``. Predicates left with nothing
        are omitted. ``coverage`` itself is not modified.
    """
    return {
        pred: still_uncovered
        for pred, uncovered in coverage.items()
        # An empty difference is falsy, so fully covered predicates drop out.
        if (still_uncovered := uncovered - covered)
    }


def _order_by(
Expand All @@ -46,41 +43,22 @@ def _score(partial: Iterable[Predicate]) -> float:
return sum(p.cover_count for p in partial)


def _suppress_recursion_error(func):
    """Decorate ``func`` so that a ``RecursionError`` becomes a warning.

    The returned wrapper forwards all positional and keyword arguments to
    ``func`` and returns its result. If ``func`` raises ``RecursionError``,
    the error is swallowed, a warning is issued, and the wrapper returns
    ``None``. Any other exception propagates unchanged.
    """

    # Keep the wrapped function's name and docstring on the wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RecursionError:
            warnings.warn("Recursion limit reached while searching for predicates")
            return None

    return wrapper


def search(candidates, target: int, max_calls: int) -> Partial:
calls = max_calls

cheapest_score = float("inf")
cheapest: Partial = ()

original_cover = candidates.copy()

def search(original_cover: Cover, target: int, calls: int) -> Partial:
def _covered(partial: Partial) -> int:
return (
len(frozenset.union(*(original_cover[p] for p in partial)))
if partial
else 0
)

@_suppress_recursion_error
def walk(candidates: Cover, partial: Partial = ()) -> None:
nonlocal calls
nonlocal cheapest
nonlocal cheapest_score
cheapest_score = float("inf")
cheapest: Partial = ()

if calls <= 0:
return
start: tuple[Cover, Partial] = (original_cover, ())
to_explore = [start]

calls -= 1
while to_explore and calls:
candidates, partial = to_explore.pop()

covered = _covered(partial)
score = _score(partial)
Expand All @@ -97,17 +75,16 @@ def walk(candidates: Cover, partial: Partial = ()) -> None:
order_by = functools.partial(_order_by, candidates)
best = max(candidates, key=order_by)

remaining = _uncovered_by(candidates, candidates[best])
walk(remaining, partial + (best,))
del remaining

reduced = _remove_dominated(candidates, best)
walk(reduced, partial)
del reduced
to_explore.append((reduced, partial))

remaining = _uncovered_by(candidates, candidates[best])
to_explore.append((remaining, partial + (best,)))

elif score < cheapest_score:
cheapest = partial
cheapest_score = score

walk(candidates)
calls -= 1

return cheapest

0 comments on commit a315886

Please sign in to comment.