1818from .bregman import sinkhorn , jcpot_barycenter
1919from .lp import emd
2020from .utils import unif , dist , kernel , cost_normalization , label_normalization , laplacian , dots
21- from .utils import list_to_array , check_params , BaseEstimator , deprecated
21+ from .utils import BaseEstimator , check_params , deprecated , labels_to_masks , list_to_array
2222from .unbalanced import sinkhorn_unbalanced
2323from .gaussian import empirical_bures_wasserstein_mapping , empirical_gaussian_gromov_wasserstein_mapping
2424from .optim import cg
@@ -499,18 +499,12 @@ class label
499499 if self .limit_max != np .infty :
500500 self .limit_max = self .limit_max * nx .max (self .cost_ )
501501
502- # assumes labeled source samples occupy the first rows
503- # and labeled target samples occupy the first columns
504- classes = [c for c in nx .unique (ys ) if c != - 1 ]
505- for c in classes :
506- idx_s = nx .where ((ys != c ) & (ys != - 1 ))
507- idx_t = nx .where (yt == c )
508-
509- # all the coefficients corresponding to a source sample
510- # and a target sample :
511- # with different labels get a infinite
512- for j in idx_t [0 ]:
513- self .cost_ [idx_s [0 ], j ] = self .limit_max
502+ # zeros where source label is missing (masked with -1)
503+ missing_labels = ys + nx .ones (ys .shape , type_as = ys )
504+ missing_labels = nx .repeat (missing_labels [:, None ], ys .shape [0 ], 1 )
505+ # zeros where labels match
506+ label_match = ys [:, None ] - yt [None , :]
507+ self .cost_ = nx .maximum (self .cost_ , nx .abs (label_match ) * nx .abs (missing_labels ) * self .limit_max )
514508
515509 # distribution estimation
516510 self .mu_s = self .distribution_estimation (Xs )
@@ -581,12 +575,11 @@ class label
581575 if check_params (Xs = Xs ):
582576
583577 if nx .array_equal (self .xs_ , Xs ):
584-
585578 # perform standard barycentric mapping
586579 transp = self .coupling_ / nx .sum (self .coupling_ , axis = 1 )[:, None ]
587580
588581 # set nans to 0
589- transp [ ~ nx .isfinite (transp )] = 0
582+ transp = nx .nan_to_num (transp , nan = 0 , posinf = 0 , neginf = 0 )
590583
591584 # compute transported samples
592585 transp_Xs = nx .dot (transp , self .xt_ )
@@ -604,9 +597,8 @@ class label
604597 idx = nx .argmin (D0 , axis = 1 )
605598
606599 # transport the source samples
607- transp = self .coupling_ / nx .sum (
608- self .coupling_ , axis = 1 )[:, None ]
609- transp [~ nx .isfinite (transp )] = 0
600+ transp = self .coupling_ / nx .sum (self .coupling_ , axis = 1 )[:, None ]
601+ transp = nx .nan_to_num (transp , nan = 0 , posinf = 0 , neginf = 0 )
610602 transp_Xs_ = nx .dot (transp , self .xt_ )
611603
612604 # define the transported points
@@ -645,23 +637,16 @@ def transform_labels(self, ys=None):
645637
646638 # check the necessary inputs parameters are here
647639 if check_params (ys = ys ):
648-
649- ysTemp = label_normalization (nx .copy (ys ))
650- classes = nx .unique (ysTemp )
651- n = len (classes )
652- D1 = nx .zeros ((n , len (ysTemp )), type_as = self .coupling_ )
653-
654640 # perform label propagation
655641 transp = self .coupling_ / nx .sum (self .coupling_ , axis = 0 )[None , :]
656642
657643 # set nans to 0
658- transp [~ nx .isfinite (transp )] = 0
659-
660- for c in classes :
661- D1 [int (c ), ysTemp == c ] = 1
644+ transp = nx .nan_to_num (transp , nan = 0 , posinf = 0 , neginf = 0 )
662645
663646 # compute propagated labels
664- transp_ys = nx .dot (D1 , transp )
647+ labels = label_normalization (ys )
648+ masks = labels_to_masks (labels , nx = nx , type_as = transp )
649+ transp_ys = nx .dot (masks .T , transp )
665650
666651 return transp_ys .T
667652
@@ -697,12 +682,11 @@ class label
697682 if check_params (Xt = Xt ):
698683
699684 if nx .array_equal (self .xt_ , Xt ):
700-
701685 # perform standard barycentric mapping
702686 transp_ = self .coupling_ .T / nx .sum (self .coupling_ , 0 )[:, None ]
703687
704688 # set nans to 0
705- transp_ [ ~ nx .isfinite (transp_ )] = 0
689+ transp_ = nx .nan_to_num (transp_ , nan = 0 , posinf = 0 , neginf = 0 )
706690
707691 # compute transported samples
708692 transp_Xt = nx .dot (transp_ , self .xs_ )
@@ -719,9 +703,8 @@ class label
719703 idx = nx .argmin (D0 , axis = 1 )
720704
721705 # transport the target samples
722- transp_ = self .coupling_ .T / nx .sum (
723- self .coupling_ , 0 )[:, None ]
724- transp_ [~ nx .isfinite (transp_ )] = 0
706+ transp_ = self .coupling_ .T / nx .sum (self .coupling_ , 0 )[:, None ]
707+ transp_ = nx .nan_to_num (transp_ , nan = 0 , posinf = 0 , neginf = 0 )
725708 transp_Xt_ = nx .dot (transp_ , self .xs_ )
726709
727710 # define the transported points
@@ -750,23 +733,15 @@ def inverse_transform_labels(self, yt=None):
750733
751734 # check the necessary inputs parameters are here
752735 if check_params (yt = yt ):
753-
754- ytTemp = label_normalization (nx .copy (yt ))
755- classes = nx .unique (ytTemp )
756- n = len (classes )
757- D1 = nx .zeros ((n , len (ytTemp )), type_as = self .coupling_ )
758-
759736 # perform label propagation
760737 transp = self .coupling_ / nx .sum (self .coupling_ , 1 )[:, None ]
761-
762738 # set nans to 0
763- transp [ ~ nx .isfinite (transp )] = 0
739+ transp = nx .nan_to_num (transp , nan = 0 , posinf = 0 , neginf = 0 )
764740
765- for c in classes :
766- D1 [int (c ), ytTemp == c ] = 1
767-
768- # compute propagated samples
769- transp_ys = nx .dot (D1 , transp .T )
741+ # compute propagated labels
742+ labels = label_normalization (yt )
743+ masks = labels_to_masks (labels , nx = nx , type_as = transp )
744+ transp_ys = nx .dot (masks .T , transp .T )
770745
771746 return transp_ys .T
772747
@@ -2151,7 +2126,7 @@ def transform_labels(self, ys=None):
21512126 type_as = ys [0 ]
21522127 )
21532128 for i in range (len (ys )):
2154- ysTemp = label_normalization (nx . copy ( ys [i ]) )
2129+ ysTemp = label_normalization (ys [i ])
21552130 classes = nx .unique (ysTemp )
21562131 n = len (classes )
21572132 ns = len (ysTemp )
@@ -2194,7 +2169,7 @@ def inverse_transform_labels(self, yt=None):
21942169 # check the necessary inputs parameters are here
21952170 if check_params (yt = yt ):
21962171 transp_ys = []
2197- ytTemp = label_normalization (nx . copy ( yt ) )
2172+ ytTemp = label_normalization (yt )
21982173 classes = nx .unique (ytTemp )
21992174 n = len (classes )
22002175 D1 = nx .zeros ((n , len (ytTemp )), type_as = self .coupling_ [0 ])
0 commit comments