
Commit 3b59d99

daphnei authored and Mesh TensorFlow Team committed
Enable doing inference where the priming sequences come from a Task rather than a file.
PiperOrigin-RevId: 354647816
1 parent 2c1a4c8 commit 3b59d99

File tree

mesh_tensorflow/transformer/utils.py
mesh_tensorflow/transformer/utils_test.py

2 files changed: +195 −81 lines

mesh_tensorflow/transformer/utils.py

Lines changed: 161 additions & 54 deletions
@@ -1155,6 +1155,43 @@ def write_lines_to_file(lines, filename):
       output_file.write("{}\n".format(str(line).replace("\n", " ")))
 
 
+def _get_combined_dataset_input_fn(
+    datasets, batch_size, sequence_length, check_for_metrics=False):
+  """Creates an input function for estimator for inference, eval, and scoring.
+
+  Args:
+    datasets: a list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
+      These will get combined together into a single tf.data.Dataset.
+    batch_size: an integer
+    sequence_length: an integer or a dict from feature-key to integer
+      the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
+    check_for_metrics: if True, only include datasets which have associated
+      metric functions.
+
+  Returns:
+    An input function for estimator.
+  """
+  def input_fn(params):
+    """Input function for estimator."""
+    del params
+
+    combined_ds = None
+    for dataset in datasets:
+      if not check_for_metrics or dataset.metric_fns:
+        ds = dataset.dataset_fn(sequence_length=sequence_length)
+        ds = ds.map(
+            _filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        combined_ds = ds if not combined_ds else combined_ds.concatenate(ds)
+
+    combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
+    # Pad the final batch.
+    combined_ds = transformer_dataset.trim_and_pad_dataset(
+        combined_ds, length=batch_size)
+    combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
+    return combined_ds
+  return input_fn
+
+
 def get_step_from_checkpoint_path(checkpoint_path):
   """Returns the global step for the checkpoint at `checkpoint_path`.
 
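For orientation, here is a minimal usage sketch of the new shared builder. It assumes the module scope of utils.py (so `_filter_features` and `transformer_dataset` resolve); `make_toy_dataset` and the toy data are hypothetical, but the `EvalDataset` fields used (`name`, `dataset_fn`, `postprocess_fn`, `metric_fns`) are the ones this diff itself relies on:

    import tensorflow.compat.v1 as tf
    from mesh_tensorflow.transformer import dataset as transformer_dataset

    def make_toy_dataset(sequence_length=None):
      # Hypothetical dataset_fn: already-tokenized "targets" sequences.
      del sequence_length  # the toy data is already trimmed/padded
      return tf.data.Dataset.from_tensor_slices(
          {"targets": tf.constant([[5, 6, 7, 1], [8, 9, 1, 0]], tf.int32)})

    toy_task = transformer_dataset.EvalDataset(
        name="toy_task",
        dataset_fn=make_toy_dataset,
        postprocess_fn=None,
        metric_fns=[])  # empty, so skipped if check_for_metrics=True

    # Datasets are concatenated in list order, batched, and the final
    # partial batch is padded up to batch_size.
    input_fn = _get_combined_dataset_input_fn(
        [toy_task], batch_size=2, sequence_length={"targets": 4})
    ds = input_fn(params=None)  # params is unused and deleted internally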
@@ -1227,6 +1264,89 @@ def input_fn(params):
   write_lines_to_file(decodes, output_filename)
 
 
+@gin.configurable
+def decode_from_dataset(estimator,
+                        vocabulary,
+                        model_type,
+                        batch_size,
+                        sequence_length,
+                        checkpoint_path=None,
+                        infer_dataset_fn=gin.REQUIRED,
+                        dataset_split="validation",
+                        decode_output_dir=gin.REQUIRED):
+  """Decodes using inputs from the Task examples and writes results to files.
+
+  Args:
+    estimator: a TPUEstimator
+    vocabulary: a mtf.transformer.vocabulary.Vocabulary
+    model_type: a string
+    batch_size: an integer
+    sequence_length: an integer or a dict from feature-key to integer
+      the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
+    checkpoint_path: checkpoint to use for inference.
+    infer_dataset_fn: a function returning a list of dataset.EvalDataset
+      tuples. See the `eval_dataset_fn` argument to `eval_model` for details.
+    dataset_split: str, which dataset split to load.
+    decode_output_dir: a string, where to write inputs, targets, and decodes.
+  """
+  if model_type != "lm":
+    raise ValueError("This function currently only supports decoder-only LMs.")
+
+  infer_datasets = infer_dataset_fn(
+      sequence_length=sequence_length,
+      vocabulary=vocabulary,
+      dataset_split=dataset_split)
+
+  input_fn = _get_combined_dataset_input_fn(
+      infer_datasets, batch_size, sequence_length)
+
+  checkpoint_step = get_step_from_checkpoint_path(checkpoint_path)
+  # TODO(dei): Deal with the case where decode() does not return the right
+  # number of outputs. This can happen if the generator in decode() has
+  # failures.
+  decodes = list(decode(
+      estimator, input_fn, vocabulary, checkpoint_path=checkpoint_path))
+
+  tf.logging.info("Caching inference examples.")
+  with tf.Graph().as_default():
+    for infer_dataset in infer_datasets:
+      ds = infer_dataset.dataset_fn()
+
+      # Create a list of postprocessed text targets.
+      examples_for_ds = list(tfds.as_numpy(ds))
+      examples_for_ds = _maybe_add_pretokenized_features(
+          examples_for_ds, vocabulary)
+
+      # Extract the portion of decodes corresponding to this dataset.
+      dataset_size = len(examples_for_ds)
+      predictions = decodes[:dataset_size]
+
+      # Remove the used decodes.
+      del decodes[:dataset_size]
+
+      # Write the predictions to file.
+      predictions_filename = os.path.join(
+          decode_output_dir,
+          "{}_{}_predictions".format(infer_dataset.name, checkpoint_step),
+      )
+      write_lines_to_file(predictions, predictions_filename)
+
+      # Write the ground-truth targets to file.
+      targets = []
+      for ex in examples_for_ds:
+        targets_pretokenized = ex["targets_pretokenized"]
+        targets.append(infer_dataset.postprocess_fn(
+            targets_pretokenized, example=ex, is_target=True))
+      targets_filename = os.path.join(
+          decode_output_dir, "{}_targets".format(infer_dataset.name))
+      write_lines_to_file(targets, targets_filename)
+
+      # Write the inputs to a file.
+      inputs = [ex["inputs_pretokenized"] for ex in examples_for_ds]
+      inputs_filename = os.path.join(
+          decode_output_dir, "{}_inputs".format(infer_dataset.name))
+      write_lines_to_file(inputs, inputs_filename)
+
+
 @gin.configurable
 def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
   """Replaces everything after EOS with PAD (along last axis).
@@ -1274,9 +1394,10 @@ def save_scores(results, vocabulary,
   write_lines_to_file(["%f" % f for f in scores], scores_filename+".scores")
 
   if save_example_text:
+    results = _maybe_add_pretokenized_features(results, vocabulary)
+
     # Targets will always exist.
     targets = [r.get("targets_pretokenized", r["targets"]) for r in results]
-    targets = _maybe_decode_python(targets, targets_vocabulary(vocabulary))
     if scores_filename is not None:
       write_lines_to_file(targets, scores_filename+".targets")
 
@@ -1295,7 +1416,6 @@ def get_sequence_length(tokens, pad_id=0):
     # Inputs may only exist for some tasks.
     if "inputs" in results[0]:
       inputs = [r.get("inputs_pretokenized", r["inputs"]) for r in results]
-      inputs = _maybe_decode_python(inputs, inputs_vocabulary(vocabulary))
       if scores_filename is not None:
         write_lines_to_file(inputs, scores_filename+".inputs")
   return scores, inputs, targets
@@ -1342,14 +1462,38 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
   return score_postprocess_fn(results, vocabulary)
 
 
-def _maybe_decode_python(ids_or_strs, vocabulary):
-  """Decode if ids_or_strs is not yet strings in pure python."""
+def _maybe_add_pretokenized_features(examples, vocabulary):
+  """Ensures decoded versions of "inputs" and "targets" exist in each example.
+
+  Args:
+    examples: a list of example dictionaries mapping feature names to
+      np.arrays of integers.
+    vocabulary: the vocabulary.
+
+  Returns:
+    The examples list, with a decoded plaintext entry added for each feature
+    that was present in the original example.
+  """
+  vocabulary = {"inputs": inputs_vocabulary(vocabulary),
+                "targets": targets_vocabulary(vocabulary)}
 
-  if ids_or_strs:
-    if isinstance(ids_or_strs[0], np.ndarray) and np.issubdtype(
-        ids_or_strs[0].dtype, np.integer):
-      ids_or_strs = [vocabulary.decode(t.tolist()) for t in ids_or_strs]
-  return ids_or_strs
+  # This is just used for logging purposes.
+  added_pretokenized = {"inputs": False, "targets": False}
+
+  for example in examples:
+    for feature_name in ["inputs", "targets"]:
+      pretokenized_feature_name = feature_name + "_pretokenized"
+      if feature_name in example and pretokenized_feature_name not in example:
+        s = vocabulary[feature_name].decode(example[feature_name].tolist())
+        example[pretokenized_feature_name] = s
+
+        if not added_pretokenized[feature_name]:
+          added_pretokenized[feature_name] = True
+          tf.logging.warning(
+              "Feature '%s' is being approximated by decoding from the "
+              "tokenized feature '%s'.",
+              pretokenized_feature_name, feature_name)
+  return examples
 
 
 @gin.configurable
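A quick sketch of the helper's contract, assuming a `vocabulary` whose `decode` maps these toy ids back to text:

    import numpy as np

    examples = [{"inputs": np.array([1, 2, 3], dtype=np.int32)}]
    examples = _maybe_add_pretokenized_features(examples, vocabulary)
    # examples[0] now also carries "inputs_pretokenized", decoded from
    # "inputs", and a warning is logged once per feature name. Pre-existing
    # *_pretokenized entries and features other than "inputs"/"targets"
    # are left untouched.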
@@ -1474,23 +1618,8 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
       vocabulary=vocabulary,
       dataset_split=dataset_split)
 
-  def input_fn(params):
-    """Eval input function for estimator."""
-    del params
-
-    dataset = None
-    for scoring_dataset in scoring_datasets:
-      ds = scoring_dataset.dataset_fn()
-      ds = ds.map(
-          _filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-      dataset = dataset.concatenate(ds) if dataset else ds
-
-    dataset = dataset.batch(batch_size, drop_remainder=False)
-    # Pad the final batch.
-    dataset = transformer_dataset.trim_and_pad_dataset(
-        dataset, length=batch_size)
-    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
-    return dataset
+  input_fn = _get_combined_dataset_input_fn(
+      scoring_datasets, batch_size, sequence_length)
 
   return score_with_estimator(
       estimator, input_fn, eval_checkpoint_step, model_dir,
@@ -1681,10 +1810,8 @@ def infer_model(estimator,
                 model_type,
                 model_dir,
                 eval_checkpoint_step,
-                input_filename=None,
-                output_filename=None,
                 checkpoint_paths=None,
-                decode_from_file_fn=decode_from_file):
+                decode_fn=decode_from_file):
   """Infer a Mesh-TF model.
 
   Args:
@@ -1699,24 +1826,20 @@ def infer_model(estimator,
     model_dir: string, estimator model_dir
     eval_checkpoint_step: int, list of ints, or None, see `eval_model`
       docstring.
-    input_filename: a string, input file with examples
-    output_filename: a string, output file to save decodes
     checkpoint_paths: optional list of checkpoints to run inference for
-    decode_from_file_fn: decoding function, defaults to decode_from_file
+    decode_fn: decoding function, defaults to decode_from_file
   """
   if checkpoint_paths is None:
     checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
 
   for checkpoint_path in checkpoint_paths:
-    decode_from_file_fn(
+    decode_fn(
         estimator,
         vocabulary=vocabulary,
         model_type=model_type,
         batch_size=batch_size,
         sequence_length=sequence_length,
-        checkpoint_path=checkpoint_path,
-        input_filename=input_filename,
-        output_filename=output_filename)
+        checkpoint_path=checkpoint_path)
 
 
 def eval_model(estimator,
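With `input_filename` and `output_filename` gone from the signature, selecting the new Task-based inference path would presumably happen through gin rather than flags. A sketch of what such bindings might look like; the exact binding scopes and `my_infer_dataset_fn` are assumptions, not part of this commit:

    import gin

    gin.parse_config("""
      infer_model.decode_fn = @decode_from_dataset
      decode_from_dataset.infer_dataset_fn = @my_infer_dataset_fn
      decode_from_dataset.decode_output_dir = '/tmp/decodes'
    """)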
@@ -1872,24 +1995,8 @@ def eval_model(estimator,
   if callable(estimator):
     estimator = estimator()
 
-  def input_fn(params):
-    """Eval input function for estimator."""
-    del params
-    # Concatenate all dataset inputs to only have to do one decode loop
-    combined_ds = None
-    for eval_dataset in eval_datasets:
-      # Only evaluate tasks with metrics.
-      if eval_dataset.metric_fns:
-        ds = eval_dataset.dataset_fn(sequence_length=sequence_length)
-        ds = ds.map(
-            _filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-        combined_ds = ds if not combined_ds else combined_ds.concatenate(ds)
-    combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
-    # Pad the final batch.
-    combined_ds = transformer_dataset.trim_and_pad_dataset(
-        combined_ds, length=batch_size)
-    combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
-    return combined_ds
+  input_fn = _get_combined_dataset_input_fn(
+      eval_datasets, batch_size, sequence_length, check_for_metrics=True)
 
   checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
   for checkpoint_path in checkpoint_paths:
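The eval path is the one caller that sets `check_for_metrics=True`, reproducing the filtering the deleted closure did inline. A standalone equivalent of that selection (hypothetical `eval_datasets`):

    # Only tasks that define metric functions contribute batches. If none
    # do, combined_ds stays None inside input_fn, so callers are expected
    # to pass at least one dataset with metric_fns.
    eval_only = [d for d in eval_datasets if d.metric_fns]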

mesh_tensorflow/transformer/utils_test.py

Lines changed: 34 additions & 27 deletions
@@ -122,36 +122,43 @@ def testCleanDecodes(self):
       ("int32", np.int32),
       ("int64", np.int64),
   )
-  def test_maybe_decode_python_with_int_inputs(self, dtype):
-    vocabulary = mock_vocabulary({
-        "a": 1,
-        "b": 2,
-        "c": 3,
-        "d": 4,
-    },
+  def test_maybe_add_pretokenized_features_with_int_inputs(self, dtype):
+    vocabulary = mock_vocabulary({"a": 1, "b": 2, "c": 3, "d": 4},
                                  vocab_size=1000)
-    ids_or_strs = [np.array([1, 2, 3, 4], dtype=np.int32)]
-    result = utils._maybe_decode_python(ids_or_strs, vocabulary)
-    expected = [["a", "b", "c", "d"]]
-    self.assertAllEqual(result, expected)
 
-  @parameterized.named_parameters(
-      ("str", [["a", "b", "c", "d"]]),
-      ("bytes", [[b"a", b"b", b"c", b"d"]]),
-      ("ndarray_str", [np.array(["a", "b", "c", "d"])]),
-      ("ndarray_bytes", [np.array([b"a", b"b", b"c", b"d"])]),
-  )
-  def test_maybe_decode_python_with_str_inputs(self, ids_or_strs):
-    vocabulary = mock_vocabulary({
-        "a": 1,
-        "b": 2,
-        "c": 3,
-        "d": 4,
-    },
+    examples = [{"targets": np.array([1, 2, 3, 4], dtype=dtype),
+                 "inputs": np.array([1, 2, 3, 4], dtype=dtype)}]
+    result = utils._maybe_add_pretokenized_features(examples, vocabulary)
+    expected = ["a", "b", "c", "d"]
+    self.assertAllEqual(result[0]["targets_pretokenized"], expected)
+    self.assertAllEqual(result[0]["inputs_pretokenized"], expected)
+    self.assertLen(result, 1)
+
+  def test_maybe_add_pretokenized_features_nonstandard_feature(self):
+    vocabulary = mock_vocabulary({"a": 1, "b": 2, "c": 3, "d": 4},
                                  vocab_size=1000)
-    result = utils._maybe_decode_python(ids_or_strs, vocabulary)
-    expected = [["a", "b", "c", "d"]]
-    self.assertAllEqual(result, expected)
+
+    examples = [{"notafeature": np.array([1, 2, 3, 4], dtype=np.int32),
+                 "inputs": np.array([1, 2, 3, 4], dtype=np.int32)}]
+    result = utils._maybe_add_pretokenized_features(examples, vocabulary)
+
+    self.assertSameElements(result[0].keys(),
+                            ("notafeature", "inputs", "inputs_pretokenized"))
+    self.assertAllEqual(result[0]["notafeature"], [1, 2, 3, 4])
+
+  def test_maybe_add_pretokenized_features_pretokenized_exists(self):
+    vocabulary = mock_vocabulary({"a": 1, "b": 2, "c": 3, "d": 4},
+                                 vocab_size=1000)
+
+    examples = [{"inputs_pretokenized": "Hello world!",
+                 "inputs": np.array([1, 2, 3, 4], dtype=np.int32)}]
+    result = utils._maybe_add_pretokenized_features(examples, vocabulary)
+    self.assertEqual(result[0]["inputs_pretokenized"], "Hello world!")
+    self.assertSameElements(result[0].keys(), ("inputs", "inputs_pretokenized"))
+    self.assertLen(result, 1)
 
 
 if __name__ == "__main__":
