diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index bbae9fc3..f83cb525 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -1429,7 +1429,6 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
   """
   scoring_datasets = score_dataset_fn(
       sequence_length=sequence_length,
-      vocabulary=vocabulary,
       dataset_split=dataset_split)
 
   def input_fn(params):
@@ -1598,7 +1597,6 @@ def input_fn(params):
 
     dataset = train_dataset_fn(
         sequence_length=sequence_length,
-        vocabulary=vocabulary,
         dataset_split=dataset_split)
     dataset = dataset.repeat().batch(
         batch_size * (ensemble_inputs or 1), drop_remainder=True)
@@ -1713,7 +1711,6 @@ def eval_model(estimator, vocabulary, sequence_length, batch_size,
 
   eval_datasets = eval_dataset_fn(
       sequence_length=sequence_length,
-      vocabulary=vocabulary,
       dataset_split=dataset_split,
   )
 
@@ -2356,7 +2353,6 @@ def run(tpu_job_name,
             name="eval",
             dataset_fn=functools.partial(train_dataset_fn,
                                          sequence_length=sequence_length,
-                                         vocabulary=vocabulary,
                                          dataset_split=dataset_split),
             postprocess_fn=None,
             metric_fns=None)]
@@ -2367,7 +2363,6 @@ def run(tpu_job_name,
     else:
       eval_datasets = eval_dataset_fn(
           sequence_length=sequence_length,
-          vocabulary=vocabulary,
           dataset_split=dataset_split,
       )
     def _input_fn(params, eval_dataset):
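
Note: after this change, the score_dataset_fn / train_dataset_fn / eval_dataset_fn callables are invoked with only sequence_length and dataset_split; any vocabulary they need must be obtained inside the callable rather than passed in from these call sites. The snippet below is a minimal illustrative sketch of a compatible callable, not code from this repository; the function name, the dummy token ids, and the assumption that sequence_length maps feature names to integer lengths are all assumptions made for the example.

    # Hypothetical dataset_fn matching the updated call sites
    # (sequence_length and dataset_split only, no vocabulary argument).
    import tensorflow.compat.v1 as tf

    def toy_train_dataset_fn(sequence_length, dataset_split="train"):
      """Illustrative only: returns a tiny tf.data.Dataset of token ids."""
      del dataset_split  # a real implementation would select the split here
      length = sequence_length["targets"]  # assumed {"targets": int} mapping
      tokens = [[5, 6, 7] + [0] * (length - 3)]  # dummy ids, padded to length
      return tf.data.Dataset.from_tensor_slices(
          {"targets": tf.constant(tokens, dtype=tf.int32)})

Under that assumption, the call sites in this diff would invoke it as toy_train_dataset_fn(sequence_length=sequence_length, dataset_split=dataset_split).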