From dc67bd87df532739eac9a2674f205cdf795c9123 Mon Sep 17 00:00:00 2001 From: Colin Raffel Date: Tue, 21 Sep 2021 19:59:53 -0400 Subject: [PATCH 1/4] Output raw model outputs during eval Currently, only the postprocessed model outputs are written out into a file suffixed with "predictions". This outputs an additional file suffixed with "outputs" that stores the raw model outputs, without postprocessing. --- mesh_tensorflow/transformer/utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index e64ca8a2..b1c66d4e 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -1550,7 +1550,14 @@ def decode_from_dataset(estimator, # Extract the portion of decodes corresponding to this dataset dataset_size = len(examples_for_ds) predictions = decodes[:dataset_size] - + + # Write the raw outputs to file. + predictions_filename = os.path.join( + decode_output_dir, + "{}_{}_predictions".format(infer_dataset.name, checkpoint_step), + ) + write_lines_to_file(predictions, predictions_filename) + # Remove the used decodes. del decodes[:dataset_size] @@ -2398,18 +2405,24 @@ def eval_model(estimator, eval_dataset.postprocess_fn(d, example=ex) for d, ex in zip(outputs[:dataset_size], examples) ] - # Remove the used decodes. - del outputs[:dataset_size] global_step = int(get_step_from_checkpoint_path(checkpoint_path)) if output_eval_examples: + outputs_filename = os.path.join( + summary_dir, + "{}_{}_outputs".format((eval_dataset.name, global_step), + ) + write_lines_to_file(outputs[:dataset_size], outputs_filename) predictions_filename = os.path.join( eval_summary_dir, "{}_{}_predictions".format(eval_dataset.name, global_step), ) write_lines_to_file(predictions, predictions_filename) + # Remove the used decodes. + del outputs[:dataset_size] + for metric_fn in eval_dataset.metric_fns: summary = tf.Summary() targets = cached_targets[eval_dataset.name] From d7cda2d5011d1dd431b67faa2a3a5d7786101d04 Mon Sep 17 00:00:00 2001 From: Colin Raffel Date: Tue, 21 Sep 2021 20:04:26 -0400 Subject: [PATCH 2/4] Fix --- mesh_tensorflow/transformer/utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index b1c66d4e..301e9754 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -1550,14 +1550,6 @@ def decode_from_dataset(estimator, # Extract the portion of decodes corresponding to this dataset dataset_size = len(examples_for_ds) predictions = decodes[:dataset_size] - - # Write the raw outputs to file. - predictions_filename = os.path.join( - decode_output_dir, - "{}_{}_predictions".format(infer_dataset.name, checkpoint_step), - ) - write_lines_to_file(predictions, predictions_filename) - # Remove the used decodes. del decodes[:dataset_size] @@ -2411,7 +2403,7 @@ def eval_model(estimator, if output_eval_examples: outputs_filename = os.path.join( summary_dir, - "{}_{}_outputs".format((eval_dataset.name, global_step), + "{}_{}_outputs".format(eval_dataset.name, global_step), ) write_lines_to_file(outputs[:dataset_size], outputs_filename) predictions_filename = os.path.join( From c3d2b7d6955c3caaa929abada19d25532a8525b0 Mon Sep 17 00:00:00 2001 From: Colin Raffel Date: Tue, 21 Sep 2021 20:05:04 -0400 Subject: [PATCH 3/4] Revert. --- mesh_tensorflow/transformer/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index 301e9754..7004b043 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -1550,6 +1550,7 @@ def decode_from_dataset(estimator, # Extract the portion of decodes corresponding to this dataset dataset_size = len(examples_for_ds) predictions = decodes[:dataset_size] + # Remove the used decodes. del decodes[:dataset_size] From 8c4ec14b1fdd69b64f97454ef27f384b2f345e51 Mon Sep 17 00:00:00 2001 From: Colin Raffel Date: Tue, 21 Sep 2021 20:16:15 -0400 Subject: [PATCH 4/4] Fix. --- mesh_tensorflow/transformer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index 7004b043..7d78bd47 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -2403,7 +2403,7 @@ def eval_model(estimator, if output_eval_examples: outputs_filename = os.path.join( - summary_dir, + eval_summary_dir, "{}_{}_outputs".format(eval_dataset.name, global_step), ) write_lines_to_file(outputs[:dataset_size], outputs_filename)