@@ -1101,6 +1101,105 @@ def softmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
  return loss


def kl_divergence(y_true, y_pred, reduced_dim, weights=None, epsilon=1e-6):
  """Kullback-Leibler divergence between `y_true` and `y_pred`.

  Computes: `loss = y_true * log(y_true / y_pred)`
  From: tf.keras.losses.KLDivergence (custom implementation with mtf).
  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

  Args:
    y_true: mtf.Tensor, target predictions (distribution).
    y_pred: mtf.Tensor, actual predictions (distribution).
    reduced_dim: mtf.Dimension, reduction dimension for the sum.
    weights: Optional mtf.Tensor, indicator for padded regions.
    epsilon: float, minimum value for numerical stability.
  Returns:
    scalar: KL divergence loss.
  Raises:
    ValueError: if the shapes do not match or reduced_dim is not valid.
  """
  if set(y_true.shape.dims) != set(y_pred.shape.dims):
    raise ValueError(
        "`y_true` and `y_pred` must be of the same shape. "
        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
  if reduced_dim not in y_true.shape.dims:
    raise ValueError(
        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
  if weights is None:
    weights = 1.

  def _clip(x, min_value, max_value):
    # Clip values for numerical stability.
    x = mtf.maximum(x, min_value)
    x = mtf.minimum(x, max_value)
    return x

  y_true = _clip(y_true, epsilon, 1.)
  y_pred = _clip(y_pred, epsilon, 1.)
  # Note: the sum runs over all dimensions and yields a scalar;
  # `reduced_dim` is only validated against the input shapes above.
  return mtf.reduce_sum(weights * y_true * mtf.log(y_true / y_pred))


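# A minimal NumPy sketch of what `kl_divergence` computes (illustrative
# only; assumes plain numpy arrays in place of mtf tensors).
import numpy as np

def _kl_divergence_reference(y_true, y_pred, weights=1.0, epsilon=1e-6):
  # Clip both distributions into [epsilon, 1] for stability, then take the
  # weighted sum of y_true * log(y_true / y_pred) over all elements.
  y_true = np.clip(y_true, epsilon, 1.0)
  y_pred = np.clip(y_pred, epsilon, 1.0)
  return np.sum(weights * y_true * np.log(y_true / y_pred))

# Identical distributions have zero divergence.
p = np.array([0.25, 0.25, 0.5])
assert abs(_kl_divergence_reference(p, p)) < 1e-9

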
def mean_squared_error(y_true, y_pred, weights=None):
  """L2 loss between `y_true` and `y_pred`.

  Args:
    y_true: mtf.Tensor, target logits.
    y_pred: mtf.Tensor, actual logits.
    weights: Optional mtf.Tensor, indicator for padded regions.
  Returns:
    scalar: L2 loss.
  Raises:
    ValueError: if the shapes do not match.
  """
  if set(y_true.shape.dims) != set(y_pred.shape.dims):
    raise ValueError(
        "`y_true` and `y_pred` must be of the same shape. "
        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
  if weights is None:
    weights = 1.
  # Note: this is a weighted *sum* of squared errors, not a mean.
  return mtf.reduce_sum(weights * mtf.square(y_true - y_pred))


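# A minimal NumPy sketch of what `mean_squared_error` computes (illustrative
# only; assumes plain numpy arrays, with `weights` zeroing padded positions).
import numpy as np

def _mean_squared_error_reference(y_true, y_pred, weights=1.0):
  return np.sum(weights * np.square(y_true - y_pred))

y_true = np.array([[0.5, 1.0], [2.0, 0.0]])
y_pred = np.array([[0.0, 1.0], [1.0, 0.0]])
pad_mask = np.array([[1.0, 1.0], [1.0, 0.0]])  # last position is padding
assert _mean_squared_error_reference(y_true, y_pred, pad_mask) == 1.25

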
def cosine_embedding_distill(y_true, y_pred, reduced_dim, weights=None,
                             epsilon=1e-6):
  """Cosine embedding loss for distillation from teacher to student logits.

  See: https://arxiv.org/abs/1910.01108 (DistilBERT) and
  https://github.com/huggingface/transformers/tree/master/examples/research_projects/distillation

  Args:
    y_true: mtf.Tensor, teacher logits.
    y_pred: mtf.Tensor, student logits.
    reduced_dim: mtf.Dimension, reduction dimension for the sum.
    weights: Optional mtf.Tensor, indicator for padded regions.
    epsilon: float, for numerical stability.
  Returns:
    scalar: weighted sum of cosine embedding distances (1 - cos).
  Raises:
    ValueError: if the shapes do not match or reduced_dim is not valid.
  """
  if set(y_true.shape.dims) != set(y_pred.shape.dims):
    raise ValueError(
        "`y_true` and `y_pred` must be of the same shape. "
        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
  if reduced_dim not in y_true.shape.dims:
    raise ValueError(
        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
  if weights is None:
    weights = 1.

  prod_sum = mtf.reduce_sum(y_true * y_pred, reduced_dim=reduced_dim)
  y_true_sq_sum = mtf.reduce_sum(y_true * y_true, reduced_dim=reduced_dim)
  y_pred_sq_sum = mtf.reduce_sum(y_pred * y_pred, reduced_dim=reduced_dim)
  inv_denom = mtf.rsqrt(y_true_sq_sum * y_pred_sq_sum + epsilon)
  cos = prod_sum * inv_denom
  # TODO(vinaysrao): Turn this into a more general cosine embedding loss with
  # a `targets` tensor.
  return mtf.reduce_sum(weights * (1. - cos))


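# A minimal NumPy sketch of the cosine term above (illustrative only; the
# `axis` argument plays the role of `reduced_dim` in the mtf version).
import numpy as np

def _cosine_distill_reference(y_true, y_pred, axis=-1, weights=1.0,
                              epsilon=1e-6):
  prod_sum = np.sum(y_true * y_pred, axis=axis)
  inv_denom = 1.0 / np.sqrt(np.sum(y_true ** 2, axis=axis) *
                            np.sum(y_pred ** 2, axis=axis) + epsilon)
  # Aligned teacher/student vectors give cos close to 1, so 1 - cos ~ 0.
  return np.sum(weights * (1.0 - prod_sum * inv_denom))

v = np.array([[3.0, 4.0]])
assert _cosine_distill_reference(v, 2.0 * v) < 1e-6

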
def sigmoid_cross_entropy_with_logits(logits, targets):
  """Sigmoid cross-entropy loss.
