9
9
#
10
10
# License: MIT License
11
11
12
- import warnings
13
12
import numpy as np
14
13
from scipy .optimize import minimize , Bounds
15
14
16
15
from ..backend import get_backend
17
- from ..utils import list_to_array , get_parameter_pair
16
+ from ..utils import list_to_array , get_parameter_pair , fun_to_numpy
18
17
19
18
20
19
def _get_loss_unbalanced (a , b , c , M , reg , reg_m1 , reg_m2 , reg_div = "kl" , regm_div = "kl" ):
@@ -46,9 +45,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div
46
45
Divergence used for regularization.
47
46
Can take three values: 'entropy' (negative entropy), or
48
47
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
49
- of two calable functions returning the reg term and its derivative.
48
+ of two callable functions returning the reg term and its derivative.
50
49
Note that the callable functions should be able to handle Numpy arrays
51
- and not tesors from the backend
50
+ and not tensors from the backend
52
51
regm_div: string, optional
53
52
Divergence to quantify the difference between the marginals.
54
53
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
@@ -206,26 +205,27 @@ def lbfgsb_unbalanced(
206
205
loss matrix
207
206
reg: float
208
207
regularization term >=0
209
- c : array-like (dim_a, dim_b), optional (default = None)
210
- Reference measure for the regularization.
211
- If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
212
208
reg_m: float or indexable object of length 1 or 2
213
209
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
214
210
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
215
211
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
216
212
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
217
- reg_div: string, optional
213
+ c : array-like (dim_a, dim_b), optional (default = None)
214
+ Reference measure for the regularization.
215
+ If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
216
+ reg_div: string or pair of callable functions, optional (default = 'kl')
218
217
Divergence used for regularization.
219
218
Can take three values: 'entropy' (negative entropy), or
220
219
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
221
- of two calable functions returning the reg term and its derivative.
220
+ of two callable functions returning the reg term and its derivative.
222
221
Note that the callable functions should be able to handle Numpy arrays
223
- and not tesors from the backend
224
- regm_div: string, optional
222
+ and not tensors from the backend, otherwise functions will be converted to Numpy
223
+ leading to a computational overhead.
224
+ regm_div: string, optional (default = 'kl')
225
225
Divergence to quantify the difference between the marginals.
226
226
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
227
- G0: array-like (dim_a, dim_b)
228
- Initialization of the transport matrix
227
+ G0: array-like (dim_a, dim_b), optional (default = None)
228
+ Initialization of the transport matrix. None corresponds to uniform product.
229
229
numItermax : int, optional
230
230
Max number of iterations
231
231
stopThr : float, optional
@@ -267,26 +267,14 @@ def lbfgsb_unbalanced(
267
267
ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
268
268
"""
269
269
270
- # wrap the callable function to handle numpy arrays
271
- if isinstance (reg_div , tuple ):
272
- f0 , df0 = reg_div
273
- try :
274
- f0 (G0 )
275
- df0 (G0 )
276
- except BaseException :
277
- warnings .warn (
278
- "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
279
- )
280
-
281
- def f (x ):
282
- return nx .to_numpy (f0 (nx .from_numpy (x , type_as = M0 )))
283
-
284
- def df (x ):
285
- return nx .to_numpy (df0 (nx .from_numpy (x , type_as = M0 )))
286
-
287
- reg_div = (f , df )
270
+ # test settings
271
+ regm_div = regm_div .lower ()
272
+ if regm_div not in ["kl" , "l2" , "tv" ]:
273
+ raise ValueError (
274
+ "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'" .format (regm_div )
275
+ )
288
276
289
- else :
277
+ if isinstance ( reg_div , str ) :
290
278
reg_div = reg_div .lower ()
291
279
if reg_div not in ["entropy" , "kl" , "l2" ]:
292
280
raise ValueError (
@@ -295,16 +283,11 @@ def df(x):
295
283
)
296
284
)
297
285
298
- regm_div = regm_div .lower ()
299
- if regm_div not in ["kl" , "l2" , "tv" ]:
300
- raise ValueError (
301
- "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'" .format (regm_div )
302
- )
303
-
286
+ # convert all inputs to numpy arrays
304
287
reg_m1 , reg_m2 = get_parameter_pair (reg_m )
305
288
306
289
M , a , b = list_to_array (M , a , b )
307
- nx = get_backend (M , a , b )
290
+ nx = get_backend (M , a , b , G0 )
308
291
M0 = M
309
292
310
293
dim_a , dim_b = M .shape
@@ -315,10 +298,22 @@ def df(x):
315
298
b = nx .ones (dim_b , type_as = M ) / dim_b
316
299
317
300
# convert to numpy
318
- a , b , M , reg_m1 , reg_m2 , reg = nx .to_numpy (a , b , M , reg_m1 , reg_m2 , reg )
301
+ if nx .__name__ == "numpy" : # remaining parameters which can be arrays
302
+ reg_m1 , reg_m2 , reg = nx .to_numpy (reg_m1 , reg_m2 , reg )
303
+ else :
304
+ a , b , M , reg_m1 , reg_m2 , reg = nx .to_numpy (a , b , M , reg_m1 , reg_m2 , reg )
305
+
319
306
G0 = a [:, None ] * b [None , :] if G0 is None else nx .to_numpy (G0 )
320
307
c = a [:, None ] * b [None , :] if c is None else nx .to_numpy (c )
321
308
309
+ # potentially convert the callable function to handle numpy arrays
310
+ if isinstance (reg_div , tuple ):
311
+ f0 , df0 = reg_div
312
+ f = fun_to_numpy (f0 , G0 , nx , warn = True )
313
+ df = fun_to_numpy (df0 , G0 , nx , warn = True )
314
+
315
+ reg_div = (f , df )
316
+
322
317
_func = _get_loss_unbalanced (a , b , c , M , reg , reg_m1 , reg_m2 , reg_div , regm_div )
323
318
324
319
res = minimize (
@@ -399,26 +394,27 @@ def lbfgsb_unbalanced2(
399
394
loss matrix
400
395
reg: float
401
396
regularization term >=0
402
- c : array-like (dim_a, dim_b), optional (default = None)
403
- Reference measure for the regularization.
404
- If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
405
397
reg_m: float or indexable object of length 1 or 2
406
398
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
407
399
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
408
400
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
409
401
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
410
- reg_div: string, optional
402
+ c : array-like (dim_a, dim_b), optional (default = None)
403
+ Reference measure for the regularization.
404
+ If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
405
+ reg_div: string or pair of callable functions, optional (default = 'kl')
411
406
Divergence used for regularization.
412
407
Can take three values: 'entropy' (negative entropy), or
413
408
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
414
- of two calable functions returning the reg term and its derivative.
409
+ of two callable functions returning the reg term and its derivative.
415
410
Note that the callable functions should be able to handle Numpy arrays
416
- and not tesors from the backend
417
- regm_div: string, optional
411
+ and not tensors from the backend, otherwise functions will be converted to Numpy
412
+ leading to a computational overhead.
413
+ regm_div: string, optional (default = 'kl')
418
414
Divergence to quantify the difference between the marginals.
419
415
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
420
- G0: array-like (dim_a, dim_b)
421
- Initialization of the transport matrix
416
+ G0: array-like (dim_a, dim_b), optional (default = None)
417
+ Initialization of the transport matrix. None corresponds to uniform product.
422
418
returnCost: string, optional (default = "linear")
423
419
If `returnCost` = "linear", then return the linear part of the unbalanced OT loss.
424
420
If `returnCost` = "total", then return the total unbalanced OT loss.
0 commit comments