[WIP] Nystrom sinkhorn #742

Open · wants to merge 8 commits into base: master

3 changes: 3 additions & 0 deletions .gitignore
@@ -122,3 +122,6 @@ debug

# pytest cache
.pytest_cache
/docs
/ot/mytest
/test/mytest
2 changes: 2 additions & 0 deletions README.md
@@ -389,3 +389,5 @@ Artificial Intelligence.
[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR.

[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145.

[76] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019.
3 changes: 2 additions & 1 deletion RELEASES.md
@@ -17,6 +17,7 @@
- Backend implementation of `ot.dist` (PR #701)
- Updated documentation Quickstart guide and User guide with new API (PR #726)
- Fix jax version for auto-grad (PR #732)
- Add Nystrom kernel approximation for Sinkhorn (PR #742)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
@@ -36,7 +37,7 @@ This new release contains several new features, starting with
a novel [Gaussian Mixture Model Optimal Transport (GMM-OT)](https://pythonot.github.io/master/gen_modules/ot.gmm.html#examples-using-ot-gmm-gmm-ot-apply-map) solver to compare GMMs while enforcing the transport plan to remain a GMM, which benefits from a closed-form solution making it practical for high-dimensional matching problems. We also extended our general unbalanced OT solvers to support any non-negative reference measure in the regularization terms, before adding the novel [translation invariant UOT](https://pythonot.github.io/master/auto_examples/unbalanced-partial/plot_conv_sinkhorn_ti.html) solver showcasing a higher convergence speed. We also implemented several new solvers and enhanced existing ones to perform OT across spaces. These include a [semi-relaxed FGW barycenter](https://pythonot.github.io/master/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.html) solver, coupled with new initialization heuristics for the inner divergence computation, to perform graph partitioning or dictionary learning. They are followed by novel [unbalanced FGW and Co-optimal transport](https://pythonot.github.io/master/auto_examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.html) solvers to promote robustness to outliers in such matching problems. Finally, we updated the implementation of partial GW, which now supports asymmetric structures and the KL divergence, while leveraging a new generic conditional gradient solver for partial transport problems that enables significant speed improvements. These latest updates required some modifications to the line search functions of our generic conditional gradient solver, paving the way for future improvements to other GW-based solvers. Last but not least, we implemented a pre-commit scheme to automatically correct common programming mistakes likely to be made by our future contributors.

This release also contains a few bug fixes, concerning the support of any metric in `ot.emd_1d` / `ot.emd2_1d`, and the support of any weights in `ot.gaussian`.

#### Breaking change
- Custom functions provided as parameter `line_search` to `ot.optim.generic_conditional_gradient` must now have the signature `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)`, adding as input `df_G` the gradient of the regularizer evaluated at the transport plan `G`. This change aims at improving the speed of solvers having quadratic polynomial functions as regularizer, such as the Gromov-Wasserstein loss (PR #663).

15 changes: 15 additions & 0 deletions ot/backend.py
@@ -1368,6 +1368,9 @@ def trace(self, a):
def inv(self, a):
return scipy.linalg.inv(a)

def pinv(self, a, hermitian=False):
return np.linalg.pinv(a, hermitian=hermitian)

def sqrtm(self, a):
L, V = np.linalg.eigh(a)
L = np.sqrt(L)
@@ -1781,6 +1784,9 @@ def trace(self, a):
def inv(self, a):
return jnp.linalg.inv(a)

def pinv(self, a, hermitian=False):
return jnp.linalg.pinv(a, hermitian=hermitian)

def sqrtm(self, a):
L, V = jnp.linalg.eigh(a)
L = jnp.sqrt(L)
@@ -2314,6 +2320,9 @@ def trace(self, a):
def inv(self, a):
return torch.linalg.inv(a)

def pinv(self, a, hermitian=False):
return torch.linalg.pinv(a, hermitian=hermitian)

def sqrtm(self, a):
L, V = torch.linalg.eigh(a)
L = torch.sqrt(L)
@@ -2728,6 +2737,9 @@ def trace(self, a):
def inv(self, a):
return cp.linalg.inv(a)

def pinv(self, a, hermitian=False):
return cp.linalg.pinv(a)

def sqrtm(self, a):
L, V = cp.linalg.eigh(a)
L = cp.sqrt(L)
@@ -3164,6 +3176,9 @@ def trace(self, a):
def inv(self, a):
return tf.linalg.inv(a)

def pinv(self, a, hermitian=False):
return tf.linalg.pinv(a)

def sqrtm(self, a):
L, V = tf.linalg.eigh(a)
L = tf.sqrt(L)
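
A minimal sketch of how the new backend `pinv` method can be called in a backend-agnostic way through `ot.backend.get_backend` (NumPy shown here; the array sizes and the sanity check are illustrative assumptions, not part of the PR):

```python
# Illustrative check (assumed example): the new `pinv` is exposed on every
# backend and dispatches to the underlying library's pseudo-inverse.
import numpy as np
from ot.backend import get_backend

a = np.random.rand(5, 3)
nx = get_backend(a)                    # NumpyBackend for a NumPy input
G = a.T @ a                            # symmetric (Hermitian) Gram matrix
G_pinv = nx.pinv(G, hermitian=True)    # calls np.linalg.pinv under the hood
assert np.allclose(G @ G_pinv @ G, G)  # Moore-Penrose property
```
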
4 changes: 4 additions & 0 deletions ot/bregman/__init__.py
@@ -38,6 +38,8 @@
empirical_sinkhorn,
empirical_sinkhorn2,
empirical_sinkhorn_divergence,
empirical_sinkhorn_nystroem,
empirical_sinkhorn_nystroem2,
)

from ._screenkhorn import screenkhorn
@@ -71,6 +73,8 @@
"empirical_sinkhorn2",
"empirical_sinkhorn2_geomloss",
"empirical_sinkhorn_divergence",
"empirical_sinkhorn_nystroem",
"empirical_sinkhorn_nystroem2",
"geomloss",
"screenkhorn",
"unmix",
243 changes: 239 additions & 4 deletions ot/bregman/_empirical.py
@@ -6,15 +6,22 @@
# Author: Remi Flamary <[email protected]>
# Kilian Fatras <[email protected]>
# Quang Huy Tran <[email protected]>
# Titouan Vayer <[email protected]>
#
# License: MIT License

import warnings
import math

from ..utils import dist, list_to_array, unif, LazyTensor
from ..backend import get_backend

from ._sinkhorn import sinkhorn, sinkhorn2
from ..lowrank import (
kernel_nystroem,
sinkhorn_low_rank_kernel,
compute_lr_sqeuclidean_matrix,
)


def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric="sqeuclidean", reg=1e-1, nx=None):
@@ -267,7 +274,7 @@ def empirical_sinkhorn(
else:
M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(
G, log = sinkhorn(
a,
b,
M,
@@ -279,9 +286,9 @@
warmstart=warmstart,
**kwargs,
)
return pi, log
return G, log
else:
pi = sinkhorn(
G = sinkhorn(
a,
b,
M,
@@ -293,7 +300,7 @@
warmstart=warmstart,
**kwargs,
)
return pi
return G


def empirical_sinkhorn2(
@@ -755,3 +762,231 @@ def empirical_sinkhorn_divergence(

sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
return nx.maximum(0, sinkhorn_div)


def empirical_sinkhorn_nystroem(
X_s,
X_t,
reg=1.0,
anchors=50,
a=None,
b=None,
numItermax=1000,
stopThr=1e-9,
verbose=False,
log=False,
warn=True,
warmstart=None,
random_state=None,
):
r"""
Solves the entropic regularization optimal transport problem with Sinkhorn iterations on a Nystroem factorization of the Gibbs kernel [76] and returns the
OT plan from empirical data.
This corresponds to an approximation of entropic OT (for a squared Euclidean cost) that runs in time linear in the number of samples.
The number of anchors controls the quality of the approximation: more anchors give a better approximation but a slower computation.

Parameters
----------
X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
anchors : int, optional
The total number of anchors sampled for the Nystroem approximation (anchors/2 in each distribution), default 50.
a : array-like, shape (n_samples_a,)
samples weights in the source domain
b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
Stop threshold on error on the dual variables (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
warmstart : tuple of arrays, shape (dim_a, dim_b), optional
Initialization of the dual potentials. If provided, these should be the
logarithms of the u, v Sinkhorn scaling vectors.
random_state : int, optional
The random state used to sample the anchor points in each distribution.


Returns
-------
gamma : LazyTensor
OT plan as lazy tensor.
log : dict
log dictionary returned only if log==True in parameters

Examples
--------

>>> import numpy as np
>>> n_samples_a = 2
>>> n_samples_b = 4
>>> reg = 0.1
>>> anchors = 3
>>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
>>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
>>> empirical_sinkhorn_nystroem(X_s, X_t, reg, anchors, random_state=42)[:] # doctest: +ELLIPSIS
array([[2.50000000e-01, 1.46537753e-01, 7.29587925e-10, 1.03462246e-01],
[3.63816797e-10, 1.03462247e-01, 2.49999999e-01, 1.46537754e-01]])

References
----------

.. [76] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., Massively scalable Sinkhorn distances via the Nyström method, Advances in Neural Information Processing Systems, 2019.

"""

left_factor, right_factor = kernel_nystroem(
X_s, X_t, anchors=anchors, sigma=math.sqrt(reg / 2.0), random_state=random_state
)
_, _, dict_log = sinkhorn_low_rank_kernel(
K1=left_factor,
K2=right_factor,
a=a,
b=b,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=True,
warn=warn,
warmstart=warmstart,
)
if log:
return dict_log["lazy_plan"], dict_log
else:
return dict_log["lazy_plan"]
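
As a usage illustration (an assumed sketch, not part of the PR: the sample sizes, `reg`, and `anchors` values below are arbitrary), the lazy plan returned above can be materialized and compared against the exact entropic plan computed by `ot.sinkhorn`:

```python
# Assumed example: compare the Nystroem plan with the exact Sinkhorn plan
# on a small problem; the gap shrinks as the number of anchors grows.
import numpy as np
import ot

rng = np.random.default_rng(0)
X_s = rng.standard_normal((40, 2))
X_t = rng.standard_normal((50, 2))
reg = 1.0

# exact entropic OT plan (squared Euclidean cost)
M = ot.dist(X_s, X_t)
G_exact = ot.sinkhorn(ot.unif(40), ot.unif(50), M, reg)

# Nystroem approximation; [:] materializes the LazyTensor into a dense array
G_nys = ot.bregman.empirical_sinkhorn_nystroem(
    X_s, X_t, reg=reg, anchors=30, random_state=0
)[:]
print(np.abs(G_exact - G_nys).max())
```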


def empirical_sinkhorn_nystroem2(
X_s,
X_t,
reg=1.0,
anchors=50,
a=None,
b=None,
numItermax=1000,
stopThr=1e-9,
verbose=False,
log=False,
warn=True,
warmstart=None,
random_state=None,
):
r"""
Solves the entropic regularization optimal transport problem with Sinkhorn iterations on a Nystroem factorization of the Gibbs kernel [76] and returns the
OT loss from empirical data.
This corresponds to an approximation of entropic OT (for a squared Euclidean cost) that runs in time linear in the number of samples.
The number of anchors controls the quality of the approximation: more anchors give a better approximation but a slower computation.

Parameters
----------
X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
anchors : int, optional
The total number of anchors sampled for the Nystroem approximation (anchors/2 in each distribution), default 50.
a : array-like, shape (n_samples_a,)
samples weights in the source domain
b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
Stop threshold on error on the dual variables (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
warmstart : tuple of arrays, shape (dim_a, dim_b), optional
Initialization of the dual potentials. If provided, these should be the
logarithms of the u, v Sinkhorn scaling vectors.
random_state : int, optional
The random state used to sample the anchor points in each distribution.

Returns
-------
W : float
Optimal transportation loss for the given parameters
log : dict
log dictionary returned only if log==True in parameters


Examples
--------

>>> import numpy as np
>>> n_samples_a = 2
>>> n_samples_b = 4
>>> reg = 0.1
>>> anchors = 3
>>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
>>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
>>> empirical_sinkhorn_nystroem2(X_s, X_t, reg, anchors, random_state=42) # doctest: +ELLIPSIS
1.9138489870270898


References
----------

.. [76] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., Massively scalable Sinkhorn distances via the Nyström method, Advances in Neural Information Processing Systems, 2019.


"""

nx = get_backend(X_s, X_t)
M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, False, nx=nx)
left_factor, right_factor = kernel_nystroem(
X_s, X_t, anchors=anchors, sigma=math.sqrt(reg / 2.0), random_state=random_state
)
if log:
u, v, dict_log = sinkhorn_low_rank_kernel(
K1=left_factor,
K2=right_factor,
a=a,
b=b,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=True,
warn=warn,
warmstart=warmstart,
)
else:
u, v = sinkhorn_low_rank_kernel(
K1=left_factor,
K2=right_factor,
a=a,
b=b,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=False,
warn=warn,
warmstart=warmstart,
)
# Scaled factors: the low-rank plan is P = Q @ R.T with Q = diag(u) @ K1 and
# R = diag(v) @ K2, while the squared Euclidean cost factors as M = M1 @ M2.T.
Q = u.reshape((-1, 1)) * left_factor
R = v.reshape((-1, 1)) * right_factor
# Compute the cost <P, M> without forming P or M, using the trace identity
# <Q R^T, M1 M2^T> = Tr((Q^T M1)(R^T M2)^T) = sum((Q^T M1) * (R^T M2)).
A = Q.T @ M1
B = R.T @ M2
loss = nx.sum(A * B)

if log:
return loss, dict_log
else:
return loss
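
A small self-contained check (an assumed NumPy sketch with arbitrary shapes, not part of the PR) of the trace identity used above to evaluate the loss without materializing the full plan or cost matrix:

```python
# Verify numerically that <Q R^T, M1 M2^T> = sum((Q^T M1) * (R^T M2)).
import numpy as np

rng = np.random.default_rng(0)
n, m, r, d = 6, 5, 3, 4                 # sample sizes, kernel rank, cost rank
Q = rng.random((n, r))                  # stands for diag(u) @ K1
R = rng.random((m, r))                  # stands for diag(v) @ K2
M1 = rng.random((n, d))                 # cost factors, M = M1 @ M2.T
M2 = rng.random((m, d))

dense = np.sum((Q @ R.T) * (M1 @ M2.T))     # naive O(n*m) evaluation
low_rank = np.sum((Q.T @ M1) * (R.T @ M2))  # O((n+m)*r*d) evaluation
assert np.isclose(dense, low_rank)
```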