diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index 2fdd24eb1..e05e13319 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -635,8 +635,9 @@ def video_raw_targets_bottom(x, model_hparams, vocab_size): # Loss transformations, applied to target features -def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weight_fn): +def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weights_fn): """Compute the CTC loss.""" + del model_hparams, vocab_size # unused arg logits = top_out with tf.name_scope("ctc_loss", values=[logits, targets]): @@ -658,7 +659,7 @@ def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weight_fn): time_major=False, preprocess_collapse_repeated=False, ctc_merge_repeated=False) - weights = weight_fn(targets) + weights = weights_fn(targets) return tf.reduce_sum(xent), tf.reduce_sum(weights)