diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index 355c0e7f..4b92ca7c 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -1220,8 +1220,11 @@ def input_fn(params): return dataset checkpoint_step = get_step_from_checkpoint_path(checkpoint_path) - decodes = decode( - estimator, input_fn, vocabulary, checkpoint_path=checkpoint_path) + decodes = [ + d.decode("utf-8") if isinstance(d, bytes) else d + for d in decode(estimator, input_fn, vocabulary, checkpoint_path) + ] + # Remove any padded examples dataset_size = len(inputs) * repeats decodes = decodes[:dataset_size]