Skip to content

Commit

Permalink
refactor Hutch++ algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
FMatti committed Jun 11, 2024
1 parent d7b2cb9 commit 1bf050c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 32 deletions.
64 changes: 38 additions & 26 deletions roughly/approximate/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ def compute(self, A : Union[np.ndarray, callable], k : int = 10, n : Union[int,
Returns
-------
self.est : float
self.trace : float
Trace estimate.
"""
self._preprocess(A, k, n, dtype=dtype)
V = self.rng(self.n, self.k)
self.est = np.sum(V.conj() * self.matvec(V)) / self.k
return self.est
self.trace = np.sum(V.conj() * self.matvec(V)) / self.k
return self.trace

def refine(self, k : int = 1):
"""
Expand All @@ -122,13 +122,13 @@ def refine(self, k : int = 1):
Returns
-------
self.est : float
self.trace : float
Trace estimate.
"""
V = self.rng(self.n, k)
self.est = (self.est * self.k + np.sum(V.conj() * self.matvec(V))) / (k + self.k)
self.trace = (self.trace * self.k + np.sum(V.conj() * self.matvec(V))) / (k + self.k)
self.k += k
return self.est
return self.trace

class SubspaceProjectionEstimator(TraceEstimator):
"""
Expand Down Expand Up @@ -191,13 +191,13 @@ def compute(self, A : Union[np.ndarray, callable], k : int = 10, n : Union[int,
Returns
-------
self.est : float
self.trace : float
Trace estimate.
"""
self._preprocess(A, k, n, dtype=dtype)
self.Q = self.sketch.compute(self.matvec, k=k, n=self.n, dtype=self.dtype)
self.est = np.sum(self.Q.conj() * self.matvec(self.Q))
return self.est
self.trace = np.sum(self.Q.conj() * self.matvec(self.Q))
return self.trace

def refine(self, k : int = 1):
"""
Expand All @@ -210,13 +210,13 @@ def refine(self, k : int = 1):
Returns
-------
self.est : float
self.trace : float
Trace estimate.
"""
self.Q = self.sketch.refine(k=k)
self.est = np.sum(self.Q.conj() * self.matvec(self.Q))
self.trace = np.sum(self.Q.conj() * self.matvec(self.Q))
self.k += k
return self.est
return self.trace

class DeflatedTraceEstimator(TraceEstimator):
"""
Expand Down Expand Up @@ -259,7 +259,7 @@ def __init__(self, rng : Union[str, callable] = "gaussian"):
self.sketch = StandardSketch(rng)
super().__init__(rng)

def compute(self, A : Union[np.ndarray, callable], k : int = 10, n : Union[int, None] = None, dtype : Union[type, None] = None):
def compute(self, A : Union[np.ndarray, callable], k : int = 10, sketch_ratio : float = 2/3, n : Union[int, None] = None, dtype : Union[type, None] = None):
"""
Compute the trace estimate.
Expand All @@ -279,17 +279,23 @@ def compute(self, A : Union[np.ndarray, callable], k : int = 10, n : Union[int,
Returns
-------
self.est : float
self.trace : float
Trace estimate.
"""
self._preprocess(A, k, n, dtype=dtype)
self.Q = self.sketch.compute(self.matvec, k=k // 3, n=self.n, dtype=self.dtype)
G = self.rng(self.n, k // 3)
G = G - self.Q @ (self.Q.T @ G)
self.est = np.sum(self.Q.conj() * self.matvec(self.Q)) + 1 / G.shape[-1] * np.trace(G.T @ self.matvec(G))
return self.est
k_sketch = round(k * sketch_ratio / 2)
k_correction = k - k_sketch

def refine(self, k : int = 1):
# Subspace projection
self.Q = self.sketch.compute(self.matvec, k=k_sketch, n=self.n, dtype=self.dtype)
self.S = self.rng(self.n, k_correction)

# Trace correction
G = self.S - self.Q @ (self.Q.T @ self.S)
self.trace = np.sum(self.Q.conj() * self.matvec(self.Q)) + np.sum(G.conj() * self.matvec(G)) / k_correction
return self.trace

def refine(self, k : int = 1, sketch_ratio : float = 2/3):
"""
Refine the trace estimate.
Expand All @@ -300,15 +306,21 @@ def refine(self, k : int = 1):
Returns
-------
self.est : float
self.trace : float
Trace estimate.
"""
self.Q = self.sketch.refine(k=k)
G = self.rng(self.n, k)
G = G - self.Q @ (self.Q.T @ G)
self.est = np.sum(self.Q.conj() * self.matvec(self.Q)) + 1/(self.k // 3 + G.shape[-1]) * np.trace(G.T @ self.matvec(G))
k_sketch = round(k * sketch_ratio / 2)
k_correction = k - k_sketch

# Subspace projection
self.Q = self.sketch.refine(k=k_sketch)
self.S = np.hstack((self.S, self.rng(self.n, k_correction)))

# Trace correction
G = self.S - self.Q @ (self.Q.T @ self.S)
self.trace = np.sum(self.Q.conj() * self.matvec(self.Q)) + np.sum(G.conj() * self.matvec(G)) / G.shape[-1]
self.k += k
return self.est
return self.trace

# TODO: class XTrace(TraceEstimator)

Expand Down
11 changes: 5 additions & 6 deletions tests/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_HutchinsonTraceEstimator():
tr_A = np.trace(A)

for k in [1, 2, 3, 10, 20]:
estimator = HutchinsonTraceEstimator(rng="rademacher")
estimator = HutchinsonTraceEstimator()
estimator.compute(A, k)
k_total = k
for k_refine in range(10):
Expand All @@ -29,7 +29,7 @@ def test_SubspaceProjectionEstimator():
tr_A = np.trace(A)

for k in [1, 2, 3, 10, 20]:
estimator = SubspaceProjectionEstimator(rng="rademacher")
estimator = SubspaceProjectionEstimator()
estimator.compute(A, k)
k_total = k
for k_refine in range(10):
Expand All @@ -44,13 +44,12 @@ def test_DeflatedTraceEstimator():
A = A + A.conj().T
tr_A = np.trace(A)

for k in [3, 10, 20]:
estimator = DeflatedTraceEstimator(rng="rademacher")
for k in [1, 2, 3, 10, 20]:
estimator = DeflatedTraceEstimator()
estimator.compute(A, k)
k_total = k
for k_refine in range(10):
k_total += k_refine
t = estimator.refine(k=k_refine)
if k_refine > 5:
assert(abs(tr_A - t) / abs(tr_A) < 10 * n / k_total)

assert(abs(tr_A - t) / abs(tr_A) < 5 * n / k_total)

0 comments on commit 1bf050c

Please sign in to comment.