This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Save scores lazily. #359

Open · wants to merge 1 commit into master
143 changes: 135 additions & 8 deletions mesh_tensorflow/transformer/utils.py
@@ -26,9 +26,11 @@

import functools
import itertools
import math
import os
import random
import re
import time

import gin
import gin.tf
@@ -38,6 +40,7 @@
from mesh_tensorflow.transformer import learning_rate_schedules
from mesh_tensorflow.transformer import transformer
import numpy as np
import pandas as pd
import pkg_resources
import six
import tensorflow.compat.v1 as tf
@@ -1654,6 +1657,52 @@ def get_sequence_length(tokens, pad_id=0):
return scores


@gin.configurable
def save_scores_to_tfrecords(
results, vocabulary, scores_filename, shard_idx=0, save_ids_only=False):
"""Processes results from scoring examples and saves them to tfrecords files.

Args:
results: list of dictionaries containing the results for each scored
example.
vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
targets_vocabulary) tuple, used to decode token IDs back into strings.
scores_filename: a string, the path prefix of the files to write scores to
(each shard is written to "<scores_filename>_<shard_idx>.tfrecord").
shard_idx: an integer, the index of the current file shard.
save_ids_only: if True, save only the ID that is prepended to each input.
"""
results = _maybe_add_pretokenized_features(results, vocabulary)
scores = [r.get("scores", 0.0) for r in results]
targets = [r.get("targets_pretokenized", r["targets"]) for r in results]
inputs = [r.get("targets_neg_pretokenized", "") for r in results]

if save_ids_only:
inputs = [r.split(" ", 1)[0] for r in inputs]

table_path = "{}_{}.tfrecord".format(scores_filename, shard_idx)
tf.logging.info("Saving results to {}".format(table_path))

with tf.io.TFRecordWriter(table_path) as file_writer:
for input_, target, score in zip(inputs, targets, scores):
record_bytes = tf.train.Example(
features=tf.train.Features(
feature={
"input":
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[bytes(input_, "utf8")])),
"target":
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[bytes(target, "utf8")])),
"score":
tf.train.Feature(
float_list=tf.train.FloatList(value=[score])),
})).SerializeToString()
file_writer.write(record_bytes)


@gin.configurable
def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
vocabulary, score_postprocess_fn=save_scores,
num_examples=None):
@@ -1691,6 +1740,70 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
return score_postprocess_fn(results, vocabulary)


@gin.configurable
def score_with_estimator_lazy(
estimator, input_fn, eval_checkpoint_step, model_dir,
vocabulary, score_postprocess_fn=save_scores_to_tfrecords,
num_examples=None, num_examples_per_shard=10000):
"""Score each example returned by input_fn lazily.

Args:
estimator: a TPUEstimator
input_fn: a function that returns a tf.data.Dataset with examples
containing the string field 'targets' and optionally the field 'inputs'
eval_checkpoint_step: int, list of ints, or None, see `eval_model`
docstring.
model_dir: string, estimator model_dir
vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
targets_vocabulary) tuple
score_postprocess_fn: a function that takes in model outputs and
post-processes, saves, and returns them.
num_examples: int, the total # of examples being scored, None if unknown
num_examples_per_shard: int, the number of examples per file shard.

Returns:
None. Results are passed to score_postprocess_fn one shard at a time
rather than returned.
"""
if num_examples is not None:
num_shards = math.ceil(num_examples / num_examples_per_shard)
else:
num_shards = None
tf.logging.info(
"Scoring {} examples with {} shards at {} examples per shard".format(
num_examples, num_shards, num_examples_per_shard))

checkpoint_path, = get_checkpoint_iterator(
eval_checkpoint_step, model_dir)
result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

start = time.time()
results = []
shard_idx = 0
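# Accumulate results and flush them to a new shard every
# num_examples_per_shard examples so that scores are not all held in memory.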

for i, result in enumerate(result_iter):
results.append(result)
num_results = len(results)
exceeded_num_examples = num_examples is not None and i + 1 >= num_examples

if num_results >= num_examples_per_shard or exceeded_num_examples:
score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)

elapsed = time.time() - start
tf.logging.info(
"Scored {} results in {} s, {} examples/s for shard {}".format(
num_results, elapsed, num_results / elapsed, shard_idx))

results = []
shard_idx += 1
start = time.time()

if exceeded_num_examples:
break
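# Flush any remaining results into a final, smaller shard.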

if results:
score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)


def _maybe_add_pretokenized_features(examples, vocabulary):
"""Ensures decoded versions of "inputs" and "targets" exist in each example.

@@ -1712,9 +1825,17 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
for example in examples:
for feature_name in ["inputs", "targets"]:
pretokenized_feature_name = feature_name + "_pretokenized"
neg_pretokenized_feature_name = feature_name + "_neg_pretokenized"
if feature_name in example and pretokenized_feature_name not in example:
s = vocabulary[feature_name].decode(example[feature_name].tolist())
ids = example[feature_name].tolist()
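# Negative IDs are split out and decoded separately as the
# "*_neg_pretokenized" feature; positive IDs form the main decoded string.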

neg_ids = [abs(i) for i in ids if i < 0]
ids = [i for i in ids if i > 0]

s = vocabulary[feature_name].decode(ids)
example[pretokenized_feature_name] = s
neg_s = vocabulary[feature_name].decode(neg_ids)
example[neg_pretokenized_feature_name] = neg_s

if not added_pretokenized[feature_name]:
added_pretokenized[feature_name] = True
Expand All @@ -1730,7 +1851,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
sequence_length, model_dir, eval_checkpoint_step,
inputs=gin.REQUIRED, targets=gin.REQUIRED,
score_postprocess_fn=gin.REQUIRED, eos_id=1,
score_eos=True):
score_eos=True,
score_with_estimator_fn=score_with_estimator):
"""Compute log likelihoods per example and write to a text file.

inputs & targets must either be the same length (in lines) or have inputs
Expand Down Expand Up @@ -1761,6 +1883,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
score_eos: a boolean - whether to score the final eos token of each line.
If this is set to False, the scores can be interpreted as prefix
log-likelihoods.
score_with_estimator_fn: a function to run scoring with the estimator.
Returns:
a list of floats
"""
@@ -1806,7 +1929,7 @@ def input_fn(params):
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset.prefetch(tf.data.experimental.AUTOTUNE)

return score_with_estimator(
return score_with_estimator_fn(
estimator, input_fn, eval_checkpoint_step, model_dir,
vocabulary, score_postprocess_fn, len(targets))

@@ -1815,7 +1938,8 @@ def input_fn(params):
def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
model_dir, eval_checkpoint_step, dataset_split,
score_dataset_fn=None,
score_postprocess_fn=gin.REQUIRED):
score_postprocess_fn=gin.REQUIRED,
score_with_estimator_fn=score_with_estimator):
"""Compute log likelihoods per example and write to a text file.

The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1961,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
See `eval_dataset_fn` argument to `eval_model` for details.
score_postprocess_fn: Function that takes in model outputs and
post-processes and then returns them.
score_with_estimator_fn: a function to run scoring with the estimator.

Returns:
scores: a list of floats, the log likelihood scores
@@ -1850,9 +1975,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
input_fn = _get_combined_dataset_input_fn(
scoring_datasets, batch_size, sequence_length)

return score_with_estimator(
return score_with_estimator_fn(
estimator, input_fn, eval_checkpoint_step, model_dir,
vocabulary, score_postprocess_fn, None)
vocabulary, score_postprocess_fn)


def get_estimator(model_type, vocabulary, mesh_shape,
Expand Down Expand Up @@ -2093,7 +2218,8 @@ def eval_model(estimator,
eval_checkpoint_step,
eval_with_score=False,
output_eval_examples=True,
eval_dir_suffix=None):
eval_dir_suffix=None,
score_with_estimator_fn=score_with_estimator):
"""Eval a Mesh-TF model.

Args:
Expand Down Expand Up @@ -2137,6 +2263,7 @@ def eval_model(estimator,
of the eval examples in plaintext to eval_summary_dir.
eval_dir_suffix: string, if not None then will appended to the
eval_summary_dir.
score_with_estimator_fn: a function to run scoring with the estimator.
"""
if eval_dataset_fn is None:
raise ValueError("Must provide eval_dataset_fn through gin for eval.")
Expand Down Expand Up @@ -2248,7 +2375,7 @@ def eval_model(estimator,
tf.logging.info("Checkpoint path %s" % checkpoint_path)
global_step = int(get_step_from_checkpoint_path(checkpoint_path))
if eval_with_score:
outputs, _ = score_with_estimator(
outputs, _ = score_with_estimator_fn(
estimator, input_fn, global_step, model_dir, vocabulary,
num_examples=sum(len(cex) for cex in cached_examples.values()))
else:
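For reference, a minimal sketch of how the lazy scoring path added above could be wired up from Python with gin. This is illustrative rather than part of the change: the binding keys assume score_from_dataset and the new functions are registered as gin configurables under these bare names (module prefixes may be needed), and the shard size and output path are placeholder values.

import gin
from mesh_tensorflow.transformer import utils

# Route dataset scoring through the lazy scorer and the TFRecord writer.
gin.bind_parameter("score_from_dataset.score_with_estimator_fn",
                   utils.score_with_estimator_lazy)
gin.bind_parameter("score_from_dataset.score_postprocess_fn",
                   utils.save_scores_to_tfrecords)

# Flush a shard every 5000 scored examples (placeholder value).
gin.bind_parameter("score_with_estimator_lazy.num_examples_per_shard", 5000)

# Path prefix for the shards: /tmp/scores_0.tfrecord, /tmp/scores_1.tfrecord, ...
gin.bind_parameter("save_scores_to_tfrecords.scores_filename", "/tmp/scores")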