diff --git a/mesh_tensorflow/transformer/transformer.py b/mesh_tensorflow/transformer/transformer.py
index abad5713..b63db348 100644
--- a/mesh_tensorflow/transformer/transformer.py
+++ b/mesh_tensorflow/transformer/transformer.py
@@ -1650,7 +1650,10 @@ def __init__(self,
                teacher,
                temperature=None,
                fraction_soft=None,
-               distill_start_step=0,
+               mse_coeff=0.,
+               kl_coeff=0.,
+               cosine_coeff=0.,
+               distill_start_steps=0,
                teacher_checkpoint=None,
                initialize_student_weights=False):
     """Create a StudentTeacher.
@@ -1664,7 +1667,10 @@ def __init__(self,
         target cross entropy to the training loss. The rest of the loss will be
         the cross entropy with the one-hot actual label. Required only when
         training.
-      distill_start_step: an int, training steps after which teacher loss is
+      mse_coeff: a float, MSE distillation loss coefficient.
+      kl_coeff: a float, KL-divergence distillation loss coefficient.
+      cosine_coeff: a float, cosine-embedding distillation loss coefficient.
+      distill_start_steps: an int, training steps after which teacher loss is
         incorporated in the overall loss.
       teacher_checkpoint: a string, the path to the teacher checkpoint that we
         wish to use. Required only when training.
@@ -1676,9 +1682,15 @@ def __init__(self,
     self.teacher = teacher
     self.temperature = temperature
     self.fraction_soft = fraction_soft
-    self.distill_start_step = distill_start_step
+    self.distill_start_steps = distill_start_steps
     self.teacher_checkpoint = teacher_checkpoint
     self.initialize_student_weights = initialize_student_weights
+    self.kl_coeff = kl_coeff
+    self.cosine_coeff = cosine_coeff
+    self.mse_coeff = mse_coeff
+    if (fraction_soft + kl_coeff + cosine_coeff + mse_coeff) > 1.:
+      raise ValueError("Distillation coefficients must not add up to a value "
+                       "greater than 1.")
 
   def call_simple(self,
                   inputs,
@@ -1751,15 +1763,40 @@ def call_simple(self,
       weights = mtf.cast(mtf.greater(targets, 0), soft_loss.dtype)
       soft_loss = (mtf.reduce_sum(soft_loss * weights) /
                    self.student.loss_denominator(targets, num_microbatches))
+      if self.kl_coeff > 0.:
+        student_pred = mtf.softmax(student_logits / self.temperature,
+                                   output_vocab_dim)
+        kl_loss = mtf.layers.kl_divergence(
+            mtf.stop_gradient(soft_targets), student_pred, output_vocab_dim,
+            weights=weights)
+      else:
+        kl_loss = 0.
+      if self.cosine_coeff > 0.:
+        cosine_loss = mtf.layers.cosine_embedding_distill(
+            mtf.stop_gradient(teacher_logits), student_logits, output_vocab_dim,
+            weights=weights)
+      else:
+        cosine_loss = 0.
+      if self.mse_coeff > 0.:
+        # Mean squared error between teacher and student logits.
+        mse_loss = (mtf.reduce_sum(
+            mtf.square(mtf.stop_gradient(teacher_logits) - student_logits)
+            * weights) / self.student.loss_denominator(
+                targets, num_microbatches))
+      else:
+        mse_loss = 0.
       global_step = tf.train.get_or_create_global_step()
-      current_fraction_soft = tf.cast(
+      distill_loss_fraction = (self.fraction_soft + self.kl_coeff +
+                               self.mse_coeff + self.cosine_coeff)
+      current_distill_fraction = tf.cast(
           tf.cond(
-              tf.math.greater(global_step, self.distill_start_step),
-              lambda: self.fraction_soft, lambda: tf.constant(0.0)),
+              tf.math.greater(global_step, self.distill_start_steps),
+              lambda: distill_loss_fraction, lambda: tf.constant(0.0)),
           dtype=tf.bfloat16)
-      loss = (1.0 - current_fraction_soft) * hard_loss \
-          + self.temperature**2 * current_fraction_soft * soft_loss
+      loss = (1.0 - current_distill_fraction) * hard_loss \
+          + current_distill_fraction * (
+              self.temperature**2 * soft_loss * self.fraction_soft
+              + self.kl_coeff * kl_loss + self.mse_coeff * mse_loss
+              + self.cosine_coeff * cosine_loss)
       return student_logits, loss
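
Note: below is a minimal, self-contained Python sketch of how the blended loss in the
last hunk behaves, using hypothetical scalar stand-ins for the mtf tensors and the
coefficients. It is an illustration of the gating and weighting only, not part of the patch.

# Hypothetical scalar stand-ins for the mtf tensors in call_simple.
hard_loss, soft_loss = 2.3, 1.9
kl_loss, mse_loss, cosine_loss = 0.4, 0.8, 0.1
fraction_soft, kl_coeff, mse_coeff, cosine_coeff = 0.5, 0.2, 0.1, 0.1
temperature = 2.0
global_step, distill_start_steps = 1200, 1000

# Total weight of the distillation terms, gated on the start step,
# mirroring distill_loss_fraction / current_distill_fraction above.
distill_loss_fraction = fraction_soft + kl_coeff + mse_coeff + cosine_coeff
current_distill_fraction = (
    distill_loss_fraction if global_step > distill_start_steps else 0.0)

loss = (1.0 - current_distill_fraction) * hard_loss \
    + current_distill_fraction * (
        temperature**2 * soft_loss * fraction_soft
        + kl_coeff * kl_loss + mse_coeff * mse_loss
        + cosine_coeff * cosine_loss)
print(loss)  # ~3.80 with these values; before distill_start_steps it is just hard_loss.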