diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 94486046f..b6add7ae3 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -1,11 +1,8 @@ -Contributing to POT -=================== +# Contributing to POT +First off, thank you for considering contributing to POT. -First off, thank you for considering contributing to POT. - -How to contribute ------------------ +## How to contribute The preferred workflow for contributing to POT is to fork the [main repository](https://github.com/rflamary/POT) on @@ -23,7 +20,7 @@ GitHub, clone, and develop on a branch. Steps: $ cd POT ``` -2. Install pre-commit hooks to ensure that your code is properly formatted: +3. Install pre-commit hooks to ensure that your code is properly formatted: ```bash $ pip install pre-commit @@ -32,15 +29,48 @@ GitHub, clone, and develop on a branch. Steps: This will install the pre-commit hooks that will run on every commit. If the hooks fail, the commit will be aborted. -3. Create a ``feature`` branch to hold your development changes: +4. Create a `feature` branch to hold your development changes: ```bash $ git checkout -b my-feature ``` - Always use a ``feature`` branch. It's good practice to never work on the ``master`` branch! + Always use a `feature` branch. It's good practice to never work on the `master` branch! + +5. Install a recent version of Python (e.g. 3.10), using conda for instance. You can create a conda environment and activate it: + + ```bash + $ conda create -n dev-pot-env python=3.10 + $ conda activate dev-pot-env + ``` + +6. Install all the necessary packages in your environment: -4. Develop the feature on your feature branch. Add changed files using ``git add`` and then ``git commit`` files: +```bash +$ pip install -r requirements_all.txt +``` + +6. Install a compiler with OpenMP support for your platform (see details on the [scikit-learn contributing guide](https://scikit-learn.org/stable/developers/advanced_installation.html#platform-specific-instructions)). + For instance, with macOS, Apple clang does not support OpenMP. One can install the LLVM OpenMP library from homebrew: + + ```bash + $ brew install libomp + ``` + + and set environment variables: + + ```bash + $ export CC=/usr/local/opt/llvm/bin/clang + $ export CXX=/usr/local/opt/llvm/bin/clang++ + ``` + +7. Build the projet with pip: + + ```bash + pip install -e . + ``` + +8. Develop the feature on your feature branch. Add changed files using `git add` and then `git commit` files: ```bash $ git add modified_files @@ -53,64 +83,62 @@ GitHub, clone, and develop on a branch. Steps: $ git push -u origin my-feature ``` -5. Follow [these instructions](https://help.github.com/articles/creating-a-pull-request-from-a-fork) -to create a pull request from your fork. This will send an email to the committers. +9. Follow [these instructions](https://help.github.com/articles/creating-a-pull-request-from-a-fork) + to create a pull request from your fork. This will send an email to the committers. (If any of the above seems like magic to you, please look up the [Git documentation](https://git-scm.com/documentation) on the web, or ask a friend or another contributor for help.) -Pull Request Checklist ----------------------- +## Pull Request Checklist We recommended that your contribution complies with the following rules before you submit a pull request: -- Follow the PEP8 Guidelines which should be handles automatically by pre-commit. - -- If your pull request addresses an issue, please use the pull request title - to describe the issue and mention the issue number in the pull request description. This will make sure a link back to the original issue is - created. - -- All public methods should have informative docstrings with sample - usage presented as doctests when appropriate. - -- Please prefix the title of your pull request with `[MRG]` (Ready for - Merge), if the contribution is complete and ready for a detailed review. - Two core developers will review your code and change the prefix of the pull - request to `[MRG + 1]` and `[MRG + 2]` on approval, making it eligible - for merging. An incomplete contribution -- where you expect to do more work before - receiving a full review -- should be prefixed `[WIP]` (to indicate a work - in progress) and changed to `[MRG]` when it matures. WIPs may be useful - to: indicate you are working on something to avoid duplicated work, - request broad review of functionality or API, or seek collaborators. - WIPs often benefit from the inclusion of a - [task list](https://github.com/blog/1375-task-lists-in-gfm-issues-pulls-comments) - in the PR description. - - -- When adding additional functionality, provide at least one - example script in the ``examples/`` folder. Have a look at other - examples for reference. Examples should demonstrate why the new - functionality is useful in practice and, if possible, compare it - to other methods available in POT. - -- Documentation and high-coverage tests are necessary for enhancements to be - accepted. Bug-fixes or new features should be provided with - [non-regression tests](https://en.wikipedia.org/wiki/Non-regression_testing). - These tests verify the correct behavior of the fix or feature. In this - manner, further modifications on the code base are granted to be consistent - with the desired behavior. - For the Bug-fixes case, at the time of the PR, this tests should fail for - the code base in master and pass for the PR code. - -- At least one paragraph of narrative documentation with links to - references in the literature (with PDF links when possible) and - the example. +* Follow the PEP8 Guidelines which should be handles automatically by pre-commit. + +* If your pull request addresses an issue, please use the pull request title + to describe the issue and mention the issue number in the pull request description. This will make sure a link back to the original issue is + created. + +* All public methods should have informative docstrings with sample + usage presented as doctests when appropriate. + +* Please prefix the title of your pull request with `[MRG]` (Ready for + Merge), if the contribution is complete and ready for a detailed review. + Two core developers will review your code and change the prefix of the pull + request to `[MRG + 1]` and `[MRG + 2]` on approval, making it eligible + for merging. An incomplete contribution -- where you expect to do more work before + receiving a full review -- should be prefixed `[WIP]` (to indicate a work + in progress) and changed to `[MRG]` when it matures. WIPs may be useful + to: indicate you are working on something to avoid duplicated work, + request broad review of functionality or API, or seek collaborators. + WIPs often benefit from the inclusion of a + [task list](https://github.com/blog/1375-task-lists-in-gfm-issues-pulls-comments) + in the PR description. + +* When adding additional functionality, provide at least one + example script in the `examples/` folder. Have a look at other + examples for reference. Examples should demonstrate why the new + functionality is useful in practice and, if possible, compare it + to other methods available in POT. + +* Documentation and high-coverage tests are necessary for enhancements to be + accepted. Bug-fixes or new features should be provided with + [non-regression tests](https://en.wikipedia.org/wiki/Non-regression_testing). + These tests verify the correct behavior of the fix or feature. In this + manner, further modifications on the code base are granted to be consistent + with the desired behavior. + For the Bug-fixes case, at the time of the PR, this tests should fail for + the code base in master and pass for the PR code. + +* At least one paragraph of narrative documentation with links to + references in the literature (with PDF links when possible) and + the example. You can also check for common programming errors with the following tools: -- All lint checks pass. You can run the following command to check: +* All lint checks pass. You can run the following command to check: ```bash $ pre-commit run --all-files @@ -118,52 +146,51 @@ tools: This will run the pre-commit checks on all files in the repository. -- All tests pass. You can run the following command to check: +* All tests pass. You can run the following command to check: ```bash $ pytest --durations=20 -v test/ --doctest-modules - ``` + ``` Bonus points for contributions that include a performance analysis with a benchmark script and profiling output (please report on the mailing list or on the GitHub issue). -Filing bugs ------------ +## Filing bugs + We use Github issues to track all bugs and feature requests; feel free to open an issue if you have found a bug or wish to see a feature implemented. It is recommended to check that your issue complies with the following rules before submitting: -- Verify that your issue is not being currently addressed by other - [issues](https://github.com/rflamary/POT/issues?q=) - or [pull requests](https://github.com/rflamary/POT/pulls?q=). +* Verify that your issue is not being currently addressed by other + [issues](https://github.com/rflamary/POT/issues?q=) + or [pull requests](https://github.com/rflamary/POT/pulls?q=). -- Please ensure all code snippets and error messages are formatted in - appropriate code blocks. - See [Creating and highlighting code blocks](https://help.github.com/articles/creating-and-highlighting-code-blocks). +* Please ensure all code snippets and error messages are formatted in + appropriate code blocks. + See [Creating and highlighting code blocks](https://help.github.com/articles/creating-and-highlighting-code-blocks). -- Please include your operating system type and version number, as well - as your Python, POT, numpy, and scipy versions. This information - can be found by running the following code snippet: +* Please include your operating system type and version number, as well + as your Python, POT, numpy, and scipy versions. This information + can be found by running the following code snippet: - ```python - import platform; print(platform.platform()) - import sys; print("Python", sys.version) - import numpy; print("NumPy", numpy.__version__) - import scipy; print("SciPy", scipy.__version__) - import ot; print("POT", ot.__version__) - ``` +```python +import platform; print(platform.platform()) +import sys; print("Python", sys.version) +import numpy; print("NumPy", numpy.__version__) +import scipy; print("SciPy", scipy.__version__) +import ot; print("POT", ot.__version__) +``` -- Please be specific about what estimators and/or functions are involved - and the shape of the data, as appropriate; please include a - [reproducible](http://stackoverflow.com/help/mcve) code snippet - or link to a [gist](https://gist.github.com). If an exception is raised, - please provide the traceback. +* Please be specific about what estimators and/or functions are involved + and the shape of the data, as appropriate; please include a + [reproducible](http://stackoverflow.com/help/mcve) code snippet + or link to a [gist](https://gist.github.com). If an exception is raised, + please provide the traceback. -New contributor tips --------------------- +## New contributor tips A great way to start contributing to POT is to pick an item from the list of [Easy issues](https://github.com/rflamary/POT/issues?labels=Easy) @@ -173,8 +200,7 @@ assistance in this area will be greatly appreciated by the more experienced developers as it helps free up their time to concentrate on other issues. -Documentation -------------- +## Documentation We are glad to accept any sort of documentation: function docstrings, reStructuredText documents (like this one), tutorials, etc. @@ -182,8 +208,8 @@ reStructuredText documents live in the source code repository under the doc/ directory. You can edit the documentation using any text editor and then generate -the HTML output by typing ``make html`` from the ``docs/`` directory. -Alternatively, ``make`` can be used to quickly generate the +the HTML output by typing `make html` from the `docs/` directory. +Alternatively, `make` can be used to quickly generate the documentation without the example gallery with `make html-noplot`. The resulting HTML files will be placed in `docs/build/html/` and are viewable in a web browser. @@ -199,5 +225,4 @@ start with a small paragraph with a hand-waving explanation of what the method does to the data and a figure (coming from an example) illustrating it. - This Contribution guide is strongly inspired by the one of the [scikit-learn](https://github.com/scikit-learn/scikit-learn) team. diff --git a/README.md b/README.md index f8880a166..f443ec365 100644 --- a/README.md +++ b/README.md @@ -12,77 +12,78 @@ This open source Python library provides several solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning. -Website and documentation: [https://PythonOT.github.io/](https://PythonOT.github.io/) +Website and documentation: Source Code (MIT): -[https://github.com/PythonOT/POT](https://github.com/PythonOT/POT) - + POT has the following main features: + * A large set of differentiable solvers for optimal transport problems, including: - * Exact linear OT, entropic and quadratic regularized OT, - * Gromov-Wasserstein (GW) distances, Fused GW distances and variants of - quadratic OT, - * Unbalanced and partial OT for different divergences, -* OT barycenters (Wasserstein and GW) for fixed and free support, -* Fast OT solvers in 1D, on the circle and between Gaussian Mixture Models (GMMs), -* Many ML related solvers, such as domain adaptation, optimal transport mapping - estimation, subspace learning, Graph Neural Networks (GNNs) layers. -* Several backends for easy use with Pytorch, Jax, Tensorflow, Numpy and Cupy arrays. + * Exact linear OT, entropic and quadratic regularized OT, + * Gromov-Wasserstein (GW) distances, Fused GW distances and variants of + quadratic OT, + * Unbalanced and partial OT for different divergences, +* OT barycenters (Wasserstein and GW) for fixed and free support, +* Fast OT solvers in 1D, on the circle and between Gaussian Mixture Models (GMMs), +* Many ML related solvers, such as domain adaptation, optimal transport mapping + estimation, subspace learning, Graph Neural Networks (GNNs) layers. +* Several backends for easy use with Pytorch, Jax, Tensorflow, Numpy and Cupy arrays. ### Implemented Features POT provides the following generic OT solvers: -* [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . -* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. +* [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance \[1] . +* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) \[6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT \[7]. * Entropic regularization OT solver with [Sinkhorn Knopp - Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , - stabilized version [9] [10] [34], lazy CPU/GPU solver from geomloss [60] [61], greedy Sinkhorn [22] and Screening - Sinkhorn [26]. -* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. -* Sinkhorn divergence [23] and entropic regularization OT from empirical data. -* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] -* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17]. -* Weak OT solver between empirical distributions [39] -* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale). -* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12,51]), differentiable using gradients from Graph Dictionary Learning [38] - * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) (exact [24] and regularized [12,51]). + Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) \[2] , + stabilized version \[9] \[10] \[34], lazy CPU/GPU solver from geomloss \[60] \[61], greedy Sinkhorn \[22] and Screening + Sinkhorn \[26]. +* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) \[3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) \[21] and unmixing \[4]. +* Sinkhorn divergence \[23] and entropic regularization OT from empirical data. +* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) \[37] +* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations \[17]. +* Weak OT solver between empirical distributions \[39] +* Non regularized [Wasserstein barycenters \[16\] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale). +* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact \[13] and regularized \[12,51]), differentiable using gradients from Graph Dictionary Learning \[38] +* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) (exact \[24] and regularized \[12,51]). * [Stochastic solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and [differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for - Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) -* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] -* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. -* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] -* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations). -* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. + Large-scale Optimal Transport (semi-dual problem \[18] and dual problem \[19]) +* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions \[33] +* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) \[20]. +* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) \[10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) \[41] +* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact \[29] and entropic \[3] formulations). +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) \[31, 32] and Max-sliced Wasserstein \[35] that can be used for gradient flows \[36]. * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html) - [44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] -* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. -* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding [barycenter solvers](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.hmtl) (exact and regularized [48]). -* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68]. -* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. + \[44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) \[46] +* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) \[38]. +* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding [barycenter solvers](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.hmtl) (exact and regularized \[48]). +* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) \[68]. +* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) \[50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. -* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59]. -* [Gaussian Mixture Model OT](https://pythonot.github.io/auto_examples/gaussian_gmm/plot_GMMOT_plan.html#sphx-glr-auto-examples-others-plot-gmmot-plan-py) [69]. -* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and -[unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. -* Fused unbalanced Gromov-Wasserstein [70]. -* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77] -* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77] +* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) \[58], with an extension to bounding potentials using \[59]. +* [Gaussian Mixture Model OT](https://pythonot.github.io/auto_examples/gaussian_gmm/plot_GMMOT_plan.html#sphx-glr-auto-examples-others-plot-gmmot-plan-py) \[69]. +* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) \[49] and + [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) \[71]. +* Fused unbalanced Gromov-Wasserstein \[70]. +* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) \[77] +* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) \[69, 77] +* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) \[82, 83, 84] POT provides the following Machine Learning related solvers: * [Optimal transport for domain adaptation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_classes.html) - with [group lasso regularization](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_classes.html), [Laplacian regularization](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_laplacian.html) [5] [30] and [semi + with [group lasso regularization](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_classes.html), [Laplacian regularization](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_laplacian.html) \[5] \[30] and [semi supervised setting](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_semi_supervised.html). -* [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8]. -* [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt). -* [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27]. -* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53] +* [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) \[14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) \[8]. +* [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) \[11] (requires autograd + pymanopt). +* [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) \[27]. +* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) \[52] and TW (OT-GNN) \[53] Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html). @@ -92,9 +93,11 @@ If you use this toolbox in your research and find it useful, please cite POT using the following references from the current version and from our [JMLR paper](https://jmlr.org/papers/v22/20-451.html): - Flamary R., Vincent-Cuaz C., Courty N., Gramfort A., Kachaiev O., Quang Tran H., David L., Bonet C., Cassereau N., Gnassounou T., Tanguy E., Delon J., Collas A., Mazelet S., Chapel L., Kerdoncuff T., Yu X., Feickert M., Krzakala P., Liu T., Fernandes Montesuma E. POT Python Optimal Transport (version 0.9.5). URL: https://github.com/PythonOT/POT +``` +Flamary R., Vincent-Cuaz C., Courty N., Gramfort A., Kachaiev O., Quang Tran H., David L., Bonet C., Cassereau N., Gnassounou T., Tanguy E., Delon J., Collas A., Mazelet S., Chapel L., Kerdoncuff T., Yu X., Feickert M., Krzakala P., Liu T., Fernandes Montesuma E. POT Python Optimal Transport (version 0.9.5). URL: https://github.com/PythonOT/POT - Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. URL: https://pythonot.github.io/ +Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. URL: https://pythonot.github.io/ +``` In Bibtex format: @@ -122,13 +125,12 @@ In Bibtex format: The library has been tested on Linux, MacOSX and Windows. It requires a C++ compiler for building/installing the EMD solver and relies on the following Python modules: -- Numpy (>=1.16) -- Scipy (>=1.0) -- Cython (>=0.23) (build only, not necessary when installing from pip or conda) +* Numpy (>=1.16) +* Scipy (>=1.0) +* Cython (>=0.23) (build only, not necessary when installing from pip or conda) #### Pip installation - You can install the toolbox through PyPI with: ```console @@ -142,9 +144,11 @@ pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user ``` Optional dependencies may be installed with + ```console pip install POT[all] ``` + Note that this installs `cvxopt`, which is licensed under GPL 3.0. Alternatively, if you cannot use GPL-licensed software, the specific optional dependencies may be installed individually, or per-submodule. The available optional installations are `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, all`. #### Anaconda installation with conda-forge @@ -156,6 +160,7 @@ conda install -c conda-forge pot ``` #### Post installation check + After a correct installation, you should be able to import the module without errors: ```python @@ -164,7 +169,6 @@ import ot Note that for easier access the module is named `ot` instead of `pot`. - ### Dependencies Some sub-modules require additional dependencies which are discussed below @@ -175,7 +179,6 @@ Some sub-modules require additional dependencies which are discussed below pip install pymanopt autograd ``` - ## Examples ### Short examples @@ -240,8 +243,7 @@ ba = ot.barycenter(A, M, reg) # reg is regularization parameter ### Examples and Notebooks -The examples folder contain several examples and use case for the library. The full documentation with examples and output is available on [https://PythonOT.github.io/](https://PythonOT.github.io/). - +The examples folder contain several examples and use case for the library. The full documentation with examples and output is available on . ## Acknowledgements @@ -278,176 +280,179 @@ You can also post bug reports and feature requests in Github issues. Make sure t ## References -[1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). [Displacement interpolation using Lagrangian mass transport](https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf). In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. +\[1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). [Displacement interpolation using Lagrangian mass transport](https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf). In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. -[2] Cuturi, M. (2013). [Sinkhorn distances: Lightspeed computation of optimal transport](https://arxiv.org/pdf/1306.0895.pdf). In Advances in Neural Information Processing Systems (pp. 2292-2300). +\[2] Cuturi, M. (2013). [Sinkhorn distances: Lightspeed computation of optimal transport](https://arxiv.org/pdf/1306.0895.pdf). In Advances in Neural Information Processing Systems (pp. 2292-2300). -[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). [Iterative Bregman projections for regularized transportation problems](https://arxiv.org/pdf/1412.5154.pdf). SIAM Journal on Scientific Computing, 37(2), A1111-A1138. +\[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). [Iterative Bregman projections for regularized transportation problems](https://arxiv.org/pdf/1412.5154.pdf). SIAM Journal on Scientific Computing, 37(2), A1111-A1138. -[4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, [Supervised planetary unmixing with optimal transport](https://hal.archives-ouvertes.fr/hal-01377236/document), Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. +\[4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, [Supervised planetary unmixing with optimal transport](https://hal.archives-ouvertes.fr/hal-01377236/document), Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. -[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, [Optimal Transport for Domain Adaptation](https://arxiv.org/pdf/1507.00504.pdf), in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 +\[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, [Optimal Transport for Domain Adaptation](https://arxiv.org/pdf/1507.00504.pdf), in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 -[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). [Regularized discrete optimal transport](https://arxiv.org/pdf/1307.5551.pdf). SIAM Journal on Imaging Sciences, 7(3), 1853-1882. +\[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). [Regularized discrete optimal transport](https://arxiv.org/pdf/1307.5551.pdf). SIAM Journal on Imaging Sciences, 7(3), 1853-1882. -[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). [Generalized conditional gradient: analysis of convergence and applications](https://arxiv.org/pdf/1510.06567.pdf). arXiv preprint arXiv:1510.06567. +\[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). [Generalized conditional gradient: analysis of convergence and applications](https://arxiv.org/pdf/1510.06567.pdf). arXiv preprint arXiv:1510.06567. -[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), [Mapping estimation for discrete optimal transport](http://remi.flamary.com/biblio/perrot2016mapping.pdf), Neural Information Processing Systems (NIPS). +\[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), [Mapping estimation for discrete optimal transport](http://remi.flamary.com/biblio/perrot2016mapping.pdf), Neural Information Processing Systems (NIPS). -[9] Schmitzer, B. (2016). [Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems](https://arxiv.org/pdf/1610.06519.pdf). arXiv preprint arXiv:1610.06519. +\[9] Schmitzer, B. (2016). [Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems](https://arxiv.org/pdf/1610.06519.pdf). arXiv preprint arXiv:1610.06519. -[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816. +\[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816. -[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063. +\[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063. -[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). +\[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). -[13] Mémoli, Facundo (2011). [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 : 417-487. +\[13] Mémoli, Facundo (2011). [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 : 417-487. -[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43. +\[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43. -[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) . +\[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) . -[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. +\[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. -[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). +\[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). -[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016). +\[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016). -[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) +\[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) -[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning +\[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning -[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. +\[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. -[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 +\[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 -[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018 +\[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018 -[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). +\[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). -[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS). +\[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS). -[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NeurIPS). +\[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NeurIPS). -[27] Redko I., Courty N., Flamary R., Tuia D. (2019). [Optimal Transport for Multi-source Domain Adaptation under Target Shift](http://proceedings.mlr.press/v89/redko19a.html), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics (AISTATS) 22, 2019. +\[27] Redko I., Courty N., Flamary R., Tuia D. (2019). [Optimal Transport for Multi-source Domain Adaptation under Target Shift](http://proceedings.mlr.press/v89/redko19a.html), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics (AISTATS) 22, 2019. -[28] Caffarelli, L. A., McCann, R. J. (2010). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730. +\[28] Caffarelli, L. A., McCann, R. J. (2010). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730. -[29] Chapel, L., Alaya, M., Gasso, G. (2020). [Partial Optimal Transport with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), Advances in Neural Information Processing Systems (NeurIPS), 2020. +\[29] Chapel, L., Alaya, M., Gasso, G. (2020). [Partial Optimal Transport with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), Advances in Neural Information Processing Systems (NeurIPS), 2020. -[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. +\[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. -[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 +\[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 -[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). +\[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). -[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 +\[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 -[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. +\[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. -[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). +\[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). -[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +\[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. (2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. -[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International +\[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 -[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph +\[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. -[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. +\[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825\&rep=rep1\&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. -[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. +\[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. -[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) +\[41] Chapel\*, L., Flamary\*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) -[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. +\[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. +\[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. -[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. +\[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. -[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. +\[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. -[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. +\[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. -[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787. +\[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787. -[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. +\[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. -[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. +\[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. -[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR). +\[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR). -[51] Xu, H., Luo, D., Zha, H., & Carin, L. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019. +\[51] Xu, H., Luo, D., Zha, H., & Carin, L. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019. -[52] Collas, A., Vayer, T., Flamary, F., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv. +\[52] Collas, A., Vayer, T., Flamary, F., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv. -[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35. +\[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35. -[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804). +\[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804). -[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR). +\[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR). -[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019. +\[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019. -[57] Delon, J., Desolneux, A., & Salmona, A. (2022). [Gromov–Wasserstein +\[57] Delon, J., Desolneux, A., & Salmona, A. (2022). [Gromov–Wasserstein distances between Gaussian distributions](https://hal.science/hal-03197398v2/file/main.pdf). Journal of Applied Probability, 59(4), 1178-1198. -[58] Paty F-P., d’Aspremont 1., & Cuturi M. (2020). [Regularity as regularization:Smooth and strongly convex brenier potentials in optimal transport.](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) In International Conference on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. +\[58] Paty F-P., d’Aspremont 1., & Cuturi M. (2020). [Regularity as regularization:Smooth and strongly convex brenier potentials in optimal transport.](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) In International Conference on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. -[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017. +\[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017. -[60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms](https://arxiv.org/pdf/2107.02010.pdf). In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part III 22 (pp. 636-644). Springer International Publishing. +\[60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms](https://arxiv.org/pdf/2107.02010.pdf). In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part III 22 (pp. 636-644). Springer International Publishing. -[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462. +\[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462. -[62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning. +\[62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning. -[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations. +\[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations. -[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems. +\[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems. -[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). +\[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). -[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021). +\[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021). -[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022. +\[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022. -[68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. +\[68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. -[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. +\[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. -[70] A. Thual, H. Tran, T. Zemskova, N. Courty, R. Flamary, S. Dehaene +\[70] A. Thual, H. Tran, T. Zemskova, N. Courty, R. Flamary, S. Dehaene & B. Thirion (2022). [Aligning individual brains with Fused Unbalanced Gromov-Wasserstein.](https://proceedings.neurips.cc/paper_files/paper/2022/file/8906cac4ca58dcaf17e97a0486ad57ca-Paper-Conference.pdf). Neural Information Processing Systems (NeurIPS). -[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on +\[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on Artificial Intelligence. -[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). +\[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). -[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. +\[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. -[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. +\[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. -[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. +\[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. -[76] Chapel, L., Tavenard, R. (2025). [One for all and all for one: Efficient computation of partial Wasserstein distances on the line](https://iclr.cc/virtual/2025/poster/28547). In International Conference on Learning Representations. +\[76] Chapel, L., Tavenard, R. (2025). [One for all and all for one: Efficient computation of partial Wasserstein distances on the line](https://iclr.cc/virtual/2025/poster/28547). In International Conference on Learning Representations. -[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) +\[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) -[78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). [LCOT: Linear Circular Optimal Transport](https://openreview.net/forum?id=49z97Y9lMq). International Conference on Learning Representations. +\[78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). [LCOT: Linear Circular Optimal Transport](https://openreview.net/forum?id=49z97Y9lMq). International Conference on Learning Representations. -[79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations. +\[79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations. -[80] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019. +\[80] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019. -[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS). +\[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS). +\[82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). [Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6f1346bac8b02f76a631400e2799b24b-Abstract-Conference.html). Advances in Neural Information Processing Systems, 36, 35350–35385. -``` +\[83] Tanguy, E., Chapel, L., Delon, J. (2025). [Sliced Optimal Transport Plans](https://arxiv.org/abs/2508.01243) arXiv preprint 2506.03661. + +\[84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). [Expected Sliced Transport Plans](https://openreview.net/forum?id=P7O1Vt1BdU). International Conference on Learning Representations. diff --git a/RELEASES.md b/RELEASES.md index ccb9b97d2..ad1c895fe 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,11 @@ # Releases +## 0.9.7dev + +#### New features + +- Added Sliced OT plans (PR #767) + ## 0.9.6.post1 *September 2025* diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py new file mode 100644 index 000000000..00f7beed2 --- /dev/null +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +""" +=============== +Sliced OT Plans +=============== + +Compares different Sliced OT plans between two 2D point clouds. The min-Pivot +Sliced plan was introduced in [82], and the Expected Sliced plan in [84], both +were further studied theoretically in [83]. + +.. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. + +.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + +.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. +""" + +# Author: Eloi Tanguy +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +############################################################################## +# Setup data and imports +# ---------------------- +import numpy as np + +import ot +import matplotlib.pyplot as plt +from ot.sliced import get_random_projections + + +seed = 0 +np.random.seed(seed) +n = 20 +m = 10 +d = 2 +X = np.random.randn(n, 2) +Y = np.random.randn(m, 2) + np.array([5.0, 0.0])[None, :] +n_proj = 50 +thetas = get_random_projections(d, n_proj).T +alpha = 0.3 + + +proj_X = X @ thetas.T +proj_Y = Y @ thetas.T + + +############################################################################## +# Compute min-Pivot Sliced permutation +# ------------------------------------ +min_plan, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas=thetas, log=True) + +############################################################################## +# Compute Expected Sliced Plan +# ------------------------------------ +expected_plan, expected_cost, log_expected = ot.expected_sliced( + X, Y, thetas=thetas, log=True +) +############################################################################## +# Compute 2-Wasserstein Plan +# ------------------------------------ +a = np.ones(n, device=X.device) / n +b = np.ones(m, device=Y.device) / m +dists = ot.dist(X, Y) +W2 = ot.emd2(a, b, dists) +W2_plan = ot.emd(a, b, dists) + +############################################################################## +# Plot resulting assignments +# ------------------------------------ +fig, axs = plt.subplots(2, 3, figsize=(12, 4)) +fig.suptitle("Sliced plans comparison", y=0.95, fontsize=16) + +# draw min sliced permutation +axs[0, 0].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") +for i in range(X.shape[0]): + for j in range(Y.shape[0]): + if min_plan[i, j] > 0: + axs[0, 0].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=alpha, + ) +axs[1, 0].imshow(min_plan, interpolation="nearest", cmap="Blues") + +# draw expected sliced plan +axs[0, 1].set_title(f"Expected Sliced: cost={expected_cost:.2f}") +for i in range(n): + for j in range(m): + w = alpha * expected_plan[i, j].item() * n + axs[0, 1].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="Expected Sliced plan" if i == 0 and j == 0 else None, + ) +axs[1, 1].imshow(expected_plan, interpolation="nearest", cmap="Blues") + +# draw W2 plan +axs[0, 2].set_title(f"W$_2$: cost={W2:.2f}") +for i in range(n): + for j in range(m): + w = alpha * W2_plan[i, j].item() * n + axs[0, 2].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="W2 plan" if i == 0 and j == 0 else None, + ) +axs[1, 2].imshow(W2_plan, interpolation="nearest", cmap="Blues") + +for ax in axs[0, :]: + ax.scatter(X[:, 0], X[:, 1], label="X") + ax.scatter(Y[:, 0], Y[:, 1], label="Y") + +for ax in axs.flatten(): + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() + +############################################################################## +# Compare Expected Sliced plans with different inverse-temperatures beta +# ------------------------------------ +# As the temperature decreases, ES becomes sparser and approaches minPS + +betas = [0.0, 5.0, 50.0] +n_plots = len(betas) + 1 +size = 4 +fig, axs = plt.subplots(2, n_plots, figsize=(size * n_plots, size)) + +fig.suptitle( + "Expected Sliced plan varying $\\beta$ (inverse temperature)", y=0.95, fontsize=16 +) +for beta_idx, beta in enumerate(betas): + expected_plan, expected_cost = ot.expected_sliced(X, Y, thetas=thetas, beta=beta) + print(f"beta={beta}: cost={expected_cost:.2f}") + + axs[0, beta_idx].set_title(f"$\\beta$={beta}: cost={expected_cost:.2f}") + for i in range(n): + for j in range(m): + w = alpha * expected_plan[i, j].item() * n + axs[0, beta_idx].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="Expected Sliced plan" if i == 0 and j == 0 else None, + ) + + axs[0, beta_idx].scatter(X[:, 0], X[:, 1], label="X") + axs[0, beta_idx].scatter(Y[:, 0], Y[:, 1], label="Y") + axs[1, beta_idx].imshow(expected_plan, interpolation="nearest", cmap="Blues") + +# draw min sliced permutation (limit when beta -> +inf) +axs[0, -1].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") +for i in range(X.shape[0]): + for j in range(Y.shape[0]): + if min_plan[i, j] > 0: + axs[0, -1].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=alpha, + ) + +axs[0, -1].scatter(X[:, 0], X[:, 1], label="X") +axs[0, -1].scatter(Y[:, 0], Y[:, 1], label="Y") +axs[1, -1].imshow(min_plan, interpolation="nearest", cmap="Blues") + +for ax in axs.flatten(): + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() diff --git a/ot/__init__.py b/ot/__init__.py index 235fb91b4..7f1d8152f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -58,6 +58,8 @@ sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif, linear_sliced_wasserstein_sphere, + min_pivot_sliced, + expected_sliced, ) from .gromov import ( gromov_wasserstein, @@ -109,6 +111,8 @@ "sliced_wasserstein_distance", "sliced_wasserstein_sphere", "linear_sliced_wasserstein_sphere", + "min_pivot_sliced", + "expected_sliced", "gromov_wasserstein", "gromov_wasserstein2", "gromov_barycenters", diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 49e0c9c41..d09f8a46a 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -16,7 +16,7 @@ from ..utils import list_to_array -def quantile_function(qs, cws, xs): +def quantile_function(qs, cws, xs, idx_xs=None): r"""Computes the quantile function of an empirical distribution Parameters @@ -27,7 +27,8 @@ def quantile_function(qs, cws, xs): cumulative weights of the 1D empirical distribution, if batched, must be similar to xs xs: array-like, shape (n, ...) locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions - + idx_xs: array-like, shape (n, ...) + associated indices. If None, do not return them Returns ------- q: array-like, shape (..., n) @@ -44,11 +45,22 @@ def quantile_function(qs, cws, xs): cws = cws.T qs = qs.T idx = nx.searchsorted(cws, qs).T - return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + if idx_xs is not None: + return nx.take_along_axis( + xs, nx.clip(idx, 0, n - 1), axis=0 + ), nx.take_along_axis(idx_xs, nx.clip(idx, 0, n - 1), axis=0) + else: + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) def wasserstein_1d( - u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True + u_values, + v_values, + u_weights=None, + v_weights=None, + p=1, + require_sort=True, + return_plan=False, ): r""" Computes the 1 dimensional OT loss [15] between two (batched) empirical @@ -79,7 +91,9 @@ def wasserstein_1d( require_sort: bool, optional sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to the function, default is True - + return_plan: bool, optional + if True, returns also the optimal transport plan between the two + (batched) measures as a coo_matrix, default is False Returns ------- cost: float/array-like, shape (...) @@ -124,15 +138,31 @@ def wasserstein_1d( v_cumweights = nx.cumsum(v_weights, 0) qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0) - u_quantiles = quantile_function(qs, u_cumweights, u_values) - v_quantiles = quantile_function(qs, v_cumweights, v_values) + u_quantiles, u_quantiles_idx = quantile_function( + qs, u_cumweights, u_values, u_sorter + ) + v_quantiles, v_quantiles_idx = quantile_function( + qs, v_cumweights, v_values, v_sorter + ) qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)]) delta = qs[1:, ...] - qs[:-1, ...] diff_quantiles = nx.abs(u_quantiles - v_quantiles) - if p == 1: - return nx.sum(delta * diff_quantiles, axis=0) - return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) + if return_plan: + plan = [ + nx.coo_matrix( + delta[:, k], + u_quantiles_idx[:, k], + v_quantiles_idx[:, k], + shape=(n, m), + type_as=u_values, + ) + for k in range(delta.shape[1]) + ] + if return_plan: + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0), plan + else: + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) def emd_1d( @@ -201,7 +231,8 @@ def emd_1d( gamma: ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters log: dict - If input log is True, a dictionary containing the cost + If input log is True, a dictionary containing the cost and the indices + of the non-zero elements of the transportation matrix Examples @@ -297,6 +328,8 @@ def emd_1d( warnings.warn("JAX does not support sparse matrices, converting to dense") if log: log = {"cost": nx.from_numpy(cost, type_as=x_a)} + log["perms_x_a"] = perm_a[indices[:, 0]] + log["perms_x_b"] = perm_b[indices[:, 1]] return G, log return G diff --git a/ot/sliced.py b/ot/sliced.py index 3cf2002e7..4d31c6119 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -1,21 +1,25 @@ """ Sliced OT Distances - """ # Author: Adrien Corenflos # Nicolas Courty # Rémi Flamary +# Eloi Tanguy +# Laetitia Chapel # # License: MIT License +import warnings + import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array, get_coordinate_circle +from .utils import list_to_array, get_coordinate_circle, dist from .lp import ( wasserstein_circle, semidiscrete_wasserstein2_unif_circle, linear_circular_ot, + wasserstein_1d, ) @@ -595,7 +599,8 @@ def linear_sliced_wasserstein_sphere( X_s: ndarray, shape (n_samples_a, dim) Samples in the source domain X_t: ndarray, shape (n_samples_b, dim), optional - Samples in the target domain. If None, computes the distance against the uniform distribution on the sphere. + Samples in the target domain. If None, computes the distance against + the uniform distribution on the sphere. a : ndarray, shape (n_samples_a,), optional samples weights in the source domain b : ndarray, shape (n_samples_b,), optional @@ -607,7 +612,8 @@ def linear_sliced_wasserstein_sphere( seed: int or RandomState or None, optional Seed used for random number generator log: bool, optional - if True, linear_sliced_wasserstein_sphere returns the projections used and their associated LCOT. + if True, linear_sliced_wasserstein_sphere returns the projections used + and their associated LCOT. Returns ------- @@ -628,7 +634,10 @@ def linear_sliced_wasserstein_sphere( .. _references-lssot: References ---------- - .. [79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data. International Conference on Learning Representations. + .. [79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, + B. A., Chang, C., & Kolouri, S. (2025). Linear Spherical Sliced Optimal + Transport: A Fast Metric for Comparing Spherical Data. International + Conference on Learning Representations. """ d = X_s.shape[-1] @@ -639,9 +648,8 @@ def linear_sliced_wasserstein_sphere( if X_s.shape[1] != X_t.shape[1]: raise ValueError( - "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( - X_s.shape[1], X_t.shape[1] - ) + "X_s and X_t must have the same number of dimensions {} and {} \ + respectively given".format(X_s.shape[1], X_t.shape[1]) ) if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)): raise ValueError("X_s is not on the sphere.") @@ -674,3 +682,545 @@ def linear_sliced_wasserstein_sphere( if log: return res, {"projections": projections, "projected_emds": projected_lcot} return res + + +def sliced_plans( + X, + Y, + a=None, + b=None, + metric="sqeuclidean", + p=2, + thetas=None, + warm_theta=None, + n_proj=None, + dense=False, + log=False, +): + r""" + Computes all the permutations that sort the projections of two `(n, d)` + datasets `X` and `Y` on the directions `thetas`. + Each permutation `perm[:, k]` is such that each `X[i, :]` is matched + to `Y[perm[i, k], :]` when projected on `thetas[k, :]`. + + Parameters + ---------- + X : array-like, shape (n, d) + The first set of vectors. + Y : array-like, shape (m, d) + The second set of vectors. + a : ndarray of float64, shape (ns,), optional + Source histogram (default is uniform weight) + b : ndarray of float64, shape (nt,), optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only works with either of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + thetas : array-like, shape (n_proj, d), optional + The projection directions. If None, random directions will be + generated. + Default is None. + warm_theta : array-like, shape (d,), optional + A direction to add to the set of directions. Default is None. + dense: bool, optional + If True, returns dense matrices instead of sparse ones. + Default is False. + n_proj : int, optional + The number of projection directions. Required if thetas is None. + log : bool, optional + If True, returns additional logging information. Default is False. + + Returns + ------- + plan : ndarray, shape (ns, nt) or coo_matrix if dense is False + Optimal transportation matrix for the given parameters + costs : list of float + The cost associated to each projection. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + """ + + X, Y = list_to_array(X, Y) + + if a is not None and b is not None and thetas is None: + nx = get_backend(X, Y, a, b) + elif a is not None and b is not None and thetas is not None: + nx = get_backend(X, Y, a, b, thetas) + elif a is None and b is None and thetas is not None: + nx = get_backend(X, Y, thetas) + else: + nx = get_backend(X, Y) + + assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" + assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" + + assert ( + X.shape[1] == Y.shape[1] + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" + if metric == "euclidean": + p = 2 + elif metric == "cityblock": + p = 1 + + d = X.shape[1] + n = X.shape[0] + m = Y.shape[0] + + is_perm = False + if n == m: + if a is None or b is None or (a == b).all(): + is_perm = True + + do_draw_thetas = thetas is None + if do_draw_thetas: # create thetas (n_proj, d) + assert n_proj is not None, "n_proj must be specified if thetas is None" + thetas = get_random_projections(d, n_proj, backend=nx, type_as=X).T + + if warm_theta is not None: + thetas = nx.concatenate([thetas, warm_theta[:, None].T], axis=0) + else: + n_proj = thetas.shape[0] + + # project on each theta: (n or m, d) -> (n or m, n_proj) + X_theta = X @ thetas.T # shape (n, n_proj) + Y_theta = Y @ thetas.T # shape (m, n_proj) + + if is_perm: + # we compute maps (permutations) + # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj] + sigma = nx.argsort(X_theta, axis=0) # (n, n_proj) + tau = nx.argsort(Y_theta, axis=0) # (m, n_proj) + if metric in ("minkowski", "euclidean", "cityblock"): + costs = [ + nx.sum( + ( + (nx.sum(nx.abs(X[sigma[:, k]] - Y[tau[:, k]]) ** p, axis=1)) + ** (1 / p) + ) + / n + ) + for k in range(n_proj) + ] + elif metric == "sqeuclidean": + costs = [ + nx.sum((nx.sum((X[sigma[:, k]] - Y[tau[:, k]]) ** 2, axis=1)) / n) + for k in range(n_proj) + ] + else: + raise ValueError( + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + ) + a = nx.ones(n) / n + plan = [ + nx.coo_matrix(a, sigma[:, k], tau[:, k], shape=(n, m), type_as=a) + for k in range(n_proj) + ] + + if not dense and str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + plan = [nx.todense(plan[k]) for k in range(n_proj)] + + else: # we compute plans + _, plan = wasserstein_1d( + X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True + ) + + if str(nx) == "jax": # dense computation + if not dense: + warnings.warn( + "JAX does not support sparse matrices, converting to dense" + ) + + plan = [nx.todense(plan[k]) for k in range(n_proj)] + + if metric in ("minkowski", "euclidean", "cityblock"): + costs = [ + nx.sum( + ( + ( + nx.sum( + nx.abs(X[:, None, :] - Y[None, :, :]) ** p, axis=-1 + ) + ) + ** (1 / p) + ) + * plan[k].data + ) + for k in range(n_proj) + ] + elif metric == "sqeuclidean": + costs = [ + nx.sum( + (nx.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)) + * plan[k].data + ) + for k in range(n_proj) + ] + else: + raise ValueError( + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + ) + + else: # not jax, sparse computation + if metric in ("minkowski", "euclidean", "cityblock"): + costs = [ + nx.sum( + ( + ( + nx.sum( + nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1 + ) + ) + ** (1 / p) + ) + * plan[k] + ) + for k in range(n_proj) + ] + elif metric == "sqeuclidean": + costs = [ + nx.sum( + (nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1)) + * plan[k] + ) + for k in range(n_proj) + ] + else: + raise ValueError( + "Sliced plans work only with metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + ) + + if dense: + plan = [nx.todense(plan[k]) for k in range(n_proj)] + + if log: + log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas} + return plan, costs, log_dict + else: + return plan, costs + + +def min_pivot_sliced( + X, + Y, + a=None, + b=None, + thetas=None, + metric="sqeuclidean", + p=2, + n_proj=None, + dense=True, + log=False, + warm_theta=None, +): + r""" + Computes the cost and permutation associated to the min-Pivot Sliced + Discrepancy (introduced as SWGG in [82] and studied further in [83]). Given + the supports `X` and `Y` of two discrete uniform measures with `n` and `m` + atoms in dimension `d`, the min-Pivot Sliced Discrepancy goes through + `n_proj` different projections of the measures on random directions, and + retains the couplings that yields the lowest cost between `X` and `Y` + (compared in :math:`\mathbb{R}^d`). When $n=m$, it gives + + .. math:: + \mathrm{min\text{-}PS}_p^p(X, Y) \approx + \min_{k \in [1, n_{\mathrm{proj}}]} \left( + \frac{1}{n} \sum_{i=1}^n \|X_i - Y_{\sigma_k(i)}\|_2^p \right), + + where :math:`\sigma_k` is a permutation such that ordering the projections + on the axis `thetas[k, :]` matches `X[i, :]` to `Y[\sigma_k(i), :]`. + + .. note:: + The computation ignores potential ambiguities in the projections: if + two points from a same measure have the same projection on a direction, + then multiple sorting permutations are possible. To avoid combinatorial + explosion, only one permutation is retained: this strays from theory in + pathological cases. + + Parameters + ---------- + X : array-like, shape (n, d) + The first set of vectors. + Y : array-like, shape (m, d) + The second set of vectors. + a : ndarray of float64, shape (n,), optional + Source histogram (default is uniform weight) + b : ndarray of float64, shape (m,), optional + Target histogram (default is uniform weight) + thetas : array-like, shape (n_proj, d), optional + The projection directions. If None, random directions will be generated + Default is None. + metric: str, optional (default='sqeuclidean') + Metric to be used. Only works with either of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + n_proj : int, optional + The number of projection directions. Required if thetas is None. + dense: boolean, optional (default=True) + If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. + log : bool, optional + If True, returns additional logging information. Default is False. + warm_theta : array-like, shape (d,), optional + A theta to add to the list of thetas. Default is None. + + Returns + ------- + plan : ndarray, shape (n, m) or coo_matrix if dense is False + Optimal transportation matrix for the given parameters. + cost : float + The cost associated to the optimal permutation. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + + References + ---------- + .. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). + Fast Optimal Transport through Sliced Generalized Wasserstein + Geodesics. Advances in Neural Information Processing Systems, 36, + 35350–35385. + + .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport + Plans. arXiv preprint 2506.03661. + + Examples + -------- + >>> x=np.array([[3.,3.], [1.,1.]]) + >>> y=np.array([[2.,2.5], [3.,2.]]) + >>> thetas=np.array([[1, 0], [0, 1]]) + >>> plan, cost = ot.expected_sliced(x, y, thetas) + >>> plan + [[0 0.5] + [0.5 0]] + >>> cost + 2.125 + """ + + X, Y = list_to_array(X, Y) + + if a is not None and b is not None and thetas is None: + nx = get_backend(X, Y, a, b) + elif a is not None and b is not None and thetas is not None: + nx = get_backend(X, Y, a, b, thetas) + elif a is None and b is None and thetas is not None: + nx = get_backend(X, Y, thetas) + else: + nx = get_backend(X, Y) + + assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" + assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" + + assert ( + X.shape[1] == Y.shape[1] + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" + + log_dict = {} + G, costs, log_dict_plans = sliced_plans( + X, + Y, + a, + b, + metric, + p, + thetas, + n_proj=n_proj, + warm_theta=warm_theta, + log=True, + ) + pos_min = nx.argmin(costs) + cost = costs[pos_min] + plan = G[pos_min] + + if log: + log_dict = { + "thetas": log_dict_plans["thetas"], + "costs": costs, + "min_theta": log_dict_plans["thetas"][pos_min], + "X_min_theta": log_dict_plans["X_theta"][:, pos_min], + "Y_min_theta": log_dict_plans["Y_theta"][:, pos_min], + } + + if dense: + plan = nx.todense(plan) + elif str(nx) == "jax": + warnings.warn( + "JAX does not support sparse matrices, converting to\ + dense" + ) + plan = nx.todense(plan) + + if log: + return plan, cost, log_dict + else: + return plan, cost + + +def expected_sliced( + X, + Y, + a=None, + b=None, + thetas=None, + metric="sqeuclidean", + p=2, + n_proj=None, + dense=True, + log=False, + beta=0.0, +): + r""" + Computes the Expected Sliced cost and plan between two datasets `X` and + `Y` of shapes `(n, d)` and `(m, d)`. Given a set of `n_proj` projection + directions, the expected sliced plan is obtained by averaging the `n_proj` + 1d optimal transport plans between the projections of `X` and `Y` on each + direction. Expected Sliced was introduced in [84] and further studied in + [83]. + + .. note:: + The computation ignores potential ambiguities in the projections: if + two points from a same measure have the same projection on a direction, + then multiple sorting permutations are possible. To avoid combinatorial + explosion, only one permutation is retained: this strays from theory in + pathological cases. + + .. warning:: + The function runs on backend but tensorflow and jax are not supported + due to array assignment. + + Parameters + ---------- + X : array-like, shape (n, d) + The first set of vectors. + Y : array-like, shape (m, d) + The second set of vectors. + a : ndarray of float64, shape (n,), optional + Source histogram (default is uniform weight) + b : ndarray of float64, shape (m,), optional + Target histogram (default is uniform weight) + thetas : torch.Tensor, optional + A tensor of shape (n_proj, d) representing the projection directions. + If None, random directions will be generated. Default is None. + metric: str, optional (default='sqeuclidean') + Metric to be used. Only works with either of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. + p: float, optional (default=2) + The p-norm to apply for if metric='minkowski' + n_proj : int, optional + The number of projection directions. Required if thetas is None. + dense: boolean, optional (default=True) + If True, returns :math:`\gamma` as a dense ndarray of shape (n, m). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. + log : bool, optional + If True, returns additional logging information. Default is False. + beta : float, optional + Inverse-temperature parameter which weights each projection's + contribution to the expected plan. Default is 0 (uniform weighting). + + Returns + ------- + plan : ndarray, shape (n, m) or coo_matrix if dense is False + Optimal transportation matrix for the given parameters. + cost : float + The cost associated to the optimal permutation. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + + References + ---------- + .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport + Plans. arXiv preprint 2506.03661. + .. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi + A., Kolouri, S. (2024). Expected Sliced Transport Plans. + International Conference on Learning Representations. + + Examples + -------- + >>> x=np.array([[3.,3.], [1.,1.]]) + >>> y=np.array([[2.,2.5], [3.,2.]]) + >>> thetas=np.array([[1, 0], [0, 1]]) + >>> plan, cost = ot.expected_sliced(x, y, thetas) + >>> plan + [[0.25 0.25] + [0.25 0.25]] + >>> cost + 2.625 + """ + + X, Y = list_to_array(X, Y) + + if a is not None and b is not None and thetas is None: + nx = get_backend(X, Y, a, b) + elif a is not None and b is not None and thetas is not None: + nx = get_backend(X, Y, a, b, thetas) + elif a is None and b is None and thetas is not None: + nx = get_backend(X, Y, thetas) + else: + nx = get_backend(X, Y) + + assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead" + assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead" + + assert ( + X.shape[1] == Y.shape[1] + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns" + + if str(nx) in ["tf", "jax"]: + raise NotImplementedError( + f"expected_sliced is not implemented for the {str(nx)} backend due" + "to array assignment." + ) + + n = X.shape[0] + m = Y.shape[0] + + log_dict = {} + G, costs, log_dict_plans = sliced_plans( + X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True + ) + if log: + log_dict = {"thetas": log_dict_plans["thetas"], "costs": costs, "G": G} + + if beta != 0.0: # computing the temperature weighting + log_factors = -beta * list_to_array(costs) + weights = nx.exp(log_factors - nx.logsumexp(log_factors)) + cost = nx.sum(list_to_array(costs) * weights) + + else: # uniform weights + if n_proj is None: + n_proj = thetas.shape[0] + weights = nx.ones(n_proj) / n_proj + + log_dict["weights"] = weights + + weights = nx.concatenate([G[i].data * weights[i] for i in range(len(G))]) + X_idx = nx.concatenate([G[i].row for i in range(len(G))]) + Y_idx = nx.concatenate([G[i].col for i in range(len(G))]) + plan = nx.coo_matrix(weights, X_idx, Y_idx, shape=(n, m), type_as=weights) + + if beta == 0.0: # otherwise already computed above + cost = plan.multiply(dist(X, Y, metric=metric, p=p)).sum() + + if dense: + plan = nx.todense(plan) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + plan = nx.todense(plan) + + if log: + return plan, cost, log_dict + else: + return plan, cost diff --git a/test/test_sliced.py b/test/test_sliced.py index 05de13755..ddcb3fe1d 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -2,6 +2,8 @@ # Author: Adrien Corenflos # Nicolas Courty +# Eloi Tanguy +# Laetitia Chapel # # License: MIT License @@ -11,6 +13,7 @@ import ot from ot.sliced import get_random_projections from ot.backend import tf, torch +from contextlib import nullcontext def test_get_random_projections(): @@ -110,6 +113,14 @@ def test_max_sliced_different_dists(): assert res > 0.0 +def test_max_sliced_dim_check(): + n = 3 + x = np.zeros((n, 2)) + y = np.zeros((n + 1, 3)) + with pytest.raises(ValueError): + _ = ot.max_sliced_wasserstein_distance(x, y, n_projections=10) + + def test_sliced_same_proj(): n_projections = 10 seed = 12 @@ -152,6 +163,16 @@ def test_sliced_backend(nx): assert np.allclose(val0, valb) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, 2 * n) + b /= b.sum() + a_b = nx.from_numpy(a) + b_b = nx.from_numpy(b) + val = ot.sliced_wasserstein_distance(x, y, a=a, b=b, projections=P) + val_b = ot.sliced_wasserstein_distance(xb, yb, a=a_b, b=b_b, projections=Pb) + np.testing.assert_almost_equal(val, nx.to_numpy(val_b)) + def test_sliced_backend_type_devices(nx): n = 100 @@ -227,6 +248,16 @@ def test_max_sliced_backend(nx): assert np.allclose(val0, valb) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, 2 * n) + b /= b.sum() + a_b = nx.from_numpy(a) + b_b = nx.from_numpy(b) + val = ot.max_sliced_wasserstein_distance(x, y, a=a, b=b, projections=P) + val_b = ot.max_sliced_wasserstein_distance(xb, yb, a=a_b, b=b_b, projections=Pb) + np.testing.assert_almost_equal(val, nx.to_numpy(val_b)) + def test_max_sliced_backend_type_devices(nx): n = 100 @@ -697,3 +728,234 @@ def test_linear_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) + + +def test_sliced_permutations(): + n = 4 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + + # test without provided thetas + _, _ = ot.sliced.sliced_plans(x, y, n_proj=n_proj) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.sliced_plans(x[:, 1:], y, thetas=thetas) + + +def test_sliced_plans(): + x = [1, 2] + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x, x, n_proj=2) + + n = 4 + m = 5 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(m, 2) + + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, m) + b /= b.sum() + + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + + # test with a and b not uniform + ot.sliced.sliced_plans(x, y, a, b, thetas=thetas, dense=True) + + # test with the minkowski metric + ot.sliced.sliced_plans(x, y, thetas=thetas, metric="minkowski") + + # test with an unsupported metric + with pytest.raises(ValueError): + ot.sliced.sliced_plans(x, y, thetas=thetas, metric="mahalanobis") + + # test with a warm theta + ot.sliced.sliced_plans(x, y, n_proj=10, warm_theta=thetas[-1]) + + +def test_min_pivot_sliced(): + x = [1, 2] + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x, x, n_proj=2) + + n = 10 + m = 4 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(m, 2) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, m) + b /= b.sum() + + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + + # identity of the indiscernibles + _, min_cost = ot.min_pivot_sliced(x, x, a, a, n_proj=10) + np.testing.assert_almost_equal(min_cost, 0.0) + + _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True) + + # result should be an upper-bound of W2 and relatively close + w2 = ot.emd2(a, b, ot.dist(x, y)) + assert min_cost >= w2 + assert min_cost <= 1.5 * w2 + + # test without provided thetas + ot.sliced.min_pivot_sliced(x, y, a, b, n_proj=n_proj, log=True) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas) + + # test the logs + _, min_cost, log = ot.sliced.min_pivot_sliced( + x, y, a, b, thetas=thetas, dense=False, log=True + ) + assert len(log) == 5 + costs = log["costs"] + assert len(costs) == thetas.shape[0] + assert len(log["min_theta"]) == d + assert (log["thetas"] == thetas).all() + for c in costs: + assert c > 0 + + # test with different metrics + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski") + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="euclidean") + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="cityblock") + + # test with an unsupported metric + with pytest.raises(ValueError): + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="mahalanobis") + + # test with a warm theta + ot.sliced.min_pivot_sliced(x, y, n_proj=10, warm_theta=thetas[-1]) + + +def test_expected_sliced(): + x = [1, 2] + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x, x, n_proj=2) + + n = 10 + m = 24 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(m, 2) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, m) + b /= b.sum() + + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + + _, expected_cost = ot.sliced.expected_sliced(x, y, a, b, dense=True, thetas=thetas) + # result should be a coarse upper-bound of W2 + w2 = ot.emd2(a, b, ot.dist(x, y)) + assert expected_cost >= w2 + assert expected_cost <= 3 * w2 + + # test without provided thetas + ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) + ot.sliced.expected_sliced(x, y, a, b, n_proj=n_proj, log=True) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x[:, 1:], y, thetas=thetas) + + # with a small temperature (i.e. large beta), the cost should be close + # to min_pivot + _, expected_cost = ot.sliced.expected_sliced( + x, y, a, b, thetas=thetas, dense=True, beta=100.0 + ) + _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, thetas=thetas, dense=True) + np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3) + + # test the logs + _, min_cost, log = ot.sliced.expected_sliced( + x, y, a, b, thetas=thetas, dense=False, log=True + ) + assert len(log) == 4 + costs = log["costs"] + assert len(costs) == thetas.shape[0] + assert len(log["weights"]) == thetas.shape[0] + assert (log["thetas"] == thetas).all() + for c in costs: + assert c > 0 + + # test with the minkowski metric + ot.sliced.expected_sliced(x, y, thetas=thetas, metric="minkowski") + + # test with an unsupported metric + with pytest.raises(ValueError): + ot.sliced.expected_sliced(x, y, thetas=thetas, metric="mahalanobis") + + +def test_sliced_plans_backends(nx): + n = 10 + m = 24 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(m, 2) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, m) + b /= b.sum() + + x_b, y_b, a_b, b_b = nx.from_numpy(x, y, a, b) + + thetas_b = ot.sliced.get_random_projections( + d, n_proj, seed=0, backend=nx, type_as=x_b + ).T + thetas = nx.to_numpy(thetas_b) + + context = ( + nullcontext() + if str(nx) not in ["tf", "jax"] + else pytest.raises(NotImplementedError) + ) + + with context: + _, expected_cost_b = ot.sliced.expected_sliced( + x_b, y_b, a_b, b_b, dense=True, thetas=thetas_b + ) + # result should be the same than numpy version + _, expected_cost = ot.sliced.expected_sliced( + x, y, a, b, dense=True, thetas=thetas + ) + np.testing.assert_almost_equal(expected_cost_b, expected_cost) + + # for min_pivot + _, min_cost_b = ot.sliced.min_pivot_sliced( + x_b, y_b, a_b, b_b, dense=True, thetas=thetas_b + ) + # result should be the same than numpy version + _, min_cost = ot.sliced.min_pivot_sliced(x, y, a, b, dense=True, thetas=thetas) + np.testing.assert_almost_equal(min_cost_b, min_cost) + + # for sliced_plans + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0, backend=nx).T + + # test with the minkowski metric + ot.sliced.min_pivot_sliced(x, y, thetas=thetas, metric="minkowski")