From 44d46149e7f12cf4f712332f70fad0a6f78a0299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 10 Sep 2024 23:46:02 +0200 Subject: [PATCH 1/5] merge --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index cc18cc91b..277af7847 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,8 @@ - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) +- restructure `ot.unbalanced` module (PR #658) +- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) From a93c60cbe7f317c77accd5dbb6850ec8cfad584c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 21 Apr 2025 22:14:33 +0200 Subject: [PATCH 2/5] first commit --- ot/__init__.py | 3 +- ot/solvers.py | 546 ++++++++++++++++++++++++++++++++++++++++++- ot/utils.py | 171 ++++++++++++++ test/test_solvers.py | 136 +++++++++++ 4 files changed, 854 insertions(+), 2 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index 5e21d6a76..dbddbd03f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -68,7 +68,7 @@ ) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample +from .solvers import solve, solve_gromov, solve_sample, bary_sample from .lowrank import lowrank_sinkhorn # utils functions @@ -116,6 +116,7 @@ "solve", "solve_gromov", "solve_sample", + "bary_sample", "smooth", "stochastic", "unbalanced", diff --git a/ot/solvers.py b/ot/solvers.py index a5bbf0e94..257e0d701 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -4,10 +4,11 @@ """ # Author: Remi Flamary +# Cédric Vincent-Cuaz # # License: MIT License -from .utils import OTResult, dist +from .utils import OTResult, BaryResult, dist from .lp import emd2, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced @@ -33,6 +34,7 @@ from .optim import cg import warnings +import numpy as np lst_method_lazy = [ @@ -1936,3 +1938,545 @@ def solve_sample( log=log, ) return res + + +def _bary_sample_bcd( + X_s, + X_init, + a_s, + b_init, + w_s, + metric, + inner_solver, + max_iter_bary, + tol_bary, + verbose, + log, + nx, +): + """Compute the barycenter using BCD. + + Parameters + ---------- + X_s : list of array-like, shape (n_samples_k, dim) + List of samples in each source distribution + X_init : array-like, shape (n_samples_b, dim), + Initialization of the barycenter samples. + a_s : list of array-like, shape (dim_k,) + List of samples weights in each source distribution + b_init : array-like, shape (n_samples_b,) + Initialization of the barycenter weights. + w_s : list of array-like, shape (N,) + Samples barycentric weights + metric : str + Metric to use for the cost matrix, by default "sqeuclidean" + inner_solver : callable + Function to solve the inner OT problem + max_iter_bary : int + Maximum number of iterations for the barycenter + tol_bary : float + Tolerance for the barycenter convergence + verbose : bool + Print information in the solver + log : bool + Log the loss during the iterations + nx: backend + Backend to use for the computation. Must match<< + Returns + ------- + TBD + """ + + X = X_init + b = b_init + inv_b = 1.0 / b + + prev_loss = np.inf + n_samples = len(X_s) + + if log: + log_ = {"loss": []} + else: + log_ = None + # Compute the barycenter using BCD + for it in range(max_iter_bary): + # Solve the inner OT problem for each source distribution + list_res = [inner_solver(X_s[k], X, a_s[k], b) for k in range(n_samples)] + + # Update the barycenter samples + if metric in ["sqeuclidean", "euclidean"]: + X_new = ( + sum([w_s[k] * list_res[k].plan.T @ X_s[k] for k in range(n_samples)]) + * inv_b[:, None] + ) + else: + raise NotImplementedError('Not implemented metric="{}"'.format(metric)) + + # compute loss + new_loss = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) + + if verbose: + if it % 1 == 0: + print(f"BCD iteration {it}: loss = {new_loss:.4f}") + + if log: + log_["loss"].append(new_loss) + # Check convergence + if abs(new_loss - prev_loss) / abs(prev_loss) < tol_bary: + print(f"BCD converged in {it} iterations") + break + + X = X_new + prev_loss = new_loss + + # compute value_linear + value_linear = sum([w_s[k] * list_res[k].value_linear for k in range(n_samples)]) + # update BaryResult + bary_res = BaryResult( + X=X_new, + b=b, + value=new_loss, + value_linear=value_linear, + log=log_, + list_res=list_res, + backend=nx, + ) + return bary_res + + +def bary_sample( + X_s, + n, + a_s=None, + w_s=None, + X_init=None, + b_init=None, + learn_X=True, + learn_b=False, + metric="sqeuclidean", + reg=None, + c=None, + reg_type="KL", + unbalanced=None, + unbalanced_type="KL", + lazy=False, + batch_size=None, + method=None, + n_threads=1, + max_iter_bary=1000, + max_iter=None, + rank=100, + scaling=0.95, + tol_bary=1e-5, + tol=None, + random_state=0, + verbose=False, +): + r"""Solve the discrete OT barycenter problem over source distributions using Block-Coordinate Descent. + + The function solves the following general OT barycenter problem + + .. math:: + \min_{\mathbf{X} \in \mathbb{R}^{n \times d}, \mathbf{b} \in \Sigma_n} \min_{\{ \mathbf{T}^{(k)} \}_k \in \R_+^{n_i \times n}} \quad \sum_k w_k \sum_{i,j} T^{(k)}_{i,j}M^{(k)}_{i,j} + \lambda_r R(\mathbf{T}^{(k)}) + + \lambda_u U(\mathbf{T^{(k)}}\mathbf{1},\mathbf{a}^{(k)}) + + \lambda_u U(\mathbf{T}^{(k)T}\mathbf{1},\mathbf{b}) + + where the cost matrices :math:`\mathbf{M}^{(k)}` for each input distribution :math:`(\mathbf{X}^{(k)}, \mathbf{b}^{(k)})` + is computed from the samples in the source and barycenter domains such that + :math:`M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)` where + :math:`d` is a metric (by default the squared Euclidean distance). + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : list of array-like, shape (n_samples_k, dim) + List of samples in each source distribution + n : int + number of samples in the barycenter domain + a_s : list of array-like, shape (dim_k,), optional + List of samples weights in each source distribution (default is uniform) + w_s : list of array-like, shape (N,), optional + Samples barycentric weights (default is uniform) + X_init : array-like, shape (n_samples_b, dim), optional + Initialization of the barycenter samples (default is gaussian random sampling). + Shape must match with required n. + b_init : array-like, shape (n_samples_b,), optional + Initialization of the barycenter weights (default is uniform). + Shape must match with required n. + learn_X : bool, optional + Learn the barycenter samples (default is True) + learn_b : bool, optional + Learn the barycenter weights (default is False) + metric : str, optional + Metric to use for the cost matrix, by default "sqeuclidean" + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + c : array-like, shape (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float or indexable object of length 1 or 2 + Marginal relaxation term. + If it is a scalar or an indexable object of length 1, + then the same relaxation is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`unbalanced=float("inf")`. + For semi-relaxed case, use either + :math:`unbalanced=(float("inf"), scalar)` or + :math:`unbalanced=(scalar, float("inf"))`. + If unbalanced is an array, + it must have the same backend as input arrays `(a, b, M)`. + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by + default False + batch_size : int, optional + Batch size for lazy solver, by default None (default values in each + solvers) + method : str, optional + Method for solving the problem, this can be used to select the solver + for unbalanced problems (see :any:`ot.solve`), or to select a specific + large scale solver. + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter_bary : int, optional + Maximum number of iteration for the BCD solver, by default 1000. + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + rank : int, optional + Rank of the OT matrix for lazy solers (method='factored'), by default 100 + scaling : float, optional + Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 + tol_bary : float, optional + Tolerance for solution precision of barycenter problem, by default None (default value 1e-5) + tol : float, optional + Tolerance for solution precision of inner OT solver, by default None (default values in each solvers) + random_state : int, optional + Random seed for the initialization of the barycenter samples, by default 0. + Only used if `X_init` is None. + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res : BaryResult() + Result of the optimization problem. The information can be obtained as follows: + + OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method) + + See :any:`OTResult` for more information. + + Notes + ----- + + The following methods are available for solving the OT problems: + + - **Classical exact OT problem [1]** (default parameters) : + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b) + + # for uniform weights + res = ot.solve_sample(xa, xb) + + - **Entropic regularized OT [2]** (when ``reg!=None``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve_sample(xa, xb, a, b, reg=1.0) + # or for original Sinkhorn paper formulation [2] + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy') + + # lazy solver of memory complexity O(n) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) + # lazy OT plan + lazy_plan = res.lazy_plan + + # Use envelope theorem differentiation for memory saving + res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope') + res.value.backward() # only the value is differentiable + + Note that by default the Sinkhorn solver uses automatic differentiation to + compute the gradients of the values and plan. This can be changed with the + `grad` parameter. The `envelope` mode computes the gradients only + for the value and the other outputs are detached. This is useful for + memory saving when only the gradient of value is needed. + + We also have a very efficient solver with compiled CPU/CUDA code using + geomloss/PyKeOps that can be used with the following code: + + .. code-block:: python + + # automatic solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss') + + # force O(n) memory efficient solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online') + + # force pre-computed cost matrix + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized') + + # use multiscale solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale') + + # One can play with speed (small scaling factor) and precision (scaling close to 1) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5) + + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2') + + - **Unbalanced OT [41]** (when ``unbalanced!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + \text{with} \ M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0) + # quadratic unbalanced OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2') + # TV = partial OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV') + + + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + \text{with} \ M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` for both + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0) + # quadratic unbalanced OT with KL regularization + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2') + # both quadratic + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', + unbalanced=1.0, unbalanced_type='L2') + + + - **Factored OT [2]** (when ``method='factored'``): + + This method solve the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where $\mu$ is a uniform weighted empirical distribution of :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated + to the samples in the source and target domains, and :math:`W_2` is the + Wasserstein distance. This problem is solved using exact OT solvers for + `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides + two transport plans that can be used to recover a low rank OT plan between + the two distributions. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='factored', rank=10) + + # recover the lazy low rank plan + factored_solution_lazy = res.lazy_plan + + # recover the full low rank plan + factored_solution = factored_solution_lazy[:] + + - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): + + This method computes the Gaussian Bures-Wasserstein distance between two + Gaussian distributions estimated from the empirical distributions + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + + where : + + .. math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + The covariances and means are estimated from the data. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='gaussian') + + # recover the squared Gaussian Bures-Wasserstein distance + BW_dist = res.value + + - **Wasserstein 1d [1]** (when ``method='1D'``): + + This method computes the Wasserstein distance between two 1d distributions + estimated from the empirical distributions. For multivariate data the + distances are computed independently for each dimension. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='1D') + + # recover the squared Wasserstein distances + W_dists = res.value + + + .. _references-bary-sample: + References + ---------- + + """ + if learn_b: + raise NotImplementedError("Barycenter weights learning not implemented yet") + + if method is not None and method.lower() in lst_method_lazy: + raise NotImplementedError("Barycenter with Lazy tensors not implemented yet") + + n_samples = len(X_s) + + if ( + not lazy + ): # default non lazy solver calls ot.solve_sample within _bary_sample_bcd + # Detect backend + nx = get_backend(*X_s, X_init, b_init, w_s) + + # check sample weights + if a_s is None: + a_s = [ + nx.ones((X_s[k].shape[0],), type_as=X_s[k]) / X_s[k].shape[0] + for k in range(n_samples) + ] + + # check samples barycentric weights + if w_s is None: + w_s = nx.ones(n_samples, type_as=X_s[0]) / n_samples + + # check X_init + if X_init is None: + if (not learn_X) and learn_b: + raise ValueError( + "X_init must be provided if learn_X=False and learn_b=True" + ) + else: + rng = np.random.RandomState(random_state) + mean_ = nx.concatenate( + [nx.mean(X_s[k], axis=0) for k in range(n_samples)], + axis=0, + ) + mean_ = nx.mean(mean_, axis=0) + std_ = nx.concatenate( + [nx.std(X_s[k], axis=0) for k in range(n_samples)], + axis=0, + ) + std_ = nx.mean(std_, axis=0) + X_init = rng.normal( + loc=mean_, + scale=std_, + size=(n, X_s[0].shape[1]), + ) + X_init = nx.from_numpy(X_init, type_as=X_s[0]) + else: + if (X_init.shape[0] != n) or (X_init.shape[1] != X_s[0].shape[1]): + raise ValueError("X_init must have shape (n, dim)") + + # check b_init + if b_init is None: + b_init = nx.ones((n,), type_as=X_s[0]) / n + + def inner_solver(X_a, X, a, b): + return solve_sample( + X_a=X_a, + X_b=X, + a=a, + b=b, + metric=metric, + reg=reg, + c=c, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + method=method, + n_threads=n_threads, + max_iter=max_iter, + tol=tol, + verbose=False, + ) + + res = _bary_sample_bcd( + X_s, + X_init, + a_s, + b_init, + w_s, + metric, + inner_solver, + max_iter_bary, + tol_bary, + verbose, + True, # log set to True by default + nx, + ) + + return res + + else: + raise (NotImplementedError("Barycenter solver with lazy=True not implemented")) diff --git a/ot/utils.py b/ot/utils.py index 1f24fa33f..6f8a5682f 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1310,6 +1310,177 @@ def citation(self): """ +class BaryResult: + """Base class for OT barycenter results. + + Parameters + ---------- + X : array-like, shape (`n`, `d`) + Barycenter features. + C: array-like, shape (`n`, `n`) + Barycenter structure for Gromov Wasserstein solutions. + b : array-like, shape (`n`,) + Barycenter weights. + value : float, array-like + Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions. + value_linear : float, array-like + The linear part of the transport cost, i.e. the product between the + transport plan and the cost. + value_quad : float, array-like + The quadratic part of the transport cost for Gromov-Wasserstein + solutions. + log : dict + Dictionary containing potential information about the solver. + list_res: list of OTResult + List of results for the individual OT matching. + + Attributes + ---------- + + X : array-like, shape (`n`, `d`) + Barycenter features. + C: array-like, shape (`n`, `n`) + Barycenter structure for Gromov Wasserstein solutions. + b : array-like, shape (`n`,) + Barycenter weights. + value : float, array-like + Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions. + value_linear : float, array-like + The linear part of the transport cost, i.e. the product between the + transport plan and the cost. + value_quad : float, array-like + The quadratic part of the transport cost for Gromov-Wasserstein + solutions. + log : dict + Dictionary containing potential information about the solver. + list_res: list of OTResult + List of results for the individual OT matching. + backend : Backend + Backend used to compute the results. + """ + + def __init__( + self, + X=None, + C=None, + b=None, + value=None, + value_linear=None, + value_quad=None, + log=None, + list_res=None, + backend=None, + ): + self._X = X + self._C = C + self._b = b + self._value = value + self._value_linear = value_linear + self._value_quad = value_quad + self._log = log + self._list_res = list_res + self._backend = backend if backend is not None else NumpyBackend() + + def __repr__(self): + s = "BaryResult(" + if self._value is not None: + s += "value={},".format(self._value) + if self._value_linear is not None: + s += "value_linear={},".format(self._value_linear) + if self._X is not None: + s += "X={}(shape={}),".format(self._X.__class__.__name__, self._X.shape) + if self._C is not None: + s += "C={}(shape={}),".format(self._C.__class__.__name__, self._C.shape) + if self._b is not None: + s += "b={}(shape={}),".format(self._b.__class__.__name__, self._b.shape) + if s[-1] != "(": + s = s[:-1] + ")" + else: + s = s + ")" + return s + + # Barycerters -------------------------------- + + @property + def X(self): + """Barycenter features.""" + return self._X + + @property + def C(self): + """Barycenter structure for Gromov Wasserstein solutions.""" + return self._C + + @property + def b(self): + """Barycenter weights.""" + return self._b + + # Loss values -------------------------------- + + @property + def value(self): + """Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions.""" + return self._value + + @property + def value_linear(self): + """The "minimal" transport cost, i.e. the product between the transport plan and the cost.""" + return self._value_linear + + @property + def value_quad(self): + """The quadratic part of the transport cost for Gromov-Wasserstein solutions.""" + return self._value_quad + + # List of OTResult objects ------------------------- + + @property + def list_res(self): + """List of results for the individual OT matching.""" + return self._list_res + + @property + def status(self): + """Optimization status of the solver.""" + return self._status + + @property + def log(self): + """Dictionary containing potential information about the solver.""" + return self._log + + # Miscellaneous -------------------------------- + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. + return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ + + class LazyTensor(object): """A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. diff --git a/test/test_solvers.py b/test/test_solvers.py index a0c1d7c43..c691b9cc0 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -1,6 +1,7 @@ """Tests for ot solvers""" # Author: Remi Flamary +# Cédric Vincent-Cuaz # # License: MIT License @@ -703,3 +704,138 @@ def test_solve_sample_NotImplemented(nx, method_params): with pytest.raises(NotImplementedError): ot.solve_sample(xb, yb, ab, bb, **method_params) + + +def assert_allclose_bary_sol(sol1, sol2): + lst_attr = ["X", "b", "value", "value_linear", "log"] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + if getattr(sol1, attr) is not None and getattr(sol2, attr) is not None: + try: + var1 = getattr(sol1, attr) + var2 = getattr(sol2, attr) + if isinstance(var1, dict): # only contains lists + for key in var1.keys(): + np.allclose( + np.array(var1[key]), + np.array(var2[key]), + equal_nan=True, + ) + else: + np.allclose( + nx1.to_numpy(getattr(sol1, attr)), + nx2.to_numpy(getattr(sol2, attr)), + equal_nan=True, + ) + except NotImplementedError: + pass + elif getattr(sol1, attr) is None and getattr(sol2, attr) is None: + return True + else: + return False + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.mark.parametrize( + "reg,reg_type,unbalanced,unbalanced_type", + itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type), +) +def test_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type): + # test bary_sample when is_Lazy = False + rng = np.random.RandomState(0) + + K = 3 # number of distributions + ns = rng.randint(10, 20, K) # number of samples within each distribution + n = 5 # number of samples in the barycenter + + X_s = [rng.randn(ns_i, 2) for ns_i in ns] + # X_init = np.reshape(1.0 * np.randn(n, 2), (n, 1)) + + a_s = [ot.utils.unif(X.shape[0]) for X in X_s] + b = ot.utils.unif(n) + + w_s = ot.utils.unif(K) + + try: + if reg_type == "tuple": + + def f(G): + return np.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + + # solve default None weights + sol0 = ot.bary_sample( + X_s, + n, + w_s=None, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + max_iter_bary=4, + tol_bary=1e-3, + verbose=True, + ) + + # solve provided uniform weights + sol = ot.bary_sample( + X_s, + n, + a_s=a_s, + b_init=b, + w_s=w_s, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + max_iter_bary=4, + tol_bary=1e-3, + verbose=True, + ) + + assert_allclose_bary_sol(sol0, sol) + + # solve in backend + X_sb = nx.from_numpy(*X_s) + a_sb = nx.from_numpy(*a_s) + w_sb, bb = nx.from_numpy(w_s, b) + + if isinstance(reg_type, tuple): + + def f(G): + return nx.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + + solb = ot.bary_sample( + X_sb, + n, + a_s=a_sb, + b_init=bb, + w_s=w_sb, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + max_iter_bary=4, + tol_bary=1e-3, + verbose=True, + ) + assert_allclose_bary_sol(sol, solb) + + except NotImplementedError: + pytest.skip("Not implemented") From 9e25e8080506512acf14b3c56623d228d751a05c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 23 Apr 2025 00:48:18 +0200 Subject: [PATCH 3/5] handle masses in unbalanced cases --- ot/solvers.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ot/solvers.py b/ot/solvers.py index 257e0d701..0c944e709 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1948,6 +1948,7 @@ def _bary_sample_bcd( w_s, metric, inner_solver, + update_masses, max_iter_bary, tol_bary, verbose, @@ -1972,6 +1973,8 @@ def _bary_sample_bcd( Metric to use for the cost matrix, by default "sqeuclidean" inner_solver : callable Function to solve the inner OT problem + update_masses : bool + Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used. max_iter_bary : int Maximum number of iterations for the barycenter tol_bary : float @@ -2003,6 +2006,10 @@ def _bary_sample_bcd( # Solve the inner OT problem for each source distribution list_res = [inner_solver(X_s[k], X, a_s[k], b) for k in range(n_samples)] + # Update the estimated barycenter weights in unbalanced cases + if update_masses: + b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) + inv_b = 1.0 / b # Update the barycenter samples if metric in ["sqeuclidean", "euclidean"]: X_new = ( @@ -2461,6 +2468,8 @@ def inner_solver(X_a, X, a, b): verbose=False, ) + # compute the barycenter using BCD + update_masses = unbalanced is not None res = _bary_sample_bcd( X_s, X_init, @@ -2469,6 +2478,7 @@ def inner_solver(X_a, X, a, b): w_s, metric, inner_solver, + update_masses, max_iter_bary, tol_bary, verbose, From 46c46385eb27c5624e5ca97aaef232a5c429e2ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 24 Apr 2025 18:20:05 +0200 Subject: [PATCH 4/5] update free support --- ot/solvers.py | 102 ++++++++++++++++++++++++++++++++++++++----- ot/utils.py | 6 +++ test/test_solvers.py | 21 ++++++--- test/test_utils.py | 27 +++++++++++- 4 files changed, 136 insertions(+), 20 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 0c944e709..3e1c35ab6 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1949,6 +1949,9 @@ def _bary_sample_bcd( metric, inner_solver, update_masses, + warmstart_plan, + warmstart_potentials, + stopping_criterion, max_iter_bary, tol_bary, verbose, @@ -1975,6 +1978,12 @@ def _bary_sample_bcd( Function to solve the inner OT problem update_masses : bool Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used. + warmstart_plan : bool + Use the previous plan as initialization for the inner solver. Set based on inner solver type in ot.bary_sample + warmstart_potentials : bool + Use the previous potentials as initialization for the inner solver. Set based on inner solver type in ot.bary_sample + stopping_criterion : str + Stopping criterion for the BCD algorithm. Can be "loss" or "bary". max_iter_bary : int Maximum number of iterations for the barycenter tol_bary : float @@ -1994,22 +2003,41 @@ def _bary_sample_bcd( b = b_init inv_b = 1.0 / b - prev_loss = np.inf + prev_criterion = np.inf n_samples = len(X_s) if log: - log_ = {"loss": []} + log_ = {"stopping_criterion": []} else: log_ = None + # Compute the barycenter using BCD for it in range(max_iter_bary): # Solve the inner OT problem for each source distribution - list_res = [inner_solver(X_s[k], X, a_s[k], b) for k in range(n_samples)] + if it == 0: + list_res = [ + inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples) + ] + elif warmstart_plan: + list_res = [ + inner_solver(X_s[k], X, a_s[k], b, list_res[k].plan, None) + for k in range(n_samples) + ] + elif warmstart_potentials: + list_res = [ + inner_solver(X_s[k], X, a_s[k], b, None, list_res[k].potentials) + for k in range(n_samples) + ] + else: + list_res = [ + inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples) + ] # Update the estimated barycenter weights in unbalanced cases if update_masses: b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) inv_b = 1.0 / b + # Update the barycenter samples if metric in ["sqeuclidean", "euclidean"]: X_new = ( @@ -2019,30 +2047,40 @@ def _bary_sample_bcd( else: raise NotImplementedError('Not implemented metric="{}"'.format(metric)) - # compute loss - new_loss = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) + # compute criterion + if stopping_criterion == "loss": + new_criterion = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) + else: # stopping_criterion = "bary" + new_criterion = nx.norm(X_new - X, ord=2) if verbose: if it % 1 == 0: - print(f"BCD iteration {it}: loss = {new_loss:.4f}") + print( + f"BCD iteration {it}: criterion {stopping_criterion} = {new_criterion:.4f}" + ) if log: - log_["loss"].append(new_loss) + log_["stopping_criterion"].append(new_criterion) # Check convergence - if abs(new_loss - prev_loss) / abs(prev_loss) < tol_bary: + if abs(new_criterion - prev_criterion) / abs(prev_criterion) < tol_bary: print(f"BCD converged in {it} iterations") break X = X_new - prev_loss = new_loss + prev_criterion = new_criterion + + # compute loss values - # compute value_linear value_linear = sum([w_s[k] * list_res[k].value_linear for k in range(n_samples)]) + if stopping_criterion == "loss": + value = new_criterion + else: + value = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) # update BaryResult bary_res = BaryResult( X=X_new, b=b, - value=new_loss, + value=value, value_linear=value_linear, log=log_, list_res=list_res, @@ -2070,6 +2108,8 @@ def bary_sample( batch_size=None, method=None, n_threads=1, + warmstart=False, + stopping_criterion="loss", max_iter_bary=1000, max_iter=None, rank=100, @@ -2154,6 +2194,11 @@ def bary_sample( large scale solver. n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 + warmstart : bool, optional + Use the previous OT or potentials as initialization for the next inner solver iteration, by default False. + stopping_criterion : str, optional + Stopping criterion for the outer loop of the BCD solver, by default 'loss'. + Either 'loss' to use the optimize objective or 'bary' for variations of the barycenter w.r.t the Frobenius norm. max_iter_bary : int, optional Maximum number of iteration for the BCD solver, by default 1000. max_iter : int, optional @@ -2398,6 +2443,13 @@ def bary_sample( if method is not None and method.lower() in lst_method_lazy: raise NotImplementedError("Barycenter with Lazy tensors not implemented yet") + if stopping_criterion not in ["loss", "bary"]: + raise ValueError( + "stopping_criterion must be either 'loss' or 'bary', got {}".format( + stopping_criterion + ) + ) + n_samples = len(X_s) if ( @@ -2449,7 +2501,28 @@ def bary_sample( if b_init is None: b_init = nx.ones((n,), type_as=X_s[0]) / n - def inner_solver(X_a, X, a, b): + if warmstart: + if reg is None: # exact OT + warmstart_plan = True + warmstart_potentials = False + else: # regularized OT + # unbalanced AND regularized OT + if ( + not isinstance(reg_type, tuple) + and reg_type.lower() in ["kl"] + and unbalanced_type.lower() == "kl" + ): + warmstart_plan = False + warmstart_potentials = True + + else: + warmstart_plan = True + warmstart_potentials = False + else: + warmstart_plan = False + warmstart_potentials = False + + def inner_solver(X_a, X, a, b, plan_init, potentials_init): return solve_sample( X_a=X_a, X_b=X, @@ -2465,6 +2538,8 @@ def inner_solver(X_a, X, a, b): n_threads=n_threads, max_iter=max_iter, tol=tol, + plan_init=plan_init, + potentials_init=potentials_init, verbose=False, ) @@ -2479,6 +2554,9 @@ def inner_solver(X_a, X, a, b): metric, inner_solver, update_masses, + warmstart_plan, + warmstart_potentials, + stopping_criterion, max_iter_bary, tol_bary, verbose, diff --git a/ot/utils.py b/ot/utils.py index 6f8a5682f..8b045984b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1334,6 +1334,8 @@ class BaryResult: Dictionary containing potential information about the solver. list_res: list of OTResult List of results for the individual OT matching. + status : int or str + Status of the solver. Attributes ---------- @@ -1357,6 +1359,8 @@ class BaryResult: Dictionary containing potential information about the solver. list_res: list of OTResult List of results for the individual OT matching. + status : int or str + Status of the solver. backend : Backend Backend used to compute the results. """ @@ -1371,6 +1375,7 @@ def __init__( value_quad=None, log=None, list_res=None, + status=None, backend=None, ): self._X = X @@ -1381,6 +1386,7 @@ def __init__( self._value_quad = value_quad self._log = log self._list_res = list_res + self._status = status self._backend = backend if backend is not None else NumpyBackend() def __repr__(self): diff --git a/test/test_solvers.py b/test/test_solvers.py index c691b9cc0..75c58dd99 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -741,12 +741,16 @@ def assert_allclose_bary_sol(sol1, sol2): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") @pytest.mark.parametrize( - "reg,reg_type,unbalanced,unbalanced_type", - itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type), + "reg,reg_type,unbalanced,unbalanced_type,warmstart", + itertools.product( + lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False] + ), ) -def test_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type): +def test_bary_sample_free_support( + nx, reg, reg_type, unbalanced, unbalanced_type, warmstart +): # test bary_sample when is_Lazy = False - rng = np.random.RandomState(0) + rng = np.random.RandomState() K = 3 # number of distributions ns = rng.randint(10, 20, K) # number of samples within each distribution @@ -781,7 +785,8 @@ def df(G): reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, - max_iter_bary=4, + warmstart=warmstart, + max_iter_bary=3, tol_bary=1e-3, verbose=True, ) @@ -798,7 +803,8 @@ def df(G): reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, - max_iter_bary=4, + warmstart=warmstart, + max_iter_bary=3, tol_bary=1e-3, verbose=True, ) @@ -831,7 +837,8 @@ def df(G): reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, - max_iter_bary=4, + warmstart=warmstart, + max_iter_bary=3, tol_bary=1e-3, verbose=True, ) diff --git a/test/test_utils.py b/test/test_utils.py index 938fd6058..1ecd1b51f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -456,7 +456,7 @@ def test_OTResult(): # test print print(res) - # tets get citation + # test get citation print(res.citation) lst_attributes = [ @@ -486,6 +486,31 @@ def test_OTResult(): getattr(res, at) +def test_BaryResult(): + res = ot.utils.BaryResult() + + # test print + print(res) + + # test get citation + print(res.citation) + + lst_attributes = [ + "X", + "C", + "b", + "value", + "value_linear", + "value_quad", + "list_res", + "status", + "log", + ] + for at in lst_attributes: + print(at) + assert getattr(res, at) is None + + def test_get_coordinate_circle(): rng = np.random.RandomState(42) u = rng.rand(1, 100) From 671788d40508042f5833695f0cb96d09fcdbdbbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Fri, 25 Apr 2025 01:12:41 +0200 Subject: [PATCH 5/5] trying to fix tests --- ot/solvers.py | 2 +- ot/unbalanced/_lbfgs.py | 8 ++++---- test/test_solvers.py | 33 +++++++++++++++++++++++++++------ 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 3e1c35ab6..daad962a8 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -2032,7 +2032,7 @@ def _bary_sample_bcd( list_res = [ inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples) ] - + print("inv_b:", inv_b) # Update the estimated barycenter weights in unbalanced cases if update_masses: b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index c4de87474..eb995efb5 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -46,9 +46,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend + and not tensors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) @@ -218,9 +218,9 @@ def lbfgsb_unbalanced( Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend + and not tensors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) diff --git a/test/test_solvers.py b/test/test_solvers.py index 75c58dd99..6ede9b3f6 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -743,7 +743,12 @@ def assert_allclose_bary_sol(sol1, sol2): @pytest.mark.parametrize( "reg,reg_type,unbalanced,unbalanced_type,warmstart", itertools.product( - lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False] + lst_reg, + ["tuple"], + lst_unbalanced, + lst_unbalanced_type, + [True, False], + # lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False] ), ) def test_bary_sample_free_support( @@ -774,7 +779,7 @@ def df(G): return 2 * G reg_type = (f, df) - + # print('test reg_type:', reg_type[0](None), reg_type[1](None)) # solve default None weights sol0 = ot.bary_sample( X_s, @@ -790,8 +795,10 @@ def df(G): tol_bary=1e-3, verbose=True, ) + print("------ [done] sol0 - no backend") # solve provided uniform weights + sol = ot.bary_sample( X_s, n, @@ -808,6 +815,7 @@ def df(G): tol_bary=1e-3, verbose=True, ) + print("------ [done] sol - no backend") assert_allclose_bary_sol(sol0, sol) @@ -816,14 +824,25 @@ def df(G): a_sb = nx.from_numpy(*a_s) w_sb, bb = nx.from_numpy(w_s, b) - if isinstance(reg_type, tuple): + if reg_type == "tuple": - def f(G): - return nx.sum(G**2) + def fb(G): + return nx.sum( + G**2 + ) # otherwise we keep previously defined (f, df) as required by inner solver - def df(G): + def dfb(G): return 2 * G + """ + if ( + unbalanced_type.lower() in ["kl", "l2", "tv"]) and ( + unbalanced is not None) and ( + reg is not None + ): + reg_type = (f, df) + else: + """ reg_type = (f, df) solb = ot.bary_sample( @@ -842,6 +861,8 @@ def df(G): tol_bary=1e-3, verbose=True, ) + print("------ [done] sol - with backend") + assert_allclose_bary_sol(sol, solb) except NotImplementedError: