This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Use multiple target objectives for distillation. Also see cl/356382304 #290

Open · wants to merge 1 commit into base: master
53 changes: 45 additions & 8 deletions mesh_tensorflow/transformer/transformer.py
@@ -1650,7 +1650,10 @@ def __init__(self,
                teacher,
                temperature=None,
                fraction_soft=None,
-               distill_start_step=0,
+               mse_coeff=0.,
+               kl_coeff=0.,
+               cosine_coeff=0.,
+               distill_start_steps=0,
                teacher_checkpoint=None,
                initialize_student_weights=False):
     """Create a StudentTeacher.
@@ -1664,7 +1667,10 @@ def __init__(self,
         target cross entropy to the training loss. The rest of the loss will be
         the cross entropy with the one-hot actual label. Required only when
         training.
-      distill_start_step: an int, training steps after which teacher loss is
+      mse_coeff: a float, the MSE distillation loss coefficient.
+      kl_coeff: a float, the KL-divergence distillation loss coefficient.
+      cosine_coeff: a float, the cosine-embedding distillation coefficient.
+      distill_start_steps: an int, training steps after which teacher loss is
         incorporated in the overall loss.
       teacher_checkpoint: a string, the path to the teacher checkpoint that we
         wish to use. Required only when training.
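Note: the `mtf.layers.kl_divergence` and `mtf.layers.cosine_embedding_distill` helpers used further down are not part of this diff (they presumably arrive with the companion CL). As a rough guide only, here is a single-token NumPy sketch of what the three objectives compute; the exact definitions (e.g. the cosine-embedding loss as 1 minus the cosine similarity of the logit vectors) are assumptions, not taken from this change.

```python
import numpy as np

def softmax(logits, temperature=1.0):
  z = logits / temperature
  z = z - z.max()  # for numerical stability
  e = np.exp(z)
  return e / e.sum()

teacher_logits = np.array([4.0, 1.0, 0.5])
student_logits = np.array([3.0, 1.5, 0.2])
temperature = 2.0

# KL divergence between temperature-softened teacher and student
# distributions (the teacher side is treated as a constant).
p = softmax(teacher_logits, temperature)
q = softmax(student_logits, temperature)
kl_loss = np.sum(p * (np.log(p) - np.log(q)))

# Assumed cosine-embedding loss: 1 - cosine similarity of the raw logits.
cosine_loss = 1.0 - np.dot(teacher_logits, student_logits) / (
    np.linalg.norm(teacher_logits) * np.linalg.norm(student_logits))

# MSE between the raw logit vectors.
mse_loss = np.mean((teacher_logits - student_logits) ** 2)

print(kl_loss, cosine_loss, mse_loss)
```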
@@ -1676,9 +1682,15 @@
     self.teacher = teacher
     self.temperature = temperature
     self.fraction_soft = fraction_soft
-    self.distill_start_step = distill_start_step
+    self.distill_start_steps = distill_start_steps
     self.teacher_checkpoint = teacher_checkpoint
     self.initialize_student_weights = initialize_student_weights
+    self.kl_coeff = kl_coeff
+    self.cosine_coeff = cosine_coeff
+    self.mse_coeff = mse_coeff
+    # fraction_soft may be None at inference time; treat it as 0 here.
+    if ((fraction_soft or 0.) + kl_coeff + cosine_coeff + mse_coeff) > 1.:
+      raise ValueError("Distillation coefficients must not add up to a value "
+                       "greater than 1.")

@@ -1751,15 +1763,40 @@ def call_simple(self,
     weights = mtf.cast(mtf.greater(targets, 0), soft_loss.dtype)
     soft_loss = (mtf.reduce_sum(soft_loss * weights) /
                  self.student.loss_denominator(targets, num_microbatches))
+    if self.kl_coeff > 0.:
+      student_pred = mtf.softmax(student_logits / self.temperature,
+                                 output_vocab_dim)
+      kl_loss = mtf.layers.kl_divergence(
+          mtf.stop_gradient(soft_targets), student_pred, output_vocab_dim,
+          weights=weights)
+    else:
+      kl_loss = 0.
+    if self.cosine_coeff > 0.:
+      cosine_loss = mtf.layers.cosine_embedding_distill(
+          mtf.stop_gradient(teacher_logits), student_logits, output_vocab_dim,
+          weights=weights)
+    else:
+      cosine_loss = 0.
+    if self.mse_coeff > 0.:
+      # Mean squared error between teacher and student logits, masked and
+      # normalized the same way as the soft loss.
+      mse_loss = (mtf.reduce_sum(
+          mtf.reduce_mean(
+              mtf.square(mtf.stop_gradient(teacher_logits) - student_logits),
+              reduced_dim=output_vocab_dim) * weights) /
+                  self.student.loss_denominator(targets, num_microbatches))
+    else:
+      mse_loss = 0.
     global_step = tf.train.get_or_create_global_step()
-    current_fraction_soft = tf.cast(
-        tf.cond(
-            tf.math.greater(global_step, self.distill_start_step),
-            lambda: self.fraction_soft, lambda: tf.constant(0.0)),
-        dtype=tf.bfloat16)
+    distill_loss_fraction = (self.fraction_soft + self.kl_coeff +
+                             self.mse_coeff + self.cosine_coeff)
+    # 1.0 once global_step passes distill_start_steps, 0.0 before; gates the
+    # distillation terms without re-scaling each one by the summed fraction.
+    distill_gate = tf.cast(
+        tf.math.greater(global_step, self.distill_start_steps), tf.bfloat16)
+    current_distill_fraction = distill_gate * distill_loss_fraction

-    loss = (1.0 - current_fraction_soft) * hard_loss \
-        + self.temperature**2 * current_fraction_soft * soft_loss
+    loss = (1.0 - current_distill_fraction) * hard_loss \
+        + distill_gate * (
+            self.temperature**2 * self.fraction_soft * soft_loss +
+            self.kl_coeff * kl_loss + self.mse_coeff * mse_loss +
+            self.cosine_coeff * cosine_loss)

     return student_logits, loss
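To make the step gating concrete, here is a plain-Python scalar mirror of the blending above (hypothetical helper and made-up values, not part of the change): before distill_start_steps only the hard loss contributes; afterwards each distillation term contributes with its own coefficient and the hard loss keeps the leftover fraction.

```python
def blended_loss(hard_loss, soft_loss, kl_loss, cosine_loss, mse_loss,
                 global_step, temperature=2.0, fraction_soft=0.4,
                 kl_coeff=0.2, cosine_coeff=0.1, mse_coeff=0.1,
                 distill_start_steps=1000):
  """Scalar sketch of the loss blending in call_simple above."""
  distill_fraction = fraction_soft + kl_coeff + cosine_coeff + mse_coeff
  gate = 1.0 if global_step > distill_start_steps else 0.0
  return ((1.0 - gate * distill_fraction) * hard_loss +
          gate * (temperature**2 * fraction_soft * soft_loss +
                  kl_coeff * kl_loss + mse_coeff * mse_loss +
                  cosine_coeff * cosine_loss))

print(blended_loss(2.0, 1.0, 0.5, 0.3, 0.4, global_step=500))   # 2.0 (hard only)
print(blended_loss(2.0, 1.0, 0.5, 0.3, 0.4, global_step=5000))  # ~2.17
```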
