diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index 277e8471..a24e90a5 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -26,9 +26,11 @@ import functools
 import itertools
+import math
 import os
 import random
 import re
+import time
 
 import gin
 import gin.tf
@@ -38,6 +40,7 @@ from mesh_tensorflow.transformer import learning_rate_schedules
 from mesh_tensorflow.transformer import transformer
 import numpy as np
+import pandas as pd
 import pkg_resources
 import six
 import tensorflow.compat.v1 as tf
@@ -1654,6 +1657,52 @@ def get_sequence_length(tokens, pad_id=0):
   return scores
 
 
+@gin.configurable
+def save_scores_to_tfrecords(
+    results, vocabulary, scores_filename, shard_idx=0, save_ids_only=False):
+  """Processes results from scoring examples and saves them to TFRecord files.
+
+  Args:
+    results: list of dictionaries containing the results for each scored
+      example.
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple.
+    scores_filename: a string (path of the file to write scores to).
+    shard_idx: an integer indicating the current index of the file for
+      sharding.
+    save_ids_only: if true, save only the ID that is prepended to the inputs.
+  """
+  results = _maybe_add_pretokenized_features(results, vocabulary)
+  scores = [r.get("scores", 0.0) for r in results]
+  targets = [r.get("targets_pretokenized", r["targets"]) for r in results]
+  inputs = [r.get("targets_neg_pretokenized", "") for r in results]
+
+  if save_ids_only:
+    inputs = [r.split(" ", 1)[0] for r in inputs]
+
+  table_path = "{}_{}.tfrecord".format(scores_filename, shard_idx)
+  tf.logging.info("Saving results to {}".format(table_path))
+
+  with tf.io.TFRecordWriter(table_path) as file_writer:
+    for input_, target, score in zip(inputs, targets, scores):
+      record_bytes = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  "input":
+                      tf.train.Feature(
+                          bytes_list=tf.train.BytesList(
+                              value=[bytes(input_, "utf8")])),
+                  "target":
+                      tf.train.Feature(
+                          bytes_list=tf.train.BytesList(
+                              value=[bytes(target, "utf8")])),
+                  "score":
+                      tf.train.Feature(
+                          float_list=tf.train.FloatList(value=[score])),
+              })).SerializeToString()
+      file_writer.write(record_bytes)
+
+
+@gin.configurable
 def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                          vocabulary, score_postprocess_fn=save_scores,
                          num_examples=None):
@@ -1691,6 +1740,70 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
   return score_postprocess_fn(results, vocabulary)
 
 
+@gin.configurable
+def score_with_estimator_lazy(
+    estimator, input_fn, eval_checkpoint_step, model_dir,
+    vocabulary, score_postprocess_fn=save_scores_to_tfrecords,
+    num_examples=None, num_examples_per_shard=10000):
+  """Score each example returned by input_fn lazily, saving results in shards.
+
+  Args:
+    estimator: a TPUEstimator.
+    input_fn: a function that returns a tf.data.Dataset with examples
+      containing the string field 'targets' and optionally the field 'inputs'.
+    eval_checkpoint_step: int, list of ints, or None, see `eval_model`
+      docstring.
+    model_dir: string, estimator model_dir.
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple.
+    score_postprocess_fn: a function that takes in model outputs and
+      post-processes, saves, and returns them.
+    num_examples: int, the total number of examples being scored; None if
+      unknown.
+    num_examples_per_shard: int, the number of examples per file shard.
+  """
+  if num_examples is not None:
+    num_shards = math.ceil(num_examples / num_examples_per_shard)
+  else:
+    num_shards = None
+  tf.logging.info(
+      "Scoring {} examples with {} shards at {} examples per shard".format(
+          num_examples, num_shards, num_examples_per_shard))
+
+  checkpoint_path, = get_checkpoint_iterator(
+      eval_checkpoint_step, model_dir)
+  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
+
+  start = time.time()
+  results = []
+  shard_idx = 0
+
+  for i, result in enumerate(result_iter):
+    results.append(result)
+    num_results = len(results)
+    # i is 0-indexed, so i + 1 examples have been scored so far.
+    exceeded_num_examples = num_examples is not None and i + 1 >= num_examples
+
+    if num_results >= num_examples_per_shard or exceeded_num_examples:
+      score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)
+
+      elapsed = time.time() - start
+      tf.logging.info(
+          "Scored {} results in {} s, {} examples/s for shard {}".format(
+              num_results, elapsed, num_results / elapsed, shard_idx))
+
+      results = []
+      shard_idx += 1
+      start = time.time()
+
+    if exceeded_num_examples:
+      break
+
+  if results:
+    score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)
+
+
 def _maybe_add_pretokenized_features(examples, vocabulary):
   """Ensures decoded versions of "inputs" and "targets" exist in each example.
 
@@ -1712,9 +1825,17 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
   for example in examples:
     for feature_name in ["inputs", "targets"]:
       pretokenized_feature_name = feature_name + "_pretokenized"
+      neg_pretokenized_feature_name = feature_name + "_neg_pretokenized"
       if feature_name in example and pretokenized_feature_name not in example:
-        s = vocabulary[feature_name].decode(example[feature_name].tolist())
+        ids = example[feature_name].tolist()
+
+        # Negative IDs decode to the "neg" pretokenized feature, positive IDs
+        # to the standard one; zero padding is dropped from both.
+        neg_ids = [abs(i) for i in ids if i < 0]
+        ids = [i for i in ids if i > 0]
+
+        s = vocabulary[feature_name].decode(ids)
         example[pretokenized_feature_name] = s
+        neg_s = vocabulary[feature_name].decode(neg_ids)
+        example[neg_pretokenized_feature_name] = neg_s
 
       if not added_pretokenized[feature_name]:
         added_pretokenized[feature_name] = True
@@ -1730,7 +1851,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
                        sequence_length, model_dir, eval_checkpoint_step,
                        inputs=gin.REQUIRED, targets=gin.REQUIRED,
                        score_postprocess_fn=gin.REQUIRED, eos_id=1,
-                       score_eos=True):
+                       score_eos=True,
+                       score_with_estimator_fn=score_with_estimator):
   """Compute log likelihoods per example and write to a text file.
 
   inputs & targets must either be the same length (in lines) or have inputs
@@ -1761,6 +1883,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
    score_eos: a boolean - whether to score the final eos token of each line.
      If this is set to false, the scores can be interpreted as prefix
      log-likelihoods.
+    score_with_estimator_fn: a function to run scoring with the estimator.
  Returns:
    a list of floats
  """
@@ -1806,7 +1929,7 @@ def input_fn(params):
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

-  return score_with_estimator(
+  return score_with_estimator_fn(
      estimator, input_fn, eval_checkpoint_step, model_dir,
      vocabulary, score_postprocess_fn, len(targets))
@@ -1815,7 +1938,8 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
                       model_dir, eval_checkpoint_step, dataset_split,
                       score_dataset_fn=None,
-                      score_postprocess_fn=gin.REQUIRED):
+                      score_postprocess_fn=gin.REQUIRED,
+                      score_with_estimator_fn=score_with_estimator):
  """Compute log likelihoods per example and write to a text file.

  The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1961,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
      See `eval_dataset_fn` argument to `eval_model` for details.
    score_postprocess_fn: Function that takes in model outputs and
      post-processes and then returns them.
+    score_with_estimator_fn: a function to run scoring with the estimator.

  Returns:
    scores: a list of floats, the log likelihood scores
@@ -1850,9 +1975,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
  input_fn = _get_combined_dataset_input_fn(
      scoring_datasets, batch_size, sequence_length)

-  return score_with_estimator(
+  return score_with_estimator_fn(
      estimator, input_fn, eval_checkpoint_step, model_dir,
-      vocabulary, score_postprocess_fn, None)
+      vocabulary, score_postprocess_fn)


 def get_estimator(model_type, vocabulary, mesh_shape,
@@ -2093,7 +2218,8 @@ def eval_model(estimator,
               eval_checkpoint_step,
               eval_with_score=False,
               output_eval_examples=True,
-               eval_dir_suffix=None):
+               eval_dir_suffix=None,
+               score_with_estimator_fn=score_with_estimator):
  """Eval a Mesh-TF model.

  Args:
@@ -2137,6 +2263,7 @@ def eval_model(estimator,
      of the eval examples in plaintext to eval_summary_dir.
    eval_dir_suffix: string, if not None then will be appended to the
      eval_summary_dir.
+    score_with_estimator_fn: a function to run scoring with the estimator.
  """
  if eval_dataset_fn is None:
    raise ValueError("Must provide eval_dataset_fn through gin for eval.")
@@ -2248,7 +2375,7 @@ def eval_model(estimator,
    tf.logging.info("Checkpoint path %s" % checkpoint_path)
    global_step = int(get_step_from_checkpoint_path(checkpoint_path))
    if eval_with_score:
-      outputs, _ = score_with_estimator(
+      outputs, _ = score_with_estimator_fn(
          estimator, input_fn, global_step, model_dir, vocabulary,
          num_examples=sum(len(cex) for cex in cached_examples.values()))
    else:
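
For downstream consumers: `save_scores_to_tfrecords` writes one file per shard, named `{scores_filename}_{shard_idx}.tfrecord`, where each record carries bytes features `input` and `target` and a single-float feature `score`. A minimal read-back sketch under those assumptions (the helper name and glob pattern here are illustrative, not part of this change):

```python
import glob

import tensorflow.compat.v1 as tf


def read_score_shards(scores_filename):
  """Hypothetical helper: yields (input, target, score) tuples from every
  shard written by save_scores_to_tfrecords."""
  for path in sorted(glob.glob("{}_*.tfrecord".format(scores_filename))):
    # tf_record_iterator is deprecated but still available under compat.v1.
    for raw in tf.io.tf_record_iterator(path):
      example = tf.train.Example.FromString(raw)
      feature = example.features.feature
      yield (feature["input"].bytes_list.value[0].decode("utf-8"),
             feature["target"].bytes_list.value[0].decode("utf-8"),
             feature["score"].float_list.value[0])
```

Note that lexicographic sorting puts shard 10 before shard 2; parse the shard index numerically if ordering matters. With the new `score_with_estimator_fn` hooks, the lazy path can presumably be selected via gin, e.g. by binding `score_from_dataset.score_with_estimator_fn = @score_with_estimator_lazy`.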
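
The `_maybe_add_pretokenized_features` change assumes a packing convention in which negated token IDs mark a second segment carried inside a feature: negative IDs are decoded into `<feature>_neg_pretokenized`, positive IDs into `<feature>_pretokenized`, and zero padding is dropped from both. A toy sketch of that split (the vocabulary stub is hypothetical):

```python
class ToyVocab(object):
  """Hypothetical stand-in for a vocabulary.Vocabulary with a decode method."""

  def decode(self, ids):
    return " ".join("tok{}".format(i) for i in ids)


ids = [-7, -8, 3, 4, 5, 0, 0]  # negated segment, positive targets, 0 padding

neg_ids = [abs(i) for i in ids if i < 0]  # [7, 8]
pos_ids = [i for i in ids if i > 0]       # [3, 4, 5]

vocab = ToyVocab()
print(vocab.decode(pos_ids))  # "tok3 tok4 tok5"  -> targets_pretokenized
print(vocab.decode(neg_ids))  # "tok7 tok8"       -> targets_neg_pretokenized
```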