Skip to content

Commit a97a788

Browse files
[MRG] Fix None init plan in unbalanced lbfgs solvers (#731)
* merge * init commit * up * add fun_to_numpy in utils * complete tests * update releases * improve doc --------- Co-authored-by: Rémi Flamary <[email protected]>
1 parent 2825682 commit a97a788

File tree

4 files changed

+104
-49
lines changed

4 files changed

+104
-49
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
- Backend implementation of `ot.dist` for (PR #701)
1818
- Updated documentation Quickstart guide and User guide with new API (PR #726)
1919
- Fix jax version for auto-grad (PR #732)
20+
- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)
2021
- Added to each example in the examples gallery the information about the release version in which it was introduced (PR #743)
2122
- Removed release information from quickstart guide (PR #744)
2223

ot/unbalanced/_lbfgs.py

Lines changed: 45 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
#
1010
# License: MIT License
1111

12-
import warnings
1312
import numpy as np
1413
from scipy.optimize import minimize, Bounds
1514

1615
from ..backend import get_backend
17-
from ..utils import list_to_array, get_parameter_pair
16+
from ..utils import list_to_array, get_parameter_pair, fun_to_numpy
1817

1918

2019
def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div="kl"):
@@ -46,9 +45,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div
4645
Divergence used for regularization.
4746
Can take three values: 'entropy' (negative entropy), or
4847
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
49-
of two calable functions returning the reg term and its derivative.
48+
of two callable functions returning the reg term and its derivative.
5049
Note that the callable functions should be able to handle Numpy arrays
51-
and not tesors from the backend
50+
and not tensors from the backend
5251
regm_div: string, optional
5352
Divergence to quantify the difference between the marginals.
5453
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
@@ -206,26 +205,27 @@ def lbfgsb_unbalanced(
206205
loss matrix
207206
reg: float
208207
regularization term >=0
209-
c : array-like (dim_a, dim_b), optional (default = None)
210-
Reference measure for the regularization.
211-
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
212208
reg_m: float or indexable object of length 1 or 2
213209
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
214210
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
215211
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
216212
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
217-
reg_div: string, optional
213+
c : array-like (dim_a, dim_b), optional (default = None)
214+
Reference measure for the regularization.
215+
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
216+
reg_div: string or pair of callable functions, optional (default = 'kl')
218217
Divergence used for regularization.
219218
Can take three values: 'entropy' (negative entropy), or
220219
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
221-
of two calable functions returning the reg term and its derivative.
220+
of two callable functions returning the reg term and its derivative.
222221
Note that the callable functions should be able to handle Numpy arrays
223-
and not tesors from the backend
224-
regm_div: string, optional
222+
and not tensors from the backend, otherwise functions will be converted to Numpy
223+
leading to a computational overhead.
224+
regm_div: string, optional (default = 'kl')
225225
Divergence to quantify the difference between the marginals.
226226
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
227-
G0: array-like (dim_a, dim_b)
228-
Initialization of the transport matrix
227+
G0: array-like (dim_a, dim_b), optional (default = None)
228+
Initialization of the transport matrix. None corresponds to uniform product.
229229
numItermax : int, optional
230230
Max number of iterations
231231
stopThr : float, optional
@@ -267,26 +267,14 @@ def lbfgsb_unbalanced(
267267
ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
268268
"""
269269

270-
# wrap the callable function to handle numpy arrays
271-
if isinstance(reg_div, tuple):
272-
f0, df0 = reg_div
273-
try:
274-
f0(G0)
275-
df0(G0)
276-
except BaseException:
277-
warnings.warn(
278-
"The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
279-
)
280-
281-
def f(x):
282-
return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0)))
283-
284-
def df(x):
285-
return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0)))
286-
287-
reg_div = (f, df)
270+
# test settings
271+
regm_div = regm_div.lower()
272+
if regm_div not in ["kl", "l2", "tv"]:
273+
raise ValueError(
274+
"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
275+
)
288276

289-
else:
277+
if isinstance(reg_div, str):
290278
reg_div = reg_div.lower()
291279
if reg_div not in ["entropy", "kl", "l2"]:
292280
raise ValueError(
@@ -295,16 +283,11 @@ def df(x):
295283
)
296284
)
297285

298-
regm_div = regm_div.lower()
299-
if regm_div not in ["kl", "l2", "tv"]:
300-
raise ValueError(
301-
"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
302-
)
303-
286+
# convert all inputs to numpy arrays
304287
reg_m1, reg_m2 = get_parameter_pair(reg_m)
305288

306289
M, a, b = list_to_array(M, a, b)
307-
nx = get_backend(M, a, b)
290+
nx = get_backend(M, a, b, G0)
308291
M0 = M
309292

310293
dim_a, dim_b = M.shape
@@ -315,10 +298,22 @@ def df(x):
315298
b = nx.ones(dim_b, type_as=M) / dim_b
316299

317300
# convert to numpy
318-
a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)
301+
if nx.__name__ == "numpy": # remaining parameters which can be arrays
302+
reg_m1, reg_m2, reg = nx.to_numpy(reg_m1, reg_m2, reg)
303+
else:
304+
a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)
305+
319306
G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0)
320307
c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c)
321308

309+
# potentially convert the callable function to handle numpy arrays
310+
if isinstance(reg_div, tuple):
311+
f0, df0 = reg_div
312+
f = fun_to_numpy(f0, G0, nx, warn=True)
313+
df = fun_to_numpy(df0, G0, nx, warn=True)
314+
315+
reg_div = (f, df)
316+
322317
_func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div)
323318

324319
res = minimize(
@@ -399,26 +394,27 @@ def lbfgsb_unbalanced2(
399394
loss matrix
400395
reg: float
401396
regularization term >=0
402-
c : array-like (dim_a, dim_b), optional (default = None)
403-
Reference measure for the regularization.
404-
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
405397
reg_m: float or indexable object of length 1 or 2
406398
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
407399
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
408400
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
409401
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
410-
reg_div: string, optional
402+
c : array-like (dim_a, dim_b), optional (default = None)
403+
Reference measure for the regularization.
404+
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
405+
reg_div: string or pair of callable functions, optional (default = 'kl')
411406
Divergence used for regularization.
412407
Can take three values: 'entropy' (negative entropy), or
413408
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
414-
of two calable functions returning the reg term and its derivative.
409+
of two callable functions returning the reg term and its derivative.
415410
Note that the callable functions should be able to handle Numpy arrays
416-
and not tesors from the backend
417-
regm_div: string, optional
411+
and not tensors from the backend, otherwise functions will be converted to Numpy
412+
leading to a computational overhead.
413+
regm_div: string, optional (default = 'kl')
418414
Divergence to quantify the difference between the marginals.
419415
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
420-
G0: array-like (dim_a, dim_b)
421-
Initialization of the transport matrix
416+
G0: array-like (dim_a, dim_b), optional (default = None)
417+
Initialization of the transport matrix. None corresponds to uniform product.
422418
returnCost: string, optional (default = "linear")
423419
If `returnCost` = "linear", then return the linear part of the unbalanced OT loss.
424420
If `returnCost` = "total", then return the total unbalanced OT loss.

ot/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,3 +1473,43 @@ def check_number_threads(numThreads):
14731473
'numThreads should either be "max" or a strictly positive integer'
14741474
)
14751475
return numThreads
1476+
1477+
1478+
def fun_to_numpy(fun, arr, nx, warn=True):
1479+
"""Convert a function to a numpy function.
1480+
1481+
Parameters
1482+
----------
1483+
fun : callable
1484+
The function to convert.
1485+
arr : array-like
1486+
The input to test the function. Can be from any backend.
1487+
nx : Backend
1488+
The backend to use for the conversion.
1489+
warn : bool, optional
1490+
Whether to raise a warning if the function is not compatible with numpy.
1491+
Default is True.
1492+
Returns
1493+
-------
1494+
fun_numpy : callable
1495+
The converted function.
1496+
"""
1497+
if arr is None:
1498+
raise ValueError("arr should not be None to test fun")
1499+
1500+
nx_arr = get_backend(arr)
1501+
if nx_arr.__name__ != "numpy":
1502+
arr = nx.to_numpy(arr)
1503+
try:
1504+
fun(arr)
1505+
return fun
1506+
except BaseException:
1507+
if warn:
1508+
warnings.warn(
1509+
"The callable function should be able to handle numpy arrays, a compatible function is created and comes with overhead"
1510+
)
1511+
1512+
def fun_numpy(x):
1513+
return nx.to_numpy(fun(nx.from_numpy(x)))
1514+
1515+
return fun_numpy

test/test_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,3 +731,21 @@ def test_exp_bures(nx):
731731
# exp_\Lambda(log_\Lambda(Sigma)) = Sigma
732732
Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T))
733733
np.testing.assert_allclose(nx.to_numpy(Sigma), nx.to_numpy(Sigma_exp), atol=1e-5)
734+
735+
736+
def test_fun_to_numpy(nx):
737+
arr = np.arange(5)
738+
arrb = nx.from_numpy(arr)
739+
740+
def fun(x): # backend function
741+
return nx.sum(x)
742+
743+
fun_numpy = ot.utils.fun_to_numpy(fun, arrb, nx, warn=True)
744+
745+
res = nx.to_numpy(fun(arrb))
746+
res_np = fun_numpy(arr)
747+
748+
np.testing.assert_allclose(res, res_np)
749+
750+
with pytest.raises(ValueError):
751+
ot.utils.fun_to_numpy(fun, None, nx, warn=True)

0 commit comments

Comments
 (0)