From 14e2e0d4fe07bd0620efa6bc933e872765932798 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 12 Sep 2025 16:32:31 +0200 Subject: [PATCH 01/19] first implementation of sliced ot plans with example --- README.md | 7 + RELEASES.md | 1 + .../sliced-wasserstein/plot_sliced_plans.py | 123 ++++++++++ ot/__init__.py | 4 + ot/sliced.py | 230 +++++++++++++++++- 5 files changed, 363 insertions(+), 2 deletions(-) create mode 100644 examples/sliced-wasserstein/plot_sliced_plans.py diff --git a/README.md b/README.md index 3a14474d2..28089b487 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ POT provides the following generic OT solvers: * Fused unbalanced Gromov-Wasserstein [70]. * [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77] * [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77] +* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) [81, 82, 83] POT provides the following Machine Learning related solvers: @@ -446,3 +447,9 @@ Artificial Intelligence. [79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations. [80] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019. + +[81] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). [Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6f1346bac8b02f76a631400e2799b24b-Abstract-Conference.html). Advances in Neural Information Processing Systems, 36, 35350–35385. + +[82] Tanguy, E., Chapel, L., Delon, J. (2025). [Sliced Optimal Transport Plans](https://arxiv.org/abs/2508.01243) arXiv preprint 2506.03661. + +[83] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). [Expected Sliced Transport Plans](https://openreview.net/forum?id=P7O1Vt1BdU). International Conference on Learning Representations. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 2b0f47eae..95d17cedf 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -27,6 +27,7 @@ - Added to each example in the examples gallery the information about the release version in which it was introduced (PR #743) - Removed release information from quickstart guide (PR #744) - Update REAMDE with new API and reorganize examples (PR #754) +- Added Sliced OT plans (PR #757) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py new file mode 100644 index 000000000..dced356d8 --- /dev/null +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +""" +=============== +Sliced OT Plans +=============== + +Compares different Sliced OT plans between two 2D point clouds. The min-Pivot +Sliced plan was introduced in [81], and the Expected Sliced plan in [83], both +were further studied theoretically in [82]. + +.. [81] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. + +.. [82] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + +.. [83] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. +""" + +# Author: Eloi Tanguy +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +############################################################################## +# Setup data and imports +# ---------------------- +# %% +import numpy as np +import ot +from time import time +import matplotlib.pyplot as plt +from ot.sliced import get_random_projections + +seed = 0 +np.random.seed(seed) +n = 10 +d = 2 +X = np.random.randn(n, 2) +Y = np.random.randn(n, 2) + np.array([5.0, 0.0])[None, :] +n_proj = 20 +thetas = get_random_projections(d, n_proj).T +alpha = 0.3 + +############################################################################## +# Compute min-Pivot Sliced permutation +# ------------------------------------ +# %% +t = time() +min_perm, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas, log=True) +min_time = time() - t + +############################################################################## +# Compute Expected Sliced Plan +# ------------------------------------ +# %% +t = time() +expected_plan, expected_cost, log_expected = ot.expected_sliced(X, Y, thetas, log=True) +expected_time = time() - t + +############################################################################## +# Compute 2-Wasserstein Plan +# ------------------------------------ +# %% +a = np.ones(n, device=X.device) / n +t = time() +dists = ot.dist(X, Y) +W2 = ot.emd2(a, a, dists) +W2_plan = ot.emd(a, a, dists) +W2_time = time() - t + +############################################################################## +# Plot resulting assignments +# ------------------------------------ +# %% +fig, axs = plt.subplots(1, 3, figsize=(12, 4)) +fig.suptitle("Sliced plans comparison", y=0.85, fontsize=16) + +# draw min sliced permutation +axs[0].set_title(f"Min Pivot Sliced, cost={min_cost:.2f}" f", time={min_time:.2e}s") +for i in range(n): + axs[0].plot( + [X[i, 0], Y[min_perm[i], 0]], + [X[i, 1], Y[min_perm[i], 1]], + color="black", + alpha=alpha, + label="min-Sliced perm" if i == 0 else None, + ) + +# draw expected sliced plan +axs[1].set_title( + f"Expected Sliced, cost={expected_cost:.2f}" f", time={expected_time:.2e}s" +) +for i in range(n): + for j in range(n): + w = alpha * expected_plan[i, j].item() * n + axs[1].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="Expected Sliced plan" if i == 0 and j == 0 else None, + ) + +# draw W2 plan +axs[2].set_title(f"W2, cost={W2:.2f}" f", time={W2_time:.2e}s") +for i in range(n): + for j in range(n): + w = alpha * W2_plan[i, j].item() * n + axs[2].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="W2 plan" if i == 0 and j == 0 else None, + ) + +for ax in axs: + ax.scatter(X[:, 0], X[:, 1], label="X") + ax.scatter(Y[:, 0], Y[:, 1], label="Y") + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() diff --git a/ot/__init__.py b/ot/__init__.py index 1aad7828a..e2c6cf31e 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -59,6 +59,8 @@ sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif, linear_sliced_wasserstein_sphere, + min_pivot_sliced, + expected_sliced, ) from .gromov import ( gromov_wasserstein, @@ -108,6 +110,8 @@ "sliced_wasserstein_distance", "sliced_wasserstein_sphere", "linear_sliced_wasserstein_sphere", + "min_pivot_sliced", + "expected_sliced", "gromov_wasserstein", "gromov_wasserstein2", "gromov_barycenters", diff --git a/ot/sliced.py b/ot/sliced.py index 3cf2002e7..b222130f9 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -1,17 +1,17 @@ """ Sliced OT Distances - """ # Author: Adrien Corenflos # Nicolas Courty # Rémi Flamary +# Eloi Tanguy # # License: MIT License import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array, get_coordinate_circle +from .utils import list_to_array, get_coordinate_circle, dist from .lp import ( wasserstein_circle, semidiscrete_wasserstein2_unif_circle, @@ -674,3 +674,229 @@ def linear_sliced_wasserstein_sphere( if log: return res, {"projections": projections, "projected_emds": projected_lcot} return res + + +def sliced_permutations(X, Y, thetas=None, n_proj=None, log=False, backend=None): + r""" + Computes all the permutations that sort the projections of two `(n, d)` + datasets `X` and `Y` on the directions `thetas`. + Each permutation `perm[:, k]` is such that each `X[i, :]` is matched + to `Y[perm[i, k], :]` when projected on `thetas[k, :]`. + + Parameters + ---------- + X : array-like, shape (n, d) + The first set of vectors. + Y : array-like, shape (n, d) + The second set of vectors. + thetas : array-like, shape (n_proj, d), optional + The projection directions. If None, random directions will be generated. + Default is None. + n_proj : int, optional + The number of projection directions. Required if thetas is None. + log : bool, optional + If True, returns additional logging information. Default is False. + backend : ot.backend, optional + Backend to use for computations. If None, the backend is inferred from the input arrays. Default is None. + + Returns + ------- + perm : array-like, shape (n, n_proj) + All sliced permutations. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + """ + nx = get_backend(X, Y) if backend is None else backend + d = X.shape[1] + do_draw_thetas = thetas is None + if do_draw_thetas: # create thetas (n_proj, d) + thetas = get_random_projections(d, n_proj, backend=nx).T + + # project on each theta: (n, d) -> (n, n_proj) + X_theta = X @ thetas.T # shape (n, n_proj) + Y_theta = Y @ thetas.T # shape (n, n_proj) + + # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj] + sigma = nx.argsort(X_theta, axis=0) # (n, n_proj) + tau = nx.argsort(Y_theta, axis=0) # (n, n_proj) + + # perm[:, i_proj] is tau[:, i_proj] o sigma[:, i_proj]^{-1} + perm = nx.take_along_axis(tau, nx.argsort(sigma, axis=0), axis=0) # (n, n_proj) + + if log: + log_dict = { + "X_theta": X_theta, + "Y_theta": Y_theta, + "sigma": sigma, + "tau": tau, + "perm": perm, + } + if do_draw_thetas: + log_dict["thetas"] = thetas + return perm, log_dict + else: + return perm + + +def min_pivot_sliced( + X, Y, thetas=None, order=2, n_proj=None, log=False, warm_perm=None +): + r""" + Computes the cost and permutation associated to the min-Pivot Sliced + Discrepancy (introduced as SWGG in [81] and studied further in [82]). Given + the supports `X` and `Y` of two discrete uniform measures with `n` atoms in + dimension `d`, the min-Pivot Sliced Discrepancy goes through `n_proj` + different projections of the measures on random directions, and retains the + permutation that yields the lowest cost between `X` and `Y` (compared + in :math:`\mathbb{R}^d`). + + .. math:: + \mathrm{min\text{-}PS}_p^p(X, Y) \approx + \min_{k \in [1, n_{\mathrm{proj}}]} \left( + \frac{1}{n} \sum_{i=1}^n \|X_i - Y_{\sigma_k(i)}\|_2^p \right), + + where :math:`\sigma_k` is a permutation such that ordering the projections + on the axis `thetas[k, :]` matches `X[i, :]` to `Y[\sigma_k(i), :]`. + + .. note:: + The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. + + Parameters + ---------- + X : array-like, shape (n, d) + The first set of vectors. + Y : array-like, shape (n, d) + The second set of vectors. + thetas : array-like, shape (n_proj, d), optional + The projection directions. If None, random directions will be generated. Default is None. + order : int, optional + Power to elevate the norm. Default is 2. + n_proj : int, optional + The number of projection directions. Required if thetas is None. + log : bool, optional + If True, returns additional logging information. Default is False. + warm_perm : array-like, shape (n,), optional + A permutation to add to the permutation list. Default is None. + + Returns + ------- + perm : array-like, shape (n,) + The permutation that minimizes the cost. + min_cost : float + The minimum cost corresponding to the optimal permutation. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + + References + ---------- + .. [81] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. + + .. [82] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + """ + n = X.shape[0] + nx = get_backend(X, Y) + log_dict = {} + + if log: + perm, log_dict = sliced_permutations( + X, Y, thetas=thetas, n_proj=n_proj, log=True, backend=nx + ) + else: + perm = sliced_permutations( + X, Y, thetas=thetas, n_proj=n_proj, log=False, backend=nx + ) + + # add the 'warm perm' to permutations to test + if warm_perm is not None: + perm = nx.concatenate([perm, warm_perm[:, None]], dim=1) + if log: + log_dict["perm"] = perm + + min_cost = None + idx_min_cost = None + costs = [] + + for k in range(perm.shape[-1]): + cost = nx.sum(nx.abs(X - Y[perm[:, k]]) ** order) / n + if min_cost is None or cost < min_cost: + min_cost = cost + idx_min_cost = k + if log: + costs.append(cost) + + min_perm = perm[:, idx_min_cost] + + if log: + log_dict["costs"] = costs + log_dict["idx_min_cost"] = idx_min_cost + return min_perm, min_cost, log_dict + else: + return min_perm, min_cost + + +def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False): + r""" + Computes the Expected Sliced cost and plan between two `(n, d)` + datasets `X` and `Y`. Given a set of `n_proj` projection directions, + the expected sliced plan is obtained by averaging the `n_proj` 1d optimal + transport plans between the projections of `X` and `Y` on each direction. + Expected Sliced was introduced in [83] and further studied in [82]. + + .. note:: + The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. + + Parameters + ---------- + X : torch.Tensor + A tensor of shape (n, d) representing the first set of vectors. + Y : torch.Tensor + A tensor of shape (n, d) representing the second set of vectors. + thetas : torch.Tensor, optional + A tensor of shape (n_proj, d) representing the projection directions. + If None, random directions will be generated. Default is None. + n_proj : int, optional + The number of projection directions. Required if thetas is None. + order : int, optional + Power to elevate the norm. Default is 2. + log : bool, optional + If True, returns additional logging information. Default is False. + + Returns + ------- + plan : torch.Tensor + A tensor of shape (n_proj, n, n) representing the expected sliced plan. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + + References + ---------- + .. [82] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + + .. [83] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. + """ + nx = get_backend(X, Y) + n = X.shape[0] + log_dict = {} + if log: + perm, log_dict = sliced_permutations( + X, Y, thetas=thetas, n_proj=n_proj, log=log, backend=nx + ) + else: + perm = sliced_permutations( + X, Y, thetas=thetas, n_proj=n_proj, log=log, backend=nx + ) + plan = nx.zeros((n, n), type_as=X) + n_proj = perm.shape[1] + range_array = nx.arange(n, type_as=X) + for k in range(n_proj): + plan[range_array, perm[:, k]] += 1 / (n_proj * n) + + cost = (dist(X, Y, p=order) * plan).sum() + + if log: + return plan, cost, log_dict + else: + return plan, cost From 2911166f63ccfa5ab43f34534ec1643d086cb26e Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 16 Sep 2025 18:08:28 +0200 Subject: [PATCH 02/19] tests + temperature option in expected-sliced --- .../sliced-wasserstein/plot_sliced_plans.py | 86 +++++++++--- ot/sliced.py | 31 ++++- test/test_sliced.py | 131 ++++++++++++++++++ 3 files changed, 226 insertions(+), 22 deletions(-) diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py index dced356d8..0ea65fb2c 100644 --- a/examples/sliced-wasserstein/plot_sliced_plans.py +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -26,7 +26,6 @@ # %% import numpy as np import ot -from time import time import matplotlib.pyplot as plt from ot.sliced import get_random_projections @@ -44,78 +43,129 @@ # Compute min-Pivot Sliced permutation # ------------------------------------ # %% -t = time() min_perm, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas, log=True) -min_time = time() - t +min_plan = np.zeros((n, n)) +min_plan[np.arange(n), min_perm] = 1 / n ############################################################################## # Compute Expected Sliced Plan # ------------------------------------ # %% -t = time() expected_plan, expected_cost, log_expected = ot.expected_sliced(X, Y, thetas, log=True) -expected_time = time() - t ############################################################################## # Compute 2-Wasserstein Plan # ------------------------------------ # %% a = np.ones(n, device=X.device) / n -t = time() dists = ot.dist(X, Y) W2 = ot.emd2(a, a, dists) W2_plan = ot.emd(a, a, dists) -W2_time = time() - t ############################################################################## # Plot resulting assignments # ------------------------------------ # %% -fig, axs = plt.subplots(1, 3, figsize=(12, 4)) -fig.suptitle("Sliced plans comparison", y=0.85, fontsize=16) +fig, axs = plt.subplots(2, 3, figsize=(12, 4)) +fig.suptitle("Sliced plans comparison", y=0.95, fontsize=16) # draw min sliced permutation -axs[0].set_title(f"Min Pivot Sliced, cost={min_cost:.2f}" f", time={min_time:.2e}s") +axs[0, 0].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") for i in range(n): - axs[0].plot( + axs[0, 0].plot( [X[i, 0], Y[min_perm[i], 0]], [X[i, 1], Y[min_perm[i], 1]], color="black", alpha=alpha, label="min-Sliced perm" if i == 0 else None, ) +axs[1, 0].imshow(min_plan, interpolation="nearest", cmap="Blues") # draw expected sliced plan -axs[1].set_title( - f"Expected Sliced, cost={expected_cost:.2f}" f", time={expected_time:.2e}s" -) +axs[0, 1].set_title(f"Expected Sliced: cost={expected_cost:.2f}") for i in range(n): for j in range(n): w = alpha * expected_plan[i, j].item() * n - axs[1].plot( + axs[0, 1].plot( [X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], color="black", alpha=w, label="Expected Sliced plan" if i == 0 and j == 0 else None, ) +axs[1, 1].imshow(expected_plan, interpolation="nearest", cmap="Blues") # draw W2 plan -axs[2].set_title(f"W2, cost={W2:.2f}" f", time={W2_time:.2e}s") +axs[0, 2].set_title(f"W2: cost={W2:.2f}") for i in range(n): for j in range(n): w = alpha * W2_plan[i, j].item() * n - axs[2].plot( + axs[0, 2].plot( [X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], color="black", alpha=w, label="W2 plan" if i == 0 and j == 0 else None, ) +axs[1, 2].imshow(W2_plan, interpolation="nearest", cmap="Blues") -for ax in axs: +for ax in axs[0, :]: ax.scatter(X[:, 0], X[:, 1], label="X") ax.scatter(Y[:, 0], Y[:, 1], label="Y") + +for ax in axs.flatten(): + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() + +############################################################################## +# Compare Expected Sliced plans with different inverse-temperatures beta +# ------------------------------------ +# %% As the temperature decreases, ES becomes sparser and approaches minPS +betas = [0.0, 5.0, 50.0] +n_plots = len(betas) + 1 +size = 4 +fig, axs = plt.subplots(2, n_plots, figsize=(size * n_plots, size)) +fig.suptitle( + "Expected Sliced plan varying beta (inverse temperature)", y=0.95, fontsize=16 +) +for beta_idx, beta in enumerate(betas): + expected_plan, expected_cost = ot.expected_sliced(X, Y, thetas, beta=beta) + print(f"beta={beta}: cost={expected_cost:.2f}") + + axs[0, beta_idx].set_title(f"beta={beta}: cost={expected_cost:.2f}") + for i in range(n): + for j in range(n): + w = alpha * expected_plan[i, j].item() * n + axs[0, beta_idx].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="Expected Sliced plan" if i == 0 and j == 0 else None, + ) + + axs[0, beta_idx].scatter(X[:, 0], X[:, 1], label="X") + axs[0, beta_idx].scatter(Y[:, 0], Y[:, 1], label="Y") + axs[1, beta_idx].imshow(expected_plan, interpolation="nearest", cmap="Blues") + +# draw min sliced permutation (limit when beta -> +inf) +axs[0, -1].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") +for i in range(n): + axs[0, -1].plot( + [X[i, 0], Y[min_perm[i], 0]], + [X[i, 1], Y[min_perm[i], 1]], + color="black", + alpha=alpha, + label="min-Sliced perm" if i == 0 else None, + ) +axs[0, -1].scatter(X[:, 0], X[:, 1], label="X") +axs[0, -1].scatter(Y[:, 0], Y[:, 1], label="Y") +axs[1, -1].imshow(min_plan, interpolation="nearest", cmap="Blues") + +for ax in axs.flatten(): ax.set_aspect("equal") ax.set_xticks([]) ax.set_yticks([]) diff --git a/ot/sliced.py b/ot/sliced.py index b222130f9..b8731a1c9 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -707,6 +707,9 @@ def sliced_permutations(X, Y, thetas=None, n_proj=None, log=False, backend=None) A dictionary containing intermediate computations for logging purposes. Returned only if `log` is True. """ + assert ( + X.shape == Y.shape + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" nx = get_backend(X, Y) if backend is None else backend d = X.shape[1] do_draw_thetas = thetas is None @@ -795,6 +798,9 @@ def min_pivot_sliced( .. [82] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. """ + assert ( + X.shape == Y.shape + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" n = X.shape[0] nx = get_backend(X, Y) log_dict = {} @@ -810,7 +816,7 @@ def min_pivot_sliced( # add the 'warm perm' to permutations to test if warm_perm is not None: - perm = nx.concatenate([perm, warm_perm[:, None]], dim=1) + perm = nx.concatenate([perm, warm_perm[:, None]], axis=1) if log: log_dict["perm"] = perm @@ -836,7 +842,7 @@ def min_pivot_sliced( return min_perm, min_cost -def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False): +def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0): r""" Computes the Expected Sliced cost and plan between two `(n, d)` datasets `X` and `Y`. Given a set of `n_proj` projection directions, @@ -862,6 +868,8 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False): Power to elevate the norm. Default is 2. log : bool, optional If True, returns additional logging information. Default is False. + beta : float, optional + Inverse-temperature parameter which weights each projection's contribution to the expected plan. Default is 0 (uniform weighting). Returns ------- @@ -877,6 +885,9 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False): .. [83] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. """ + assert ( + X.shape == Y.shape + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" nx = get_backend(X, Y) n = X.shape[0] log_dict = {} @@ -891,8 +902,20 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False): plan = nx.zeros((n, n), type_as=X) n_proj = perm.shape[1] range_array = nx.arange(n, type_as=X) - for k in range(n_proj): - plan[range_array, perm[:, k]] += 1 / (n_proj * n) + + if beta != 0.0: # computing the temperature weighting + log_factors = nx.zeros(n_proj, type_as=X) # for beta weighting + for k in range(n_proj): + cost_k = nx.sum(nx.abs(X - Y[perm[:, k]]) ** order) / n + log_factors[k] = -beta * cost_k + weights = nx.exp(log_factors - nx.logsumexp(log_factors)) + + else: # uniform weights + weights = nx.ones(n_proj, type_as=X) / n_proj + + for k in range(n_proj): # populating the expected plan + # 1 / n is because is a permutation of [1, n] + plan[range_array, perm[:, k]] += (1 / n) * weights[k] cost = (dist(X, Y, p=order) * plan).sum() diff --git a/test/test_sliced.py b/test/test_sliced.py index 05de13755..48009dfca 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -2,6 +2,7 @@ # Author: Adrien Corenflos # Nicolas Courty +# Eloi Tanguy # # License: MIT License @@ -110,6 +111,14 @@ def test_max_sliced_different_dists(): assert res > 0.0 +def test_max_sliced_dim_check(): + n = 3 + x = np.zeros((n, 2)) + y = np.zeros((n + 1, 3)) + with pytest.raises(ValueError): + _ = ot.max_sliced_wasserstein_distance(x, y, n_projections=10) + + def test_sliced_same_proj(): n_projections = 10 seed = 12 @@ -152,6 +161,16 @@ def test_sliced_backend(nx): assert np.allclose(val0, valb) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, 2 * n) + b /= b.sum() + a_b = nx.from_numpy(a) + b_b = nx.from_numpy(b) + val = ot.sliced_wasserstein_distance(x, y, a=a, b=b, projections=P) + val_b = ot.sliced_wasserstein_distance(xb, yb, a=a_b, b=b_b, projections=Pb) + np.testing.assert_almost_equal(val, nx.to_numpy(val_b)) + def test_sliced_backend_type_devices(nx): n = 100 @@ -227,6 +246,16 @@ def test_max_sliced_backend(nx): assert np.allclose(val0, valb) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, 2 * n) + b /= b.sum() + a_b = nx.from_numpy(a) + b_b = nx.from_numpy(b) + val = ot.max_sliced_wasserstein_distance(x, y, a=a, b=b, projections=P) + val_b = ot.max_sliced_wasserstein_distance(xb, yb, a=a_b, b=b_b, projections=Pb) + np.testing.assert_almost_equal(val, nx.to_numpy(val_b)) + def test_max_sliced_backend_type_devices(nx): n = 100 @@ -697,3 +726,105 @@ def test_linear_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) + + +def test_sliced_permutations(nx): + n = 4 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + + x_b, y_b = nx.from_numpy(x, y) + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + thetas_b = nx.from_numpy(thetas) + + perm = ot.sliced.sliced_permutations(x, y, thetas=thetas) + perm_b, _ = ot.sliced.sliced_permutations( + x_b, y_b, thetas=thetas_b, log=True, backend=nx + ) + + np.testing.assert_almost_equal(perm, nx.to_numpy(perm_b)) + + # test without provided thetas + perm = ot.sliced.sliced_permutations(x, y, n_proj=n_proj) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.sliced_permutations(x[1:, :], y, thetas=thetas) + + +def test_min_pivot_sliced(nx): + n = 10 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + + x_b, y_b = nx.from_numpy(x, y) + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + thetas_b = nx.from_numpy(thetas) + + min_perm, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) + min_perm_b, min_cost_b, _ = ot.sliced.min_pivot_sliced( + x_b, y_b, thetas=thetas_b, log=True + ) + + np.testing.assert_almost_equal(min_perm, nx.to_numpy(min_perm_b)) + np.testing.assert_almost_equal(min_cost, nx.to_numpy(min_cost_b)) + + # result should be an upper-bound of W2 and relatively close + w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) + assert min_cost >= w2 + assert min_cost <= 1.5 * w2 + + # test without provided thetas and with a warm permutation + ot.sliced.min_pivot_sliced(x, y, n_proj=n_proj, warm_perm=np.arange(n), log=True) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) + + +def test_expected_sliced(nx): + n = 10 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + + x_b, y_b = nx.from_numpy(x, y) + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + thetas_b = nx.from_numpy(thetas) + + expected_plan, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas) + expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced( + x_b, y_b, thetas=thetas_b, log=True + ) + + np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b)) + np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b)) + + # result should be a coarse upper-bound of W2 + w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) + assert expected_cost >= w2 + assert expected_cost <= 3 * w2 + + # test without provided thetas + ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) + + # with a small temperature (i.e. large beta), + # the cost should be close to min_pivot + _, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas, beta=100.0) + _, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) + np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3) From 134846cc2e38c8c24ebd96eb4d10a6869b95de59 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 16 Sep 2025 18:12:49 +0200 Subject: [PATCH 03/19] fixed cell rendering in example --- examples/sliced-wasserstein/plot_sliced_plans.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py index 0ea65fb2c..9a2c6cd66 100644 --- a/examples/sliced-wasserstein/plot_sliced_plans.py +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -23,7 +23,6 @@ ############################################################################## # Setup data and imports # ---------------------- -# %% import numpy as np import ot import matplotlib.pyplot as plt @@ -42,7 +41,6 @@ ############################################################################## # Compute min-Pivot Sliced permutation # ------------------------------------ -# %% min_perm, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas, log=True) min_plan = np.zeros((n, n)) min_plan[np.arange(n), min_perm] = 1 / n @@ -50,13 +48,11 @@ ############################################################################## # Compute Expected Sliced Plan # ------------------------------------ -# %% expected_plan, expected_cost, log_expected = ot.expected_sliced(X, Y, thetas, log=True) ############################################################################## # Compute 2-Wasserstein Plan # ------------------------------------ -# %% a = np.ones(n, device=X.device) / n dists = ot.dist(X, Y) W2 = ot.emd2(a, a, dists) @@ -65,7 +61,6 @@ ############################################################################## # Plot resulting assignments # ------------------------------------ -# %% fig, axs = plt.subplots(2, 3, figsize=(12, 4)) fig.suptitle("Sliced plans comparison", y=0.95, fontsize=16) @@ -123,7 +118,7 @@ ############################################################################## # Compare Expected Sliced plans with different inverse-temperatures beta # ------------------------------------ -# %% As the temperature decreases, ES becomes sparser and approaches minPS +## As the temperature decreases, ES becomes sparser and approaches minPS betas = [0.0, 5.0, 50.0] n_plots = len(betas) + 1 size = 4 From 414a2958e358a74a9adb92ef8e4e2daf8e56d137 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 17 Sep 2025 10:45:39 +0200 Subject: [PATCH 04/19] ref number update --- README.md | 2 +- examples/sliced-wasserstein/plot_sliced_plans.py | 10 +++++----- ot/sliced.py | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 4a1688a7f..ed6f2a89c 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ POT provides the following generic OT solvers: * Fused unbalanced Gromov-Wasserstein [70]. * [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77] * [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77] -* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) [81, 82, 83] +* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) [82, 83, 84] POT provides the following Machine Learning related solvers: diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py index 9a2c6cd66..ca7b35a3f 100644 --- a/examples/sliced-wasserstein/plot_sliced_plans.py +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -5,14 +5,14 @@ =============== Compares different Sliced OT plans between two 2D point clouds. The min-Pivot -Sliced plan was introduced in [81], and the Expected Sliced plan in [83], both -were further studied theoretically in [82]. +Sliced plan was introduced in [82], and the Expected Sliced plan in [84], both +were further studied theoretically in [83]. -.. [81] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. +.. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. -.. [82] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. +.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. -.. [83] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. +.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. """ # Author: Eloi Tanguy diff --git a/ot/sliced.py b/ot/sliced.py index b8731a1c9..8e155cb9f 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -747,7 +747,7 @@ def min_pivot_sliced( ): r""" Computes the cost and permutation associated to the min-Pivot Sliced - Discrepancy (introduced as SWGG in [81] and studied further in [82]). Given + Discrepancy (introduced as SWGG in [82] and studied further in [83]). Given the supports `X` and `Y` of two discrete uniform measures with `n` atoms in dimension `d`, the min-Pivot Sliced Discrepancy goes through `n_proj` different projections of the measures on random directions, and retains the @@ -794,9 +794,9 @@ def min_pivot_sliced( References ---------- - .. [81] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. + .. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. - .. [82] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. """ assert ( X.shape == Y.shape @@ -848,7 +848,7 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0 datasets `X` and `Y`. Given a set of `n_proj` projection directions, the expected sliced plan is obtained by averaging the `n_proj` 1d optimal transport plans between the projections of `X` and `Y` on each direction. - Expected Sliced was introduced in [83] and further studied in [82]. + Expected Sliced was introduced in [84] and further studied in [83]. .. note:: The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. @@ -881,9 +881,9 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0 References ---------- - .. [82] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. - .. [83] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. + .. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. """ assert ( X.shape == Y.shape From 778bc73b75d80b42992f66ec6585254ca182a219 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 17 Sep 2025 11:19:14 +0200 Subject: [PATCH 05/19] skip jax and tf in expected sliced testing due to array assignment --- test/test_sliced.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_sliced.py b/test/test_sliced.py index 48009dfca..6e220238a 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -790,6 +790,8 @@ def test_min_pivot_sliced(nx): ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) +@pytest.skip_backend("tf") # skips because of array assignment +@pytest.skip_backend("jax") def test_expected_sliced(nx): n = 10 n_proj = 10 From 08c5348d0a9d54ce7ad70ea764283213304da78c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 22 Sep 2025 19:21:51 +0200 Subject: [PATCH 06/19] raise NotImplementedError when expected_sliced is used with tf or jax --- ot/sliced.py | 10 ++++++++++ test/test_sliced.py | 48 +++++++++++++++++++++++++-------------------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/ot/sliced.py b/ot/sliced.py index 8e155cb9f..cc79fa968 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -853,6 +853,9 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0 .. note:: The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. + .. warning:: + The function runs on backend but tensorflow and jax are not supported due to array assignment. + Parameters ---------- X : torch.Tensor @@ -888,8 +891,15 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0 assert ( X.shape == Y.shape ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" + nx = get_backend(X, Y) + if str(nx) in ["tf", "jax"]: + raise NotImplementedError( + f"expected_sliced is not implemented for the {str(nx)} backend due" + "to array assignment." + ) n = X.shape[0] + log_dict = {} if log: perm, log_dict = sliced_permutations( diff --git a/test/test_sliced.py b/test/test_sliced.py index 6e220238a..7f12d378a 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -12,6 +12,7 @@ import ot from ot.sliced import get_random_projections from ot.backend import tf, torch +from contextlib import nullcontext def test_get_random_projections(): @@ -790,8 +791,6 @@ def test_min_pivot_sliced(nx): ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) -@pytest.skip_backend("tf") # skips because of array assignment -@pytest.skip_backend("jax") def test_expected_sliced(nx): n = 10 n_proj = 10 @@ -805,28 +804,35 @@ def test_expected_sliced(nx): thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T thetas_b = nx.from_numpy(thetas) - expected_plan, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas) - expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced( - x_b, y_b, thetas=thetas_b, log=True + context = ( + nullcontext() + if str(nx) not in ["tf", "jax"] + else pytest.raises(NotImplementedError) ) - np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b)) - np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b)) + with context: + expected_plan, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas) + expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced( + x_b, y_b, thetas=thetas_b, log=True + ) - # result should be a coarse upper-bound of W2 - w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) - assert expected_cost >= w2 - assert expected_cost <= 3 * w2 + np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b)) + np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b)) - # test without provided thetas - ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) + # result should be a coarse upper-bound of W2 + w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) + assert expected_cost >= w2 + assert expected_cost <= 3 * w2 - # test with invalid shapes - with pytest.raises(AssertionError): - ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) + # test without provided thetas + ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) - # with a small temperature (i.e. large beta), - # the cost should be close to min_pivot - _, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas, beta=100.0) - _, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) - np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3) + # with a small temperature (i.e. large beta), the cost should be close + # to min_pivot + _, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas, beta=100.0) + _, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) + np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3) From 40a978ea4a9887ca32ee32eca1f8a19c728359e8 Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Tue, 7 Oct 2025 13:38:25 +0200 Subject: [PATCH 07/19] update with n\neq m and a\neq b --- .../sliced-wasserstein/plot_sliced_plans.py | 79 ++-- ot/backend.py | 19 + ot/lp/solver_1d.py | 55 ++- ot/sliced.py | 399 +++++++++++++----- 4 files changed, 399 insertions(+), 153 deletions(-) diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py index ca7b35a3f..43703e57e 100644 --- a/examples/sliced-wasserstein/plot_sliced_plans.py +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -24,39 +24,48 @@ # Setup data and imports # ---------------------- import numpy as np + import ot import matplotlib.pyplot as plt from ot.sliced import get_random_projections +from ot.lp import wasserstein_1d + seed = 0 np.random.seed(seed) -n = 10 +n = 20 +m = 10 d = 2 X = np.random.randn(n, 2) -Y = np.random.randn(n, 2) + np.array([5.0, 0.0])[None, :] -n_proj = 20 +Y = np.random.randn(m, 2) + np.array([5.0, 0.0])[None, :] +n_proj = 50 thetas = get_random_projections(d, n_proj).T alpha = 0.3 + +proj_X = X @ thetas.T +proj_Y = Y @ thetas.T + + ############################################################################## # Compute min-Pivot Sliced permutation # ------------------------------------ -min_perm, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas, log=True) -min_plan = np.zeros((n, n)) -min_plan[np.arange(n), min_perm] = 1 / n +min_plan, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas=thetas, log=True) ############################################################################## # Compute Expected Sliced Plan # ------------------------------------ -expected_plan, expected_cost, log_expected = ot.expected_sliced(X, Y, thetas, log=True) - +expected_plan, expected_cost, log_expected = ot.expected_sliced( + X, Y, thetas=thetas, log=True +) ############################################################################## # Compute 2-Wasserstein Plan # ------------------------------------ a = np.ones(n, device=X.device) / n +b = np.ones(m, device=Y.device) / m dists = ot.dist(X, Y) -W2 = ot.emd2(a, a, dists) -W2_plan = ot.emd(a, a, dists) +W2 = ot.emd2(a, b, dists) +W2_plan = ot.emd(a, b, dists) ############################################################################## # Plot resulting assignments @@ -66,20 +75,21 @@ # draw min sliced permutation axs[0, 0].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") -for i in range(n): - axs[0, 0].plot( - [X[i, 0], Y[min_perm[i], 0]], - [X[i, 1], Y[min_perm[i], 1]], - color="black", - alpha=alpha, - label="min-Sliced perm" if i == 0 else None, - ) +for i in range(X.shape[0]): + for j in range(Y.shape[0]): + if min_plan[i, j] > 0: + axs[0, 0].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=alpha, + ) axs[1, 0].imshow(min_plan, interpolation="nearest", cmap="Blues") # draw expected sliced plan axs[0, 1].set_title(f"Expected Sliced: cost={expected_cost:.2f}") for i in range(n): - for j in range(n): + for j in range(m): w = alpha * expected_plan[i, j].item() * n axs[0, 1].plot( [X[i, 0], Y[j, 0]], @@ -91,9 +101,9 @@ axs[1, 1].imshow(expected_plan, interpolation="nearest", cmap="Blues") # draw W2 plan -axs[0, 2].set_title(f"W2: cost={W2:.2f}") +axs[0, 2].set_title(f"W$_2$: cost={W2:.2f}") for i in range(n): - for j in range(n): + for j in range(m): w = alpha * W2_plan[i, j].item() * n axs[0, 2].plot( [X[i, 0], Y[j, 0]], @@ -123,16 +133,17 @@ n_plots = len(betas) + 1 size = 4 fig, axs = plt.subplots(2, n_plots, figsize=(size * n_plots, size)) + fig.suptitle( - "Expected Sliced plan varying beta (inverse temperature)", y=0.95, fontsize=16 + "Expected Sliced plan varying $\\beta$ (inverse temperature)", y=0.95, fontsize=16 ) for beta_idx, beta in enumerate(betas): - expected_plan, expected_cost = ot.expected_sliced(X, Y, thetas, beta=beta) + expected_plan, expected_cost = ot.expected_sliced(X, Y, thetas=thetas, beta=beta) print(f"beta={beta}: cost={expected_cost:.2f}") - axs[0, beta_idx].set_title(f"beta={beta}: cost={expected_cost:.2f}") + axs[0, beta_idx].set_title(f"$\\beta$={beta}: cost={expected_cost:.2f}") for i in range(n): - for j in range(n): + for j in range(m): w = alpha * expected_plan[i, j].item() * n axs[0, beta_idx].plot( [X[i, 0], Y[j, 0]], @@ -148,14 +159,16 @@ # draw min sliced permutation (limit when beta -> +inf) axs[0, -1].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") -for i in range(n): - axs[0, -1].plot( - [X[i, 0], Y[min_perm[i], 0]], - [X[i, 1], Y[min_perm[i], 1]], - color="black", - alpha=alpha, - label="min-Sliced perm" if i == 0 else None, - ) +for i in range(X.shape[0]): + for j in range(Y.shape[0]): + if min_plan[i, j] > 0: + axs[0, -1].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=alpha, + ) + axs[0, -1].scatter(X[:, 0], X[:, 1], label="X") axs[0, -1].scatter(Y[:, 0], Y[:, 1], label="Y") axs[1, -1].imshow(min_plan, interpolation="nearest", cmap="Blues") diff --git a/ot/backend.py b/ot/backend.py index 64b5a88cf..549ce43c7 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -729,6 +729,16 @@ def stack(self, arrays, axis=0): """ raise NotImplementedError() + def unstack(self, arrays, axis=0): + r""" + Split an array into a sequence of arrays along the given axis. + + This function follows the api from :any:`numpy.unstack` + + See: https://numpy.org/doc/stable/reference/generated/numpy.unstack.html + """ + raise NotImplementedError() + def outer(self, a, b): r""" Computes the outer product between two vectors. @@ -1300,6 +1310,9 @@ def unique(self, a, return_inverse=False): def stack(self, arrays, axis=0): return np.stack(arrays, axis) + def unstack(self, arrays, axis=0): + return np.unstack(arrays, axis=axis) + def reshape(self, a, shape): return np.reshape(a, shape) @@ -1710,6 +1723,9 @@ def unique(self, a, return_inverse=False): def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) + def unstack(self, arrays, axis=0): + return jnp.unstack(arrays, axis=axis) + def reshape(self, a, shape): return jnp.reshape(a, shape) @@ -2213,6 +2229,9 @@ def logsumexp(self, a, axis=None, keepdims=False): def stack(self, arrays, axis=0): return torch.stack(arrays, dim=axis) + def unstack(self, arrays, axis=0): + return torch.unbind(arrays, dim=axis) + def reshape(self, a, shape): return torch.reshape(a, shape) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 49e0c9c41..d09f8a46a 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -16,7 +16,7 @@ from ..utils import list_to_array -def quantile_function(qs, cws, xs): +def quantile_function(qs, cws, xs, idx_xs=None): r"""Computes the quantile function of an empirical distribution Parameters @@ -27,7 +27,8 @@ def quantile_function(qs, cws, xs): cumulative weights of the 1D empirical distribution, if batched, must be similar to xs xs: array-like, shape (n, ...) locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions - + idx_xs: array-like, shape (n, ...) + associated indices. If None, do not return them Returns ------- q: array-like, shape (..., n) @@ -44,11 +45,22 @@ def quantile_function(qs, cws, xs): cws = cws.T qs = qs.T idx = nx.searchsorted(cws, qs).T - return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + if idx_xs is not None: + return nx.take_along_axis( + xs, nx.clip(idx, 0, n - 1), axis=0 + ), nx.take_along_axis(idx_xs, nx.clip(idx, 0, n - 1), axis=0) + else: + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) def wasserstein_1d( - u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True + u_values, + v_values, + u_weights=None, + v_weights=None, + p=1, + require_sort=True, + return_plan=False, ): r""" Computes the 1 dimensional OT loss [15] between two (batched) empirical @@ -79,7 +91,9 @@ def wasserstein_1d( require_sort: bool, optional sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to the function, default is True - + return_plan: bool, optional + if True, returns also the optimal transport plan between the two + (batched) measures as a coo_matrix, default is False Returns ------- cost: float/array-like, shape (...) @@ -124,15 +138,31 @@ def wasserstein_1d( v_cumweights = nx.cumsum(v_weights, 0) qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0) - u_quantiles = quantile_function(qs, u_cumweights, u_values) - v_quantiles = quantile_function(qs, v_cumweights, v_values) + u_quantiles, u_quantiles_idx = quantile_function( + qs, u_cumweights, u_values, u_sorter + ) + v_quantiles, v_quantiles_idx = quantile_function( + qs, v_cumweights, v_values, v_sorter + ) qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)]) delta = qs[1:, ...] - qs[:-1, ...] diff_quantiles = nx.abs(u_quantiles - v_quantiles) - if p == 1: - return nx.sum(delta * diff_quantiles, axis=0) - return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) + if return_plan: + plan = [ + nx.coo_matrix( + delta[:, k], + u_quantiles_idx[:, k], + v_quantiles_idx[:, k], + shape=(n, m), + type_as=u_values, + ) + for k in range(delta.shape[1]) + ] + if return_plan: + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0), plan + else: + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) def emd_1d( @@ -201,7 +231,8 @@ def emd_1d( gamma: ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters log: dict - If input log is True, a dictionary containing the cost + If input log is True, a dictionary containing the cost and the indices + of the non-zero elements of the transportation matrix Examples @@ -297,6 +328,8 @@ def emd_1d( warnings.warn("JAX does not support sparse matrices, converting to dense") if log: log = {"cost": nx.from_numpy(cost, type_as=x_a)} + log["perms_x_a"] = perm_a[indices[:, 0]] + log["perms_x_b"] = perm_b[indices[:, 1]] return G, log return G diff --git a/ot/sliced.py b/ot/sliced.py index cc79fa968..687adcae4 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -9,6 +9,8 @@ # # License: MIT License +import warnings + import numpy as np from .backend import get_backend, NumpyBackend from .utils import list_to_array, get_coordinate_circle, dist @@ -16,10 +18,12 @@ wasserstein_circle, semidiscrete_wasserstein2_unif_circle, linear_circular_ot, + wasserstein_1d, ) -def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): +def get_random_projections(d, n_projections, seed=None, backend=None, + type_as=None): r""" Generates n_projections samples from the uniform on the unit sphere of dimension :math:`d-1`: :math:`\mathcal{U}(\mathcal{S}^{d-1})` @@ -595,7 +599,8 @@ def linear_sliced_wasserstein_sphere( X_s: ndarray, shape (n_samples_a, dim) Samples in the source domain X_t: ndarray, shape (n_samples_b, dim), optional - Samples in the target domain. If None, computes the distance against the uniform distribution on the sphere. + Samples in the target domain. If None, computes the distance against + the uniform distribution on the sphere. a : ndarray, shape (n_samples_a,), optional samples weights in the source domain b : ndarray, shape (n_samples_b,), optional @@ -607,7 +612,8 @@ def linear_sliced_wasserstein_sphere( seed: int or RandomState or None, optional Seed used for random number generator log: bool, optional - if True, linear_sliced_wasserstein_sphere returns the projections used and their associated LCOT. + if True, linear_sliced_wasserstein_sphere returns the projections used + and their associated LCOT. Returns ------- @@ -628,7 +634,10 @@ def linear_sliced_wasserstein_sphere( .. _references-lssot: References ---------- - .. [79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data. International Conference on Learning Representations. + .. [79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, + B. A., Chang, C., & Kolouri, S. (2025). Linear Spherical Sliced Optimal + Transport: A Fast Metric for Comparing Spherical Data. International + Conference on Learning Representations. """ d = X_s.shape[-1] @@ -639,9 +648,8 @@ def linear_sliced_wasserstein_sphere( if X_s.shape[1] != X_t.shape[1]: raise ValueError( - "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( - X_s.shape[1], X_t.shape[1] - ) + "X_s and X_t must have the same number of dimensions {} and {} \ + respectively given".format(X_s.shape[1], X_t.shape[1]) ) if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)): raise ValueError("X_s is not on the sphere.") @@ -654,7 +662,8 @@ def linear_sliced_wasserstein_sphere( ) Xps_coords, _ = projection_sphere_to_circle( - X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx + X_s, n_projections=n_projections, projections=projections, seed=seed, + backend=nx ) if X_t is not None: @@ -672,11 +681,24 @@ def linear_sliced_wasserstein_sphere( res = nx.mean(projected_lcot) ** (1 / 2) if log: - return res, {"projections": projections, "projected_emds": projected_lcot} + return res, {"projections": projections, + "projected_emds": projected_lcot} return res -def sliced_permutations(X, Y, thetas=None, n_proj=None, log=False, backend=None): +def sliced_plans( + X, + Y, + a=None, + b=None, + metric="sqeuclidean", + p=2, + thetas=None, + warm_theta=False, + n_proj=None, + log=False, + backend=None, +): r""" Computes all the permutations that sort the projections of two `(n, d)` datasets `X` and `Y` on the directions `thetas`. @@ -687,63 +709,156 @@ def sliced_permutations(X, Y, thetas=None, n_proj=None, log=False, backend=None) ---------- X : array-like, shape (n, d) The first set of vectors. - Y : array-like, shape (n, d) + Y : array-like, shape (m, d) The second set of vectors. + a : ndarray of float64, shape (ns,), optional + Source histogram (default is uniform weight) + b : ndarray of float64, shape (nt,), optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only works with either of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' thetas : array-like, shape (n_proj, d), optional - The projection directions. If None, random directions will be generated. + The projection directions. If None, random directions will be + generated. Default is None. + warm_theta : array-like, shape (d,), optional + A direction to add to the set of directions. Default is None. n_proj : int, optional The number of projection directions. Required if thetas is None. log : bool, optional If True, returns additional logging information. Default is False. backend : ot.backend, optional - Backend to use for computations. If None, the backend is inferred from the input arrays. Default is None. + Backend to use for computations. If None, the backend is inferred from + the input arrays. Default is None. Returns ------- - perm : array-like, shape (n, n_proj) - All sliced permutations. + G, sigma, tau, costs + G: ndarray, shape (ns, nt) or coo_matrix if dense is False + Optimal transportation matrix for the given parameters + sigma : list of elements of array-like + All the indices of X sorted along each projection. + tau : list of elements of array-like + All the indices of Y sorted along each projection. log_dict : dict, optional A dictionary containing intermediate computations for logging purposes. Returned only if `log` is True. """ + assert ( - X.shape == Y.shape - ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" - nx = get_backend(X, Y) if backend is None else backend + X.shape[1] == Y.shape[1] + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" + if metric == "euclidean": + p = 2 + elif metric == "cityblock": + p = 1 + d = X.shape[1] + n = X.shape[0] + m = Y.shape[0] + nx = get_backend(X, Y) if backend is None else backend + do_draw_thetas = thetas is None if do_draw_thetas: # create thetas (n_proj, d) + assert n_proj is not None, "n_proj must be specified if thetas is None" thetas = get_random_projections(d, n_proj, backend=nx).T + if warm_theta is not None: + thetas = nx.concatenate([thetas, warm_theta[:, None]], axis=1) + else: + n_proj = thetas.shape[0] - # project on each theta: (n, d) -> (n, n_proj) + # project on each theta: (n or m, d) -> (n or m, n_proj) X_theta = X @ thetas.T # shape (n, n_proj) - Y_theta = Y @ thetas.T # shape (n, n_proj) + Y_theta = Y @ thetas.T # shape (m, n_proj) + + if n == m and (a is None or b is None or (a == b).all()): + # we compute maps (permutations) + # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj] + sigma = nx.argsort(X_theta, axis=0) # (n, n_proj) + tau = nx.argsort(Y_theta, axis=0) # (m, n_proj) + + if metric in ("minkowski", "euclidean", "cityblock"): + costs = [ + nx.sum( + ((nx.sum(nx.abs(X[sigma[:, k]] - Y[tau[:, k]]) ** p, + axis=1)) ** (1 / p)) / n + ) + for k in range(n_proj) + ] + elif metric == "sqeuclidean": + costs = [ + nx.sum((nx.sum((X[sigma[:, k]] - Y[tau[:, k]]) ** 2, + axis=1)) / n) + for k in range(n_proj) + ] + else: + raise ValueError( + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + ) + + G = [ + nx.coo_matrix( + np.ones(n) / n, + sigma[:, k], + tau[:, k], + shape=(n, m), + type_as=X_theta, + ) + for k in range(n_proj) + ] - # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj] - sigma = nx.argsort(X_theta, axis=0) # (n, n_proj) - tau = nx.argsort(Y_theta, axis=0) # (n, n_proj) + else: # we compute plans + _, G = wasserstein_1d( + X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True + ) - # perm[:, i_proj] is tau[:, i_proj] o sigma[:, i_proj]^{-1} - perm = nx.take_along_axis(tau, nx.argsort(sigma, axis=0), axis=0) # (n, n_proj) + if metric in ("minkowski", "euclidean", "cityblock"): + costs = [ + nx.sum( + ( + (nx.sum(nx.abs(X[G[k].row] - Y[G[k].col]) ** p, + axis=1)) ** (1 / p) + ) * G[k].data + ) + for k in range(n_proj) + ] + elif metric == "sqeuclidean": + costs = [ + nx.sum((nx.sum((X[G[k].row] - Y[G[k].col]) ** 2, axis=1)) * + G[k].data) for k in range(n_proj) + ] + else: + raise ValueError( + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + ) if log: - log_dict = { - "X_theta": X_theta, - "Y_theta": Y_theta, - "sigma": sigma, - "tau": tau, - "perm": perm, - } - if do_draw_thetas: - log_dict["thetas"] = thetas - return perm, log_dict + log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas} + return costs, G, log_dict else: - return perm + return costs, G def min_pivot_sliced( - X, Y, thetas=None, order=2, n_proj=None, log=False, warm_perm=None + X, + Y, + a=None, + b=None, + thetas=None, + metric="sqeuclidean", + p=2, + n_proj=None, + dense=True, + log=False, + warm_theta=None, + backend=None, ): r""" Computes the cost and permutation associated to the min-Pivot Sliced @@ -763,7 +878,11 @@ def min_pivot_sliced( on the axis `thetas[k, :]` matches `X[i, :]` to `Y[\sigma_k(i), :]`. .. note:: - The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. + The computation ignores potential ambiguities in the projections: if + two points from a same measure have the same projection on a direction, + then multiple sorting permutations are possible. To avoid combinatorial + explosion, only one permutation is retained: this strays from theory in + pathological cases. Parameters ---------- @@ -771,16 +890,31 @@ def min_pivot_sliced( The first set of vectors. Y : array-like, shape (n, d) The second set of vectors. + a : ndarray of float64, shape (ns,), optional + Source histogram (default is uniform weight) + b : ndarray of float64, shape (nt,), optional + Target histogram (default is uniform weight) thetas : array-like, shape (n_proj, d), optional - The projection directions. If None, random directions will be generated. Default is None. - order : int, optional - Power to elevate the norm. Default is 2. + The projection directions. If None, random directions will be generated + Default is None. + metric: str, optional (default='sqeuclidean') + Metric to be used. Only works with either of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' n_proj : int, optional The number of projection directions. Required if thetas is None. + dense: boolean, optional (default=True) + If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. log : bool, optional If True, returns additional logging information. Default is False. - warm_perm : array-like, shape (n,), optional - A permutation to add to the permutation list. Default is None. + warm_theta : array-like, shape (d,), optional + A theta to add to the list of thetas. Default is None. + backend : ot.backend, optional + Backend to use for computations. If None, the backend is inferred from + the input arrays. Default is None. Returns ------- @@ -794,55 +928,75 @@ def min_pivot_sliced( References ---------- - .. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. + .. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). + Fast Optimal Transport through Sliced Generalized Wasserstein + Geodesics. Advances in Neural Information Processing Systems, 36, + 35350–35385. - .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport + Plans. arXiv preprint 2506.03661. """ - assert ( - X.shape == Y.shape - ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" - n = X.shape[0] - nx = get_backend(X, Y) - log_dict = {} - if log: - perm, log_dict = sliced_permutations( - X, Y, thetas=thetas, n_proj=n_proj, log=True, backend=nx - ) - else: - perm = sliced_permutations( - X, Y, thetas=thetas, n_proj=n_proj, log=False, backend=nx - ) + assert ( + X.shape[1] == Y.shape[1] + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" - # add the 'warm perm' to permutations to test - if warm_perm is not None: - perm = nx.concatenate([perm, warm_perm[:, None]], axis=1) - if log: - log_dict["perm"] = perm + nx = get_backend(X, Y) if backend is None else backend - min_cost = None - idx_min_cost = None - costs = [] + log_dict = {} + costs, G, log_dict_plans = sliced_plans( + X, + Y, + a, + b, + metric, + p, + thetas, + n_proj=n_proj, + warm_theta=warm_theta, + log=True, + backend=nx, + ) + pos_min = np.argmin(costs) + cost = costs[pos_min] + plan = G[pos_min] - for k in range(perm.shape[-1]): - cost = nx.sum(nx.abs(X - Y[perm[:, k]]) ** order) / n - if min_cost is None or cost < min_cost: - min_cost = cost - idx_min_cost = k - if log: - costs.append(cost) + if log: + log_dict = { + "thetas": log_dict_plans["thetas"], + "costs": costs, + "min_theta": log_dict_plans["thetas"][pos_min], + "X_min_theta": log_dict_plans["X_theta"][:, pos_min], + "Y_min_theta": log_dict_plans["Y_theta"][:, pos_min], + } - min_perm = perm[:, idx_min_cost] + if dense: + plan = nx.todense(plan) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to\ + dense") + plan = nx.todense(plan) if log: - log_dict["costs"] = costs - log_dict["idx_min_cost"] = idx_min_cost - return min_perm, min_cost, log_dict + return plan, cost, log_dict else: - return min_perm, min_cost + return plan, cost -def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0): +def expected_sliced( + X, + Y, + a=None, + b=None, + thetas=None, + metric="sqeuclidean", + p=2, + n_proj=None, + dense=True, + log=False, + backend=None, + beta=0.0, +): r""" Computes the Expected Sliced cost and plan between two `(n, d)` datasets `X` and `Y`. Given a set of `n_proj` projection directions, @@ -851,17 +1005,22 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0 Expected Sliced was introduced in [84] and further studied in [83]. .. note:: - The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. + The computation ignores potential ambiguities in the projections: if + two points from a same measure have the same projection on a direction, + then multiple sorting permutations are possible. To avoid combinatorial + explosion, only one permutation is retained: this strays from theory in + pathological cases. .. warning:: - The function runs on backend but tensorflow and jax are not supported due to array assignment. + The function runs on backend but tensorflow and jax are not supported + due to array assignment. Parameters ---------- X : torch.Tensor - A tensor of shape (n, d) representing the first set of vectors. + A tensor of shape (ns, d) representing the first set of vectors. Y : torch.Tensor - A tensor of shape (n, d) representing the second set of vectors. + A tensor of shape (nt, d) representing the second set of vectors. thetas : torch.Tensor, optional A tensor of shape (n_proj, d) representing the projection directions. If None, random directions will be generated. Default is None. @@ -869,10 +1028,15 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0 The number of projection directions. Required if thetas is None. order : int, optional Power to elevate the norm. Default is 2. + dense: boolean, optional (default=True) + If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. log : bool, optional If True, returns additional logging information. Default is False. beta : float, optional - Inverse-temperature parameter which weights each projection's contribution to the expected plan. Default is 0 (uniform weighting). + Inverse-temperature parameter which weights each projection's + contribution to the expected plan. Default is 0 (uniform weighting). Returns ------- @@ -884,50 +1048,67 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0 References ---------- - .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport + Plans. arXiv preprint 2506.03661. - .. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. + .. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi + A., Kolouri, S. (2024). Expected Sliced Transport Plans. International + Conference on Learning Representations. """ assert ( - X.shape == Y.shape - ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" + X.shape[1] == Y.shape[1] + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" + + nx = get_backend(X, Y) if backend is None else backend - nx = get_backend(X, Y) if str(nx) in ["tf", "jax"]: raise NotImplementedError( f"expected_sliced is not implemented for the {str(nx)} backend due" "to array assignment." ) - n = X.shape[0] + + ns = X.shape[0] + nt = Y.shape[0] log_dict = {} + costs, G, log_dict_plans = sliced_plans( + X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True, backend=nx + ) if log: - perm, log_dict = sliced_permutations( - X, Y, thetas=thetas, n_proj=n_proj, log=log, backend=nx - ) - else: - perm = sliced_permutations( - X, Y, thetas=thetas, n_proj=n_proj, log=log, backend=nx - ) - plan = nx.zeros((n, n), type_as=X) - n_proj = perm.shape[1] - range_array = nx.arange(n, type_as=X) + log_dict = {"thetas": log_dict_plans["thetas"], "costs": costs, "G": G} if beta != 0.0: # computing the temperature weighting - log_factors = nx.zeros(n_proj, type_as=X) # for beta weighting - for k in range(n_proj): - cost_k = nx.sum(nx.abs(X - Y[perm[:, k]]) ** order) / n - log_factors[k] = -beta * cost_k + log_factors = -beta * list_to_array(costs) weights = nx.exp(log_factors - nx.logsumexp(log_factors)) + cost = nx.sum(list_to_array(costs) * weights) else: # uniform weights + if n_proj is None: + n_proj = thetas.shape[0] weights = nx.ones(n_proj, type_as=X) / n_proj - for k in range(n_proj): # populating the expected plan - # 1 / n is because is a permutation of [1, n] - plan[range_array, perm[:, k]] += (1 / n) * weights[k] + log_dict["weights"] = weights + + weights = nx.concatenate([G[i].data * weights[i] for i in range(len(G))]) + X_idx = nx.concatenate([G[i].row for i in range(len(G))]) + Y_idx = nx.concatenate([G[i].col for i in range(len(G))]) + plan = nx.coo_matrix( + weights, + X_idx, + Y_idx, + shape=(ns, nt), + type_as=X, + ) + + if beta == 0.0: # otherwise already computed above + cost = plan.multiply(dist(X, Y, metric=metric, p=p)).sum() - cost = (dist(X, Y, p=order) * plan).sum() + if dense: + plan = nx.todense(plan) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to\ + dense") + plan = nx.todense(plan) if log: return plan, cost, log_dict From cd6ce31b002c915f1ec143a62ffd7ad202281f38 Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Tue, 7 Oct 2025 14:23:35 +0200 Subject: [PATCH 08/19] update lint --- ot/sliced.py | 54 ++++++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/ot/sliced.py b/ot/sliced.py index 687adcae4..d0d0c88cd 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -22,8 +22,7 @@ ) -def get_random_projections(d, n_projections, seed=None, backend=None, - type_as=None): +def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): r""" Generates n_projections samples from the uniform on the unit sphere of dimension :math:`d-1`: :math:`\mathcal{U}(\mathcal{S}^{d-1})` @@ -662,8 +661,7 @@ def linear_sliced_wasserstein_sphere( ) Xps_coords, _ = projection_sphere_to_circle( - X_s, n_projections=n_projections, projections=projections, seed=seed, - backend=nx + X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx ) if X_t is not None: @@ -681,8 +679,7 @@ def linear_sliced_wasserstein_sphere( res = nx.mean(projected_lcot) ** (1 / 2) if log: - return res, {"projections": projections, - "projected_emds": projected_lcot} + return res, {"projections": projections, "projected_emds": projected_lcot} return res @@ -783,22 +780,24 @@ def sliced_plans( if metric in ("minkowski", "euclidean", "cityblock"): costs = [ nx.sum( - ((nx.sum(nx.abs(X[sigma[:, k]] - Y[tau[:, k]]) ** p, - axis=1)) ** (1 / p)) / n + ( + (nx.sum(nx.abs(X[sigma[:, k]] - Y[tau[:, k]]) ** p, axis=1)) + ** (1 / p) + ) + / n ) for k in range(n_proj) ] elif metric == "sqeuclidean": costs = [ - nx.sum((nx.sum((X[sigma[:, k]] - Y[tau[:, k]]) ** 2, - axis=1)) / n) + nx.sum((nx.sum((X[sigma[:, k]] - Y[tau[:, k]]) ** 2, axis=1)) / n) for k in range(n_proj) ] else: raise ValueError( - "Sliced plans work only with metrics " + - "from the following list: " + - "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" ) G = [ @@ -821,22 +820,23 @@ def sliced_plans( costs = [ nx.sum( ( - (nx.sum(nx.abs(X[G[k].row] - Y[G[k].col]) ** p, - axis=1)) ** (1 / p) - ) * G[k].data + (nx.sum(nx.abs(X[G[k].row] - Y[G[k].col]) ** p, axis=1)) + ** (1 / p) + ) + * G[k].data ) for k in range(n_proj) ] elif metric == "sqeuclidean": costs = [ - nx.sum((nx.sum((X[G[k].row] - Y[G[k].col]) ** 2, axis=1)) * - G[k].data) for k in range(n_proj) + nx.sum((nx.sum((X[G[k].row] - Y[G[k].col]) ** 2, axis=1)) * G[k].data) + for k in range(n_proj) ] else: raise ValueError( - "Sliced plans work only with metrics " + - "from the following list: " + - "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" ) if log: @@ -973,8 +973,10 @@ def min_pivot_sliced( if dense: plan = nx.todense(plan) elif str(nx) == "jax": - warnings.warn("JAX does not support sparse matrices, converting to\ - dense") + warnings.warn( + "JAX does not support sparse matrices, converting to\ + dense" + ) plan = nx.todense(plan) if log: @@ -1106,8 +1108,10 @@ def expected_sliced( if dense: plan = nx.todense(plan) elif str(nx) == "jax": - warnings.warn("JAX does not support sparse matrices, converting to\ - dense") + warnings.warn( + "JAX does not support sparse matrices, converting to\ + dense" + ) plan = nx.todense(plan) if log: From 01e53c36c381514d20967f0d7ccaeb3bc21ceafe Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Tue, 7 Oct 2025 16:36:23 +0200 Subject: [PATCH 09/19] update contributing.md --- .github/CONTRIBUTING.md | 201 ++++++++++-------- .../sliced-wasserstein/plot_sliced_plans.py | 3 +- 2 files changed, 115 insertions(+), 89 deletions(-) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 94486046f..b6add7ae3 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -1,11 +1,8 @@ -Contributing to POT -=================== +# Contributing to POT +First off, thank you for considering contributing to POT. -First off, thank you for considering contributing to POT. - -How to contribute ------------------ +## How to contribute The preferred workflow for contributing to POT is to fork the [main repository](https://github.com/rflamary/POT) on @@ -23,7 +20,7 @@ GitHub, clone, and develop on a branch. Steps: $ cd POT ``` -2. Install pre-commit hooks to ensure that your code is properly formatted: +3. Install pre-commit hooks to ensure that your code is properly formatted: ```bash $ pip install pre-commit @@ -32,15 +29,48 @@ GitHub, clone, and develop on a branch. Steps: This will install the pre-commit hooks that will run on every commit. If the hooks fail, the commit will be aborted. -3. Create a ``feature`` branch to hold your development changes: +4. Create a `feature` branch to hold your development changes: ```bash $ git checkout -b my-feature ``` - Always use a ``feature`` branch. It's good practice to never work on the ``master`` branch! + Always use a `feature` branch. It's good practice to never work on the `master` branch! + +5. Install a recent version of Python (e.g. 3.10), using conda for instance. You can create a conda environment and activate it: + + ```bash + $ conda create -n dev-pot-env python=3.10 + $ conda activate dev-pot-env + ``` + +6. Install all the necessary packages in your environment: -4. Develop the feature on your feature branch. Add changed files using ``git add`` and then ``git commit`` files: +```bash +$ pip install -r requirements_all.txt +``` + +6. Install a compiler with OpenMP support for your platform (see details on the [scikit-learn contributing guide](https://scikit-learn.org/stable/developers/advanced_installation.html#platform-specific-instructions)). + For instance, with macOS, Apple clang does not support OpenMP. One can install the LLVM OpenMP library from homebrew: + + ```bash + $ brew install libomp + ``` + + and set environment variables: + + ```bash + $ export CC=/usr/local/opt/llvm/bin/clang + $ export CXX=/usr/local/opt/llvm/bin/clang++ + ``` + +7. Build the projet with pip: + + ```bash + pip install -e . + ``` + +8. Develop the feature on your feature branch. Add changed files using `git add` and then `git commit` files: ```bash $ git add modified_files @@ -53,64 +83,62 @@ GitHub, clone, and develop on a branch. Steps: $ git push -u origin my-feature ``` -5. Follow [these instructions](https://help.github.com/articles/creating-a-pull-request-from-a-fork) -to create a pull request from your fork. This will send an email to the committers. +9. Follow [these instructions](https://help.github.com/articles/creating-a-pull-request-from-a-fork) + to create a pull request from your fork. This will send an email to the committers. (If any of the above seems like magic to you, please look up the [Git documentation](https://git-scm.com/documentation) on the web, or ask a friend or another contributor for help.) -Pull Request Checklist ----------------------- +## Pull Request Checklist We recommended that your contribution complies with the following rules before you submit a pull request: -- Follow the PEP8 Guidelines which should be handles automatically by pre-commit. - -- If your pull request addresses an issue, please use the pull request title - to describe the issue and mention the issue number in the pull request description. This will make sure a link back to the original issue is - created. - -- All public methods should have informative docstrings with sample - usage presented as doctests when appropriate. - -- Please prefix the title of your pull request with `[MRG]` (Ready for - Merge), if the contribution is complete and ready for a detailed review. - Two core developers will review your code and change the prefix of the pull - request to `[MRG + 1]` and `[MRG + 2]` on approval, making it eligible - for merging. An incomplete contribution -- where you expect to do more work before - receiving a full review -- should be prefixed `[WIP]` (to indicate a work - in progress) and changed to `[MRG]` when it matures. WIPs may be useful - to: indicate you are working on something to avoid duplicated work, - request broad review of functionality or API, or seek collaborators. - WIPs often benefit from the inclusion of a - [task list](https://github.com/blog/1375-task-lists-in-gfm-issues-pulls-comments) - in the PR description. - - -- When adding additional functionality, provide at least one - example script in the ``examples/`` folder. Have a look at other - examples for reference. Examples should demonstrate why the new - functionality is useful in practice and, if possible, compare it - to other methods available in POT. - -- Documentation and high-coverage tests are necessary for enhancements to be - accepted. Bug-fixes or new features should be provided with - [non-regression tests](https://en.wikipedia.org/wiki/Non-regression_testing). - These tests verify the correct behavior of the fix or feature. In this - manner, further modifications on the code base are granted to be consistent - with the desired behavior. - For the Bug-fixes case, at the time of the PR, this tests should fail for - the code base in master and pass for the PR code. - -- At least one paragraph of narrative documentation with links to - references in the literature (with PDF links when possible) and - the example. +* Follow the PEP8 Guidelines which should be handles automatically by pre-commit. + +* If your pull request addresses an issue, please use the pull request title + to describe the issue and mention the issue number in the pull request description. This will make sure a link back to the original issue is + created. + +* All public methods should have informative docstrings with sample + usage presented as doctests when appropriate. + +* Please prefix the title of your pull request with `[MRG]` (Ready for + Merge), if the contribution is complete and ready for a detailed review. + Two core developers will review your code and change the prefix of the pull + request to `[MRG + 1]` and `[MRG + 2]` on approval, making it eligible + for merging. An incomplete contribution -- where you expect to do more work before + receiving a full review -- should be prefixed `[WIP]` (to indicate a work + in progress) and changed to `[MRG]` when it matures. WIPs may be useful + to: indicate you are working on something to avoid duplicated work, + request broad review of functionality or API, or seek collaborators. + WIPs often benefit from the inclusion of a + [task list](https://github.com/blog/1375-task-lists-in-gfm-issues-pulls-comments) + in the PR description. + +* When adding additional functionality, provide at least one + example script in the `examples/` folder. Have a look at other + examples for reference. Examples should demonstrate why the new + functionality is useful in practice and, if possible, compare it + to other methods available in POT. + +* Documentation and high-coverage tests are necessary for enhancements to be + accepted. Bug-fixes or new features should be provided with + [non-regression tests](https://en.wikipedia.org/wiki/Non-regression_testing). + These tests verify the correct behavior of the fix or feature. In this + manner, further modifications on the code base are granted to be consistent + with the desired behavior. + For the Bug-fixes case, at the time of the PR, this tests should fail for + the code base in master and pass for the PR code. + +* At least one paragraph of narrative documentation with links to + references in the literature (with PDF links when possible) and + the example. You can also check for common programming errors with the following tools: -- All lint checks pass. You can run the following command to check: +* All lint checks pass. You can run the following command to check: ```bash $ pre-commit run --all-files @@ -118,52 +146,51 @@ tools: This will run the pre-commit checks on all files in the repository. -- All tests pass. You can run the following command to check: +* All tests pass. You can run the following command to check: ```bash $ pytest --durations=20 -v test/ --doctest-modules - ``` + ``` Bonus points for contributions that include a performance analysis with a benchmark script and profiling output (please report on the mailing list or on the GitHub issue). -Filing bugs ------------ +## Filing bugs + We use Github issues to track all bugs and feature requests; feel free to open an issue if you have found a bug or wish to see a feature implemented. It is recommended to check that your issue complies with the following rules before submitting: -- Verify that your issue is not being currently addressed by other - [issues](https://github.com/rflamary/POT/issues?q=) - or [pull requests](https://github.com/rflamary/POT/pulls?q=). +* Verify that your issue is not being currently addressed by other + [issues](https://github.com/rflamary/POT/issues?q=) + or [pull requests](https://github.com/rflamary/POT/pulls?q=). -- Please ensure all code snippets and error messages are formatted in - appropriate code blocks. - See [Creating and highlighting code blocks](https://help.github.com/articles/creating-and-highlighting-code-blocks). +* Please ensure all code snippets and error messages are formatted in + appropriate code blocks. + See [Creating and highlighting code blocks](https://help.github.com/articles/creating-and-highlighting-code-blocks). -- Please include your operating system type and version number, as well - as your Python, POT, numpy, and scipy versions. This information - can be found by running the following code snippet: +* Please include your operating system type and version number, as well + as your Python, POT, numpy, and scipy versions. This information + can be found by running the following code snippet: - ```python - import platform; print(platform.platform()) - import sys; print("Python", sys.version) - import numpy; print("NumPy", numpy.__version__) - import scipy; print("SciPy", scipy.__version__) - import ot; print("POT", ot.__version__) - ``` +```python +import platform; print(platform.platform()) +import sys; print("Python", sys.version) +import numpy; print("NumPy", numpy.__version__) +import scipy; print("SciPy", scipy.__version__) +import ot; print("POT", ot.__version__) +``` -- Please be specific about what estimators and/or functions are involved - and the shape of the data, as appropriate; please include a - [reproducible](http://stackoverflow.com/help/mcve) code snippet - or link to a [gist](https://gist.github.com). If an exception is raised, - please provide the traceback. +* Please be specific about what estimators and/or functions are involved + and the shape of the data, as appropriate; please include a + [reproducible](http://stackoverflow.com/help/mcve) code snippet + or link to a [gist](https://gist.github.com). If an exception is raised, + please provide the traceback. -New contributor tips --------------------- +## New contributor tips A great way to start contributing to POT is to pick an item from the list of [Easy issues](https://github.com/rflamary/POT/issues?labels=Easy) @@ -173,8 +200,7 @@ assistance in this area will be greatly appreciated by the more experienced developers as it helps free up their time to concentrate on other issues. -Documentation -------------- +## Documentation We are glad to accept any sort of documentation: function docstrings, reStructuredText documents (like this one), tutorials, etc. @@ -182,8 +208,8 @@ reStructuredText documents live in the source code repository under the doc/ directory. You can edit the documentation using any text editor and then generate -the HTML output by typing ``make html`` from the ``docs/`` directory. -Alternatively, ``make`` can be used to quickly generate the +the HTML output by typing `make html` from the `docs/` directory. +Alternatively, `make` can be used to quickly generate the documentation without the example gallery with `make html-noplot`. The resulting HTML files will be placed in `docs/build/html/` and are viewable in a web browser. @@ -199,5 +225,4 @@ start with a small paragraph with a hand-waving explanation of what the method does to the data and a figure (coming from an example) illustrating it. - This Contribution guide is strongly inspired by the one of the [scikit-learn](https://github.com/scikit-learn/scikit-learn) team. diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py index 43703e57e..877a02190 100644 --- a/examples/sliced-wasserstein/plot_sliced_plans.py +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -128,7 +128,8 @@ ############################################################################## # Compare Expected Sliced plans with different inverse-temperatures beta # ------------------------------------ -## As the temperature decreases, ES becomes sparser and approaches minPS +# As the temperature decreases, ES becomes sparser and approaches minPS + betas = [0.0, 5.0, 50.0] n_plots = len(betas) + 1 size = 4 From 86749e0cca36dbb0d8e0fb1a5a472fe1b96a302a Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Tue, 7 Oct 2025 17:17:27 +0200 Subject: [PATCH 10/19] update tests for sliced plans --- test/test_sliced.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/test_sliced.py b/test/test_sliced.py index 7f12d378a..f15a1598f 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -742,19 +742,17 @@ def test_sliced_permutations(nx): thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T thetas_b = nx.from_numpy(thetas) - perm = ot.sliced.sliced_permutations(x, y, thetas=thetas) - perm_b, _ = ot.sliced.sliced_permutations( - x_b, y_b, thetas=thetas_b, log=True, backend=nx - ) + perm = ot.sliced.sliced_plans(x, y, thetas=thetas) + perm_b, _ = ot.sliced.sliced_plans(x_b, y_b, thetas=thetas_b, log=True, backend=nx) np.testing.assert_almost_equal(perm, nx.to_numpy(perm_b)) # test without provided thetas - perm = ot.sliced.sliced_permutations(x, y, n_proj=n_proj) + perm = ot.sliced.sliced_plans(x, y, n_proj=n_proj) # test with invalid shapes with pytest.raises(AssertionError): - ot.sliced.sliced_permutations(x[1:, :], y, thetas=thetas) + ot.sliced.sliced_plans(x[1:, :], y, thetas=thetas) def test_min_pivot_sliced(nx): @@ -783,8 +781,8 @@ def test_min_pivot_sliced(nx): assert min_cost >= w2 assert min_cost <= 1.5 * w2 - # test without provided thetas and with a warm permutation - ot.sliced.min_pivot_sliced(x, y, n_proj=n_proj, warm_perm=np.arange(n), log=True) + # test without provided thetas + ot.sliced.min_pivot_sliced(x, y, n_proj=n_proj, log=True) # test with invalid shapes with pytest.raises(AssertionError): @@ -811,9 +809,11 @@ def test_expected_sliced(nx): ) with context: - expected_plan, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas) + expected_plan, expected_cost = ot.sliced.expected_sliced( + x, y, dense=True, thetas=thetas + ) expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced( - x_b, y_b, thetas=thetas_b, log=True + x_b, y_b, thetas=thetas_b, dense=True, log=True ) np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b)) From 83b653bc1a92d48e95edac15b8dac2632d874d8c Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Wed, 8 Oct 2025 09:59:52 +0200 Subject: [PATCH 11/19] update tests and doc --- .../sliced-wasserstein/plot_sliced_plans.py | 1 - ot/sliced.py | 166 +++++++++++------- test/test_sliced.py | 101 ++++++++--- 3 files changed, 182 insertions(+), 86 deletions(-) diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py index 877a02190..00f7beed2 100644 --- a/examples/sliced-wasserstein/plot_sliced_plans.py +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -28,7 +28,6 @@ import ot import matplotlib.pyplot as plt from ot.sliced import get_random_projections -from ot.lp import wasserstein_1d seed = 0 diff --git a/ot/sliced.py b/ot/sliced.py index d0d0c88cd..6a76c7fee 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -691,8 +691,9 @@ def sliced_plans( metric="sqeuclidean", p=2, thetas=None, - warm_theta=False, + warm_theta=None, n_proj=None, + dense=False, log=False, backend=None, ): @@ -723,6 +724,9 @@ def sliced_plans( Default is None. warm_theta : array-like, shape (d,), optional A direction to add to the set of directions. Default is None. + dense: bool, optional + If True, returns dense matrices instead of sparse ones. + Default is False. n_proj : int, optional The number of projection directions. Required if thetas is None. log : bool, optional @@ -733,18 +737,19 @@ def sliced_plans( Returns ------- - G, sigma, tau, costs - G: ndarray, shape (ns, nt) or coo_matrix if dense is False + plan : ndarray, shape (ns, nt) or coo_matrix if dense is False Optimal transportation matrix for the given parameters - sigma : list of elements of array-like - All the indices of X sorted along each projection. - tau : list of elements of array-like - All the indices of Y sorted along each projection. + costs : list of float + The cost associated to each projection. log_dict : dict, optional A dictionary containing intermediate computations for logging purposes. Returned only if `log` is True. """ + X, Y = list_to_array(X, Y) + assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" + assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" + assert ( X.shape[1] == Y.shape[1] ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" @@ -758,6 +763,11 @@ def sliced_plans( m = Y.shape[0] nx = get_backend(X, Y) if backend is None else backend + is_perm = False + if n == m: + if a is None or b is None or (a == b).all(): + is_perm = True + do_draw_thetas = thetas is None if do_draw_thetas: # create thetas (n_proj, d) assert n_proj is not None, "n_proj must be specified if thetas is None" @@ -771,12 +781,11 @@ def sliced_plans( X_theta = X @ thetas.T # shape (n, n_proj) Y_theta = Y @ thetas.T # shape (m, n_proj) - if n == m and (a is None or b is None or (a == b).all()): + if is_perm: # we compute maps (permutations) # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj] sigma = nx.argsort(X_theta, axis=0) # (n, n_proj) tau = nx.argsort(Y_theta, axis=0) # (m, n_proj) - if metric in ("minkowski", "euclidean", "cityblock"): costs = [ nx.sum( @@ -799,20 +808,13 @@ def sliced_plans( + "from the following list: " + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" ) - - G = [ - nx.coo_matrix( - np.ones(n) / n, - sigma[:, k], - tau[:, k], - shape=(n, m), - type_as=X_theta, - ) + plan = [ + nx.coo_matrix(np.ones(n) / n, sigma[:, k], tau[:, k], shape=(n, m)) for k in range(n_proj) ] else: # we compute plans - _, G = wasserstein_1d( + _, plan = wasserstein_1d( X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True ) @@ -820,16 +822,19 @@ def sliced_plans( costs = [ nx.sum( ( - (nx.sum(nx.abs(X[G[k].row] - Y[G[k].col]) ** p, axis=1)) + (nx.sum(nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1)) ** (1 / p) ) - * G[k].data + * plan[k].data ) for k in range(n_proj) ] elif metric == "sqeuclidean": costs = [ - nx.sum((nx.sum((X[G[k].row] - Y[G[k].col]) ** 2, axis=1)) * G[k].data) + nx.sum( + (nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1)) + * plan[k].data + ) for k in range(n_proj) ] else: @@ -839,11 +844,17 @@ def sliced_plans( + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" ) + if dense: + plan = [nx.todense(plan[k]) for k in range(n_proj)] + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + plan = [nx.todense(plan[k]) for k in range(n_proj)] + if log: log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas} - return costs, G, log_dict + return plan, costs, log_dict else: - return costs, G + return plan, costs def min_pivot_sliced( @@ -863,11 +874,11 @@ def min_pivot_sliced( r""" Computes the cost and permutation associated to the min-Pivot Sliced Discrepancy (introduced as SWGG in [82] and studied further in [83]). Given - the supports `X` and `Y` of two discrete uniform measures with `n` atoms in - dimension `d`, the min-Pivot Sliced Discrepancy goes through `n_proj` - different projections of the measures on random directions, and retains the - permutation that yields the lowest cost between `X` and `Y` (compared - in :math:`\mathbb{R}^d`). + the supports `X` and `Y` of two discrete uniform measures with `n` and `m` + atoms in dimension `d`, the min-Pivot Sliced Discrepancy goes through + `n_proj` different projections of the measures on random directions, and + retains the couplings that yields the lowest cost between `X` and `Y` + (compared in :math:`\mathbb{R}^d`). When $n=m$, it gives .. math:: \mathrm{min\text{-}PS}_p^p(X, Y) \approx @@ -888,7 +899,7 @@ def min_pivot_sliced( ---------- X : array-like, shape (n, d) The first set of vectors. - Y : array-like, shape (n, d) + Y : array-like, shape (m, d) The second set of vectors. a : ndarray of float64, shape (ns,), optional Source histogram (default is uniform weight) @@ -918,10 +929,10 @@ def min_pivot_sliced( Returns ------- - perm : array-like, shape (n,) - The permutation that minimizes the cost. - min_cost : float - The minimum cost corresponding to the optimal permutation. + plan : ndarray, shape (n, m) or coo_matrix if dense is False + Optimal transportation matrix for the given parameters. + cost : float + The cost associated to the optimal permutation. log_dict : dict, optional A dictionary containing intermediate computations for logging purposes. Returned only if `log` is True. @@ -935,8 +946,24 @@ def min_pivot_sliced( .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + + Examples + -------- + >>> x=np.array([[3,3], [1,1]]) + >>> y=np.array([[2,2.5], [3,2]]) + >>> thetas=np.array([[1, 0], [0, 1]]) + >>> plan, cost = ot.expected_sliced(x, y, thetas) + >>> plan + [[0 0.5] + [0.5 0]] + >>> cost + 2.125 """ + X, Y = list_to_array(X, Y) + assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" + assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" + assert ( X.shape[1] == Y.shape[1] ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" @@ -944,7 +971,7 @@ def min_pivot_sliced( nx = get_backend(X, Y) if backend is None else backend log_dict = {} - costs, G, log_dict_plans = sliced_plans( + G, costs, log_dict_plans = sliced_plans( X, Y, a, @@ -1000,11 +1027,12 @@ def expected_sliced( beta=0.0, ): r""" - Computes the Expected Sliced cost and plan between two `(n, d)` - datasets `X` and `Y`. Given a set of `n_proj` projection directions, - the expected sliced plan is obtained by averaging the `n_proj` 1d optimal - transport plans between the projections of `X` and `Y` on each direction. - Expected Sliced was introduced in [84] and further studied in [83]. + Computes the Expected Sliced cost and plan between two datasets `X` and + `Y` of shapes `(n, d)` and `(m, d)`. Given a set of `n_proj` projection + directions, the expected sliced plan is obtained by averaging the `n_proj` + 1d optimal transport plans between the projections of `X` and `Y` on each + direction. Expected Sliced was introduced in [84] and further studied in + [83]. .. note:: The computation ignores potential ambiguities in the projections: if @@ -1020,9 +1048,9 @@ def expected_sliced( Parameters ---------- X : torch.Tensor - A tensor of shape (ns, d) representing the first set of vectors. + A tensor of shape (n, d) representing the first set of vectors. Y : torch.Tensor - A tensor of shape (nt, d) representing the second set of vectors. + A tensor of shape (m, d) representing the second set of vectors. thetas : torch.Tensor, optional A tensor of shape (n_proj, d) representing the projection directions. If None, random directions will be generated. Default is None. @@ -1031,7 +1059,7 @@ def expected_sliced( order : int, optional Power to elevate the norm. Default is 2. dense: boolean, optional (default=True) - If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt). + If True, returns :math:`\gamma` as a dense ndarray of shape (n, m). Otherwise returns a sparse representation using scipy's `coo_matrix` format. log : bool, optional @@ -1042,8 +1070,10 @@ def expected_sliced( Returns ------- - plan : torch.Tensor - A tensor of shape (n_proj, n, n) representing the expected sliced plan. + plan : ndarray, shape (n, m) or coo_matrix if dense is False + Optimal transportation matrix for the given parameters. + cost : float + The cost associated to the optimal permutation. log_dict : dict, optional A dictionary containing intermediate computations for logging purposes. Returned only if `log` is True. @@ -1051,12 +1081,29 @@ def expected_sliced( References ---------- .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport - Plans. arXiv preprint 2506.03661. - + Plans. arXiv preprint 2506.03661. .. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi - A., Kolouri, S. (2024). Expected Sliced Transport Plans. International - Conference on Learning Representations. + A., Kolouri, S. (2024). Expected Sliced Transport Plans. + International Conference on Learning Representations. + + Examples + -------- + >>> x=np.array([[3,3], [1,1]]) + >>> y=np.array([[2,2.5], [3,2]]) + >>> thetas=np.array([[1, 0], [0, 1]]) + >>> plan, cost = ot.expected_sliced(x, y, thetas) + >>> plan + [[0.25 0.25] + [0.25 0.25]] + >>> cost + 2.625 """ + + X, Y = list_to_array(X, Y) + + assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" + assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" + assert ( X.shape[1] == Y.shape[1] ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" @@ -1069,11 +1116,11 @@ def expected_sliced( "to array assignment." ) - ns = X.shape[0] - nt = Y.shape[0] + n = X.shape[0] + m = Y.shape[0] log_dict = {} - costs, G, log_dict_plans = sliced_plans( + G, costs, log_dict_plans = sliced_plans( X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True, backend=nx ) if log: @@ -1087,20 +1134,14 @@ def expected_sliced( else: # uniform weights if n_proj is None: n_proj = thetas.shape[0] - weights = nx.ones(n_proj, type_as=X) / n_proj + weights = nx.ones(n_proj) / n_proj log_dict["weights"] = weights weights = nx.concatenate([G[i].data * weights[i] for i in range(len(G))]) X_idx = nx.concatenate([G[i].row for i in range(len(G))]) Y_idx = nx.concatenate([G[i].col for i in range(len(G))]) - plan = nx.coo_matrix( - weights, - X_idx, - Y_idx, - shape=(ns, nt), - type_as=X, - ) + plan = nx.coo_matrix(weights, X_idx, Y_idx, shape=(n, m)) if beta == 0.0: # otherwise already computed above cost = plan.multiply(dist(X, Y, metric=metric, p=p)).sum() @@ -1108,10 +1149,7 @@ def expected_sliced( if dense: plan = nx.todense(plan) elif str(nx) == "jax": - warnings.warn( - "JAX does not support sparse matrices, converting to\ - dense" - ) + warnings.warn("JAX does not support sparse matrices, converting to dense") plan = nx.todense(plan) if log: diff --git a/test/test_sliced.py b/test/test_sliced.py index f15a1598f..39a26e9b9 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -742,61 +742,118 @@ def test_sliced_permutations(nx): thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T thetas_b = nx.from_numpy(thetas) - perm = ot.sliced.sliced_plans(x, y, thetas=thetas) - perm_b, _ = ot.sliced.sliced_plans(x_b, y_b, thetas=thetas_b, log=True, backend=nx) - - np.testing.assert_almost_equal(perm, nx.to_numpy(perm_b)) + plan, _ = ot.sliced.sliced_plans(x, y, thetas=thetas, dense=True) + plan_b, _, _ = ot.sliced.sliced_plans( + x_b, y_b, thetas=thetas_b, log=True, dense=True, backend=nx + ) + np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b)) # test without provided thetas - perm = ot.sliced.sliced_plans(x, y, n_proj=n_proj) + _, _ = ot.sliced.sliced_plans(x, y, n_proj=n_proj) # test with invalid shapes with pytest.raises(AssertionError): - ot.sliced.sliced_plans(x[1:, :], y, thetas=thetas) + ot.sliced.sliced_plans(x[:, 1:], y, thetas=thetas) + + +def test_sliced_plans(nx): + x = [1, 2] + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x, x, n_proj=2) + + n = 4 + m = 5 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(m, 2) + + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, m) + b /= b.sum() + + x_b, y_b = nx.from_numpy(x, y) + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + thetas_b = nx.from_numpy(thetas) + + # test with a and b uniform + plan, _ = ot.sliced.sliced_plans(x, y, thetas=thetas, dense=True) + plan_b, _, _ = ot.sliced.sliced_plans( + x_b, y_b, thetas=thetas_b, log=True, dense=True, backend=nx + ) + np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b)) + + # test with a and b not uniform + plan, _ = ot.sliced.sliced_plans(x, y, a, b, thetas=thetas, dense=True) + plan_b, _, _ = ot.sliced.sliced_plans( + x_b, y_b, a, b, thetas=thetas_b, log=True, dense=True, backend=nx + ) + np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b)) def test_min_pivot_sliced(nx): + x = [1, 2] + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x, x, n_proj=2) + n = 10 + m = 4 n_proj = 10 d = 2 rng = np.random.RandomState(0) x = rng.randn(n, 2) - y = rng.randn(n, 2) + y = rng.randn(m, 2) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, m) + b /= b.sum() x_b, y_b = nx.from_numpy(x, y) thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T thetas_b = nx.from_numpy(thetas) - min_perm, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) - min_perm_b, min_cost_b, _ = ot.sliced.min_pivot_sliced( - x_b, y_b, thetas=thetas_b, log=True + G, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True) + G_b, min_cost_b, _ = ot.sliced.min_pivot_sliced( + x_b, y_b, a, b, thetas=thetas_b, log=True, dense=True ) - np.testing.assert_almost_equal(min_perm, nx.to_numpy(min_perm_b)) + np.testing.assert_almost_equal(G, nx.to_numpy(G_b)) np.testing.assert_almost_equal(min_cost, nx.to_numpy(min_cost_b)) # result should be an upper-bound of W2 and relatively close - w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) + w2 = ot.emd2(a, b, ot.dist(x, y)) assert min_cost >= w2 assert min_cost <= 1.5 * w2 # test without provided thetas - ot.sliced.min_pivot_sliced(x, y, n_proj=n_proj, log=True) + ot.sliced.min_pivot_sliced(x, y, a, b, n_proj=n_proj, log=True) # test with invalid shapes with pytest.raises(AssertionError): - ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) + ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas) def test_expected_sliced(nx): + x = [1, 2] + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x, x, n_proj=2) + n = 10 + m = 24 n_proj = 10 d = 2 rng = np.random.RandomState(0) x = rng.randn(n, 2) - y = rng.randn(n, 2) + y = rng.randn(m, 2) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, m) + b /= b.sum() x_b, y_b = nx.from_numpy(x, y) thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T @@ -810,17 +867,17 @@ def test_expected_sliced(nx): with context: expected_plan, expected_cost = ot.sliced.expected_sliced( - x, y, dense=True, thetas=thetas + x, y, a, b, dense=True, thetas=thetas ) expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced( - x_b, y_b, thetas=thetas_b, dense=True, log=True + x_b, y_b, a, b, thetas=thetas_b, dense=True, log=True ) np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b)) np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b)) # result should be a coarse upper-bound of W2 - w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) + w2 = ot.emd2(a, b, ot.dist(x, y)) assert expected_cost >= w2 assert expected_cost <= 3 * w2 @@ -829,10 +886,12 @@ def test_expected_sliced(nx): # test with invalid shapes with pytest.raises(AssertionError): - ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) + ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas) # with a small temperature (i.e. large beta), the cost should be close # to min_pivot - _, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas, beta=100.0) - _, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) + _, expected_cost = ot.sliced.expected_sliced( + x, y, a, b, thetas=thetas, dense=True, beta=100.0 + ) + _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True) np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3) From e5917370f61b34c94dc52a021396d6180a45b982 Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Wed, 8 Oct 2025 11:22:51 +0200 Subject: [PATCH 12/19] update tests and doc --- README.md | 287 ++++++++++++++++++++++---------------------- ot/backend.py | 19 --- ot/sliced.py | 22 ++-- test/test_sliced.py | 11 ++ 4 files changed, 167 insertions(+), 172 deletions(-) diff --git a/README.md b/README.md index ed6f2a89c..f443ec365 100644 --- a/README.md +++ b/README.md @@ -12,78 +12,78 @@ This open source Python library provides several solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning. -Website and documentation: [https://PythonOT.github.io/](https://PythonOT.github.io/) +Website and documentation: Source Code (MIT): -[https://github.com/PythonOT/POT](https://github.com/PythonOT/POT) - + POT has the following main features: + * A large set of differentiable solvers for optimal transport problems, including: - * Exact linear OT, entropic and quadratic regularized OT, - * Gromov-Wasserstein (GW) distances, Fused GW distances and variants of - quadratic OT, - * Unbalanced and partial OT for different divergences, -* OT barycenters (Wasserstein and GW) for fixed and free support, -* Fast OT solvers in 1D, on the circle and between Gaussian Mixture Models (GMMs), -* Many ML related solvers, such as domain adaptation, optimal transport mapping - estimation, subspace learning, Graph Neural Networks (GNNs) layers. -* Several backends for easy use with Pytorch, Jax, Tensorflow, Numpy and Cupy arrays. + * Exact linear OT, entropic and quadratic regularized OT, + * Gromov-Wasserstein (GW) distances, Fused GW distances and variants of + quadratic OT, + * Unbalanced and partial OT for different divergences, +* OT barycenters (Wasserstein and GW) for fixed and free support, +* Fast OT solvers in 1D, on the circle and between Gaussian Mixture Models (GMMs), +* Many ML related solvers, such as domain adaptation, optimal transport mapping + estimation, subspace learning, Graph Neural Networks (GNNs) layers. +* Several backends for easy use with Pytorch, Jax, Tensorflow, Numpy and Cupy arrays. ### Implemented Features POT provides the following generic OT solvers: -* [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . -* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. +* [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance \[1] . +* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) \[6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT \[7]. * Entropic regularization OT solver with [Sinkhorn Knopp - Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , - stabilized version [9] [10] [34], lazy CPU/GPU solver from geomloss [60] [61], greedy Sinkhorn [22] and Screening - Sinkhorn [26]. -* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. -* Sinkhorn divergence [23] and entropic regularization OT from empirical data. -* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] -* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17]. -* Weak OT solver between empirical distributions [39] -* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale). -* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12,51]), differentiable using gradients from Graph Dictionary Learning [38] - * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) (exact [24] and regularized [12,51]). + Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) \[2] , + stabilized version \[9] \[10] \[34], lazy CPU/GPU solver from geomloss \[60] \[61], greedy Sinkhorn \[22] and Screening + Sinkhorn \[26]. +* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) \[3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) \[21] and unmixing \[4]. +* Sinkhorn divergence \[23] and entropic regularization OT from empirical data. +* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) \[37] +* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations \[17]. +* Weak OT solver between empirical distributions \[39] +* Non regularized [Wasserstein barycenters \[16\] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale). +* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact \[13] and regularized \[12,51]), differentiable using gradients from Graph Dictionary Learning \[38] +* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) (exact \[24] and regularized \[12,51]). * [Stochastic solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and [differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for - Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) -* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] -* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. -* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] -* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations). -* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. + Large-scale Optimal Transport (semi-dual problem \[18] and dual problem \[19]) +* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions \[33] +* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) \[20]. +* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) \[10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) \[41] +* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact \[29] and entropic \[3] formulations). +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) \[31, 32] and Max-sliced Wasserstein \[35] that can be used for gradient flows \[36]. * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html) - [44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] -* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. -* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding [barycenter solvers](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.hmtl) (exact and regularized [48]). -* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68]. -* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. + \[44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) \[46] +* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) \[38]. +* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding [barycenter solvers](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.hmtl) (exact and regularized \[48]). +* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) \[68]. +* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) \[50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. -* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59]. -* [Gaussian Mixture Model OT](https://pythonot.github.io/auto_examples/gaussian_gmm/plot_GMMOT_plan.html#sphx-glr-auto-examples-others-plot-gmmot-plan-py) [69]. -* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and -[unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. -* Fused unbalanced Gromov-Wasserstein [70]. -* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77] -* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77] -* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) [82, 83, 84] +* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) \[58], with an extension to bounding potentials using \[59]. +* [Gaussian Mixture Model OT](https://pythonot.github.io/auto_examples/gaussian_gmm/plot_GMMOT_plan.html#sphx-glr-auto-examples-others-plot-gmmot-plan-py) \[69]. +* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) \[49] and + [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) \[71]. +* Fused unbalanced Gromov-Wasserstein \[70]. +* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) \[77] +* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) \[69, 77] +* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) \[82, 83, 84] POT provides the following Machine Learning related solvers: * [Optimal transport for domain adaptation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_classes.html) - with [group lasso regularization](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_classes.html), [Laplacian regularization](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_laplacian.html) [5] [30] and [semi + with [group lasso regularization](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_classes.html), [Laplacian regularization](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_laplacian.html) \[5] \[30] and [semi supervised setting](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_semi_supervised.html). -* [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8]. -* [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt). -* [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27]. -* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53] +* [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) \[14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) \[8]. +* [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) \[11] (requires autograd + pymanopt). +* [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) \[27]. +* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) \[52] and TW (OT-GNN) \[53] Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html). @@ -93,9 +93,11 @@ If you use this toolbox in your research and find it useful, please cite POT using the following references from the current version and from our [JMLR paper](https://jmlr.org/papers/v22/20-451.html): - Flamary R., Vincent-Cuaz C., Courty N., Gramfort A., Kachaiev O., Quang Tran H., David L., Bonet C., Cassereau N., Gnassounou T., Tanguy E., Delon J., Collas A., Mazelet S., Chapel L., Kerdoncuff T., Yu X., Feickert M., Krzakala P., Liu T., Fernandes Montesuma E. POT Python Optimal Transport (version 0.9.5). URL: https://github.com/PythonOT/POT +``` +Flamary R., Vincent-Cuaz C., Courty N., Gramfort A., Kachaiev O., Quang Tran H., David L., Bonet C., Cassereau N., Gnassounou T., Tanguy E., Delon J., Collas A., Mazelet S., Chapel L., Kerdoncuff T., Yu X., Feickert M., Krzakala P., Liu T., Fernandes Montesuma E. POT Python Optimal Transport (version 0.9.5). URL: https://github.com/PythonOT/POT - Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. URL: https://pythonot.github.io/ +Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. URL: https://pythonot.github.io/ +``` In Bibtex format: @@ -123,13 +125,12 @@ In Bibtex format: The library has been tested on Linux, MacOSX and Windows. It requires a C++ compiler for building/installing the EMD solver and relies on the following Python modules: -- Numpy (>=1.16) -- Scipy (>=1.0) -- Cython (>=0.23) (build only, not necessary when installing from pip or conda) +* Numpy (>=1.16) +* Scipy (>=1.0) +* Cython (>=0.23) (build only, not necessary when installing from pip or conda) #### Pip installation - You can install the toolbox through PyPI with: ```console @@ -143,9 +144,11 @@ pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user ``` Optional dependencies may be installed with + ```console pip install POT[all] ``` + Note that this installs `cvxopt`, which is licensed under GPL 3.0. Alternatively, if you cannot use GPL-licensed software, the specific optional dependencies may be installed individually, or per-submodule. The available optional installations are `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, all`. #### Anaconda installation with conda-forge @@ -157,6 +160,7 @@ conda install -c conda-forge pot ``` #### Post installation check + After a correct installation, you should be able to import the module without errors: ```python @@ -165,7 +169,6 @@ import ot Note that for easier access the module is named `ot` instead of `pot`. - ### Dependencies Some sub-modules require additional dependencies which are discussed below @@ -176,7 +179,6 @@ Some sub-modules require additional dependencies which are discussed below pip install pymanopt autograd ``` - ## Examples ### Short examples @@ -241,8 +243,7 @@ ba = ot.barycenter(A, M, reg) # reg is regularization parameter ### Examples and Notebooks -The examples folder contain several examples and use case for the library. The full documentation with examples and output is available on [https://PythonOT.github.io/](https://PythonOT.github.io/). - +The examples folder contain several examples and use case for the library. The full documentation with examples and output is available on . ## Acknowledgements @@ -279,179 +280,179 @@ You can also post bug reports and feature requests in Github issues. Make sure t ## References -[1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). [Displacement interpolation using Lagrangian mass transport](https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf). In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. +\[1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). [Displacement interpolation using Lagrangian mass transport](https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf). In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. -[2] Cuturi, M. (2013). [Sinkhorn distances: Lightspeed computation of optimal transport](https://arxiv.org/pdf/1306.0895.pdf). In Advances in Neural Information Processing Systems (pp. 2292-2300). +\[2] Cuturi, M. (2013). [Sinkhorn distances: Lightspeed computation of optimal transport](https://arxiv.org/pdf/1306.0895.pdf). In Advances in Neural Information Processing Systems (pp. 2292-2300). -[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). [Iterative Bregman projections for regularized transportation problems](https://arxiv.org/pdf/1412.5154.pdf). SIAM Journal on Scientific Computing, 37(2), A1111-A1138. +\[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). [Iterative Bregman projections for regularized transportation problems](https://arxiv.org/pdf/1412.5154.pdf). SIAM Journal on Scientific Computing, 37(2), A1111-A1138. -[4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, [Supervised planetary unmixing with optimal transport](https://hal.archives-ouvertes.fr/hal-01377236/document), Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. +\[4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, [Supervised planetary unmixing with optimal transport](https://hal.archives-ouvertes.fr/hal-01377236/document), Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. -[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, [Optimal Transport for Domain Adaptation](https://arxiv.org/pdf/1507.00504.pdf), in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 +\[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, [Optimal Transport for Domain Adaptation](https://arxiv.org/pdf/1507.00504.pdf), in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 -[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). [Regularized discrete optimal transport](https://arxiv.org/pdf/1307.5551.pdf). SIAM Journal on Imaging Sciences, 7(3), 1853-1882. +\[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). [Regularized discrete optimal transport](https://arxiv.org/pdf/1307.5551.pdf). SIAM Journal on Imaging Sciences, 7(3), 1853-1882. -[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). [Generalized conditional gradient: analysis of convergence and applications](https://arxiv.org/pdf/1510.06567.pdf). arXiv preprint arXiv:1510.06567. +\[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). [Generalized conditional gradient: analysis of convergence and applications](https://arxiv.org/pdf/1510.06567.pdf). arXiv preprint arXiv:1510.06567. -[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), [Mapping estimation for discrete optimal transport](http://remi.flamary.com/biblio/perrot2016mapping.pdf), Neural Information Processing Systems (NIPS). +\[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), [Mapping estimation for discrete optimal transport](http://remi.flamary.com/biblio/perrot2016mapping.pdf), Neural Information Processing Systems (NIPS). -[9] Schmitzer, B. (2016). [Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems](https://arxiv.org/pdf/1610.06519.pdf). arXiv preprint arXiv:1610.06519. +\[9] Schmitzer, B. (2016). [Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems](https://arxiv.org/pdf/1610.06519.pdf). arXiv preprint arXiv:1610.06519. -[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816. +\[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816. -[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063. +\[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063. -[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). +\[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). -[13] Mémoli, Facundo (2011). [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 : 417-487. +\[13] Mémoli, Facundo (2011). [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 : 417-487. -[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43. +\[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43. -[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) . +\[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) . -[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. +\[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. -[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). +\[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). -[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016). +\[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016). -[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) +\[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) -[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning +\[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning -[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. +\[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. -[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 +\[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 -[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018 +\[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018 -[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). +\[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). -[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS). +\[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS). -[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NeurIPS). +\[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NeurIPS). -[27] Redko I., Courty N., Flamary R., Tuia D. (2019). [Optimal Transport for Multi-source Domain Adaptation under Target Shift](http://proceedings.mlr.press/v89/redko19a.html), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics (AISTATS) 22, 2019. +\[27] Redko I., Courty N., Flamary R., Tuia D. (2019). [Optimal Transport for Multi-source Domain Adaptation under Target Shift](http://proceedings.mlr.press/v89/redko19a.html), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics (AISTATS) 22, 2019. -[28] Caffarelli, L. A., McCann, R. J. (2010). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730. +\[28] Caffarelli, L. A., McCann, R. J. (2010). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730. -[29] Chapel, L., Alaya, M., Gasso, G. (2020). [Partial Optimal Transport with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), Advances in Neural Information Processing Systems (NeurIPS), 2020. +\[29] Chapel, L., Alaya, M., Gasso, G. (2020). [Partial Optimal Transport with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), Advances in Neural Information Processing Systems (NeurIPS), 2020. -[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. +\[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. -[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 +\[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 -[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). +\[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). -[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 +\[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 -[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. +\[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. -[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). +\[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). -[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +\[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. (2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. -[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International +\[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 -[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph +\[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. -[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. +\[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825\&rep=rep1\&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. -[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. +\[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. -[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) +\[41] Chapel\*, L., Flamary\*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) -[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. +\[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. +\[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. -[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. +\[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. -[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. +\[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. -[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. +\[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. -[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787. +\[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787. -[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. +\[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. -[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. +\[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. -[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR). +\[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR). -[51] Xu, H., Luo, D., Zha, H., & Carin, L. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019. +\[51] Xu, H., Luo, D., Zha, H., & Carin, L. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019. -[52] Collas, A., Vayer, T., Flamary, F., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv. +\[52] Collas, A., Vayer, T., Flamary, F., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv. -[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35. +\[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35. -[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804). +\[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804). -[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR). +\[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR). -[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019. +\[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019. -[57] Delon, J., Desolneux, A., & Salmona, A. (2022). [Gromov–Wasserstein +\[57] Delon, J., Desolneux, A., & Salmona, A. (2022). [Gromov–Wasserstein distances between Gaussian distributions](https://hal.science/hal-03197398v2/file/main.pdf). Journal of Applied Probability, 59(4), 1178-1198. -[58] Paty F-P., d’Aspremont 1., & Cuturi M. (2020). [Regularity as regularization:Smooth and strongly convex brenier potentials in optimal transport.](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) In International Conference on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. +\[58] Paty F-P., d’Aspremont 1., & Cuturi M. (2020). [Regularity as regularization:Smooth and strongly convex brenier potentials in optimal transport.](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) In International Conference on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. -[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017. +\[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017. -[60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms](https://arxiv.org/pdf/2107.02010.pdf). In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part III 22 (pp. 636-644). Springer International Publishing. +\[60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms](https://arxiv.org/pdf/2107.02010.pdf). In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part III 22 (pp. 636-644). Springer International Publishing. -[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462. +\[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462. -[62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning. +\[62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning. -[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations. +\[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations. -[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems. +\[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems. -[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). +\[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). -[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021). +\[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021). -[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022. +\[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022. -[68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. +\[68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. -[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. +\[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. -[70] A. Thual, H. Tran, T. Zemskova, N. Courty, R. Flamary, S. Dehaene +\[70] A. Thual, H. Tran, T. Zemskova, N. Courty, R. Flamary, S. Dehaene & B. Thirion (2022). [Aligning individual brains with Fused Unbalanced Gromov-Wasserstein.](https://proceedings.neurips.cc/paper_files/paper/2022/file/8906cac4ca58dcaf17e97a0486ad57ca-Paper-Conference.pdf). Neural Information Processing Systems (NeurIPS). -[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on +\[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on Artificial Intelligence. -[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). +\[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). -[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. +\[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. -[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. +\[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. -[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. +\[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. -[76] Chapel, L., Tavenard, R. (2025). [One for all and all for one: Efficient computation of partial Wasserstein distances on the line](https://iclr.cc/virtual/2025/poster/28547). In International Conference on Learning Representations. +\[76] Chapel, L., Tavenard, R. (2025). [One for all and all for one: Efficient computation of partial Wasserstein distances on the line](https://iclr.cc/virtual/2025/poster/28547). In International Conference on Learning Representations. -[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) +\[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) -[78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). [LCOT: Linear Circular Optimal Transport](https://openreview.net/forum?id=49z97Y9lMq). International Conference on Learning Representations. +\[78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). [LCOT: Linear Circular Optimal Transport](https://openreview.net/forum?id=49z97Y9lMq). International Conference on Learning Representations. -[79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations. +\[79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations. -[80] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019. +\[80] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019. -[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS). +\[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS). -[82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). [Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6f1346bac8b02f76a631400e2799b24b-Abstract-Conference.html). Advances in Neural Information Processing Systems, 36, 35350–35385. +\[82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). [Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6f1346bac8b02f76a631400e2799b24b-Abstract-Conference.html). Advances in Neural Information Processing Systems, 36, 35350–35385. -[83] Tanguy, E., Chapel, L., Delon, J. (2025). [Sliced Optimal Transport Plans](https://arxiv.org/abs/2508.01243) arXiv preprint 2506.03661. +\[83] Tanguy, E., Chapel, L., Delon, J. (2025). [Sliced Optimal Transport Plans](https://arxiv.org/abs/2508.01243) arXiv preprint 2506.03661. -[84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). [Expected Sliced Transport Plans](https://openreview.net/forum?id=P7O1Vt1BdU). International Conference on Learning Representations. +\[84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). [Expected Sliced Transport Plans](https://openreview.net/forum?id=P7O1Vt1BdU). International Conference on Learning Representations. diff --git a/ot/backend.py b/ot/backend.py index 549ce43c7..64b5a88cf 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -729,16 +729,6 @@ def stack(self, arrays, axis=0): """ raise NotImplementedError() - def unstack(self, arrays, axis=0): - r""" - Split an array into a sequence of arrays along the given axis. - - This function follows the api from :any:`numpy.unstack` - - See: https://numpy.org/doc/stable/reference/generated/numpy.unstack.html - """ - raise NotImplementedError() - def outer(self, a, b): r""" Computes the outer product between two vectors. @@ -1310,9 +1300,6 @@ def unique(self, a, return_inverse=False): def stack(self, arrays, axis=0): return np.stack(arrays, axis) - def unstack(self, arrays, axis=0): - return np.unstack(arrays, axis=axis) - def reshape(self, a, shape): return np.reshape(a, shape) @@ -1723,9 +1710,6 @@ def unique(self, a, return_inverse=False): def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) - def unstack(self, arrays, axis=0): - return jnp.unstack(arrays, axis=axis) - def reshape(self, a, shape): return jnp.reshape(a, shape) @@ -2229,9 +2213,6 @@ def logsumexp(self, a, axis=None, keepdims=False): def stack(self, arrays, axis=0): return torch.stack(arrays, dim=axis) - def unstack(self, arrays, axis=0): - return torch.unbind(arrays, dim=axis) - def reshape(self, a, shape): return torch.reshape(a, shape) diff --git a/ot/sliced.py b/ot/sliced.py index 6a76c7fee..3d5a98f84 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -747,6 +747,7 @@ def sliced_plans( """ X, Y = list_to_array(X, Y) + nx = get_backend(X, Y) if backend is None else backend assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" @@ -761,7 +762,6 @@ def sliced_plans( d = X.shape[1] n = X.shape[0] m = Y.shape[0] - nx = get_backend(X, Y) if backend is None else backend is_perm = False if n == m: @@ -772,8 +772,9 @@ def sliced_plans( if do_draw_thetas: # create thetas (n_proj, d) assert n_proj is not None, "n_proj must be specified if thetas is None" thetas = get_random_projections(d, n_proj, backend=nx).T + if warm_theta is not None: - thetas = nx.concatenate([thetas, warm_theta[:, None]], axis=1) + thetas = nx.concatenate([thetas, warm_theta[:, None].T], axis=0) else: n_proj = thetas.shape[0] @@ -808,8 +809,9 @@ def sliced_plans( + "from the following list: " + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" ) + a = nx.ones(n) / n plan = [ - nx.coo_matrix(np.ones(n) / n, sigma[:, k], tau[:, k], shape=(n, m)) + nx.coo_matrix(a, sigma[:, k], tau[:, k], shape=(n, m), type_as=a) for k in range(n_proj) ] @@ -949,8 +951,8 @@ def min_pivot_sliced( Examples -------- - >>> x=np.array([[3,3], [1,1]]) - >>> y=np.array([[2,2.5], [3,2]]) + >>> x=np.array([[3.,3.], [1.,1.]]) + >>> y=np.array([[2.,2.5], [3.,2.]]) >>> thetas=np.array([[1, 0], [0, 1]]) >>> plan, cost = ot.expected_sliced(x, y, thetas) >>> plan @@ -961,6 +963,7 @@ def min_pivot_sliced( """ X, Y = list_to_array(X, Y) + nx = get_backend(X, Y) if backend is None else backend assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" @@ -1088,8 +1091,8 @@ def expected_sliced( Examples -------- - >>> x=np.array([[3,3], [1,1]]) - >>> y=np.array([[2,2.5], [3,2]]) + >>> x=np.array([[3.,3.], [1.,1.]]) + >>> y=np.array([[2.,2.5], [3.,2.]]) >>> thetas=np.array([[1, 0], [0, 1]]) >>> plan, cost = ot.expected_sliced(x, y, thetas) >>> plan @@ -1100,6 +1103,7 @@ def expected_sliced( """ X, Y = list_to_array(X, Y) + nx = get_backend(X, Y) if backend is None else backend assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" @@ -1108,8 +1112,6 @@ def expected_sliced( X.shape[1] == Y.shape[1] ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" - nx = get_backend(X, Y) if backend is None else backend - if str(nx) in ["tf", "jax"]: raise NotImplementedError( f"expected_sliced is not implemented for the {str(nx)} backend due" @@ -1141,7 +1143,7 @@ def expected_sliced( weights = nx.concatenate([G[i].data * weights[i] for i in range(len(G))]) X_idx = nx.concatenate([G[i].row for i in range(len(G))]) Y_idx = nx.concatenate([G[i].col for i in range(len(G))]) - plan = nx.coo_matrix(weights, X_idx, Y_idx, shape=(n, m)) + plan = nx.coo_matrix(weights, X_idx, Y_idx, shape=(n, m), type_as=weights) if beta == 0.0: # otherwise already computed above cost = plan.multiply(dist(X, Y, metric=metric, p=p)).sum() diff --git a/test/test_sliced.py b/test/test_sliced.py index 39a26e9b9..24e5d116e 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -777,8 +777,19 @@ def test_sliced_plans(nx): x_b, y_b = nx.from_numpy(x, y) thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + print("et là ???", thetas.shape) thetas_b = nx.from_numpy(thetas) + # test with the minkowski metric + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski") + + # test with an unsupported metric + with pytest.raises(ValueError): + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="mahalanobis") + + # test with a warm theta + ot.sliced.min_pivot_sliced(x, y, n_proj=10, warm_theta=thetas[-1]) + # test with a and b uniform plan, _ = ot.sliced.sliced_plans(x, y, thetas=thetas, dense=True) plan_b, _, _ = ot.sliced.sliced_plans( From 85056c6af3ba2bc51768b1938f7e8292cc6b6881 Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Wed, 8 Oct 2025 12:08:29 +0200 Subject: [PATCH 13/19] update tests and doc --- ot/sliced.py | 31 +++++++++++++++++++------------ test/test_sliced.py | 4 +++- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/ot/sliced.py b/ot/sliced.py index 3d5a98f84..90b04c6ee 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -746,7 +746,6 @@ def sliced_plans( Returned only if `log` is True. """ - X, Y = list_to_array(X, Y) nx = get_backend(X, Y) if backend is None else backend assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" @@ -903,9 +902,9 @@ def min_pivot_sliced( The first set of vectors. Y : array-like, shape (m, d) The second set of vectors. - a : ndarray of float64, shape (ns,), optional + a : ndarray of float64, shape (n,), optional Source histogram (default is uniform weight) - b : ndarray of float64, shape (nt,), optional + b : ndarray of float64, shape (m,), optional Target histogram (default is uniform weight) thetas : array-like, shape (n_proj, d), optional The projection directions. If None, random directions will be generated @@ -962,7 +961,6 @@ def min_pivot_sliced( 2.125 """ - X, Y = list_to_array(X, Y) nx = get_backend(X, Y) if backend is None else backend assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" @@ -987,7 +985,7 @@ def min_pivot_sliced( log=True, backend=nx, ) - pos_min = np.argmin(costs) + pos_min = nx.argmin(costs) cost = costs[pos_min] plan = G[pos_min] @@ -1050,23 +1048,33 @@ def expected_sliced( Parameters ---------- - X : torch.Tensor - A tensor of shape (n, d) representing the first set of vectors. - Y : torch.Tensor - A tensor of shape (m, d) representing the second set of vectors. + X : array-like, shape (n, d) + The first set of vectors. + Y : array-like, shape (m, d) + The second set of vectors. + a : ndarray of float64, shape (n,), optional + Source histogram (default is uniform weight) + b : ndarray of float64, shape (m,), optional + Target histogram (default is uniform weight) thetas : torch.Tensor, optional A tensor of shape (n_proj, d) representing the projection directions. If None, random directions will be generated. Default is None. + metric: str, optional (default='sqeuclidean') + Metric to be used. Only works with either of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. + p: float, optional (default=2) + The p-norm to apply for if metric='minkowski' n_proj : int, optional The number of projection directions. Required if thetas is None. - order : int, optional - Power to elevate the norm. Default is 2. dense: boolean, optional (default=True) If True, returns :math:`\gamma` as a dense ndarray of shape (n, m). Otherwise returns a sparse representation using scipy's `coo_matrix` format. log : bool, optional If True, returns additional logging information. Default is False. + backend : ot.backend, optional + Backend to use for computations. If None, the backend is inferred from + the input arrays. Default is None. beta : float, optional Inverse-temperature parameter which weights each projection's contribution to the expected plan. Default is 0 (uniform weighting). @@ -1102,7 +1110,6 @@ def expected_sliced( 2.625 """ - X, Y = list_to_array(X, Y) nx = get_backend(X, Y) if backend is None else backend assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" diff --git a/test/test_sliced.py b/test/test_sliced.py index 24e5d116e..07cf0911b 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -776,8 +776,10 @@ def test_sliced_plans(nx): b /= b.sum() x_b, y_b = nx.from_numpy(x, y) + print(x_b) + t_X = torch.tensor(x_b) + t_Y = torch.tensor(y_b) thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T - print("et là ???", thetas.shape) thetas_b = nx.from_numpy(thetas) # test with the minkowski metric From 75615300d736e3747ceed5d8a58f220a6b3209c0 Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Wed, 8 Oct 2025 15:11:37 +0200 Subject: [PATCH 14/19] update tests with backend --- ot/sliced.py | 52 ++++++++----- test/test_sliced.py | 181 +++++++++++++++++++++++++++----------------- 2 files changed, 146 insertions(+), 87 deletions(-) diff --git a/ot/sliced.py b/ot/sliced.py index 90b04c6ee..528cfa315 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -695,7 +695,6 @@ def sliced_plans( n_proj=None, dense=False, log=False, - backend=None, ): r""" Computes all the permutations that sort the projections of two `(n, d)` @@ -731,9 +730,6 @@ def sliced_plans( The number of projection directions. Required if thetas is None. log : bool, optional If True, returns additional logging information. Default is False. - backend : ot.backend, optional - Backend to use for computations. If None, the backend is inferred from - the input arrays. Default is None. Returns ------- @@ -746,7 +742,17 @@ def sliced_plans( Returned only if `log` is True. """ - nx = get_backend(X, Y) if backend is None else backend + X, Y = list_to_array(X, Y) + + if a is not None and b is not None and thetas is None: + nx = get_backend(X, Y, a, b) + elif a is not None and b is not None and thetas is not None: + nx = get_backend(X, Y, a, b, thetas) + elif a is None and b is None and thetas is not None: + nx = get_backend(X, Y, thetas) + else: + nx = get_backend(X, Y) + assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" @@ -870,7 +876,6 @@ def min_pivot_sliced( dense=True, log=False, warm_theta=None, - backend=None, ): r""" Computes the cost and permutation associated to the min-Pivot Sliced @@ -924,9 +929,6 @@ def min_pivot_sliced( If True, returns additional logging information. Default is False. warm_theta : array-like, shape (d,), optional A theta to add to the list of thetas. Default is None. - backend : ot.backend, optional - Backend to use for computations. If None, the backend is inferred from - the input arrays. Default is None. Returns ------- @@ -961,7 +963,17 @@ def min_pivot_sliced( 2.125 """ - nx = get_backend(X, Y) if backend is None else backend + X, Y = list_to_array(X, Y) + + if a is not None and b is not None and thetas is None: + nx = get_backend(X, Y, a, b) + elif a is not None and b is not None and thetas is not None: + nx = get_backend(X, Y, a, b, thetas) + elif a is None and b is None and thetas is not None: + nx = get_backend(X, Y, thetas) + else: + nx = get_backend(X, Y) + assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" @@ -969,8 +981,6 @@ def min_pivot_sliced( X.shape[1] == Y.shape[1] ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" - nx = get_backend(X, Y) if backend is None else backend - log_dict = {} G, costs, log_dict_plans = sliced_plans( X, @@ -983,7 +993,6 @@ def min_pivot_sliced( n_proj=n_proj, warm_theta=warm_theta, log=True, - backend=nx, ) pos_min = nx.argmin(costs) cost = costs[pos_min] @@ -1024,7 +1033,6 @@ def expected_sliced( n_proj=None, dense=True, log=False, - backend=None, beta=0.0, ): r""" @@ -1072,9 +1080,6 @@ def expected_sliced( format. log : bool, optional If True, returns additional logging information. Default is False. - backend : ot.backend, optional - Backend to use for computations. If None, the backend is inferred from - the input arrays. Default is None. beta : float, optional Inverse-temperature parameter which weights each projection's contribution to the expected plan. Default is 0 (uniform weighting). @@ -1110,7 +1115,16 @@ def expected_sliced( 2.625 """ - nx = get_backend(X, Y) if backend is None else backend + X, Y = list_to_array(X, Y) + + if a is not None and b is not None and thetas is None: + nx = get_backend(X, Y, a, b) + elif a is not None and b is not None and thetas is not None: + nx = get_backend(X, Y, a, b, thetas) + elif a is None and b is None and thetas is not None: + nx = get_backend(X, Y, thetas) + else: + nx = get_backend(X, Y) assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" @@ -1130,7 +1144,7 @@ def expected_sliced( log_dict = {} G, costs, log_dict_plans = sliced_plans( - X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True, backend=nx + X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True ) if log: log_dict = {"thetas": log_dict_plans["thetas"], "costs": costs, "G": G} diff --git a/test/test_sliced.py b/test/test_sliced.py index 07cf0911b..7b177aee4 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -729,7 +729,7 @@ def test_linear_sliced_sphere_backend_type_devices(nx): np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) -def test_sliced_permutations(nx): +def test_sliced_permutations(): n = 4 n_proj = 10 d = 2 @@ -738,15 +738,7 @@ def test_sliced_permutations(nx): x = rng.randn(n, 2) y = rng.randn(n, 2) - x_b, y_b = nx.from_numpy(x, y) thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T - thetas_b = nx.from_numpy(thetas) - - plan, _ = ot.sliced.sliced_plans(x, y, thetas=thetas, dense=True) - plan_b, _, _ = ot.sliced.sliced_plans( - x_b, y_b, thetas=thetas_b, log=True, dense=True, backend=nx - ) - np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b)) # test without provided thetas _, _ = ot.sliced.sliced_plans(x, y, n_proj=n_proj) @@ -756,7 +748,7 @@ def test_sliced_permutations(nx): ot.sliced.sliced_plans(x[:, 1:], y, thetas=thetas) -def test_sliced_plans(nx): +def test_sliced_plans(): x = [1, 2] with pytest.raises(AssertionError): ot.sliced.min_pivot_sliced(x, x, n_proj=2) @@ -775,39 +767,23 @@ def test_sliced_plans(nx): b = rng.uniform(0, 1, m) b /= b.sum() - x_b, y_b = nx.from_numpy(x, y) - print(x_b) - t_X = torch.tensor(x_b) - t_Y = torch.tensor(y_b) thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T - thetas_b = nx.from_numpy(thetas) + + # test with a and b not uniform + ot.sliced.sliced_plans(x, y, a, b, thetas=thetas, dense=True) # test with the minkowski metric - ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski") + ot.sliced.sliced_plans(x, y, thetas=thetas, metric="minkowski") # test with an unsupported metric with pytest.raises(ValueError): - ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="mahalanobis") + ot.sliced.sliced_plans(x, y, thetas=thetas, metric="mahalanobis") # test with a warm theta - ot.sliced.min_pivot_sliced(x, y, n_proj=10, warm_theta=thetas[-1]) + ot.sliced.sliced_plans(x, y, n_proj=10, warm_theta=thetas[-1]) - # test with a and b uniform - plan, _ = ot.sliced.sliced_plans(x, y, thetas=thetas, dense=True) - plan_b, _, _ = ot.sliced.sliced_plans( - x_b, y_b, thetas=thetas_b, log=True, dense=True, backend=nx - ) - np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b)) - # test with a and b not uniform - plan, _ = ot.sliced.sliced_plans(x, y, a, b, thetas=thetas, dense=True) - plan_b, _, _ = ot.sliced.sliced_plans( - x_b, y_b, a, b, thetas=thetas_b, log=True, dense=True, backend=nx - ) - np.testing.assert_almost_equal(plan, nx.to_numpy(plan_b)) - - -def test_min_pivot_sliced(nx): +def test_min_pivot_sliced(): x = [1, 2] with pytest.raises(AssertionError): ot.sliced.min_pivot_sliced(x, x, n_proj=2) @@ -825,17 +801,13 @@ def test_min_pivot_sliced(nx): b = rng.uniform(0, 1, m) b /= b.sum() - x_b, y_b = nx.from_numpy(x, y) thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T - thetas_b = nx.from_numpy(thetas) - G, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True) - G_b, min_cost_b, _ = ot.sliced.min_pivot_sliced( - x_b, y_b, a, b, thetas=thetas_b, log=True, dense=True - ) + # identity of the indiscernibles + _, min_cost = ot.min_pivot_sliced(x, x, a, a, n_proj=10) + np.testing.assert_almost_equal(min_cost, 0.0) - np.testing.assert_almost_equal(G, nx.to_numpy(G_b)) - np.testing.assert_almost_equal(min_cost, nx.to_numpy(min_cost_b)) + _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True) # result should be an upper-bound of W2 and relatively close w2 = ot.emd2(a, b, ot.dist(x, y)) @@ -849,8 +821,30 @@ def test_min_pivot_sliced(nx): with pytest.raises(AssertionError): ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas) + # test the logs + _, min_cost, log = ot.sliced.min_pivot_sliced( + x, y, a, b, thetas=thetas, dense=False, log=True + ) + assert len(log) == 5 + costs = log["costs"] + assert len(costs) == thetas.shape[0] + assert len(log["min_theta"]) == d + assert (log["thetas"] == thetas).all() + for c in costs: + assert c > 0 + + # test with the minkowski metric + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski") + + # test with an unsupported metric + with pytest.raises(ValueError): + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="mahalanobis") -def test_expected_sliced(nx): + # test with a warm theta + ot.sliced.min_pivot_sliced(x, y, n_proj=10, warm_theta=thetas[-1]) + + +def test_expected_sliced(): x = [1, 2] with pytest.raises(AssertionError): ot.sliced.min_pivot_sliced(x, x, n_proj=2) @@ -868,9 +862,67 @@ def test_expected_sliced(nx): b = rng.uniform(0, 1, m) b /= b.sum() - x_b, y_b = nx.from_numpy(x, y) thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T - thetas_b = nx.from_numpy(thetas) + + _, expected_cost = ot.sliced.expected_sliced(x, y, a, b, dense=True, thetas=thetas) + # result should be a coarse upper-bound of W2 + w2 = ot.emd2(a, b, ot.dist(x, y)) + assert expected_cost >= w2 + assert expected_cost <= 3 * w2 + + # test without provided thetas + ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas) + + # with a small temperature (i.e. large beta), the cost should be close + # to min_pivot + _, expected_cost = ot.sliced.expected_sliced( + x, y, a, b, thetas=thetas, dense=True, beta=100.0 + ) + _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True) + np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3) + + # test the logs + _, min_cost, log = ot.sliced.expected_sliced( + x, y, a, b, thetas=thetas, dense=False, log=True + ) + assert len(log) == 4 + costs = log["costs"] + assert len(costs) == thetas.shape[0] + assert len(log["weights"]) == thetas.shape[0] + assert (log["thetas"] == thetas).all() + for c in costs: + assert c > 0 + + # test with the minkowski metric + ot.sliced.expected_sliced(x, y, thetas=thetas, metric="minkowski") + + # test with an unsupported metric + with pytest.raises(ValueError): + ot.sliced.expected_sliced(x, y, thetas=thetas, metric="mahalanobis") + + +def test_sliced_plans_backends(nx): + n = 10 + m = 24 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(m, 2) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, m) + b /= b.sum() + + x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b) + + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T + thetas_t = nx.to_numpy(thetas) context = ( nullcontext() @@ -879,32 +931,25 @@ def test_expected_sliced(nx): ) with context: - expected_plan, expected_cost = ot.sliced.expected_sliced( - x, y, a, b, dense=True, thetas=thetas + _, expected_cost_b = ot.sliced.expected_sliced( + x_b, y_b, a_b, b_b, dense=True, thetas=thetas_t ) - expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced( - x_b, y_b, a, b, thetas=thetas_b, dense=True, log=True + # result should be the same than numpy version + _, expected_cost = ot.sliced.expected_sliced( + x, y, a, b, dense=True, thetas=thetas ) + np.testing.assert_almost_equal(expected_cost_b, expected_cost) - np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b)) - np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b)) - - # result should be a coarse upper-bound of W2 - w2 = ot.emd2(a, b, ot.dist(x, y)) - assert expected_cost >= w2 - assert expected_cost <= 3 * w2 - - # test without provided thetas - ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) + # for min_pivot + _, min_cost_b = ot.sliced.min_pivot_sliced( + x_b, y_b, a_b, b_b, dense=True, thetas=thetas_t + ) + # result should be the same than numpy version + _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, dense=True, thetas=thetas) + np.testing.assert_almost_equal(min_cost_b, min_cost) - # test with invalid shapes - with pytest.raises(AssertionError): - ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas) + # for sliced_plans + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T - # with a small temperature (i.e. large beta), the cost should be close - # to min_pivot - _, expected_cost = ot.sliced.expected_sliced( - x, y, a, b, thetas=thetas, dense=True, beta=100.0 - ) - _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True) - np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3) + # test with the minkowski metric + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski") From 2f4b6759933b51dc1cb5cd73387dde1023071840 Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Wed, 8 Oct 2025 16:17:19 +0200 Subject: [PATCH 15/19] update tests with backend and improve code coverage --- test/test_sliced.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/test_sliced.py b/test/test_sliced.py index 7b177aee4..cb6e895f0 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -833,8 +833,10 @@ def test_min_pivot_sliced(): for c in costs: assert c > 0 - # test with the minkowski metric + # test with different metrics ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski") + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="euclidean") + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="cityblock") # test with an unsupported metric with pytest.raises(ValueError): @@ -872,6 +874,7 @@ def test_expected_sliced(): # test without provided thetas ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) + ot.sliced.expected_sliced(x, y, a, b, n_proj=n_proj, log=True) # test with invalid shapes with pytest.raises(AssertionError): @@ -921,8 +924,8 @@ def test_sliced_plans_backends(nx): x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b) - thetas = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T - thetas_t = nx.to_numpy(thetas) + thetas_b = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T + thetas = nx.to_numpy(thetas_b) context = ( nullcontext() @@ -932,7 +935,7 @@ def test_sliced_plans_backends(nx): with context: _, expected_cost_b = ot.sliced.expected_sliced( - x_b, y_b, a_b, b_b, dense=True, thetas=thetas_t + x_b, y_b, a_b, b_b, dense=True, thetas=thetas_b ) # result should be the same than numpy version _, expected_cost = ot.sliced.expected_sliced( @@ -942,7 +945,7 @@ def test_sliced_plans_backends(nx): # for min_pivot _, min_cost_b = ot.sliced.min_pivot_sliced( - x_b, y_b, a_b, b_b, dense=True, thetas=thetas_t + x_b, y_b, a_b, b_b, dense=True, thetas=thetas_b ) # result should be the same than numpy version _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, dense=True, thetas=thetas) From 15e0d7183c208e6258782ee534ea2ae4ae50c864 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 10 Oct 2025 14:29:06 +0200 Subject: [PATCH 16/19] PR number, authot update, and small backend fix --- .coveragerc | 8 ++++++++ RELEASES.md | 2 +- coverage_help.md | 4 ++++ debug.py | 32 ++++++++++++++++++++++++++++++++ ot/sliced.py | 3 ++- test/test_sliced.py | 1 + 6 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 .coveragerc create mode 100644 coverage_help.md create mode 100644 debug.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..d883888a6 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +[run] +omit = + /tmp/* + */_remote_module_non_scriptable.py + */site-packages/* + +[report] +skip_covered = True \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index eee4257af..ad1c895fe 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,7 +4,7 @@ #### New features -- Added Sliced OT plans (PR #757) +- Added Sliced OT plans (PR #767) ## 0.9.6.post1 diff --git a/coverage_help.md b/coverage_help.md new file mode 100644 index 000000000..3ed518bff --- /dev/null +++ b/coverage_help.md @@ -0,0 +1,4 @@ +Example: + + coverage run -m pytest test/test_ot.py + coverage html --rcfile=.coveragerc \ No newline at end of file diff --git a/debug.py b/debug.py new file mode 100644 index 000000000..e4eb86527 --- /dev/null +++ b/debug.py @@ -0,0 +1,32 @@ +# %% +import numpy as np +from ot.backend import get_backend +import torch +import ot +from torch.optim import Adam + + +# %% +rng = np.random.RandomState(0) +n = 10 +d = 2 +X = rng.randn(n, d) +Y = rng.randn(n, d) + np.array([5.0, 0.0])[None, :] +n_proj = 20 +P = ot.sliced.get_random_projections(d, n_proj) +a = rng.uniform(0, 1, n) +a /= a.sum() +b = rng.uniform(0, 1, n) +b /= b.sum() +sw2 = ot.sliced.sliced_wasserstein_distance(X, Y, a=a, b=b, projections=P) + +# %% +nx = get_backend(torch.tensor([0.0])) +X_t = nx.from_numpy(X) +Y_t = nx.from_numpy(Y) +a_t = nx.from_numpy(a) +b_t = nx.from_numpy(b) +P_t = nx.from_numpy(P) +sw2_t = ot.sliced.sliced_wasserstein_distance(X_t, Y_t, a=a_t, b=b_t, projections=P_t) + +# %% diff --git a/ot/sliced.py b/ot/sliced.py index 528cfa315..de7158190 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -6,6 +6,7 @@ # Nicolas Courty # Rémi Flamary # Eloi Tanguy +# Laetitia Chapel # # License: MIT License @@ -776,7 +777,7 @@ def sliced_plans( do_draw_thetas = thetas is None if do_draw_thetas: # create thetas (n_proj, d) assert n_proj is not None, "n_proj must be specified if thetas is None" - thetas = get_random_projections(d, n_proj, backend=nx).T + thetas = get_random_projections(d, n_proj, backend=nx, type_as=X).T if warm_theta is not None: thetas = nx.concatenate([thetas, warm_theta[:, None].T], axis=0) diff --git a/test/test_sliced.py b/test/test_sliced.py index cb6e895f0..0c1597e56 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -3,6 +3,7 @@ # Author: Adrien Corenflos # Nicolas Courty # Eloi Tanguy +# Laetitia Chapel # # License: MIT License From aee78fce04f201978a9d8e2b779092b67e932da2 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 10 Oct 2025 14:29:47 +0200 Subject: [PATCH 17/19] deleted debug files --- .coveragerc | 8 -------- coverage_help.md | 4 ---- debug.py | 32 -------------------------------- 3 files changed, 44 deletions(-) delete mode 100644 .coveragerc delete mode 100644 coverage_help.md delete mode 100644 debug.py diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index d883888a6..000000000 --- a/.coveragerc +++ /dev/null @@ -1,8 +0,0 @@ -[run] -omit = - /tmp/* - */_remote_module_non_scriptable.py - */site-packages/* - -[report] -skip_covered = True \ No newline at end of file diff --git a/coverage_help.md b/coverage_help.md deleted file mode 100644 index 3ed518bff..000000000 --- a/coverage_help.md +++ /dev/null @@ -1,4 +0,0 @@ -Example: - - coverage run -m pytest test/test_ot.py - coverage html --rcfile=.coveragerc \ No newline at end of file diff --git a/debug.py b/debug.py deleted file mode 100644 index e4eb86527..000000000 --- a/debug.py +++ /dev/null @@ -1,32 +0,0 @@ -# %% -import numpy as np -from ot.backend import get_backend -import torch -import ot -from torch.optim import Adam - - -# %% -rng = np.random.RandomState(0) -n = 10 -d = 2 -X = rng.randn(n, d) -Y = rng.randn(n, d) + np.array([5.0, 0.0])[None, :] -n_proj = 20 -P = ot.sliced.get_random_projections(d, n_proj) -a = rng.uniform(0, 1, n) -a /= a.sum() -b = rng.uniform(0, 1, n) -b /= b.sum() -sw2 = ot.sliced.sliced_wasserstein_distance(X, Y, a=a, b=b, projections=P) - -# %% -nx = get_backend(torch.tensor([0.0])) -X_t = nx.from_numpy(X) -Y_t = nx.from_numpy(Y) -a_t = nx.from_numpy(a) -b_t = nx.from_numpy(b) -P_t = nx.from_numpy(P) -sw2_t = ot.sliced.sliced_wasserstein_distance(X_t, Y_t, a=a_t, b=b_t, projections=P_t) - -# %% From dd1b31f454f1de488d81fe9a4c9b1b803cf951e0 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 10 Oct 2025 14:48:11 +0200 Subject: [PATCH 18/19] dense computation of costs for sliced_plans with jax --- ot/sliced.py | 94 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 25 deletions(-) diff --git a/ot/sliced.py b/ot/sliced.py index de7158190..01fb52cd7 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -821,42 +821,86 @@ def sliced_plans( for k in range(n_proj) ] + if not dense and str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + plan = [nx.todense(plan[k]) for k in range(n_proj)] + else: # we compute plans _, plan = wasserstein_1d( X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True ) - if metric in ("minkowski", "euclidean", "cityblock"): - costs = [ - nx.sum( - ( - (nx.sum(nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1)) - ** (1 / p) + if str(nx) == "jax": # dense computation + if not dense: + warnings.warn( + "JAX does not support sparse matrices, converting to dense" + ) + + plan = [nx.todense(plan[k]) for k in range(n_proj)] + + if metric in ("minkowski", "euclidean", "cityblock"): + costs = [ + nx.sum( + ( + ( + nx.sum( + nx.abs(X[:, None, :] - Y[None, :, :]) ** p, axis=-1 + ) + ) + ** (1 / p) + ) + * plan[k].data + ) + for k in range(n_proj) + ] + elif metric == "sqeuclidean": + costs = [ + nx.sum( + (nx.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)) + * plan[k].data ) - * plan[k].data + for k in range(n_proj) + ] + else: + raise ValueError( + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" ) - for k in range(n_proj) - ] - elif metric == "sqeuclidean": - costs = [ - nx.sum( - (nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1)) - * plan[k].data + + else: # not jax, sparse computation + if metric in ("minkowski", "euclidean", "cityblock"): + costs = [ + nx.sum( + ( + ( + nx.sum( + nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1 + ) + ) + ** (1 / p) + ) + * plan[k].data + ) + for k in range(n_proj) + ] + elif metric == "sqeuclidean": + costs = [ + nx.sum( + (nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1)) + * plan[k].data + ) + for k in range(n_proj) + ] + else: + raise ValueError( + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" ) - for k in range(n_proj) - ] - else: - raise ValueError( - "Sliced plans work only with metrics " - + "from the following list: " - + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" - ) if dense: plan = [nx.todense(plan[k]) for k in range(n_proj)] - elif str(nx) == "jax": - warnings.warn("JAX does not support sparse matrices, converting to dense") - plan = [nx.todense(plan[k]) for k in range(n_proj)] if log: log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas} From 652b49d0aad3bd3c25df8299e638739f2c6d7488 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 10 Oct 2025 15:05:36 +0200 Subject: [PATCH 19/19] jax .data fix + backend typing fix in sliced_plans test function --- ot/sliced.py | 4 ++-- test/test_sliced.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ot/sliced.py b/ot/sliced.py index 01fb52cd7..4d31c6119 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -880,7 +880,7 @@ def sliced_plans( ) ** (1 / p) ) - * plan[k].data + * plan[k] ) for k in range(n_proj) ] @@ -888,7 +888,7 @@ def sliced_plans( costs = [ nx.sum( (nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1)) - * plan[k].data + * plan[k] ) for k in range(n_proj) ] diff --git a/test/test_sliced.py b/test/test_sliced.py index 0c1597e56..ddcb3fe1d 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -925,7 +925,9 @@ def test_sliced_plans_backends(nx): x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b) - thetas_b = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T + thetas_b = ot.sliced.get_random_projections( + d, n_proj, seed=0, backend=nx, type_as=x_b + ).T thetas = nx.to_numpy(thetas_b) context = (