@@ -1155,6 +1155,43 @@ def write_lines_to_file(lines, filename):
       output_file.write("{}\n".format(str(line).replace("\n", " ")))


+def _get_combined_dataset_input_fn(
+    datasets, batch_size, sequence_length, check_for_metrics=False):
+  """Creates input function for estimator for inference, eval, and scoring.
+
+  Args:
+    datasets: A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
+      These will get combined together into a single tf.data.Dataset.
+    batch_size: an integer
+    sequence_length: an integer or a dict from feature-key to integer
+      the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
+    check_for_metrics: If True, then only include datasets which have
+      associated metric functions.
+
+  Returns:
+    An input function for estimator.
+  """
+  def input_fn(params):
+    """Input function for estimator."""
+    del params
+
+    combined_ds = None
+    for dataset in datasets:
+      if not check_for_metrics or dataset.metric_fns:
+        ds = dataset.dataset_fn(sequence_length=sequence_length)
+        ds = ds.map(
+            _filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        combined_ds = ds if not combined_ds else combined_ds.concatenate(ds)
+
+    combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
+    # Pad the final batch.
+    combined_ds = transformer_dataset.trim_and_pad_dataset(
+        combined_ds, length=batch_size)
+    combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
+    return combined_ds
+  return input_fn
+
+
 def get_step_from_checkpoint_path(checkpoint_path):
   """Returns the global step for the checkpoint at `checkpoint_path`.

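For context, a minimal sketch of how the new helper is consumed (the `eval_datasets`, `estimator`, and `checkpoint_path` values are assumed to be set up by the surrounding `eval_model`/`score_from_dataset`/`decode_from_dataset` code and are not defined here):

    # Hypothetical caller; mirrors the call sites introduced later in this diff.
    input_fn = _get_combined_dataset_input_fn(
        eval_datasets, batch_size=32,
        sequence_length={"inputs": 512, "targets": 128},
        check_for_metrics=True)
    # The returned closure ignores the `params` argument supplied by the
    # TPUEstimator and yields a single concatenated, batched, padded dataset.
    outputs = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

Note that if every dataset is filtered out (e.g. check_for_metrics=True and no dataset has metric functions), combined_ds stays None and the .batch() call raises; callers are expected to supply at least one usable dataset.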
@@ -1227,6 +1264,89 @@ def input_fn(params):
   write_lines_to_file(decodes, output_filename)


+@gin.configurable
+def decode_from_dataset(estimator,
+                        vocabulary,
+                        model_type,
+                        batch_size,
+                        sequence_length,
+                        checkpoint_path=None,
+                        infer_dataset_fn=gin.REQUIRED,
+                        dataset_split="validation",
+                        decode_output_dir=gin.REQUIRED):
+  """Decodes using inputs from the Task examples and writes results to files.
+
+  Args:
+    estimator: a TPUEstimator
+    vocabulary: a mtf.transformer.vocabulary.Vocabulary
+    model_type: a string
+    batch_size: an integer
+    sequence_length: an integer or a dict from feature-key to integer
+      the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
+    checkpoint_path: Checkpoint to use for inference.
+    infer_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
+      See `eval_dataset_fn` argument to `eval_model` for details.
+    dataset_split: str, which dataset split to load.
+    decode_output_dir: a string, where to write inputs, targets, and decodes.
+  """
+  if model_type != "lm":
+    raise ValueError("This function currently only supports decoder-only LMs.")
+
+  infer_datasets = infer_dataset_fn(
+      sequence_length=sequence_length,
+      vocabulary=vocabulary,
+      dataset_split=dataset_split)
+
+  input_fn = _get_combined_dataset_input_fn(
+      infer_datasets, batch_size, sequence_length)
+
+  checkpoint_step = get_step_from_checkpoint_path(checkpoint_path)
+  # TODO(dei): Deal with case where decode() does not return the right number
+  # of outputs. This can happen if the generator in decode() has failures.
+  decodes = list(decode(
+      estimator, input_fn, vocabulary, checkpoint_path=checkpoint_path))
+
+  tf.logging.info("Caching inference examples.")
+  with tf.Graph().as_default():
+    for infer_dataset in infer_datasets:
+      ds = infer_dataset.dataset_fn()
+
+      # Create list of postprocessed text targets.
+      examples_for_ds = list(tfds.as_numpy(ds))
+      examples_for_ds = _maybe_add_pretokenized_features(
+          examples_for_ds, vocabulary)
+
+      # Extract the portion of decodes corresponding to this dataset.
+      dataset_size = len(examples_for_ds)
+      predictions = decodes[:dataset_size]
+
+      # Remove the used decodes.
+      del decodes[:dataset_size]
+
+      # Write the predictions to file.
+      predictions_filename = os.path.join(
+          decode_output_dir,
+          "{}_{}_predictions".format(infer_dataset.name, checkpoint_step),
+      )
+      write_lines_to_file(predictions, predictions_filename)
+
+      # Write the ground-truth targets to file.
+      targets = []
+      for ex in examples_for_ds:
+        targets_pretokenized = ex["targets_pretokenized"]
+        targets.append(infer_dataset.postprocess_fn(
+            targets_pretokenized, example=ex, is_target=True))
+      targets_filename = os.path.join(
+          decode_output_dir, "{}_targets".format(infer_dataset.name))
+      write_lines_to_file(targets, targets_filename)
+
+      # Write the inputs to a file.
+      inputs = [ex["inputs_pretokenized"] for ex in examples_for_ds]
+      inputs_filename = os.path.join(
+          decode_output_dir, "{}_inputs".format(infer_dataset.name))
+      write_lines_to_file(inputs, inputs_filename)
+
+
 @gin.configurable
 def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
   """Replaces everything after EOS with PAD (along last axis).
@@ -1274,9 +1394,10 @@ def save_scores(results, vocabulary,
     write_lines_to_file(["%f" % f for f in scores], scores_filename + ".scores")

   if save_example_text:
+    results = _maybe_add_pretokenized_features(results, vocabulary)
+
     # Targets will always exist.
     targets = [r.get("targets_pretokenized", r["targets"]) for r in results]
-    targets = _maybe_decode_python(targets, targets_vocabulary(vocabulary))
     if scores_filename is not None:
       write_lines_to_file(targets, scores_filename + ".targets")

@@ -1295,7 +1416,6 @@ def get_sequence_length(tokens, pad_id=0):
     # Inputs may only exist for some tasks.
     if "inputs" in results[0]:
       inputs = [r.get("inputs_pretokenized", r["inputs"]) for r in results]
-      inputs = _maybe_decode_python(inputs, inputs_vocabulary(vocabulary))
       if scores_filename is not None:
         write_lines_to_file(inputs, scores_filename + ".inputs")
   return scores, inputs, targets
@@ -1342,14 +1462,38 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
   return score_postprocess_fn(results, vocabulary)


-def _maybe_decode_python(ids_or_strs, vocabulary):
-  """Decode if ids_or_strs is not yet strings in pure python."""
+def _maybe_add_pretokenized_features(examples, vocabulary):
+  """Ensures decoded versions of "inputs" and "targets" exist in each example.
+
+  Args:
+    examples: List of example dictionaries containing mappings from feature
+      name to np.array of integers.
+    vocabulary: The vocabulary.
+
+  Returns:
+    The list of example dictionaries, with a decoded plaintext entry added
+    for each feature that was present in the original example.
+  """
+  vocabulary = {"inputs": inputs_vocabulary(vocabulary),
+                "targets": targets_vocabulary(vocabulary)}

-  if ids_or_strs:
-    if isinstance(ids_or_strs[0], np.ndarray) and np.issubdtype(
-        ids_or_strs[0].dtype, np.integer):
-      ids_or_strs = [vocabulary.decode(t.tolist()) for t in ids_or_strs]
-  return ids_or_strs
+  # This is just used for logging purposes.
+  added_pretokenized = {"inputs": False, "targets": False}
+
+  for example in examples:
+    for feature_name in ["inputs", "targets"]:
+      pretokenized_feature_name = feature_name + "_pretokenized"
+      if feature_name in example and pretokenized_feature_name not in example:
+        s = vocabulary[feature_name].decode(example[feature_name].tolist())
+        example[pretokenized_feature_name] = s
+
+        if not added_pretokenized[feature_name]:
+          added_pretokenized[feature_name] = True
+          tf.logging.warning(
+              "Feature '%s' is being approximated by decoding from the "
+              "tokenized feature '%s'.",
+              pretokenized_feature_name, feature_name)
+  return examples


 @gin.configurable
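A small sketch of what `_maybe_add_pretokenized_features` does to an example that only carries token ids (the vocabulary and the decoded text are illustrative):

    import numpy as np

    # Assume the targets vocabulary decodes [71, 3, 1] to "hello".
    examples = [{"targets": np.array([71, 3, 1])}]
    examples = _maybe_add_pretokenized_features(examples, vocabulary)
    # examples[0] now also contains "targets_pretokenized" == "hello". A warning
    # is logged the first time a feature is backfilled this way; examples that
    # already carry the pretokenized feature are left unchanged.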
@@ -1474,23 +1618,8 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
       vocabulary=vocabulary,
       dataset_split=dataset_split)

-  def input_fn(params):
-    """Eval input function for estimator."""
-    del params
-
-    dataset = None
-    for scoring_dataset in scoring_datasets:
-      ds = scoring_dataset.dataset_fn()
-      ds = ds.map(
-          _filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-      dataset = dataset.concatenate(ds) if dataset else ds
-
-    dataset = dataset.batch(batch_size, drop_remainder=False)
-    # Pad the final batch.
-    dataset = transformer_dataset.trim_and_pad_dataset(
-        dataset, length=batch_size)
-    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
-    return dataset
+  input_fn = _get_combined_dataset_input_fn(
+      scoring_datasets, batch_size, sequence_length)

   return score_with_estimator(
       estimator, input_fn, eval_checkpoint_step, model_dir,
@@ -1681,10 +1810,8 @@ def infer_model(estimator,
                 model_type,
                 model_dir,
                 eval_checkpoint_step,
-                input_filename=None,
-                output_filename=None,
                 checkpoint_paths=None,
-                decode_from_file_fn=decode_from_file):
+                decode_fn=decode_from_file):
   """Infer a Mesh-TF model.

   Args:
@@ -1699,24 +1826,20 @@ def infer_model(estimator,
     model_dir: string, estimator model_dir
     eval_checkpoint_step: int, list of ints, or None, see `eval_model`
       docstring.
-    input_filename: a string, input file with examples
-    output_filename: a string, output file to save decodes
     checkpoint_paths: optional list of checkpoints to run inference for
-    decode_from_file_fn: decoding function, defaults to decode_from_file
+    decode_fn: decoding function, defaults to decode_from_file
   """
   if checkpoint_paths is None:
     checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)

   for checkpoint_path in checkpoint_paths:
-    decode_from_file_fn(
+    decode_fn(
         estimator,
         vocabulary=vocabulary,
         model_type=model_type,
         batch_size=batch_size,
         sequence_length=sequence_length,
-        checkpoint_path=checkpoint_path,
-        input_filename=input_filename,
-        output_filename=output_filename)
+        checkpoint_path=checkpoint_path)


 def eval_model(estimator,
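With this signature change, `infer_model` no longer threads input/output file names through to the decode function; `decode_from_file` presumably picks its paths up from its own (gin-configured) arguments, and dataset-based decoding can be selected by swapping `decode_fn`. A hypothetical call, with keyword arguments and placeholder values:

    # Use dataset-based decoding instead of the default decode_from_file.
    infer_model(
        estimator,
        vocabulary=vocabulary,
        sequence_length={"targets": 1024},
        batch_size=8,
        model_type="lm",
        model_dir="/tmp/model_dir",   # placeholder
        eval_checkpoint_step=10000,   # placeholder step
        decode_fn=decode_from_dataset)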
@@ -1872,24 +1995,8 @@ def eval_model(estimator,
   if callable(estimator):
     estimator = estimator()

-  def input_fn(params):
-    """Eval input function for estimator."""
-    del params
-    # Concatenate all dataset inputs to only have to do one decode loop
-    combined_ds = None
-    for eval_dataset in eval_datasets:
-      # Only evaluate tasks with metrics.
-      if eval_dataset.metric_fns:
-        ds = eval_dataset.dataset_fn(sequence_length=sequence_length)
-        ds = ds.map(
-            _filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-        combined_ds = ds if not combined_ds else combined_ds.concatenate(ds)
-    combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
-    # Pad the final batch.
-    combined_ds = transformer_dataset.trim_and_pad_dataset(
-        combined_ds, length=batch_size)
-    combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
-    return combined_ds
+  input_fn = _get_combined_dataset_input_fn(
+      eval_datasets, batch_size, sequence_length, check_for_metrics=True)

   checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
   for checkpoint_path in checkpoint_paths: