-
Notifications
You must be signed in to change notification settings - Fork 526
[WIP] Partial optimal transport 1d solver #741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #741 +/- ##
=======================================
Coverage 97.09% 97.10%
=======================================
Files 100 101 +1
Lines 20432 20459 +27
=======================================
+ Hits 19839 19866 +27
Misses 593 593 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds an efficient 1D partial optimal transport solver implemented in Cython with a Python wrapper, updates packaging to include the new extension, and provides basic tests.
- Introduces
partial_wasserstein_1d
inot.partial
, backed by a Cython routine. - Registers a new Cython extension
partial_cython
insetup.py
. - Adds tests for dimensionality checks and correctness of 1D solver outputs.
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
Show a summary per file
File | Description |
---|---|
test/test_partial.py | New tests for error on non-1D input and basic functional tests |
setup.py | Added partial_cython extension to cythonize call |
ot/partial/partial_solvers.py | Imported and exposed the Python wrapper partial_wasserstein_1d |
ot/partial/partial_cython.pyx | Complete Cython implementation of the 1D partial Wasserstein solver |
ot/partial/init.py | Exported partial_wasserstein_1d in the package API |
Comments suppressed due to low confidence (2)
ot/partial/partial_solvers.py:1308
- The docstring states it "returns the OT matrix", but the function actually returns (indices_x, indices_y, marginal_costs). Please update the brief description to reflect the actual return values.
r"""Solves the partial Wasserstein distance problem between 1d measures and returns
test/test_partial.py:62
- [nitpick] The test comments mention non-1D inputs but don’t define
xs
andxt
with invalid shapes. Consider explicitly creating e.g.xs = np.zeros((5,2))
,xt = np.zeros((5,2))
before calling to ensure the assertion is triggered.
with pytest.raises(AssertionError):
OK I have something I am rather happy with. My last step would be to move from list/dict/set implementation to only numpy arrays and benchmark efficiency. @rflamary maybe you can have a look at the PR before this last step, as that should not impact too much the overall integration of the method in the package (docs, example and API should stay the same). Also, I'm not sure when I will have time for this last step (should either be tomorrow or in 1-2months). |
...from Chapel & Tavenard 2025
Types of changes
It introduces an efficient partial optimal transport solver for 1d problems
Motivation and context / Related issue
How has this been tested (if it applies)
PR checklist