From 44d46149e7f12cf4f712332f70fad0a6f78a0299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 10 Sep 2024 23:46:02 +0200 Subject: [PATCH 1/2] merge --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index cc18cc91b..277af7847 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,8 @@ - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) +- restructure `ot.unbalanced` module (PR #658) +- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) From 35aa7f7a21b21abb59fbb494fd79a42cc065a88b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Fri, 25 Apr 2025 13:51:13 +0200 Subject: [PATCH 2/2] fix jax autograd --- RELEASES.md | 1 + ot/backend.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 62240fa77..a24747fb7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -16,6 +16,7 @@ - `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680) - Backend implementation of `ot.dist` for (PR #701) - Updated documentation Quickstart guide and User guide with new API (PR #726) +- Fix jax version for auto-grad (PR #732) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/backend.py b/ot/backend.py index d5f58bbcc..3d59639fa 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1509,7 +1509,7 @@ def set_gradients(self, val, inputs, grads): aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2 aux = aux - jax.lax.stop_gradient(aux) - (val,) = jax.tree_map(lambda z: z + aux, (val,)) + (val,) = jax.tree_util.tree_map(lambda z: z + aux, (val,)) return val def _detach(self, a):