diff --git a/README.md b/README.md index 8b4cca7f7..b8e976f55 100644 --- a/README.md +++ b/README.md @@ -389,3 +389,7 @@ Artificial Intelligence. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. + +[76] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). [LCOT: Linear Circular Optimal Transport](https://openreview.net/forum?id=49z97Y9lMq). International Conference on Learning Representations. + +[77] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations. 
diff --git a/RELEASES.md b/RELEASES.md index ec7e5774c..20f7323a9 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,8 @@ - Backend implementation of `ot.dist` for (PR #701) - Updated documentation Quickstart guide and User guide with new API (PR #726) - Fix jax version for auto-grad (PR #732) +- Added `ot.solver_1d.linear_circular_ot` (PR #736) +- Added `ot.sliced.linear_sliced_wasserstein_sphere` (PR #736) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py index 5420fea97..27d9fe117 100644 --- a/examples/backends/plot_ssw_unif_torch.py +++ b/examples/backends/plot_ssw_unif_torch.py @@ -10,7 +10,7 @@ .. math:: \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i}) -where :math:`\nu=\mathrm{Unif}(S^1)`. +where :math:`\nu=\mathrm{Unif}(S^{d-1})`. """ @@ -46,15 +46,18 @@ def plot_sphere(ax): - xlist = np.linspace(-1.0, 1.0, 50) - ylist = np.linspace(-1.0, 1.0, 50) - r = np.linspace(1.0, 1.0, 50) - X, Y = np.meshgrid(xlist, ylist) + # Create a sphere using spherical coordinates + phi = np.linspace(0, 2 * np.pi, 100) + theta = np.linspace(0, np.pi, 100) + phi, theta = np.meshgrid(phi, theta) - Z = np.sqrt(np.maximum(r**2 - X**2 - Y**2, 0)) + # Compute the spherical coordinates + X = np.sin(theta) * np.cos(phi) + Y = np.sin(theta) * np.sin(phi) + Z = np.cos(theta) + # Plot the wireframe ax.plot_wireframe(X, Y, Z, color="gray", alpha=0.3) - ax.plot_wireframe(X, Y, -Z, color="gray", alpha=0.3) # Now plot the bottom half # plot the distributions diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py index 0335fcbe7..e431f7f93 100644 --- a/examples/plot_compute_wasserstein_circle.py +++ b/examples/plot_compute_wasserstein_circle.py @@ -102,17 +102,23 @@ def pdf_von_Mises(theta, mu, kappa): L_w2_circle = np.zeros((n_try, 200)) L_w2 = np.zeros((n_try, 200)) 
+L_lcot = np.zeros((n_try, 200)) for i in range(n_try): w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2) w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2) + w_lcot = ot.linear_circular_ot(xs2.T, xts2[i].T) L_w2_circle[i] = w2_circle L_w2[i] = w2 + L_lcot[i] = w_lcot m_w2_circle = np.mean(L_w2_circle, axis=0) std_w2_circle = np.std(L_w2_circle, axis=0) +m_w2_lcot = np.mean(L_lcot, axis=0) +std_w2_lcot = np.std(L_lcot, axis=0) + m_w2 = np.mean(L_w2, axis=0) std_w2 = np.std(L_w2, axis=0) @@ -128,6 +134,13 @@ def pdf_von_Mises(theta, mu, kappa): pl.fill_between( mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5 ) +pl.plot(mu_targets / (2 * np.pi), m_w2_lcot, label="Linear COT") +pl.fill_between( + mu_targets / (2 * np.pi), + m_w2_lcot - 2 * std_w2_lcot, + m_w2_lcot + 2 * std_w2_lcot, + alpha=0.5, +) pl.vlines( x=[mu1 / (2 * np.pi)], ymin=0, @@ -159,15 +172,23 @@ def pdf_von_Mises(theta, mu, kappa): xts[i, k] = xt / (2 * np.pi) L_w2 = np.zeros((n_try, 100)) +L_lcot = np.zeros((n_try, 100)) for i in range(n_try): L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T) + L_lcot[i] = ot.linear_circular_ot(xts[i].T) m_w2 = np.mean(L_w2, axis=0) std_w2 = np.std(L_w2, axis=0) +m_lcot = np.mean(L_lcot, axis=0) +std_lcot = np.std(L_lcot, axis=0) + pl.figure(1) -pl.plot(kappas, m_w2) +pl.plot(kappas, m_w2, label="Wasserstein") pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5) +pl.plot(kappas, m_lcot, label="LCOT") +pl.fill_between(kappas, m_lcot - std_lcot, m_lcot + std_lcot, alpha=0.5) +pl.legend() pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$") pl.xlabel(r"$\kappa$") pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 5e21d6a76..1aad7828a 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -48,6 +48,7 @@ binary_search_circle, wasserstein_circle, semidiscrete_wasserstein2_unif_circle, + linear_circular_ot, ) from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import sinkhorn_unbalanced, 
barycenter_unbalanced, sinkhorn_unbalanced2 @@ -57,6 +58,7 @@ max_sliced_wasserstein_distance, sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif, + linear_sliced_wasserstein_sphere, ) from .gromov import ( gromov_wasserstein, @@ -105,6 +107,7 @@ "sinkhorn_unbalanced2", "sliced_wasserstein_distance", "sliced_wasserstein_sphere", + "linear_sliced_wasserstein_sphere", "gromov_wasserstein", "gromov_wasserstein2", "gromov_barycenters", @@ -129,6 +132,7 @@ "binary_search_circle", "wasserstein_circle", "semidiscrete_wasserstein2_unif_circle", + "linear_circular_ot", "sliced_wasserstein_sphere_unif", "lowrank_sinkhorn", "lowrank_gromov_wasserstein_samples", diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 932b261df..de4ca07a1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,6 +26,7 @@ binary_search_circle, wasserstein_circle, semidiscrete_wasserstein2_unif_circle, + linear_circular_ot, ) __all__ = [ @@ -42,6 +43,7 @@ "binary_search_circle", "wasserstein_circle", "semidiscrete_wasserstein2_unif_circle", + "linear_circular_ot", "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", "check_number_threads", diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index c308549f8..515c40b03 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -791,10 +791,12 @@ def binary_search_circle( -1, 1 ) - mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) - tc[mask_end > 0] = ( - (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) - )[mask_end > 0] + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) + tc[mask_end > 0] = ( + (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) + )[mask_end > 0] done[nx.prod(mask, axis=-1) > 0] = 1 elif nx.any(1 - done): tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] @@ -933,8 +935,8 @@ def wasserstein_circle( eps=1e-6, require_sort=True, ): - r"""Computes the Wasserstein distance on the circle using 
either [45] for p=1 or - the binary search algorithm proposed in [44] otherwise. + r"""Computes the Wasserstein distance on the circle using either :ref:`[45] <references-wasserstein-circle>` for p=1 or + the binary search algorithm proposed in :ref:`[44] <references-wasserstein-circle>` otherwise. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, takes the value modulo 1. If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates @@ -996,6 +998,8 @@ def wasserstein_circle( >>> wasserstein_circle(u.T, v.T) array([0.1]) + + .. _references-wasserstein-circle: References ---------- .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. @@ -1042,7 +1046,7 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): .. math:: u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, - using e.g. ot.utils.get_coordinate_circle(x) + using e.g. ot.utils.get_coordinate_circle(x). Parameters ---------- @@ -1095,3 +1099,152 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): cpt2 = nx.sum(u_values * u_weights * ns, axis=0) return cpt1 - u_mean**2 + cpt2 + 1 / 12 + + +def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True): + r"""Returns the embedding :math:`\hat{\mu}(x)` of Linear Circular OT with reference + :math:`\eta=\mathrm{Unif}(S^1)` evaluated in :math:`x`. + + For any :math:`x\in [0,1[`, the embedding is given by (see :ref:`[76] <references-lcot>`) + + .. 
math:: + \hat{\mu}(x) = F_{\mu}^{-1}\big(x - \int z\mathrm{d}\mu(z) + \frac12\big) - x. + + Parameters + ---------- + x : ndarray, shape (m,) + Points in [0,1[ where to evaluate the embedding + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + require_sort : bool, optional + If True (default), sort u_values (and u_weights accordingly) along the first axis before building the cdf + + Returns + ------- + embedding: ndarray of shape (m, ...) + Embedding evaluated at :math:`x` + + .. _references-lcot: + References + ---------- + .. [76] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. + """ + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + q_s = ( + x[:, None] - nx.sum(u_values * u_weights, axis=0)[None] + 0.5 + ) # shape (m, ...) + + u_quantiles = quantile_function(q_s % 1, u_cdf, u_values) + + return (u_quantiles - x[:, None]) % 1 + + +def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None): + r"""Computes the Linear Circular Optimal Transport distance from :ref:`[76] <references-lcot>` using :math:`\eta=\mathrm{Unif}(S^1)` + as reference measure. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. 
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + General loss returned: + + .. math:: + \mathrm{LCOT}_2^2(\mu, \nu) = \int_0^1 d_{S^1}\big(\hat{\mu}(t), \hat{\nu}(t)\big)^2\ \mathrm{d}t + + where :math:`\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x` for all :math:`x\in [0,1[`, + and :math:`d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)` for :math:`x,y\in [0,1[`. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...), optional + samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + + Returns + ------- + loss: float + Cost associated to the linear optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> linear_circular_ot(u.T, v.T) + array([0.0127]) + + + .. _references-lcot: + References + ---------- + .. [76] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. 
+ """ + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + unif_s1 = nx.linspace(0, 1, 101, type_as=u_values)[:-1] + + emb_u = linear_circular_embedding(unif_s1, u_values, u_weights) + + if v_values is None: + dist_u = nx.minimum(nx.abs(emb_u), 1 - nx.abs(emb_u)) + return nx.mean(dist_u**2, axis=0) + else: + m = v_values.shape[0] + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format( + u_values.shape[1], v_values.shape[1] + ) + ) + + emb_v = linear_circular_embedding(unif_s1, v_values, v_weights) + + dist_uv = nx.minimum(nx.abs(emb_u - emb_v), 1 - nx.abs(emb_u - emb_v)) + return nx.mean(dist_uv**2, axis=0) diff --git a/ot/sliced.py b/ot/sliced.py index cd095ed6d..376f3967d 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -12,7 +12,11 @@ import numpy as np from .backend import get_backend, NumpyBackend from .utils import list_to_array, get_coordinate_circle -from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle +from .lp import ( + wasserstein_circle, + semidiscrete_wasserstein2_unif_circle, + linear_circular_ot, +) def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): @@ -96,7 +100,7 @@ def sliced_wasserstein_distance( samples weights in the target domain n_projections : int, optional Number of projections used for the Monte-Carlo approximation - p: float, optional = + p: float, optional Power p used for computing the sliced Wasserstein projections: shape (dim, 
n_projections), optional Projection matrix (n_projections and seed are not used in this case) @@ -284,6 +288,109 @@ def max_sliced_wasserstein_distance( return res +def get_projections_sphere(d, n_projections, seed=None, backend=None, type_as=None): + r""" + Generates n_projections samples from the uniform distribution on the Stiefel manifold of dimension :math:`d\times 2`: :math:`\mathbb{V}_{d,2}=\{X \in \mathbb{R}^{d\times 2}, X^TX=I_2\}` + + Parameters + ---------- + d : int + dimension of the space + n_projections : int + number of samples requested + seed: int or RandomState, optional + Seed used for numpy random number generator + backend: + Backend to use for random generation + type_as: optional + Type to use for random generation + + Returns + ------- + out: ndarray, shape (n_projections, d, 2) + + Examples + -------- + >>> n_projections = 100 + >>> d = 5 + >>> projs = get_projections_sphere(d, n_projections) + >>> np.allclose(np.einsum("nij, nik -> njk", projs, projs), np.eye(2)) # doctest: +NORMALIZE_WHITESPACE + True + """ + if backend is None: + nx = NumpyBackend() + else: + nx = backend + + if isinstance(seed, np.random.RandomState) and str(nx) == "numpy": + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=type_as) + + projections, _ = nx.qr(Z) + return projections + + +def projection_sphere_to_circle( + x, n_projections=50, projections=None, seed=None, backend=None +): + r""" + Projection of :math:`x\in S^{d-1}` on circles using coordinates on [0,1[. + + To get the projection on the circle, we use the following formula: + .. math:: + P^U(x) = \frac{U^Tx}{\|U^Tx\|_2} + + where :math:`U` is a random matrix sampled from the uniform distribution on the Stiefel manifold of dimension :math:`d\times 2`: :math:`\mathbb{V}_{d,2}=\{X \in \mathbb{R}^{d\times 2}, X^TX=I_2\}` + and :math:`x` is a point on the sphere. 
Then, we apply the function get_coordinate_circle to get the coordinates on :math:`[0,1[`. + + Parameters + ---------- + x : ndarray, shape (n_samples, dim) + samples on the sphere + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + projections: shape (n_projections, dim, 2), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + backend: + Backend to use for random generation + + Returns + ------- + Xp_coords: ndarray, shape (n_projections, n_samples) + Coordinates of the projections on the circle (returned as a tuple together with the projections used) + """ + if backend is None: + nx = get_backend(x) + else: + nx = backend + + n, d = x.shape + + if projections is None: + projections = get_projections_sphere( + d, n_projections, seed=seed, backend=nx, type_as=x + ) + + # Projection on S^1 + # Projection on plane + Xp = nx.einsum("ikj, lk -> ilj", projections, x) + + # Projection on sphere + Xp = Xp / nx.sqrt(nx.sum(Xp**2, -1, keepdims=True)) + + # Get coordinates on [0,1[ + Xp_coords = nx.reshape( + get_coordinate_circle(nx.reshape(Xp, (-1, 2))), (projections.shape[0], n) + ) + + return Xp_coords, projections + + +def sliced_wasserstein_sphere( X_s, X_t, @@ -347,14 +454,13 @@ def sliced_wasserstein_sphere( ---------- .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. 
""" + d = X_s.shape[-1] + if a is not None and b is not None: nx = get_backend(X_s, X_t, a, b) else: nx = get_backend(X_s, X_t) - n, d = X_s.shape - m, _ = X_t.shape - if X_s.shape[1] != X_t.shape[1]: raise ValueError( "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( @@ -367,33 +473,16 @@ def sliced_wasserstein_sphere( raise ValueError("X_t is not on the sphere.") if projections is None: - # Uniforms and independent samples on the Stiefel manifold V_{d,2} - if isinstance(seed, np.random.RandomState) and str(nx) == "numpy": - Z = seed.randn(n_projections, d, 2) - else: - if seed is not None: - nx.seed(seed) - Z = nx.randn(n_projections, d, 2, type_as=X_s) - - projections, _ = nx.qr(Z) - else: - n_projections = projections.shape[0] - - # Projection on S^1 - # Projection on plane - Xps = nx.einsum("ikj, lk -> ilj", projections, X_s) - Xpt = nx.einsum("ikj, lk -> ilj", projections, X_t) - - # Projection on sphere - Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) - Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) + projections = get_projections_sphere( + d, n_projections, seed=seed, backend=nx, type_as=X_s + ) - # Get coordinates on [0,1[ - Xps_coords = nx.reshape( - get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n) + Xps_coords, _ = projection_sphere_to_circle( + X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx ) - Xpt_coords = nx.reshape( - get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m) + + Xpt_coords, _ = projection_sphere_to_circle( + X_t, n_projections=n_projections, projections=projections, seed=seed, backend=nx ) projected_emd = wasserstein_circle( @@ -406,7 +495,9 @@ def sliced_wasserstein_sphere( return res -def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False): +def sliced_wasserstein_sphere_unif( + X_s, a=None, n_projections=50, projections=None, seed=None, log=False +): r"""Compute the 2-spherical 
sliced wasserstein w.r.t. a uniform distribution. .. math:: @@ -415,7 +506,7 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log where - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}` - - :math:`\nu=\mathrm{Unif}(S^1)` + - :math:`\nu=\mathrm{Unif}(S^{d-1})` Parameters ---------- @@ -425,6 +516,8 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log samples weights in the source domain n_projections : int, optional Number of projections used for the Monte-Carlo approximation + projections: shape (n_projections, dim, 2), optional + Projection matrix (n_projections and seed are not used in this case) seed: int or RandomState or None, optional Seed used for random number generator log: bool, optional @@ -450,36 +543,23 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log ----------- .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. 
""" + d = X_s.shape[-1] + if a is not None: nx = get_backend(X_s, a) else: nx = get_backend(X_s) - n, d = X_s.shape - if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)): raise ValueError("X_s is not on the sphere.") - # Uniforms and independent samples on the Stiefel manifold V_{d,2} - if isinstance(seed, np.random.RandomState) and str(nx) == "numpy": - Z = seed.randn(n_projections, d, 2) - else: - if seed is not None: - nx.seed(seed) - Z = nx.randn(n_projections, d, 2, type_as=X_s) - - projections, _ = nx.qr(Z) - - # Projection on S^1 - # Projection on plane - Xps = nx.einsum("ikj, lk -> ilj", projections, X_s) - - # Projection on sphere - Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + if projections is None: + projections = get_projections_sphere( + d, n_projections, seed=seed, backend=nx, type_as=X_s + ) - # Get coordinates on [0,1[ - Xps_coords = nx.reshape( - get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n) + Xps_coords, _ = projection_sphere_to_circle( + X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx ) projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a) @@ -488,3 +568,114 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log if log: return res, {"projections": projections, "projected_emds": projected_emd} return res + + +def linear_sliced_wasserstein_sphere( + X_s, + X_t=None, + a=None, + b=None, + n_projections=50, + projections=None, + seed=None, + log=False, +): + r"""Computes the linear spherical sliced wasserstein distance from :ref:`[77] <references-lssot>`. + + General loss returned: + + .. 
math:: + \mathrm{LSSOT}_2(\mu, \nu) = \left(\int_{\mathbb{V}_{d,2}} \mathrm{LCOT}_2^2(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac12}, + + where :math:`\mu,\nu\in\mathcal{P}(S^{d-1})` are two probability measures on the sphere, :math:`\mathrm{LCOT}_2` is the linear circular optimal transport distance, + and :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}`. + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + X_t: ndarray, shape (n_samples_b, dim), optional + Samples in the target domain. If None, computes the distance against the uniform distribution on the sphere. + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + projections: shape (n_projections, dim, 2), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, linear_sliced_wasserstein_sphere returns the projections used and their associated LCOT. + + Returns + ------- + cost: float + Linear Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + --------- + >>> n_samples_a = 20 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True)) + >>> linear_sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + + .. _references-lssot: + References + ---------- + .. [77] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data. 
International Conference on Learning Representations. + """ + d = X_s.shape[-1] + + # X_t, a and b are optional: only hand the provided arrays to get_backend + nx = get_backend(*[arr for arr in (X_s, X_t, a, b) if arr is not None]) + + if X_t is not None and X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( + X_s.shape[1], X_t.shape[1] + ) + ) + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)): + raise ValueError("X_s is not on the sphere.") + if X_t is not None and nx.any( + nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10 ** (-4) + ): + raise ValueError("X_t is not on the sphere.") + + if projections is None: + projections = get_projections_sphere( + d, n_projections, seed=seed, backend=nx, type_as=X_s + ) + else: + n_projections = projections.shape[0] + + Xps_coords, _ = projection_sphere_to_circle( + X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx + ) + + if X_t is None: + # No target samples: compare each projection to the uniform distribution on the circle + projected_lcot = linear_circular_ot(Xps_coords.T, u_weights=a) + else: + Xpt_coords, _ = projection_sphere_to_circle( + X_t, + n_projections=n_projections, + projections=projections, + seed=seed, + backend=nx, + ) + projected_lcot = linear_circular_ot( + Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b + ) + + res = nx.mean(projected_lcot) ** (1 / 2) + + if log: + return res, {"projections": projections, "projected_emds": projected_lcot} + return res diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 7ab1009af..c2f377469 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -355,3 +355,88 @@ def test_wasserstein_circle_bad_shape(): with pytest.raises(ValueError): _ = ot.wasserstein_circle(u, v, p=1) + + +@pytest.skip_backend("tf") +def test_linear_circular_ot_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) + + lcot = ot.linear_circular_ot(xb, xb, rho_ub, rho_vb) 
+ + nx.assert_same_dtype_device(xb, lcot) + + +def test_linear_circular_ot_bad_shape(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n, 2) + v = rng.rand(m, 1) + + with pytest.raises(ValueError): + _ = ot.linear_circular_ot(u, v) + + +def test_linear_circular_ot_same_dist(): + n = 20 + rng = np.random.RandomState(0) + u = rng.rand(n) + + lcot = ot.linear_circular_ot(u, u) + np.testing.assert_almost_equal(lcot, 0.0) + + +def test_linear_circular_ot_different_dist(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n) + v = rng.rand(m) + + lcot = ot.linear_circular_ot(u, v) + assert lcot > 0.0 + + +def test_linear_circular_embedding_shape(): + n = 20 + rng = np.random.RandomState(0) + u = rng.rand(n, 2) + + ts = np.linspace(0, 1, 101)[:-1] + + emb = ot.lp.solver_1d.linear_circular_embedding(ts, u) + assert emb.shape == (100, 2) + + emb = ot.lp.solver_1d.linear_circular_embedding(ts, u[:, 0]) + assert emb.shape == (100, 1) + + +def test_linear_circular_ot_unif_circle(): + n = 20 + m = 1000 + + rng = np.random.RandomState(0) + u = rng.rand( + n, + ) + v = rng.rand( + m, + ) + + lcot = ot.linear_circular_ot(u, v) + lcot_unif = ot.linear_circular_ot(u) + + # check loss is similar + np.testing.assert_allclose(lcot, lcot_unif, atol=1e-2) diff --git a/test/test_sliced.py b/test/test_sliced.py index 566a7fdf6..05de13755 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -289,6 +289,30 @@ def test_projections_stiefel(): np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)]) ) + rng = np.random.RandomState(0) + + projections = ot.sliced.get_projections_sphere(3, n_projs, seed=rng) + projections_T = np.transpose(projections, [0, 2, 1]) + + np.testing.assert_almost_equal( + np.matmul(projections_T, projections), + np.array([np.eye(2) for k in range(n_projs)]), + ) + + # np.testing.assert_almost_equal(projections, P) + + +def test_projections_sphere_to_circle(): + rng = np.random.RandomState(0) + + n_projs = 500 + x = 
rng.randn(100, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + x_projs, _ = ot.sliced.projection_sphere_to_circle(x, n_projs) + assert x_projs.shape == (n_projs, 100) + assert np.all(x_projs >= 0) and np.all(x_projs < 1) + def test_sliced_sphere_same_dist(): n = 100 @@ -346,16 +370,24 @@ def test_sliced_sphere_values_on_the_sphere(): n = 100 rng = np.random.RandomState(0) + u = ot.utils.unif(n) + x = rng.randn(n, 3) x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + # dimension problem y = rng.randn(n, 4) + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) - u = ot.utils.unif(n) - + # not on the sphere + y = rng.randn(n, 3) with pytest.raises(ValueError): _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(y, x, u, u, 10, seed=rng) + def test_sliced_sphere_log(): n = 100 @@ -506,3 +538,162 @@ def test_sliced_sphere_unif_backend_type_devices(nx): valb = ot.sliced_wasserstein_sphere_unif(xb) nx.assert_same_dtype_device(xb, valb) + + +def test_linear_sliced_sphere_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res = ot.linear_sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.0) + + +def test_linear_sliced_sphere_same_proj(): + n_projections = 10 + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + seed = 42 + + cost1, log1 = ot.linear_sliced_wasserstein_sphere( + x, y, seed=seed, n_projections=n_projections, log=True + ) + cost2, log2 = ot.linear_sliced_wasserstein_sphere( + x, y, seed=seed, n_projections=n_projections, log=True + ) + + assert np.allclose(log1["projections"], log2["projections"]) + assert np.isclose(cost1, cost2) + + +def 
test_linear_sliced_sphere_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_linear_sliced_sphere_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + u = ot.utils.unif(n) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + # shape problem + y = rng.randn(n, 4) + + with pytest.raises(ValueError): + _ = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + # not on sphere + y = rng.randn(n, 3) + + with pytest.raises(ValueError): + _ = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + with pytest.raises(ValueError): + _ = ot.linear_sliced_wasserstein_sphere(y, x, u, u, 10, seed=rng) + + +def test_linear_sliced_sphere_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_linear_sliced_sphere_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + res = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + assert res > 0.0 + + +def test_1d_linear_sliced_sphere_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) 
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + a = rng.uniform(0, 1, n) + a /= a.sum() + + y = rng.randn(m, 2) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi) + u = ot.utils.unif(m) + + res = ot.linear_sliced_wasserstein_sphere(x, y, a, u, 100, seed=42) + expected = ot.linear_circular_ot(x_coords.T, y_coords.T, a, u) + + np.testing.assert_almost_equal(res**2, expected, decimal=5) + + +@pytest.skip_backend("tf") +def test_linear_sliced_sphere_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(2 * n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + sw_np, log = ot.linear_sliced_wasserstein_sphere(x, y, log=True) + P = log["projections"] + + for tp in nx.__type_list__: + xb, yb = nx.from_numpy(x, y, type_as=tp) + + valb = ot.linear_sliced_wasserstein_sphere( + xb, yb, projections=nx.from_numpy(P, type_as=tp) + ) + + nx.assert_same_dtype_device(xb, valb) + np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb))