diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index bbae9fc3..f83cb525 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -1429,7 +1429,6 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
   """
   scoring_datasets = score_dataset_fn(
       sequence_length=sequence_length,
-      vocabulary=vocabulary,
       dataset_split=dataset_split)
 
   def input_fn(params):
@@ -1598,7 +1597,6 @@ def input_fn(params):
 
     dataset = train_dataset_fn(
         sequence_length=sequence_length,
-        vocabulary=vocabulary,
         dataset_split=dataset_split)
     dataset = dataset.repeat().batch(
         batch_size * (ensemble_inputs or 1), drop_remainder=True)
@@ -1713,7 +1711,6 @@ def eval_model(estimator, vocabulary, sequence_length, batch_size,
 
   eval_datasets = eval_dataset_fn(
       sequence_length=sequence_length,
-      vocabulary=vocabulary,
       dataset_split=dataset_split,
   )
 
@@ -2356,7 +2353,6 @@ def run(tpu_job_name,
             name="eval",
             dataset_fn=functools.partial(train_dataset_fn,
                                          sequence_length=sequence_length,
-                                         vocabulary=vocabulary,
                                          dataset_split=dataset_split),
             postprocess_fn=None,
             metric_fns=None)]
@@ -2367,7 +2363,6 @@ def run(tpu_job_name,
     else:
       eval_datasets = eval_dataset_fn(
           sequence_length=sequence_length,
-          vocabulary=vocabulary,
           dataset_split=dataset_split,
       )
     def _input_fn(params, eval_dataset):
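
Context for the change: every hunk drops vocabulary=vocabulary from a *_dataset_fn call site, so the dataset functions are now invoked with only sequence_length and dataset_split. A minimal sketch of a function compatible with the new call sites follows, assuming sequence_length is a dict mapping feature names to lengths as elsewhere in these utilities; the function name and body are hypothetical and only illustrate the updated calling convention, they are not taken from the repo.

import tensorflow as tf

def toy_train_dataset_fn(sequence_length, dataset_split):
  # Hypothetical dataset_fn matching the updated call sites above:
  # it receives only `sequence_length` and `dataset_split`, and no
  # longer takes a `vocabulary` argument.
  del dataset_split  # unused in this toy sketch
  features = {
      "inputs": tf.zeros([sequence_length["inputs"]], dtype=tf.int32),
      "targets": tf.zeros([sequence_length["targets"]], dtype=tf.int32),
  }
  return tf.data.Dataset.from_tensors(features).repeat()

# Example call mirroring the new call sites:
# ds = toy_train_dataset_fn(
#     sequence_length={"inputs": 64, "targets": 64},
#     dataset_split="train")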