diff --git a/.gitignore b/.gitignore index dd5860a94..4fa600f52 100644 --- a/.gitignore +++ b/.gitignore @@ -122,3 +122,4 @@ debug # pytest cahche .pytest_cache + diff --git a/README.md b/README.md index 8b4cca7f7..d1c77689d 100644 --- a/README.md +++ b/README.md @@ -389,3 +389,5 @@ Artificial Intelligence. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. + +[76] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019. diff --git a/RELEASES.md b/RELEASES.md index ec7e5774c..4e2cc94f5 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,7 @@ - Backend implementation of `ot.dist` for (PR #701) - Updated documentation Quickstart guide and User guide with new API (PR #726) - Fix jax version for auto-grad (PR #732) +- Add Nystrom kernel approximation for Sinkhorn (PR #742) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) @@ -36,7 +37,7 @@ This new release contains several new features, starting with a novel [Gaussian Mixture Model Optimal Transport (GMM-OT)](https://pythonot.github.io/master/gen_modules/ot.gmm.html#examples-using-ot-gmm-gmm-ot-apply-map) solver to compare GMM while enforcing the transport plan to remain a GMM, that benefits from a closed-form solution making it practical for high-dimensional matching problems. We also extended our general unbalanced OT solvers to support any non-negative reference measure in the regularization terms, before adding the novel [translation invariant UOT](https://pythonot.github.io/master/auto_examples/unbalanced-partial/plot_conv_sinkhorn_ti.html) solver showcasing a higher convergence speed. We also implemented several new solvers and enhanced existing ones to perform OT across spaces. These include a [semi-relaxed FGW barycenter](https://pythonot.github.io/master/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.html) solver, coupled with new initialization heuristics for the inner divergence computation, to perform graph partitioning or dictionary learning. Followed by novel [unbalanced FGW and Co-optimal transport](https://pythonot.github.io/master/auto_examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.html) solvers to promote robustness to outliers in such matching problems. And we finally updated the implementation of partial GW now supporting asymmetric structures and the KL divergence, while leveraging a new generic conditional gradient solver for partial transport problems enabling significant speed improvements. These latest updates required some modifications to the line search functions of our generic conditional gradient solver, paving the way for future improvements to other GW-based solvers. 
Last but not least, we implemented a pre-commit scheme to automatically correct common programming mistakes likely to be made by our future contributors. This release also contains few bug fixes, concerning the support of any metric in `ot.emd_1d` / `ot.emd2_1d`, and the support of any weights in `ot.gaussian`. - + #### Breaking change - Custom functions provided as parameter `line_search` to `ot.optim.generic_conditional_gradient` must now have the signature `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)`, adding as input `df_G` the gradient of the regularizer evaluated at the transport plan `G`. This change aims at improving speed of solvers having quadratic polynomial functions as regularizer such as the Gromov-Wassertein loss (PR #663). diff --git a/examples/others/plot_nystroem_approximation.py b/examples/others/plot_nystroem_approximation.py new file mode 100644 index 000000000..db76b7b7d --- /dev/null +++ b/examples/others/plot_nystroem_approximation.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +""" +============================ +Nyström approximation for OT +============================ + +Shows how to use Nyström kernel approximation for approximating the Sinkhorn algorithm in linear time. + + +""" + +# Author: Titouan Vayer +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +from ot.lowrank import kernel_nystroem, sinkhorn_low_rank_kernel +import math +import ot +import matplotlib.pyplot as plt +from matplotlib.colors import LogNorm + +############################################################################## +# Generate data +# ------------- + +# %% +offset = 1 +n_samples_per_blob = 500 # We use 2D ''blobs'' data +random_state = 42 +std = 0.2 # standard deviation +np.random.seed(random_state) + +centers = np.array( + [ + [-offset, -offset], # Class 0 - blob 1 + [-offset, offset], # Class 0 - blob 2 + [offset, -offset], # Class 1 - blob 1 + [offset, offset], # Class 1 - blob 2 + ] +) + +X_list = [] +y_list = [] + +for i, center in enumerate(centers): + blob_points = np.random.randn(n_samples_per_blob, 2) * std + center + label = 0 if i < 2 else 1 + X_list.append(blob_points) + y_list.append(np.full(n_samples_per_blob, label)) + +X = np.vstack(X_list) +y = np.concatenate(y_list) +Xs = X[y == 0] # source data +Xt = X[y == 1] # target data +# %% + +############################################################################## +# Plot data +# --------- +# %% +plt.scatter(Xs[:, 0], Xs[:, 1], label="Source") +plt.scatter(Xt[:, 0], Xt[:, 1], label="Target") +plt.legend() + +############################################################################## +# Compute the Nyström approximation of the Gaussian kernel +# -------------------------------------------------------- + +# %% +reg = 5.0 # proportional to the std of the Gaussian kernel +anchors = 5 # number of anchor points for the Nyström approximation +ot.tic() +left_factor, right_factor = kernel_nystroem( + Xs, Xt, anchors=anchors, sigma=math.sqrt(reg / 2.0), random_state=random_state +) +ot.toc() + +############################################################################## +# Use this approximation in a Sinkhorn algorithm with low rank kernel. 
+# Each matrix/vector product in the Sinkhorn is accelerated +# since :math:`Kv = K_1 (K_2^\top v)` can be computed in :math:`O(nr)` time +# instead of :math:`O(n^2)` +# ------------------------------------------------------------------------- +# %% +numItermax = 1000 +stopThr = 1e-7 +verbose = True +a, b = None, None +warn = True +warmstart = None +ot.tic() +u, v, dict_log = sinkhorn_low_rank_kernel( + K1=left_factor, + K2=right_factor, + a=a, + b=b, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=True, + warn=warn, + warmstart=warmstart, +) +ot.toc() +############################################################################## +# Compare with Sinkhorn +# --------------------- +# %% +M = ot.dist(Xs, Xt) +ot.tic() +G, log_ = ot.sinkhorn( + a=[], + b=[], + M=M, + reg=reg, + numItermax=numItermax, + verbose=verbose, + log=True, + warn=warn, + warmstart=warmstart, +) +ot.toc() + +############################################################################## +# Use directly ot.bregman.empirical_sinkhorn_nystroem +# -------------------------------------------------- +# %% We can compare directly with empirical +ot.tic() +G_nys = ot.bregman.empirical_sinkhorn_nystroem( + Xs, + Xt, + anchors=anchors, + reg=reg, + numItermax=numItermax, + verbose=True, + random_state=random_state, +)[:] +ot.toc() +# %% +ot.tic() +G_sinkh = ot.bregman.empirical_sinkhorn( + Xs, Xt, reg=reg, numIterMax=numItermax, verbose=True +) +ot.toc() + +############################################################################## +# Compare OT plans +# ---------------- +# %% +fig, ax = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True) +vmin = min(G_sinkh.min(), G_nys.min()) +vmax = max(G_sinkh.max(), G_nys.max()) +norm = LogNorm(vmin=vmin, vmax=vmax) +im0 = ax[0].imshow(G_sinkh, norm=norm, cmap="coolwarm") +im1 = ax[1].imshow(G_nys, norm=norm, cmap="coolwarm") +cbar = fig.colorbar(im1, ax=ax, orientation="vertical", fraction=0.046, pad=0.04) +ax[0].set_title("OT plan Sinkhorn") +ax[1].set_title("OT plan Nyström Sinkhorn") +for a in ax: + a.set_xticks([]) + a.set_yticks([]) +plt.show() diff --git a/ot/backend.py b/ot/backend.py index 3d59639fa..14b75c9c4 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1368,6 +1368,9 @@ def trace(self, a): def inv(self, a): return scipy.linalg.inv(a) + def pinv(self, a, hermitian=False): + return np.linalg.pinv(a, hermitian=hermitian) + def sqrtm(self, a): L, V = np.linalg.eigh(a) L = np.sqrt(L) @@ -1781,6 +1784,9 @@ def trace(self, a): def inv(self, a): return jnp.linalg.inv(a) + def pinv(self, a, hermitian=False): + return jnp.linalg.pinv(a, hermitian=hermitian) + def sqrtm(self, a): L, V = jnp.linalg.eigh(a) L = jnp.sqrt(L) @@ -2314,6 +2320,9 @@ def trace(self, a): def inv(self, a): return torch.linalg.inv(a) + def pinv(self, a, hermitian=False): + return torch.linalg.pinv(a, hermitian=hermitian) + def sqrtm(self, a): L, V = torch.linalg.eigh(a) L = torch.sqrt(L) @@ -2728,6 +2737,9 @@ def trace(self, a): def inv(self, a): return cp.linalg.inv(a) + def pinv(self, a, hermitian=False): + return cp.linalg.pinv(a) + def sqrtm(self, a): L, V = cp.linalg.eigh(a) L = cp.sqrt(L) @@ -3164,6 +3176,9 @@ def trace(self, a): def inv(self, a): return tf.linalg.inv(a) + def pinv(self, a, hermitian=False): + return tf.linalg.pinv(a) + def sqrtm(self, a): L, V = tf.linalg.eigh(a) L = tf.sqrt(L) diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py index 54d7eca27..9f2b89e94 100644 --- a/ot/bregman/__init__.py +++ b/ot/bregman/__init__.py @@ -38,6 +38,8 @@ 
empirical_sinkhorn, empirical_sinkhorn2, empirical_sinkhorn_divergence, + empirical_sinkhorn_nystroem, + empirical_sinkhorn_nystroem2, ) from ._screenkhorn import screenkhorn @@ -71,6 +73,8 @@ "empirical_sinkhorn2", "empirical_sinkhorn2_geomloss", "empirical_sinkhorn_divergence", + "empirical_sinkhorn_nystroem", + "empirical_sinkhorn_nystroem2", "geomloss", "screenkhorn", "unmix", diff --git a/ot/bregman/_empirical.py b/ot/bregman/_empirical.py index e010aa0c7..c51fc0a81 100644 --- a/ot/bregman/_empirical.py +++ b/ot/bregman/_empirical.py @@ -6,15 +6,22 @@ # Author: Remi Flamary # Kilian Fatras # Quang Huy Tran +# Titouan Vayer # # License: MIT License import warnings +import math from ..utils import dist, list_to_array, unif, LazyTensor from ..backend import get_backend from ._sinkhorn import sinkhorn, sinkhorn2 +from ..lowrank import ( + kernel_nystroem, + sinkhorn_low_rank_kernel, + compute_lr_sqeuclidean_matrix, +) def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric="sqeuclidean", reg=1e-1, nx=None): @@ -267,7 +274,7 @@ def empirical_sinkhorn( else: M = dist(X_s, X_t, metric=metric) if log: - pi, log = sinkhorn( + G, log = sinkhorn( a, b, M, @@ -279,9 +286,9 @@ def empirical_sinkhorn( warmstart=warmstart, **kwargs, ) - return pi, log + return G, log else: - pi = sinkhorn( + G = sinkhorn( a, b, M, @@ -293,7 +300,7 @@ def empirical_sinkhorn( warmstart=warmstart, **kwargs, ) - return pi + return G def empirical_sinkhorn2( @@ -755,3 +762,235 @@ def empirical_sinkhorn_divergence( sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) return nx.maximum(0, sinkhorn_div) + + +def empirical_sinkhorn_nystroem( + X_s, + X_t, + reg=1.0, + anchors=50, + a=None, + b=None, + numItermax=1000, + stopThr=1e-9, + verbose=False, + log=False, + warn=True, + warmstart=None, + random_state=None, +): + r""" + Solves the entropic regularization optimal transport problem with Sinkhorn and Nystroem factorization [76] and returns the + OT matrix from empirical data. + Corresponds to an approximation of entropic OT (for a squared Euclidean cost) that runs in linear time. + The number of anchors controls the level of approximation (the higher, the better the approximation, but the slower the computation becomes). + Warning: with low level of regularization, the OT plan can have non-positive values. + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + anchors : int, optional + The total number of anchors sampled for the Nystroem approximation (anchors/2 in each distribution), default 50. + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations of Sinkhorn + stopThr : float, optional + Stop threshold on error on the dual variables (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + random_state : int, optional + The random state for sampling the components in each distribution. 
+ + Returns + ------- + gamma : LazyTensor + OT plan as lazy tensor. + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import numpy as np + >>> n_samples_a = 2 + >>> n_samples_b = 4 + >>> reg = 0.1 + >>> anchors = 3 + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) + >>> empirical_sinkhorn_nystroem(X_s, X_t, reg, anchors, random_state=42)[:] # doctest: +ELLIPSIS + array([[2.50000000e-01, 1.46537753e-01, 7.29587925e-10, 1.03462246e-01], + [3.63816797e-10, 1.03462247e-01, 2.49999999e-01, 1.46537754e-01]]) + + References + ---------- + + .. [76] Massively scalable Sinkhorn distances via the Nyström method, + Jason Altschuler, Francis Bach, Alessandro Rudi, Jonathan Niles-Weed, NeurIPS 2019. + + """ + left_factor, right_factor = kernel_nystroem( + X_s, X_t, anchors=anchors, sigma=math.sqrt(reg / 2.0), random_state=random_state + ) + _, _, dict_log = sinkhorn_low_rank_kernel( + K1=left_factor, + K2=right_factor, + a=a, + b=b, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=True, + warn=warn, + warmstart=warmstart, + ) + if log: + dict_log["left_factor"] = left_factor + dict_log["right_factor"] = right_factor + return dict_log["lazy_plan"], dict_log + else: + return dict_log["lazy_plan"] + + +def empirical_sinkhorn_nystroem2( + X_s, + X_t, + reg=1.0, + anchors=50, + a=None, + b=None, + numItermax=1000, + stopThr=1e-9, + verbose=False, + log=False, + warn=True, + warmstart=None, + random_state=None, +): + r""" + Solves the entropic regularization optimal transport problem with Sinkhorn and Nystroem factorization [76] and returns the + OT loss from empirical data. + Corresponds to an approximation of entropic OT (for a squared Euclidean cost) that runs in linear time. + The number of anchors controls the level of approximation (the higher, the better the approximation, but the slower the computation becomes). + Warning: with low level of regularization, the OT plan can have non-positive values. + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + anchors : int, optional + The total number of anchors sampled for the Nystroem approximation (anchors/2 in each distribution), default 50. + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations of Sinkhorn + stopThr : float, optional + Stop threshold on error on the dual variables (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + random_state : int, optional + The random state for sampling the components in each distribution. 
+ + Returns + ------- + W : float + Optimal transportation loss for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + Examples + -------- + + >>> import numpy as np + >>> n_samples_a = 2 + >>> n_samples_b = 4 + >>> reg = 0.1 + >>> anchors = 3 + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) + >>> empirical_sinkhorn_nystroem2(X_s, X_t, reg, anchors, random_state=42) # doctest: +ELLIPSIS + 1.9138489870270898 + + + References + ---------- + + .. [76] Massively scalable Sinkhorn distances via the Nyström method, + Jason Altschuler, Francis Bach, Alessandro Rudi, Jonathan Niles-Weed, NeurIPS 2019. + + + """ + + nx = get_backend(X_s, X_t) + M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, False, nx=nx) + left_factor, right_factor = kernel_nystroem( + X_s, X_t, anchors=anchors, sigma=math.sqrt(reg / 2.0), random_state=random_state + ) + if log: + u, v, dict_log = sinkhorn_low_rank_kernel( + K1=left_factor, + K2=right_factor, + a=a, + b=b, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=True, + warn=warn, + warmstart=warmstart, + ) + dict_log["left_factor"] = left_factor + dict_log["right_factor"] = right_factor + else: + u, v = sinkhorn_low_rank_kernel( + K1=left_factor, + K2=right_factor, + a=a, + b=b, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=False, + warn=warn, + warmstart=warmstart, + ) + Q = u.reshape((-1, 1)) * left_factor + R = v.reshape((-1, 1)) * right_factor + # Compute the cost (using trace formula) + A = Q.T @ M1 + B = R.T @ M2 + loss = nx.sum(A * B) + + if log: + return loss, dict_log + else: + return loss diff --git a/ot/bregman/_sinkhorn.py b/ot/bregman/_sinkhorn.py index 342ba02e8..cf5efadfc 100644 --- a/ot/bregman/_sinkhorn.py +++ b/ot/bregman/_sinkhorn.py @@ -14,7 +14,6 @@ import warnings import numpy as np - from ..utils import list_to_array from ..backend import get_backend diff --git a/ot/lowrank.py b/ot/lowrank.py index 14a10f163..312251839 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -3,6 +3,7 @@ """ # Author: Laurène David +# Titouan Vayer # # License: MIT License @@ -10,6 +11,8 @@ from .utils import unif, dist, get_lowrank_lazytensor from .backend import get_backend from .bregman import sinkhorn +import random +import math # test if sklearn is installed for linux-minimal-deps try: @@ -505,9 +508,9 @@ def lowrank_sinkhorn( lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g) # Compute value_linear (using trace formula) - v1 = nx.dot(Q.T, M1) - v2 = nx.dot(R, (v1.T * diag_g).T) - value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2))) + A = nx.dot(M1.T, Q) + B = nx.dot(M2.T, R * diag_g) + value_linear = nx.sum(A * B) # Compute value with entropy reg (see "Section 3.2" in the paper) reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q @@ -524,3 +527,212 @@ def lowrank_sinkhorn( return Q, R, g, dict_log return Q, R, g + + +def kernel_nystroem(X_s, X_t, anchors=50, sigma=1.0, random_state=None): + r""" + Compute left and right factors corresponding to the Nystroem method on the Gaussian kernel :math:`K(x^s_i, x^t_j) = \exp(-\|x^s_i-x^t_j\|^2/2\sigma^2)`. + The Nystroem approximation is computed by sampling :math:`\min(n, \lceil(c / 2))\rceil' components in each distribution, + where :math:`n` is the number of samples in the distribution and :math:`c` the total number of anchor points. 
+
+    Parameters
+    ----------
+    X_s : array-like, shape (n_samples_a, dim)
+        samples in the source domain
+    X_t : array-like, shape (n_samples_b, dim)
+        samples in the target domain
+    anchors : int, optional
+        The total number of anchors sampled for the Nystroem approximation (anchors/2 in each distribution), default 50.
+    sigma : float, optional
+        The standard deviation parameter for the Gaussian kernel.
+    random_state : int, optional
+        The random state for sampling the components in each distribution.
+
+    Returns
+    -------
+    left_factor : array-like, shape (n_samples_a, dim_r)
+        Left factor of Nystroem
+    right_factor : array-like, shape (n_samples_b, dim_r)
+        Right factor of Nystroem
+    """
+    nx = get_backend(X_s, X_t)
+
+    random.seed(random_state)
+    n, m = X_s.shape[0], X_t.shape[0]
+    n_components_source = min(n, math.ceil(anchors / 2))
+    n_components_target = min(m, math.ceil(anchors / 2))
+    # draw anchors/2 points in each distribution
+    inds_source = nx.arange(n)  # sample n_components_source uniformly
+    random.shuffle(inds_source)
+    basis_source = X_s[inds_source[:n_components_source]]
+
+    inds_target = nx.arange(m)
+    random.shuffle(inds_target)
+    basis_target = X_t[inds_target[:n_components_target]]
+
+    basis = nx.concatenate((basis_source, basis_target))
+
+    Mzz = dist(basis, metric="sqeuclidean")  # compute \|z_i - z_j\|_2^2
+    basis_kernel = nx.exp(-Mzz / (2.0 * sigma**2))
+
+    normalization = nx.pinv(basis_kernel, hermitian=True)
+
+    Mxz = dist(X_s, basis, metric="sqeuclidean")
+    Myz = dist(X_t, basis, metric="sqeuclidean")
+
+    left_factor = nx.exp(-Mxz / (2.0 * sigma**2)) @ normalization
+    right_factor = nx.exp(-Myz / (2.0 * sigma**2))
+
+    return left_factor, right_factor  # left_factor @ right_factor.T approx K
+
+
+def sinkhorn_low_rank_kernel(
+    K1,  # left factor
+    K2,  # right factor
+    a=None,
+    b=None,
+    numItermax=1000,
+    stopThr=1e-9,
+    verbose=False,
+    log=False,
+    warn=True,
+    warmstart=None,
+):
+    r"""
+    Run the Sinkhorn algorithm for a kernel :math:`K` that can be written as a low rank factorization :math:`K = K_1 K_2^\top`.
+
+    More precisely:
+
+    - :math:`\mathbf{K}_1` and :math:`\mathbf{K}_2` are the (`dim_a`, `dim_r`) and (`dim_b`, `dim_r`) factors of the kernel matrix
+    - :math:`\Omega` is the entropic regularization term
+      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+      weights (histograms, both sum to 1)
+
+    The algorithm used for solving the problem is the Sinkhorn-Knopp
+    matrix scaling algorithm as proposed in :ref:`[2] `
+
+    Parameters
+    ----------
+    K1 : array-like, shape (n_samples_a, dim_r)
+        Left factor
+    K2 : array-like, shape (n_samples_b, dim_r)
+        Right factor
+    a : array-like, shape (n_samples_a,)
+        samples weights in the source domain
+    b : array-like, shape (n_samples_b,) or array-like, shape (n_samples_b, n_hists)
+        samples in the target domain, compute sinkhorn with multiple targets
+        if :math:`\mathbf{b}` is a matrix
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
+    warmstart: tuple of arrays, shape (n_samples_a, n_samples_b), optional
+        Initialization of dual potentials.
If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + Returns + --------- + u : array-like, shape (n_samples_a, ) or array-like, shape (n_samples_a, n_hists) + Left dual variable + v: array-like, shape (n_samples_b, ) or array-like, shape (n_samples_b, n_hists) + Right dual variable + log : dict (lazy_plan) + log dictionary return only if log==True in parameters + + """ + + nx = get_backend(K1, K2, a, b) + + if a is None: + a = nx.full((K1.shape[0],), 1.0 / K1.shape[0], type_as=K1) + if b is None: + b = nx.full((K2.shape[0],), 1.0 / K2.shape[0], type_as=K2) + + # init data + dim_a = len(a) + dim_b = b.shape[0] + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + dict_log = {"err": []} + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=K1) / dim_a + v = nx.ones((dim_b, n_hists), type_as=K2) / dim_b + else: + u = nx.ones(dim_a, type_as=K1) / dim_a + v = nx.ones(dim_b, type_as=K2) / dim_b + else: + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + + err = 1 + for ii in range(numItermax): + uprev = u + vprev = v + KtransposeU = K2 @ (nx.transpose(K1) @ u) + v = b / KtransposeU + KV = K1 @ (nx.transpose(K2) @ v) + u = a / KV + + if ( + nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) + or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) + or nx.any(nx.isinf(v)) + ): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn("Warning: numerical errors at iteration %d" % ii) + u = uprev + v = vprev + break + if ii % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + tmp2 = ( + v * (K2 @ (nx.transpose(K1) @ u)) + ) # same as nx.einsum("ik, ir, jr, jk->jk", u, K1, K2, v) and nx.einsum("i, ir, jr, j->j", u, K1, K2, v) in the scalar case + err = nx.norm(tmp2 - b) # violation of marginal + if log: + dict_log["err"].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) + else: + if warn: + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." 
+ ) + + if log: + dict_log["niter"] = ii + dict_log["u"] = u + dict_log["v"] = v + dict_log["lazy_plan"] = get_lowrank_lazytensor( + u.reshape((-1, 1)) * K1, v.reshape((-1, 1)) * K2 + ) + return u, v, dict_log + + else: + return u, v diff --git a/ot/solvers.py b/ot/solvers.py index a5bbf0e94..410ef6d19 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1388,9 +1388,9 @@ def solve_sample( Parameters ---------- - X_s : array-like, shape (n_samples_a, dim) + X_a : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : array-like, shape (n_samples_b, dim) + X_b : array-like, shape (n_samples_b, dim) samples in the target domain a : array-like, shape (dim_a,), optional Samples weights in the source domain (default is uniform) diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 4c755d3e9..53525ced8 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -1,6 +1,7 @@ """Test for low rank sinkhorn solvers""" # Author: Laurène DAVID +# Titouan Vayer # # License: MIT License @@ -10,6 +11,75 @@ from ot.lowrank import sklearn_import # check sklearn installation +def test_nystroem_kernel_approx(): + # test nystroem kernel approx in easy regime (nb anchors = nb points) + n = 30 + d = 3 + Xs = np.random.randn(n, d) + Xt = np.random.randn(n, d) + 2 + sigma = 2.0 + K = np.exp(-ot.dist(Xs, Xt) / (2 * sigma**2)) + U, V = ot.lowrank.kernel_nystroem(Xs, Xt, anchors=60, sigma=sigma, random_state=42) + + np.testing.assert_allclose(K, U @ V.T, atol=1e-7) + + +def test_nystroem_sinkhorn(): + # test Nystrom approximation for Sinkhorn + offset = 2 + n_samples_per_blob = 50 + random_state = 42 + std = 0.1 + np.random.seed(random_state) + + centers = np.array( + [ + [-offset, -offset], # Class 0 - blob 1 + [-offset, offset], # Class 0 - blob 2 + [offset, -offset], # Class 1 - blob 1 + [offset, offset], # Class 1 - blob 2 + ] + ) + + X_list = [] + y_list = [] + + for i, center in enumerate(centers): + blob_points = np.random.randn(n_samples_per_blob, 2) * std + center + label = 0 if i < 2 else 1 + X_list.append(blob_points) + y_list.append(np.full(n_samples_per_blob, label)) + + X = np.vstack(X_list) + y = np.concatenate(y_list) + Xs = X[y == 0] + Xt = X[y == 1] + + reg = 5.0 + anchors = 5 + + G_nys = ot.bregman.empirical_sinkhorn_nystroem( + Xs, + Xt, + anchors=anchors, + reg=reg, + numItermax=3000, + verbose=True, + random_state=random_state, + )[:] + + G_sinkh = ot.bregman.empirical_sinkhorn( + Xs, Xt, reg=reg, numIterMax=3000, verbose=True + ) + + a = ot.unif(Xs.shape[0]) + b = ot.unif(Xt.shape[0]) + + np.testing.assert_allclose(G_sinkh, G_nys, atol=1e-04) + np.testing.assert_allclose(a, G_nys.sum(1), atol=1e-05) + np.testing.assert_allclose(b, G_nys.sum(0), atol=1e-05) + + def test_compute_lr_sqeuclidean_matrix(): # test computation of low rank cost matrices M1 and M2 n = 100